|
import ast
import json

from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
from bfcl.model_handler.utils import func_doc_language_specific_pre_processing
| 5 | + |
| 6 | +# TODO: Merge with LlamaHandler |
| 7 | + |
| 8 | + |
class LlamaFCHandler(OSSHandler):
    """OSS handler for Llama models run in function-calling ("-FC") mode.

    Builds the Llama 3.1-style prompt by hand (tools embedded in the first
    user turn) and decodes the model's JSON function-call output back into
    BFCL's AST / executable formats.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)
        # The "-FC" suffix is a benchmark-side naming convention only; the
        # Hugging Face checkpoint name does not carry it, so strip it here.
        self.model_name_huggingface = model_name.replace("-FC", "")
| 13 | + |
    @staticmethod
    def _format_prompt(messages, function):
        """Serialize chat messages and tool documents into a Llama 3.1 prompt.

        Hand-rolled equivalent of the model's chat template, following the
        ``tools_in_user_message`` path: the function documents are injected
        into the first user turn, and the prompt ends with an open assistant
        header so generation continues from there.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts. An
                optional leading system message is folded into the system
                block; remaining turns are rendered in order.
            function: List of JSON-serializable function documents advertised
                to the model.

        Returns:
            The fully formatted prompt string, starting with
            ``<|begin_of_text|>``.

        Reference chat template (reproduced from the model's tokenizer
        config for comparison; an accidental triplication of the
        ``<|eom_id|>`` emission in the earlier copy has been collapsed to the
        single occurrence the upstream template contains):

        "bos_token": "<|begin_of_text|>",
        "chat_template":
        {{- bos_token }}
        {%- if custom_tools is defined %}
            {%- set tools = custom_tools %}
        {%- endif %}
        {%- if not tools_in_user_message is defined %}
            {%- set tools_in_user_message = true %}
        {%- endif %}
        {%- if not date_string is defined %}
            {%- set date_string = "26 Jul 2024" %}
        {%- endif %}
        {%- if not tools is defined %}
            {%- set tools = none %}
        {%- endif %}

        {#- This block extracts the system message, so we can slot it into the right place. #}
        {%- if messages[0]['role'] == 'system' %}
            {%- set system_message = messages[0]['content']|trim %}
            {%- set messages = messages[1:] %}
        {%- else %}
            {%- set system_message = "" %}
        {%- endif %}

        {#- System message + builtin tools #}
        {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
        {%- if builtin_tools is defined or tools is not none %}
            {{- "Environment: ipython\n" }}
        {%- endif %}
        {%- if builtin_tools is defined %}
            {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n"}}
        {%- endif %}
        {{- "Cutting Knowledge Date: December 2023\n" }}
        {{- "Today Date: " + date_string + "\n\n" }}
        {%- if tools is not none and not tools_in_user_message %}
            {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
            {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
            {{- "Do not use variables.\n\n" }}
            {%- for t in tools %}
                {{- t | tojson(indent=4) }}
                {{- "\n\n" }}
            {%- endfor %}
        {%- endif %}
        {{- system_message }}
        {{- "<|eot_id|>" }}

        {#- Custom tools are passed in a user message with some extra guidance #}
        {%- if tools_in_user_message and not tools is none %}
            {#- Extract the first user message so we can plug it in here #}
            {%- if messages | length != 0 %}
                {%- set first_user_message = messages[0]['content']|trim %}
                {%- set messages = messages[1:] %}
            {%- else %}
                {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
            {%- endif %}
            {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
            {{- "Given the following functions, please respond with a JSON for a function call " }}
            {{- "with its proper arguments that best answers the given prompt.\n\n" }}
            {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
            {{- "Do not use variables.\n\n" }}
            {%- for t in tools %}
                {{- t | tojson(indent=4) }}
                {{- "\n\n" }}
            {%- endfor %}
            {{- first_user_message + "<|eot_id|>"}}
        {%- endif %}

        {%- for message in messages %}
            {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
                {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }}
            {%- elif 'tool_calls' in message %}
                {%- if not message.tool_calls|length == 1 %}
                    {{- raise_exception("This model only supports single tool-calls at once!") }}
                {%- endif %}
                {%- set tool_call = message.tool_calls[0].function %}
                {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
                    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
                    {{- "<|python_tag|>" + tool_call.name + ".call(" }}
                    {%- for arg_name, arg_val in tool_call.arguments | items %}
                        {{- arg_name + '="' + arg_val + '"' }}
                        {%- if not loop.last %}
                            {{- ", " }}
                        {%- endif %}
                    {%- endfor %}
                    {{- ")" }}
                {%- else %}
                    {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
                    {{- '{"name": "' + tool_call.name + '", ' }}
                    {{- '"parameters": ' }}
                    {{- tool_call.arguments | tojson }}
                    {{- "}" }}
                {%- endif %}
                {%- if builtin_tools is defined %}
                    {#- This means we're in ipython mode #}
                    {{- "<|eom_id|>" }}
                {%- else %}
                    {{- "<|eot_id|>" }}
                {%- endif %}
            {%- elif message.role == "tool" or message.role == "ipython" %}
                {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
                {%- if message.content is mapping or message.content is iterable %}
                    {{- message.content | tojson }}
                {%- else %}
                    {{- message.content }}
                {%- endif %}
                {{- "<|eot_id|>" }}
            {%- endif %}
        {%- endfor %}
        {%- if add_generation_prompt %}
            {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
        {%- endif %}
        """
        formatted_prompt = "<|begin_of_text|>"

        # Pull a leading system message out so it can be merged into the
        # fixed system block below.
        system_message = ""
        remaining_messages = messages
        if messages[0]["role"] == "system":
            system_message = messages[0]["content"].strip()
            remaining_messages = messages[1:]

        # System block: the "Environment: ipython" line and dates mirror the
        # reference template; the knowledge/date strings are hard-coded to
        # match the template's defaults.
        formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n\n"
        formatted_prompt += "Environment: ipython\n"
        formatted_prompt += "Cutting Knowledge Date: December 2023\n"
        formatted_prompt += "Today Date: 26 Jul 2024\n\n"
        formatted_prompt += system_message + "<|eot_id|>"

        # Llama pass in custom tools in first user message
        is_first_user_message = True
        for message in remaining_messages:
            if message["role"] == "user" and is_first_user_message:
                # First user turn also carries the tool instructions and the
                # JSON-dumped function documents, per the reference template.
                is_first_user_message = False
                formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n\n"
                formatted_prompt += "Given the following functions, please respond with a JSON for a function call "
                formatted_prompt += (
                    "with its proper arguments that best answers the given prompt.\n\n"
                )
                formatted_prompt += 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.'
                formatted_prompt += "Do not use variables.\n\n"
                for func in function:
                    formatted_prompt += json.dumps(func, indent=4) + "\n\n"
                formatted_prompt += f"{message['content'].strip()}<|eot_id|>"

            elif message["role"] == "tool":
                # Tool results are rendered under the "ipython" role header;
                # structured content is JSON-encoded, plain strings pass through.
                formatted_prompt += "<|start_header_id|>ipython<|end_header_id|>\n\n"
                if isinstance(message["content"], (dict, list)):
                    formatted_prompt += json.dumps(message["content"])
                else:
                    formatted_prompt += message["content"]
                formatted_prompt += "<|eot_id|>"

            else:
                # Any other role (user after the first, assistant) is rendered
                # verbatim under its own header.
                formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content'].strip()}<|eot_id|>"

        # Open assistant header so generation continues as the assistant.
        formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

        return formatted_prompt
| 176 | + |
| 177 | + def decode_ast(self, result, language="Python"): |
| 178 | + result = result.replace("<|python_tag|>", "") |
| 179 | + # Llama sometimes separates the function calls with `;` and sometimes with `,` |
| 180 | + if ";" in result: |
| 181 | + """ |
| 182 | + "<|python_tag|>{\"name\": \"calc_binomial_probability\", \"parameters\": {\"n\": \"10\", \"k\": \"3\", \"p\": \"0\"}}; {\"name\": \"calc_binomial_probability\", \"parameters\": {\"n\": \"15\", \"k\": \"5\", \"p\": \"0\"}}; {\"name\": \"calc_binomial_probability\", \"parameters\": {\"n\": \"20\", \"k\": \"7\", \"p\": \"0\"}}" |
| 183 | + """ |
| 184 | + function_calls = result.split(";") |
| 185 | + function_calls = [json.loads(func_call) for func_call in function_calls] |
| 186 | + else: |
| 187 | + """ |
| 188 | + "[\n {\"name\": \"calculate_permutations\", \"parameters\": {\"n\": \"20\", \"k\": \"5\"}},\n {\"name\": \"calculate_permutations\", \"parameters\": {\"n\": \"12\", \"k\": \"5\"}},\n {\"name\": \"calculate_permutations\", \"parameters\": {\"n\": \"10\", \"k\": \"3\"}}\n]" |
| 189 | + """ |
| 190 | + function_calls = eval(result) |
| 191 | + if type(function_calls) == dict: |
| 192 | + function_calls = [function_calls] |
| 193 | + |
| 194 | + decoded_output = [] |
| 195 | + for func_call in function_calls: |
| 196 | + name = func_call["name"] |
| 197 | + params = func_call["parameters"] |
| 198 | + decoded_output.append({name: params}) |
| 199 | + |
| 200 | + return decoded_output |
| 201 | + |
| 202 | + def decode_execute(self, result): |
| 203 | + result = result.replace("<|python_tag|>", "") |
| 204 | + # Llama sometimes separates the function calls with `;` and sometimes with `,` |
| 205 | + if ";" in result: |
| 206 | + function_calls = result.split(";") |
| 207 | + function_calls = [json.loads(func_call) for func_call in function_calls] |
| 208 | + else: |
| 209 | + function_calls = eval(result) |
| 210 | + if type(function_calls) == dict: |
| 211 | + function_calls = [function_calls] |
| 212 | + |
| 213 | + execution_list = [] |
| 214 | + for func_call in function_calls: |
| 215 | + name = func_call["name"] |
| 216 | + params = func_call["parameters"] |
| 217 | + execution_list.append( |
| 218 | + f"{name}({','.join([f'{k}={repr(v)}' for k,v in params.items()])})" |
| 219 | + ) |
| 220 | + |
| 221 | + return execution_list |
| 222 | + |
| 223 | + def _pre_query_processing_prompting(self, test_entry: dict) -> dict: |
| 224 | + functions: list = test_entry["function"] |
| 225 | + test_category: str = test_entry["id"].rsplit("_", 1)[0] |
| 226 | + |
| 227 | + functions = func_doc_language_specific_pre_processing(functions, test_category) |
| 228 | + |
| 229 | + # Llama use its own system prompt |
| 230 | + |
| 231 | + return {"message": [], "function": functions} |
0 commit comments