
Commit 40b07a5

[BFCL] Patch Gemini Handler (#421)
Building on the work of @vandyxiaowei in #406, this PR fixes a bug in the generation pipeline for the Gemini models ([Gemini-1.5-Pro (FC)](https://deepmind.google/technologies/gemini/#introduction) and [Gemini-1.0-Pro (FC)](https://deepmind.google/technologies/gemini/#introduction)) to make it more robust. The Gemini model output `result["candidates"][0]` does not always contain the key `"content"`, which caused the previous generation pipeline to error out. This PR **DOES NOT** change the leaderboard score.
1 parent 2fc82a9 commit 40b07a5
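
For context, the failure mode is the one named above: `result["candidates"][0]` sometimes has no `"content"` key, so the old parsing loop raised an uncaught `KeyError`. Below is a minimal sketch of the guard this patch adopts, run against a hypothetical response dict (the `finishReason` field is illustrative and not taken from this PR):

```python
import json

# Hypothetical Gemini REST response whose first candidate has no "content" key.
result = {"candidates": [{"finishReason": "SAFETY"}]}

try:
    parts = []
    for part in result["candidates"][0]["content"]["parts"]:
        if "functionCall" in part:
            # Serialize each function call as {name: json-encoded args}.
            parts.append({part["functionCall"]["name"]: json.dumps(part["functionCall"]["args"])})
        else:
            parts.append(part["text"])
    result = parts
except Exception as e:
    # Previously this KeyError propagated and errored out the generation run;
    # the patch records it as a parsing error instead.
    result = f"Parsing error: {e}"

print(result)  # Parsing error: 'content'
```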

3 files changed: +37 lines, -20 lines

berkeley-function-call-leaderboard/README.md

Lines changed: 3 additions & 0 deletions
@@ -231,7 +231,10 @@ For inferencing `Databrick-DBRX-instruct`, you need to create a Databrick Azure


 ## Changelog
+
+* [May 8, 2024] [#406](https://github.com/ShishirPatil/gorilla/pull/406) and [#421](https://github.com/ShishirPatil/gorilla/pull/421): Update the `gemini_handler.py` to better handle parallel function calls for Gemini models.
 * [May 6, 2024] [#412](https://github.com/ShishirPatil/gorilla/pull/412): Bug fix in evaluation dataset for AST categories. This includes updates to both prompts and function docs.
+* [May 2, 2024] [#405](https://github.com/ShishirPatil/gorilla/pull/405): Bug fix in the possible answers for the AST Simple evaluation dataset. Prompt and function docs are not affected.
 * [April 28, 2024] [#397](https://github.com/ShishirPatil/gorilla/pull/397): Add new model `snowflake/arctic` to the leaderboard. Note that there are multiple ways to inference the model, and we choose to do it via Nvidia API catalog.
 * [April 27, 2024] [#390](https://github.com/ShishirPatil/gorilla/pull/390): Bug fix in cost and latency calculation for open-source models, which are now all calculated when serving the model with [vLLM](https://github.com/vllm-project/vllm) using 8 V100 GPUs for consistency. $$\text{Cost} = \text{Latency per 1000 function call} * (\text{8xV100 azure-pay-as-you-go-price per hour / 3600})$$
 * [April 25, 2024] [#386](https://github.com/ShishirPatil/gorilla/pull/386): Add 5 new models to the leaderboard: `meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `gemini-1.5-pro-preview-0409`, `command-r-plus`, `command-r-plus-FC`.

berkeley-function-call-leaderboard/eval_checker/eval_runner_helper.py

Lines changed: 4 additions & 2 deletions
@@ -681,9 +681,11 @@ def record_cost_latency(leaderboard_table, model_name, model_output_data):
                 )
                 print("*" * 100)
         if "input_token_count" in data:
-            input_token.append(data["input_token_count"])
+            if data["input_token_count"] != 0:
+                input_token.append(data["input_token_count"])
         if "output_token_count" in data:
-            output_token.append(data["output_token_count"])
+            if data["output_token_count"] != 0:
+                output_token.append(data["output_token_count"])

     leaderboard_table[model_name]["cost"]["input_data"].extend(input_token)
     leaderboard_table[model_name]["cost"]["output_data"].extend(output_token)
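
The `!= 0` checks above pair with the fallback added in `gemini_handler.py` below: a token count of 0 is a sentinel for "`usageMetadata` could not be parsed" and is kept out of the cost averages rather than dragging them down. A minimal sketch of that intent, using a hypothetical `mean_nonzero` helper that is not part of the repository:

```python
def mean_nonzero(token_counts: list[int]) -> float:
    """Average token counts, treating 0 as 'count unavailable' rather than a measurement."""
    valid = [c for c in token_counts if c != 0]
    return sum(valid) / len(valid) if valid else 0.0

# Zeros come from responses whose usageMetadata could not be parsed.
print(mean_nonzero([120, 0, 80, 100]))  # 100.0
```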

berkeley-function-call-leaderboard/model_handler/gemini_handler.py

Lines changed: 30 additions & 18 deletions
@@ -62,26 +62,38 @@ def _query_gemini(self, user_query, functions):
                 "output_tokens": 0,
                 "latency": latency,
             }
-        parts = []
-        for part in result["candidates"][0]["content"]["parts"]:
-            if "functionCall" in part:
-                if (
-                    "name" in part["functionCall"]
-                    and "args" in part["functionCall"]
-                ):
-                    parts.append({part["functionCall"]["name"]: json.dumps(part["functionCall"]["args"])})
+        try:
+            parts = []
+            for part in result["candidates"][0]["content"]["parts"]:
+                if "functionCall" in part:
+                    if (
+                        "name" in part["functionCall"]
+                        and "args" in part["functionCall"]
+                    ):
+                        parts.append({part["functionCall"]["name"]: json.dumps(part["functionCall"]["args"])})
+                    else:
+                        parts.append("Parsing error: " + json.dumps(part["functionCall"]))
                 else:
-                    parts.append("Parsing error: " + json.dumps(part["functionCall"]))
-            else:
-                parts.append(part["text"])
-        result = parts
+                    parts.append(part["text"])
+            result = parts
+        # This try-except is necessary because sometimes `result["candidates"][0]` does not have the key "content"
+        except Exception as e:
+            result = f"Parsing error: {e}"
+
         metatdata = {}
-        metatdata["input_tokens"] = json.loads(response.content)["usageMetadata"][
-            "promptTokenCount"
-        ]
-        metatdata["output_tokens"] = json.loads(response.content)["usageMetadata"][
-            "candidatesTokenCount"
-        ]
+        try:
+            metatdata["input_tokens"] = json.loads(response.content)["usageMetadata"][
+                "promptTokenCount"
+            ]
+        except:
+            metatdata["input_tokens"] = 0  # We special handle the 0 value when aggregating the results. 0 token will be ignored and not be counted in the average.
+        try:
+            metatdata["output_tokens"] = json.loads(response.content)["usageMetadata"][
+                "candidatesTokenCount"
+            ]
+        except:
+            metatdata["output_tokens"] = 0  # We special handle the 0 value when aggregating the results. 0 token will be ignored and not be counted in the average.
+
         metatdata["latency"] = latency
         return result, metatdata
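
A note on the design choice in the `usageMetadata` fallback: the same behaviour could also be expressed with `dict.get()` defaults rather than bare `except:` clauses, which avoids masking unrelated errors. This is only a sketch, under the assumption that `response.content` holds the raw JSON body of the REST call; the sample body below is made up:

```python
import json

# Made-up response body with no usageMetadata field.
raw_body = b'{"candidates": [{"finishReason": "SAFETY"}]}'

usage = json.loads(raw_body).get("usageMetadata", {})
metadata = {
    # 0 is the sentinel the aggregator skips (see eval_runner_helper.py above).
    "input_tokens": usage.get("promptTokenCount", 0),
    "output_tokens": usage.get("candidatesTokenCount", 0),
}
print(metadata)  # {'input_tokens': 0, 'output_tokens': 0}
```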
