|
| 1 | +import re |
| 2 | + |
| 3 | +from bfcl_eval.model_handler.api_inference.nvidia import NvidiaHandler |
| 4 | +from bfcl_eval.model_handler.utils import ( |
| 5 | + combine_consecutive_user_prompts, |
| 6 | + convert_system_prompt_into_user_prompt, |
| 7 | + default_decode_ast_prompting, |
| 8 | + default_decode_execute_prompting, |
| 9 | + func_doc_language_specific_pre_processing, |
| 10 | +) |
| 11 | +from overrides import override |
| 12 | + |
| 13 | + |
| 14 | +class NemotronHandler(NvidiaHandler): |
| 15 | + """Handler for the LLaMA 3.1 Nemotron Ultra 253B v1 model. |
| 16 | +
|
| 17 | + This handler extends NvidiaHandler to support the Nemotron model's XML-based |
| 18 | + function calling format. The model expects: |
| 19 | + - <TOOLCALL>[function_calls]</TOOLCALL> for function calls |
| 20 | + - <AVAILABLE_TOOLS>{functions}</AVAILABLE_TOOLS> for function documentation |
| 21 | + """ |
| 22 | + |
| 23 | + def _format_system_prompt(self, prompts, function_docs, test_category): |
| 24 | + """Format the system prompt in the Nemotron-specific XML format.""" |
| 25 | + |
| 26 | + system_prompt_template = """You are an expert in composing functions. You are given a question and a set of possible functions. |
| 27 | +Based on the question, you will need to make one or more function/tool calls to achieve the purpose. |
| 28 | +If none of the function can be used, point it out. If the given question lacks the parameters required by the function, |
| 29 | +also point it out. You should only return the function call in tools call sections. |
| 30 | +
|
| 31 | +If you decide to invoke any of the function(s), you MUST put it in the format of <TOOLCALL>[func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]</TOOLCALL> |
| 32 | +
|
| 33 | +You SHOULD NOT include any other text in the response. |
| 34 | +Here is a list of functions in JSON format that you can invoke. |
| 35 | +
|
| 36 | +<AVAILABLE_TOOLS>{functions}</AVAILABLE_TOOLS> |
| 37 | +
|
| 38 | +{user_prompt}""" |
| 39 | + |
| 40 | + # Extract the first user message content (if any) and remove it from the list. |
| 41 | + user_prompt = "" |
| 42 | + for idx, msg in enumerate(prompts): |
| 43 | + if msg["role"] == "user": |
| 44 | + user_prompt = msg["content"] |
| 45 | + # Delete the user message – it will be folded into the system prompt. |
| 46 | + prompts.pop(idx) |
| 47 | + break |
| 48 | + |
| 49 | + system_prompt = system_prompt_template.format( |
| 50 | + functions=function_docs, user_prompt=user_prompt |
| 51 | + ) |
| 52 | + |
| 53 | + # Insert the system prompt at the beginning of the list. |
| 54 | + prompts.insert(0, {"role": "system", "content": system_prompt}) |
| 55 | + |
| 56 | + return prompts |
| 57 | + |
| 58 | + @override |
| 59 | + def _pre_query_processing_prompting(self, test_entry: dict) -> dict: |
| 60 | + """Process the input query and format it for the Nemotron model.""" |
| 61 | + functions: list = test_entry["function"] |
| 62 | + test_category: str = test_entry["id"].rsplit("_", 1)[0] |
| 63 | + |
| 64 | + # Pre-process functions based on language |
| 65 | + functions = func_doc_language_specific_pre_processing(functions, test_category) |
| 66 | + |
| 67 | + for round_idx in range(len(test_entry["question"])): |
| 68 | + test_entry["question"][round_idx] = convert_system_prompt_into_user_prompt( |
| 69 | + test_entry["question"][round_idx] |
| 70 | + ) |
| 71 | + test_entry["question"][round_idx] = combine_consecutive_user_prompts( |
| 72 | + test_entry["question"][round_idx] |
| 73 | + ) |
| 74 | + |
| 75 | + test_entry["question"][0] = self._format_system_prompt( |
| 76 | + test_entry["question"][0], functions, test_category |
| 77 | + ) |
| 78 | + |
| 79 | + # Return empty message list - messages will be added incrementally |
| 80 | + return {"message": []} |
| 81 | + |
| 82 | + @override |
| 83 | + def decode_ast(self, result, language="Python"): |
| 84 | + """Extract function calls from the Nemotron XML format.""" |
| 85 | + # Extract content between TOOLCALL tags |
| 86 | + toolcall_match = re.search(r"<TOOLCALL>(.*?)</TOOLCALL>", result, re.DOTALL) |
| 87 | + if not toolcall_match: |
| 88 | + return [] |
| 89 | + |
| 90 | + # Get the function call string |
| 91 | + func_call_str = toolcall_match.group(1) |
| 92 | + |
| 93 | + return default_decode_ast_prompting(func_call_str, language) |
| 94 | + |
| 95 | + @override |
| 96 | + def decode_execute(self, result, language="Python"): |
| 97 | + """Convert Nemotron response to executable function calls.""" |
| 98 | + # Extract content between TOOLCALL tags |
| 99 | + toolcall_match = re.search(r"<TOOLCALL>(.*?)</TOOLCALL>", result, re.DOTALL) |
| 100 | + if not toolcall_match: |
| 101 | + return [] |
| 102 | + |
| 103 | + # Get the function call string |
| 104 | + func_call_str = toolcall_match.group(1) |
| 105 | + |
| 106 | + return default_decode_execute_prompting(func_call_str, language) |
0 commit comments