@@ -1,3 +1,4 @@
+import logging
 import os
 
 import vertexai
@@ -13,6 +14,13 @@
     func_doc_language_specific_pre_processing,
     system_prompt_pre_processing_chat_model,
 )
+from google.api_core.exceptions import ResourceExhausted
+from tenacity import (
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from vertexai.generative_models import (
     Content,
     FunctionDeclaration,
@@ -22,6 +30,7 @@
     Tool,
 )
 
+logging.basicConfig(level=logging.INFO)
 
 class GeminiHandler(BaseHandler):
     def __init__(self, model_name, temperature) -> None:
@@ -69,6 +78,18 @@ def decode_execute(self, result):
             )
         return func_call_list
 
+    @retry(
+        wait=wait_random_exponential(min=6, max=120),
+        stop=stop_after_attempt(10),
+        retry=retry_if_exception_type(ResourceExhausted),
+        before_sleep=lambda retry_state: print(
+            f"Attempt {retry_state.attempt_number} failed. Sleeping for {float(round(retry_state.next_action.sleep, 2))} seconds before retrying..."
+            f"Error: {retry_state.outcome.exception()}"
+        ),
+    )
+    def generate_with_backoff(self, client, **kwargs):
+        return client.generate_content(**kwargs)
+
     #### FC methods ####
 
     def _query_FC(self, inference_data: dict):
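
The hunk above is the core of the change: every Vertex AI request is now routed through `generate_with_backoff`, which retries on `ResourceExhausted` (quota/rate-limit errors) with randomized exponential backoff between 6 and 120 seconds, giving up after 10 attempts. A minimal runnable sketch of the same tenacity pattern, using a hypothetical `FlakyQuota` exception in place of `ResourceExhausted` so it needs no Google Cloud setup:

```python
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)


class FlakyQuota(Exception):
    """Hypothetical stand-in for google.api_core.exceptions.ResourceExhausted."""


attempts = {"n": 0}


@retry(
    wait=wait_random_exponential(min=1, max=4),  # shorter waits than the PR's 6-120s, for demo purposes
    stop=stop_after_attempt(5),                  # give up after 5 attempts
    retry=retry_if_exception_type(FlakyQuota),   # any other exception type propagates immediately
)
def flaky_call():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise FlakyQuota("429 quota exceeded")
    return "ok"


print(flaky_call())  # fails twice, sleeps with jittered backoff, then returns "ok"
```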
@@ -100,21 +121,17 @@ def _query_FC(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-                tools=tools if len(tools) > 0 else None,
-            )
+            client = self.client
+
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+            tools=tools if len(tools) > 0 else None,
+        )
         return api_response
 
     def _pre_query_processing_FC(self, inference_data: dict, test_entry: dict) -> dict:
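
One wrinkle worth noting: the diff adds `logging.basicConfig(level=logging.INFO)` at module scope, yet the retry decorator's `before_sleep` hook still reports attempts via `print`. tenacity ships a ready-made `before_sleep_log` helper that routes the same information through the logging module instead; a possible simplification (not part of this PR), assuming the rest of the decorator stays unchanged:

```python
import logging

from google.api_core.exceptions import ResourceExhausted
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger(__name__)


@retry(
    wait=wait_random_exponential(min=6, max=120),
    stop=stop_after_attempt(10),
    retry=retry_if_exception_type(ResourceExhausted),
    # Logs the upcoming sleep time and the triggering exception at INFO level,
    # replacing the hand-rolled lambda/print combination.
    before_sleep=before_sleep_log(logger, logging.INFO),
)
def generate_with_backoff(client, **kwargs):
    return client.generate_content(**kwargs)
```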
@@ -237,19 +254,15 @@ def _query_prompting(self, inference_data: dict):
                 self.model_name.replace("-FC", ""),
                 system_instruction=inference_data["system_prompt"],
             )
-            api_response = client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
         else:
-            api_response = self.client.generate_content(
-                contents=inference_data["message"],
-                generation_config=GenerationConfig(
-                    temperature=self.temperature,
-                ),
-            )
+            client = self.client
+        api_response = self.generate_with_backoff(
+            client=client,
+            contents=inference_data["message"],
+            generation_config=GenerationConfig(
+                temperature=self.temperature,
+            ),
+        )
         return api_response
 
     def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
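
The hunk above mirrors the FC change, just without the `tools=` keyword; both call sites can share one wrapper because `generate_with_backoff` forwards `**kwargs` untouched to `client.generate_content`. A toy illustration with a stub client (hypothetical, standing in for vertexai's `GenerativeModel`):

```python
class StubClient:
    """Hypothetical stand-in for vertexai.generative_models.GenerativeModel."""

    def generate_content(self, **kwargs):
        return sorted(kwargs)  # echo back which keyword arguments arrived


def generate_with_backoff(client, **kwargs):
    # Retry decorator omitted; only the keyword pass-through matters here.
    return client.generate_content(**kwargs)


client = StubClient()
print(generate_with_backoff(client, contents=["hi"], generation_config=None, tools=None))
# ['contents', 'generation_config', 'tools']   <- FC path shape
print(generate_with_backoff(client, contents=["hi"], generation_config=None))
# ['contents', 'generation_config']            <- prompting path shape
```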
@@ -275,13 +288,6 @@ def _pre_query_processing_prompting(self, test_entry: dict) -> dict:
         return {"message": []}
 
     def _parse_query_response_prompting(self, api_response: any) -> dict:
-        # Note: Same issue as with mentioned in `_parse_query_response_FC` method
-        # According to the Vertex AI documentation, `api_response.text` should be enough.
-        # However, under the hood, it is calling `api_response.candidates[0].content.parts[0].text` which is causing the issue
-        """TypeError: argument of type 'Part' is not iterable"""
-        # So again, we need to directly access the `api_response.candidates[0].content.parts[0]._raw_part.text` attribute to get the text content of the part
-        # This is a workaround for this bug, until the bug is fixed
-
         if len(api_response.candidates[0].content.parts) > 0:
             model_responses = api_response.text
         else: