
Commit 30994aa

[BFCL] Migrate Gemini Inference to Google AI Studio (#1099)
This PR updates the inference mechanism for Google Gemini models, replacing the use of Google Vertex AI with Google AI Studio. In addition, this PR downgrades `tenacity` from 9.0.0 → 8.5.0 because `google-genai` currently pins `tenacity<9.0`.

---

**Compatibility note on tenacity**

Our code does exercise the retry path affected by [jd/tenacity #425](jd/tenacity#425), but the issue has no functional impact on our evaluation accuracy. Therefore, the temporary downgrade is considered safe. We will revert to tenacity ≥9.0 once python-genai removes the <9.0 pin (tracked in [googleapis/python-genai #1005](googleapis/python-genai#1005)).
1 parent c753d3c commit 30994aa
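At a high level, the handler now calls Google AI Studio through the `google-genai` SDK instead of `vertexai`. A minimal sketch of the new call path, assuming the SDK is installed and `GOOGLE_API_KEY` is exported; the model id here is illustrative only, not taken from this PR:

```python
# Minimal sketch of the Google AI Studio call path adopted by this PR.
# Assumes the google-genai SDK and a GOOGLE_API_KEY in the environment;
# the model id is illustrative.
import os

from google import genai
from google.genai.types import GenerateContentConfig

client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))  # replaces vertexai.init(...)

response = client.models.generate_content(
    model="gemini-2.5-pro",
    contents="Say hello.",
    config=GenerateContentConfig(temperature=0.7),
)
print(response.text)
```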


9 files changed: +92 −84 lines changed

berkeley-function-call-leaderboard/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
  
  All notable changes to the Berkeley Function Calling Leaderboard will be documented in this file.
  
+ - [Jul 6, 2025] [#1099](https://github.com/ShishirPatil/gorilla/pull/1099): Migrate Gemini inference to Google AI Studio.
  - [Jul 2, 2025] [#1090](https://github.com/ShishirPatil/gorilla/pull/1090): Updated OpenAI models to use `developer` role instead of `system` role, following OpenAI's documentation recommendations. This change affects only the OpenAI Responses handler.
  - [Jul 2, 2025] [#1062](https://github.com/ShishirPatil/gorilla/pull/1062): Introduce OpenAI Responses handler, and add support for `o3-2025-04-16` and `o4-mini-2025-04-16`.
  - [Jun 30, 2025] [#956](https://github.com/ShishirPatil/gorilla/pull/956): Fix typo in ground truth for multi_turn_base.

berkeley-function-call-leaderboard/SUPPORTED_MODELS.md

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ For model names containing `{...}`, multiple versions are available. For example
  ## Additional Requirements for Certain Models
  
  - **Gemini Models:**
-   For `Gemini` models, we use the Google Vertex AI endpoint for inference. Ensure you have set the `VERTEX_AI_PROJECT_ID` and `VERTEX_AI_LOCATION` in your `.env` file.
+   For `Gemini` models, we use the Google AI Studio API for inference. Ensure you have set the `GOOGLE_API_KEY` in your `.env` file.
  
  - **Databricks Models:**
    For `databrick-dbrx-instruct`, you must create an Azure Databricks workspace and set up a dedicated inference endpoint. Provide the endpoint URL via `DATABRICKS_AZURE_ENDPOINT_URL` in `.env`.

berkeley-function-call-leaderboard/bfcl_eval/.env.example

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ MINING_API_KEY=sk-XXXXXX
  DMCITO_BASE_URL=XXXXXX
  DMCITO_API_KEY=sk-XXXXXX
  
- # We use Vertex AI to inference Google Gemini models
- VERTEX_AI_PROJECT_ID=
- VERTEX_AI_LOCATION=
+ # We use Google AI Studio to inference Google Gemini models
+ GOOGLE_API_KEY=
  
  AWS_ACCESS_KEY_ID=
  AWS_SECRET_ACCESS_KEY=
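For context, a sketch of how such an entry might be consumed, assuming `python-dotenv`; the way the BFCL codebase actually loads `.env` is not shown in this diff:

```python
# Hypothetical loader sketch (assumes python-dotenv). The handler below only
# reads GOOGLE_API_KEY from the environment; .env loading is not part of this diff.
import os

from dotenv import load_dotenv

load_dotenv()  # copies entries from .env into os.environ
api_key = os.getenv("GOOGLE_API_KEY")
if not api_key:
    raise ValueError("GOOGLE_API_KEY environment variable must be set for Gemini models")
```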

berkeley-function-call-leaderboard/bfcl_eval/_llm_response_generation.py

Lines changed: 1 addition & 1 deletion
@@ -205,7 +205,7 @@ def multi_threaded_inference(handler, test_case, include_input_log, exclude_stat
              "❗️❗️ Error occurred during inference. Maximum reties reached for rate limit or other error. Continuing to next test case."
          )
          print(f"❗️❗️ Test case ID: {test_case['id']}, Error: {str(e)}")
-         traceback.print_exc()
+         traceback.print_exc(limit=10)
          print("-" * 100)
  
          return {
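For reference, `limit=10` caps how many stack-trace entries `traceback.print_exc` prints, which keeps batch logs readable when an error surfaces deep inside nested calls. A tiny stdlib-only illustration, not repo code:

```python
# Stdlib-only illustration of traceback.print_exc(limit=...): at most `limit`
# stack-trace entries are printed instead of the full traceback.
import traceback

def recurse(n):
    if n == 0:
        raise ValueError("boom")
    recurse(n - 1)

try:
    recurse(50)
except ValueError:
    traceback.print_exc(limit=10)  # truncated trace, easier to scan in logs
```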

berkeley-function-call-leaderboard/bfcl_eval/model_handler/api_inference/gemini.py

Lines changed: 71 additions & 69 deletions
@@ -1,7 +1,6 @@
  import os
  import time
  
- import vertexai
  from bfcl_eval.constants.type_mappings import GORILLA_TO_OPENAPI
  from bfcl_eval.model_handler.base_handler import BaseHandler
  from bfcl_eval.model_handler.model_style import ModelStyle
@@ -15,38 +14,37 @@
      retry_with_backoff,
      system_prompt_pre_processing_chat_model,
  )
- from google.api_core.exceptions import ResourceExhausted, TooManyRequests
- from vertexai.generative_models import (
+ from google import genai
+ from google.genai import errors as genai_errors
+ from google.genai.types import (
+     AutomaticFunctionCallingConfig,
      Content,
-     FunctionDeclaration,
-     GenerationConfig,
-     GenerativeModel,
+     GenerateContentConfig,
      Part,
+     ThinkingConfig,
      Tool,
  )
  
  
  class GeminiHandler(BaseHandler):
      def __init__(self, model_name, temperature) -> None:
          super().__init__(model_name, temperature)
-         self.model_style = ModelStyle.Google
-         # Initialize Vertex AI
-         vertexai.init(
-             project=os.getenv("VERTEX_AI_PROJECT_ID"),
-             location=os.getenv("VERTEX_AI_LOCATION"),
-         )
-         self.client = GenerativeModel(self.model_name.replace("-FC", ""))
+         self.model_style = ModelStyle.GOOGLE
+         api_key = os.getenv("GOOGLE_API_KEY")
+         if not api_key:
+             raise ValueError(
+                 "GOOGLE_API_KEY environment variable must be set for Gemini models"
+             )
+         self.client = genai.Client(api_key=api_key)
  
      @staticmethod
      def _substitute_prompt_role(prompts: list[dict]) -> list[dict]:
-         # Allowed roles: user, model, function
+         # Allowed roles: user, model
          for prompt in prompts:
              if prompt["role"] == "user":
                  prompt["role"] = "user"
              elif prompt["role"] == "assistant":
                  prompt["role"] = "model"
-             elif prompt["role"] == "tool":
-                 prompt["role"] = "function"
  
          return prompts

@@ -72,58 +70,41 @@ def decode_execute(self, result):
          )
          return func_call_list
  
-     @retry_with_backoff(error_type=[ResourceExhausted, TooManyRequests])
-     def generate_with_backoff(self, client, **kwargs):
+     # We can't retry on ClientError because it's too broad.
+     # Both rate limit and invalid function description will trigger google.genai.errors.ClientError
+     @retry_with_backoff(error_message_pattern=r".*RESOURCE_EXHAUSTED.*")
+     def generate_with_backoff(self, **kwargs):
          start_time = time.time()
-         api_response = client.generate_content(**kwargs)
+         api_response = self.client.models.generate_content(**kwargs)
          end_time = time.time()
  
          return api_response, end_time - start_time
  
      #### FC methods ####
  
      def _query_FC(self, inference_data: dict):
-         # Gemini models needs to first conver the function doc to FunctionDeclaration and Tools objects.
-         # We do it here to avoid json serialization issues.
-         func_declarations = []
-         for function in inference_data["tools"]:
-             func_declarations.append(
-                 FunctionDeclaration(
-                     name=function["name"],
-                     description=function["description"],
-                     parameters=function["parameters"],
-                 )
-             )
-
-         if func_declarations:
-             tools = [Tool(function_declarations=func_declarations)]
-         else:
-             tools = None
-
          inference_data["inference_input_log"] = {
              "message": repr(inference_data["message"]),
              "tools": inference_data["tools"],
              "system_prompt": inference_data.get("system_prompt", None),
          }
  
-         # messages are already converted to Content object
+         config = GenerateContentConfig(
+             temperature=self.temperature,
+             automatic_function_calling=AutomaticFunctionCallingConfig(disable=True),
+             thinking_config=ThinkingConfig(include_thoughts=True),
+         )
+
          if "system_prompt" in inference_data:
-             # We re-instantiate the GenerativeModel object with the system prompt
-             # We cannot reassign the self.client object as it will affect other entries
-             client = GenerativeModel(
-                 self.model_name.replace("-FC", ""),
-                 system_instruction=inference_data["system_prompt"],
-             )
-         else:
-             client = self.client
+             config.system_instruction = inference_data["system_prompt"]
+
+         if len(inference_data["tools"]) > 0:
+             config.tools = [Tool(function_declarations=inference_data["tools"])]
  
          return self.generate_with_backoff(
-             client=client,
+             model=self.model_name.replace("-FC", ""),
              contents=inference_data["message"],
-             generation_config=GenerationConfig(
-                 temperature=self.temperature,
-             ),
-             tools=tools,
+             config=config,
          )
  
      def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
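The retry decorator now keys on the error message rather than the exception type because `google-genai` raises the broad `genai_errors.ClientError` for both rate limits and malformed function docs; only `RESOURCE_EXHAUSTED` responses should be retried. A hypothetical sketch of what a message-pattern retry could look like; the repo's actual `retry_with_backoff` helper is imported above and may differ:

```python
# Hypothetical message-pattern retry with exponential backoff; not the repo's
# actual retry_with_backoff implementation.
import functools
import random
import re
import time

def retry_on_message(pattern: str, max_attempts: int = 5, base_delay: float = 1.0):
    regex = re.compile(pattern)

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Retry only when the message matches (e.g. RESOURCE_EXHAUSTED);
                    # other ClientErrors (bad function schema, etc.) surface immediately.
                    if attempt == max_attempts - 1 or not regex.search(str(e)):
                        raise
                    time.sleep(base_delay * (2 ** attempt) + random.random())
        return wrapper

    return decorator
```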
@@ -155,9 +136,12 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
          tool_call_func_names = []
          fc_parts = []
          text_parts = []
+         reasoning_content = []
  
          if (
              len(api_response.candidates) > 0
+             and api_response.candidates[0].content
+             and api_response.candidates[0].content.parts
              and len(api_response.candidates[0].content.parts) > 0
          ):
              response_function_call_content = api_response.candidates[0].content
@@ -172,13 +156,17 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
  
                  fc_parts.append({part_func_name: part_func_args_dict})
                  tool_call_func_names.append(part_func_name)
+             # Aggregate reasoning content
+             elif part.thought:
+                 reasoning_content.append(part.text)
              else:
                  text_parts.append(part.text)
+
          else:
              response_function_call_content = Content(
                  role="model",
                  parts=[
-                     Part.from_text("The model did not return any response."),
+                     Part(text="The model did not return any response."),
                  ],
              )
@@ -188,6 +176,7 @@ def _parse_query_response_FC(self, api_response: any) -> dict:
              "model_responses": model_responses,
              "model_responses_message_for_chat_history": response_function_call_content,
              "tool_call_func_names": tool_call_func_names,
+             "reasoning_content": "\n".join(reasoning_content),
              "input_token": api_response.usage_metadata.prompt_token_count,
              "output_token": api_response.usage_metadata.candidates_token_count,
          }
@@ -200,7 +189,7 @@ def add_first_turn_message_FC(
                  Content(
                      role=message["role"],
                      parts=[
-                         Part.from_text(message["content"]),
+                         Part(text=message["content"]),
                      ],
                  )
              )
@@ -235,12 +224,12 @@ def _add_execution_results_FC(
                  Part.from_function_response(
                      name=tool_call_func_name,
                      response={
-                         "content": execution_result,
+                         "result": execution_result,
                      },
                  )
              )
  
-         tool_response_content = Content(parts=tool_response_parts)
+         tool_response_content = Content(role="user", parts=tool_response_parts)
          inference_data["message"].append(tool_response_content)
  
          return inference_data
@@ -253,20 +242,18 @@ def _query_prompting(self, inference_data: dict):
              "system_prompt": inference_data.get("system_prompt", None),
          }
  
-         # messages are already converted to Content object
+         config = GenerateContentConfig(
+             temperature=self.temperature,
+             thinking_config=ThinkingConfig(include_thoughts=True),
+         )
+
          if "system_prompt" in inference_data:
-             client = GenerativeModel(
-                 self.model_name.replace("-FC", ""),
-                 system_instruction=inference_data["system_prompt"],
-             )
-         else:
-             client = self.client
+             config.system_instruction = inference_data["system_prompt"]
+
          api_response = self.generate_with_backoff(
-             client=client,
+             model=self.model_name.replace("-FC", ""),
              contents=inference_data["message"],
-             generation_config=GenerationConfig(
-                 temperature=self.temperature,
-             ),
+             config=config,
          )
          return api_response

@@ -295,13 +282,28 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
      def _parse_query_response_prompting(self, api_response: any) -> dict:
          if (
              len(api_response.candidates) > 0
+             and api_response.candidates[0].content
+             and api_response.candidates[0].content.parts
              and len(api_response.candidates[0].content.parts) > 0
          ):
-             model_responses = api_response.text
+             assert (
+                 len(api_response.candidates[0].content.parts) == 2
+             ), api_response.candidates[0].content.parts
+
+             model_responses = ""
+             reasoning_content = ""
+             for part in api_response.candidates[0].content.parts:
+                 if part.thought:
+                     reasoning_content = part.text
+                 else:
+                     model_responses = part.text
+
          else:
              model_responses = "The model did not return any response."
+
          return {
              "model_responses": model_responses,
+             "reasoning_content": reasoning_content,
              "input_token": api_response.usage_metadata.prompt_token_count,
              "output_token": api_response.usage_metadata.candidates_token_count,
          }
@@ -314,7 +316,7 @@ def add_first_turn_message_prompting(
                  Content(
                      role=message["role"],
                      parts=[
-                         Part.from_text(message["content"]),
+                         Part(text=message["content"]),
                      ],
                  )
              )
@@ -332,7 +334,7 @@ def _add_assistant_message_prompting(
              Content(
                  role="model",
                  parts=[
-                     Part.from_text(model_response_data["model_responses"]),
+                     Part(text=model_response_data["model_responses"]),
                  ],
              )
          )
@@ -347,7 +349,7 @@ def _add_execution_results_prompting(
          tool_message = Content(
              role="user",
              parts=[
-                 Part.from_text(formatted_results_message),
+                 Part(text=formatted_results_message),
              ],
          )
          inference_data["message"].append(tool_message)
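Putting the handler changes together, a minimal sketch of the function-calling round trip under the new SDK; the model id and the weather tool are illustrative only and are not taken from this PR:

```python
# Illustrative function-calling round trip with google-genai, mirroring the
# handler's pattern above. Model id and tool schema are hypothetical.
import os

from google import genai
from google.genai.types import (
    AutomaticFunctionCallingConfig,
    Content,
    GenerateContentConfig,
    Part,
    Tool,
)

client = genai.Client(api_key=os.getenv("GOOGLE_API_KEY"))

weather_tool = {
    "name": "get_weather",
    "description": "Get the current weather for a city.",
    "parameters": {
        "type": "OBJECT",
        "properties": {"city": {"type": "STRING"}},
        "required": ["city"],
    },
}

config = GenerateContentConfig(
    temperature=0.001,
    automatic_function_calling=AutomaticFunctionCallingConfig(disable=True),
    tools=[Tool(function_declarations=[weather_tool])],
)

messages = [Content(role="user", parts=[Part(text="What's the weather in Berkeley?")])]
response = client.models.generate_content(
    model="gemini-2.5-flash", contents=messages, config=config
)

# Function-call parts carry the tool name and arguments chosen by the model.
for part in response.candidates[0].content.parts:
    if part.function_call:
        print(part.function_call.name, dict(part.function_call.args))

# The tool's output is then appended as a user-role Content with a
# function_response Part, as in _add_execution_results_FC above.
messages.append(response.candidates[0].content)
messages.append(
    Content(
        role="user",
        parts=[Part.from_function_response(name="get_weather", response={"result": "72F, sunny"})],
    )
)
```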

berkeley-function-call-leaderboard/bfcl_eval/model_handler/local_inference/base_oss_handler.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def _multi_threaded_inference(
              "❗️❗️ Error occurred during inference. Maximum reties reached for rate limit or other error. Continuing to next test case."
          )
          print(f"❗️❗️ Test case ID: {test_case['id']}, Error: {str(e)}")
-         traceback.print_exc()
+         traceback.print_exc(limit=10)
          print("-" * 100)
  
          model_responses = f"Error during inference: {str(e)}"

berkeley-function-call-leaderboard/bfcl_eval/model_handler/model_style.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ class ModelStyle(Enum):
      OpenAI_Responses = "openai-responses"
      Anthropic = "claude"
      Mistral = "mistral"
-     Google = "google"
+     GOOGLE = "google"
      AMAZON = "amazon"
      FIREWORK_AI = "firework_ai"
      NEXUS = "nexus"
