Skip to content

Commit a1af085

Browse files
committed
feat: support gemini-2.0-flash-thinking-exp
1 parent 90475d3 commit a1af085

File tree

2 files changed

+44
-31
lines changed

2 files changed

+44
-31
lines changed

models.yaml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@
114114
output_price: 0
115115
supports_vision: true
116116
supports_function_calling: true
117+
- name: gemini-2.0-flash-thinking-exp
118+
max_input_tokens: 32768
119+
max_output_tokens: 8192
120+
input_price: 0
121+
output_price: 0
122+
supports_vision: true
117123
- name: gemini-exp-1206
118124
max_input_tokens: 32768
119125
max_output_tokens: 8192
@@ -452,6 +458,15 @@
452458
output_price: 0.075
453459
supports_vision: true
454460
supports_function_calling: true
461+
- name: gemini-2.0-flash-exp
462+
max_input_tokens: 1048576
463+
max_output_tokens: 8192
464+
supports_vision: true
465+
supports_function_calling: true
466+
- name: gemini-2.0-flash-thinking-exp-1219
467+
max_input_tokens: 32768
468+
max_output_tokens: 8192
469+
supports_vision: true
455470
- name: claude-3-5-sonnet-v2@20241022
456471
max_input_tokens: 200000
457472
max_output_tokens: 8192

src/client/vertexai.rs

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -197,24 +197,25 @@ pub async fn gemini_chat_completions_streaming(
197197
let handle = |value: &str| -> Result<()> {
198198
let data: Value = serde_json::from_str(value)?;
199199
debug!("stream-data: {data}");
200-
if let Some(text) = data["candidates"][0]["content"]["parts"][0]["text"].as_str() {
201-
if !text.is_empty() {
202-
handler.text(text)?;
203-
}
204-
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
205-
.as_str()
206-
.or_else(|| data["candidates"][0]["finishReason"].as_str())
207-
{
208-
bail!("Content Blocked")
209-
} else if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
210-
for part in parts {
211-
if let (Some(name), Some(args)) = (
200+
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
201+
for (i, part) in parts.iter().enumerate() {
202+
if let Some(text) = part["text"].as_str() {
203+
if i > 0 {
204+
handler.text("\n\n")?;
205+
}
206+
handler.text(text)?;
207+
} else if let (Some(name), Some(args)) = (
212208
part["functionCall"]["name"].as_str(),
213209
part["functionCall"]["args"].as_object(),
214210
) {
215211
handler.tool_call(ToolCall::new(name.to_string(), json!(args), None))?;
216212
}
217213
}
214+
} else if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
215+
.as_str()
216+
.or_else(|| data["candidates"][0]["finishReason"].as_str())
217+
{
218+
bail!("Blocked due to safety")
218219
}
219220

220221
Ok(())
@@ -257,38 +258,35 @@ struct EmbeddingsResBodyPredictionEmbeddings {
257258
}
258259

259260
fn gemini_extract_chat_completions_text(data: &Value) -> Result<ChatCompletionsOutput> {
260-
let text = data["candidates"][0]["content"]["parts"][0]["text"]
261-
.as_str()
262-
.unwrap_or_default();
263-
261+
let mut text_parts = vec![];
264262
let mut tool_calls = vec![];
265263
if let Some(parts) = data["candidates"][0]["content"]["parts"].as_array() {
266-
tool_calls = parts
267-
.iter()
268-
.filter_map(|part| {
269-
if let (Some(name), Some(args)) = (
270-
part["functionCall"]["name"].as_str(),
271-
part["functionCall"]["args"].as_object(),
272-
) {
273-
Some(ToolCall::new(name.to_string(), json!(args), None))
274-
} else {
275-
None
276-
}
277-
})
278-
.collect()
264+
for part in parts {
265+
if let Some(text) = part["text"].as_str() {
266+
text_parts.push(text);
267+
}
268+
if let (Some(name), Some(args)) = (
269+
part["functionCall"]["name"].as_str(),
270+
part["functionCall"]["args"].as_object(),
271+
) {
272+
tool_calls.push(ToolCall::new(name.to_string(), json!(args), None));
273+
}
274+
}
279275
}
276+
277+
let text = text_parts.join("\n\n");
280278
if text.is_empty() && tool_calls.is_empty() {
281279
if let Some("SAFETY") = data["promptFeedback"]["blockReason"]
282280
.as_str()
283281
.or_else(|| data["candidates"][0]["finishReason"].as_str())
284282
{
285-
bail!("Content Blocked")
283+
bail!("Blocked due to safety")
286284
} else {
287285
bail!("Invalid response data: {data}");
288286
}
289287
}
290288
let output = ChatCompletionsOutput {
291-
text: text.to_string(),
289+
text,
292290
tool_calls,
293291
id: None,
294292
input_tokens: data["usageMetadata"]["promptTokenCount"].as_u64(),

0 commit comments

Comments
 (0)