
Commit 291904c

Add exponential retry logic for gemini models (#764)
This is to avoid the following error with long context on Gemini models due to insufficient quota:

```
Error: 429 Resource exhausted. Please try again later. Please refer to https://cloud.google.com/vertex-ai/generative-ai/docs/quotas#error-code-429 for more details.
```

This approach uses exponential backoff retries when encountering a ResourceExhausted error.

---------

Co-authored-by: Huanzhi (Hans) Mao <[email protected]>
1 parent cb8ff39 commit 291904c
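
For context, the fix boils down to wrapping the Vertex AI `generate_content` call in a tenacity retry decorator. Below is a minimal, standalone sketch of that pattern, not the handler code itself; `call_gemini` is a hypothetical stand-in for the `generate_with_backoff` method added in the diff, using the same retry parameters:

```python
from google.api_core.exceptions import ResourceExhausted
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)


# Hypothetical stand-in for the GeminiHandler.generate_with_backoff method in the diff below.
@retry(
    wait=wait_random_exponential(min=6, max=120),      # exponentially growing random wait, capped at 120 s
    stop=stop_after_attempt(10),                       # give up and re-raise after 10 attempts
    retry=retry_if_exception_type(ResourceExhausted),  # only retry 429 "Resource exhausted" errors
)
def call_gemini(client, **kwargs):
    # Each ResourceExhausted raised here triggers another attempt until the retry budget is spent.
    return client.generate_content(**kwargs)
```

Any other exception propagates immediately, so only quota errors are retried.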

2 files changed: +40, -33 lines


berkeley-function-call-leaderboard/bfcl/model_handler/proprietary_model/gemini.py

Lines changed: 39 additions & 33 deletions

```diff
@@ -1,3 +1,4 @@
+import logging
 import os
 
 import vertexai
@@ -13,6 +14,13 @@
     func_doc_language_specific_pre_processing,
     system_prompt_pre_processing_chat_model,
 )
+from google.api_core.exceptions import ResourceExhausted
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
@@ -22,6 +30,7 @@
     Tool,
 )
 
+logging.basicConfig(level=logging.INFO)
 
 class GeminiHandler(BaseHandler):
     def __init__(self, model_name, temperature) -> None:
@@ -69,6 +78,18 @@ def decode_execute(self, result):
         )
         return func_call_list
 
+    @retry(
+        wait=wait_random_exponential(min=6, max=120),
+        stop=stop_after_attempt(10),
+        retry=retry_if_exception_type(ResourceExhausted),
+        before_sleep=lambda retry_state: print(
+            f"Attempt {retry_state.attempt_number} failed. Sleeping for {float(round(retry_state.next_action.sleep, 2))} seconds before retrying..."
+            f"Error: {retry_state.outcome.exception()}"
+        ),
+    )
+    def generate_with_backoff(self, client, **kwargs):
+        return client.generate_content(**kwargs)
+
     #### FC methods ####
 
     def _query_FC(self, inference_data: dict):
@@ -100,21 +121,17 @@ def _query_FC(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
+            client = self.client
+
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+            tools=tools if len(tools) > 0 else None,
+        )
         return api_response
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
@@ -237,19 +254,15 @@ def _query_prompting(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
+            client = self.client
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+        )
         return api_response
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
@@ -275,13 +288,6 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         return {"message": []}
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
-        # Note: Same issue as with mentioned in `_parse_query_response_FC` method
-        # According to the Vertex AI documentation, `api_response.text` should be enough.
-        # However, under the hood, it is calling `api_response.candidates[0].content.parts[0].text` which is causing the issue
-        """TypeError: argument of type 'Part' is not iterable"""
-        # So again, we need to directly access the `api_response.candidates[0].content.parts[0]._raw_part.text` attribute to get the text content of the part
-        # This is a workaround for this bug, until the bug is fixed
-
         if len(api_response.candidates[0].content.parts) > 0:
             model_responses = api_response.text
         else:
```

berkeley-function-call-leaderboard/pyproject.toml

Lines changed: 1 addition & 0 deletions

```diff
@@ -30,6 +30,7 @@ dependencies = [
     "tabulate>=0.9.0",
     "google-cloud-aiplatform==1.72.0",
     "mpmath==1.3.0",
+    "tenacity==9.0.0"
 ]
 
 [project.scripts]
```
