Skip to content

Commit 42f9d28

Browse files
vandyxiaoweiXiaowei Li
andauthored
handle parallel function calls from gemini (#406)
Handle parallel function calls for Gemini handler for the Berkeley Function Calling Leaderboard. This PR does NOT change values in BFCL. Co-authored-by: Xiaowei Li <[email protected]>
1 parent 0eb02bb commit 42f9d28

File tree

1 file changed

+13
-19
lines changed

1 file changed

+13
-19
lines changed

berkeley-function-call-leaderboard/model_handler/gemini_handler.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,7 @@ def _query_gemini(self, user_query, functions):
4343
}
4444

4545
# NOTE: To run the gemini model, you need to provide your own GCP project ID, which can be found in the GCP console.
46-
if self.model_name == "gemini-1.5-pro-preview-0409":
47-
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/gemini-1.5-pro-preview-0409:generateContent"
48-
else:
49-
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/gemini-1.0-pro:generateContent"
46+
API_URL = "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/{YOUR_GCP_PROJECT_ID_HERE}/locations/us-central1/publishers/google/models/" + self.model_name + ":generateContent"
5047
headers = {
5148
"Authorization": "Bearer " + token,
5249
"Content-Type": "application/json",
@@ -65,22 +62,19 @@ def _query_gemini(self, user_query, functions):
6562
"output_tokens": 0,
6663
"latency": latency,
6764
}
68-
contents = result["candidates"][0]["content"]["parts"][0]
69-
if "functionCall" in contents:
70-
if (
71-
"name" in contents["functionCall"]
72-
and "args" in contents["functionCall"]
73-
):
74-
result = {
75-
contents["functionCall"]["name"]: json.dumps(
76-
contents["functionCall"]["args"]
77-
)
78-
}
79-
65+
parts = []
66+
for part in result["candidates"][0]["content"]["parts"]:
67+
if "functionCall" in part:
68+
if (
69+
"name" in part["functionCall"]
70+
and "args" in part["functionCall"]
71+
):
72+
parts.append({part["functionCall"]["name"]: json.dumps(part["functionCall"]["args"])})
73+
else:
74+
parts.append("Parsing error: " + json.dumps(part["functionCall"]))
8075
else:
81-
result = "Parsing error: " + json.dumps(contents["functionCall"])
82-
else:
83-
result = contents["text"]
76+
parts.append(part["text"])
77+
result = parts
8478
metatdata = {}
8579
metatdata["input_tokens"] = json.loads(response.content)["usageMetadata"][
8680
"promptTokenCount"

0 commit comments

Comments
 (0)