@@ -1,3 +1,4 @@
+import logging
 import os
 
 import vertexai
@@ -13,6 +14,13 @@
     func_doc_language_specific_pre_processing,
     system_prompt_pre_processing_chat_model,
 )
+from google.api_core.exceptions import ResourceExhausted
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
@@ -22,6 +30,7 @@
     Tool,
 )
 
+logging.basicConfig(level=logging.INFO)
 
 class GeminiHandler(BaseHandler):
     def __init__(self, model_name, temperature) -> None:
@@ -69,6 +78,18 @@ def decode_execute(self, result):
             )
         return func_call_list
 
+    @retry(
+        wait=wait_random_exponential(min=6, max=120),
+        stop=stop_after_attempt(10),
+        retry=retry_if_exception_type(ResourceExhausted),
+        before_sleep=lambda retry_state: print(
+            f"Attempt {retry_state.attempt_number} failed. Sleeping for {round(retry_state.next_action.sleep, 2)} seconds before retrying... "
+            f"Error: {retry_state.outcome.exception()}"
+        ),
+    )
+    def generate_with_backoff(self, client, **kwargs):
+        return client.generate_content(**kwargs)
+
     #### FC methods ####
 
     def _query_FC(self, inference_data: dict):
@@ -100,21 +121,17 @@ def _query_FC(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
+            client = self.client
+
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+            tools=tools if len(tools) > 0 else None,
+        )
         return api_response
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
@@ -237,19 +254,15 @@ def _query_prompting(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
+            client = self.client
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+        )
         return api_response
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
@@ -275,13 +288,6 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         return {"message": []}
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
-        # Note: same issue as mentioned in the `_parse_query_response_FC` method.
-        # According to the Vertex AI documentation, `api_response.text` should be enough.
-        # However, under the hood it calls `api_response.candidates[0].content.parts[0].text`,
-        # which raises: TypeError: argument of type 'Part' is not iterable
-        # So we directly access `api_response.candidates[0].content.parts[0]._raw_part.text` instead.
-        # This is a workaround until the upstream bug is fixed.
-
         if len(api_response.candidates[0].content.parts) > 0:
             model_responses = api_response.text
         else:
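Note: the retry policy introduced above can be exercised in isolation. The sketch below is illustrative only and not part of the PR; the FlakyClient stub is hypothetical. It reuses the same tenacity configuration against a fake client that raises ResourceExhausted twice before succeeding. Like the real handler, it sleeps for real between attempts (roughly 6 to 120 seconds each).

from google.api_core.exceptions import ResourceExhausted
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

class FlakyClient:
    """Hypothetical stub: raises ResourceExhausted twice, then succeeds."""

    def __init__(self):
        self.calls = 0

    def generate_content(self, **kwargs):
        self.calls += 1
        if self.calls <= 2:
            raise ResourceExhausted("429 quota exceeded (simulated)")
        return f"ok on attempt {self.calls}"

@retry(
    wait=wait_random_exponential(min=6, max=120),      # randomized exponential backoff, capped at 120 s
    stop=stop_after_attempt(10),                       # re-raise after 10 failed attempts
    retry=retry_if_exception_type(ResourceExhausted),  # retry only on quota errors
)
def generate_with_backoff(client, **kwargs):
    return client.generate_content(**kwargs)

print(generate_with_backoff(FlakyClient(), contents=["hello"]))  # "ok on attempt 3"

Because tenacity re-invokes the wrapped function with the same bound arguments on every attempt, the single FlakyClient instance (and its call counter) persists across retries, which is why the handler can simply pass its client through generate_with_backoff.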
berkeley-function-call-leaderboard/pyproject.toml (1 addition, 0 deletions)

@@ -30,6 +30,7 @@ dependencies = [
     "tabulate>=0.9.0",
     "google-cloud-aiplatform==1.72.0",
     "mpmath==1.3.0",
+    "tenacity==9.0.0"
 ]
 
 [project.scripts]
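A practical note on the policy's worst case: stop_after_attempt(10) allows at most 9 sleeps between attempts, each capped at 120 s by wait_random_exponential, so a single request can spend up to about 18 minutes backing off before ResourceExhausted is finally re-raised. A back-of-envelope check (assuming every sleep hits the cap, which the randomized wait only approaches):

attempts = 10          # stop_after_attempt(10)
max_sleep_s = 120      # wait_random_exponential(min=6, max=120)
worst_case_min = (attempts - 1) * max_sleep_s / 60
print(worst_case_min)  # 18.0 minutes of backoff, worst case

In practice the randomized wait usually lands well below the cap, but long-running benchmark jobs should budget for this tail.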