
Commit 82f32ca

feat: google genai integration with tool block (#20096)
1 parent 1f4c6fa commit 82f32ca

File tree

6 files changed: +131 additions, -64 deletions

llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/base.py

Lines changed: 23 additions & 26 deletions
@@ -6,6 +6,7 @@
 import typing
 from typing import (
     TYPE_CHECKING,
+    cast,
     Any,
     AsyncGenerator,
     Dict,
@@ -38,6 +39,7 @@
     MessageRole,
     ThinkingBlock,
     TextBlock,
+    ToolCallBlock,
 )
 from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
 from llama_index.core.callbacks import CallbackManager
@@ -376,7 +378,6 @@ def _stream_chat(
 
         def gen() -> ChatResponseGen:
             content = ""
-            existing_tool_calls = []
             thoughts = ""
             for r in response:
                 if not r.candidates:
@@ -390,14 +391,11 @@ def gen() -> ChatResponseGen:
                 else:
                     content += content_delta
                 llama_resp = chat_from_gemini_response(r)
-                existing_tool_calls.extend(
-                    llama_resp.message.additional_kwargs.get("tool_calls", [])
-                )
-                llama_resp.delta = content_delta
-                llama_resp.message.blocks = [TextBlock(text=content)]
-                llama_resp.message.blocks.append(ThinkingBlock(content=thoughts))
-                llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls
-                yield llama_resp
+                if content:
+                    llama_resp.message.blocks.append(TextBlock(text=content))
+                if thoughts:
+                    llama_resp.message.blocks.append(ThinkingBlock(content=thoughts))
+                yield llama_resp
 
             if self.use_file_api:
                 asyncio.run(
@@ -429,7 +427,6 @@ def _astream_chat(
 
         async def gen() -> ChatResponseAsyncGen:
             content = ""
-            existing_tool_calls = []
             thoughts = ""
             async for r in await chat.send_message_stream(
                 next_msg.parts if isinstance(next_msg, types.Content) else next_msg
@@ -448,19 +445,15 @@ async def gen() -> ChatResponseAsyncGen:
                 else:
                     content += content_delta
                 llama_resp = chat_from_gemini_response(r)
-                existing_tool_calls.extend(
-                    llama_resp.message.additional_kwargs.get(
-                        "tool_calls", []
-                    )
-                )
                 llama_resp.delta = content_delta
-                llama_resp.message.blocks = [TextBlock(text=content)]
-                llama_resp.message.blocks.append(
-                    ThinkingBlock(content=thoughts)
-                )
-                llama_resp.message.additional_kwargs["tool_calls"] = (
-                    existing_tool_calls
-                )
+                if content:
+                    llama_resp.message.blocks.append(
+                        TextBlock(text=content)
+                    )
+                if thoughts:
+                    llama_resp.message.blocks.append(
+                        ThinkingBlock(content=thoughts)
+                    )
                 yield llama_resp
 
             if self.use_file_api:
@@ -551,7 +544,11 @@ def get_tool_calls_from_response(
         **kwargs: Any,
     ) -> List[ToolSelection]:
         """Predict and call the tool."""
-        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
+        tool_calls = [
+            block
+            for block in response.message.blocks
+            if isinstance(block, ToolCallBlock)
+        ]
 
         if len(tool_calls) < 1:
             if error_on_no_tool_call:
@@ -565,9 +562,9 @@ def get_tool_calls_from_response(
         for tool_call in tool_calls:
             tool_selections.append(
                 ToolSelection(
-                    tool_id=tool_call["name"],
-                    tool_name=tool_call["name"],
-                    tool_kwargs=tool_call["args"],
+                    tool_id=tool_call.tool_name,
+                    tool_name=tool_call.tool_name,
+                    tool_kwargs=cast(Dict[str, Any], tool_call.tool_kwargs),
                 )
             )

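With this change, tool calls no longer ride in additional_kwargs["tool_calls"]; they are emitted as ToolCallBlock entries in message.blocks, and get_tool_calls_from_response filters the blocks directly (note that ToolSelection.tool_id is now populated from the tool name rather than a Gemini call id). A minimal sketch of consuming them downstream, assuming a configured GoogleGenAI instance `llm`, an existing `search_tool`, and the usual llama-index-core import path for ToolCallBlock:

    # Sketch only: `llm` and `search_tool` are assumed, not defined in this commit.
    from llama_index.core.base.llms.types import ToolCallBlock

    response = llm.chat_with_tools(tools=[search_tool], user_msg="Find me a puppy")

    # Tool calls now live on the message's content blocks.
    tool_calls = [
        block for block in response.message.blocks if isinstance(block, ToolCallBlock)
    ]
    for call in tool_calls:
        print(call.tool_name, call.tool_kwargs)
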
llama-index-integrations/llms/llama-index-llms-google-genai/llama_index/llms/google_genai/utils.py

Lines changed: 48 additions & 26 deletions
@@ -1,16 +1,9 @@
 import asyncio
+import json
 import logging
 from collections.abc import Sequence
 from io import BytesIO
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Union,
-    Optional,
-    Type,
-    Tuple,
-)
+from typing import TYPE_CHECKING, Any, Dict, Union, Optional, Type, Tuple, cast
 import typing
 
 import google.genai.types as types
@@ -29,6 +22,7 @@
     DocumentBlock,
     VideoBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.program.utils import _repair_incomplete_json
 from tenacity import (
@@ -188,16 +182,33 @@ def chat_from_gemini_response(
             )
             additional_kwargs["thought_signatures"].append(part.thought_signature)
         if part.function_call:
-            if "tool_calls" not in additional_kwargs:
-                additional_kwargs["tool_calls"] = []
-            additional_kwargs["tool_calls"].append(
-                {
-                    "id": part.function_call.id if part.function_call.id else "",
-                    "name": part.function_call.name,
-                    "args": part.function_call.args,
-                    "thought_signature": part.thought_signature,
-                }
+            if (
+                part.thought_signature
+                not in additional_kwargs["thought_signatures"]
+            ):
+                additional_kwargs["thought_signatures"].append(
+                    part.thought_signature
+                )
+            content_blocks.append(
+                ToolCallBlock(
+                    tool_call_id=part.function_call.id or "",
+                    tool_name=part.function_call.name or "",
+                    tool_kwargs=part.function_call.args or {},
+                )
             )
+        if part.function_response:
+            # follow the same pattern as for transforming a chatmessage into a gemini message: if it's a function response, package it alone and return it
+            additional_kwargs["tool_call_id"] = part.function_response.id
+            role = ROLES_FROM_GEMINI[top_candidate.content.role]
+            print("RESPONSE", json.dumps(part.function_response.response))
+            return ChatResponse(
+                message=ChatMessage(
+                    role=role, content=json.dumps(part.function_response.response)
+                ),
+                raw=raw,
+                additional_kwargs=additional_kwargs,
+            )
+
     if thought_tokens:
         thinking_blocks = [
             i
@@ -271,6 +282,7 @@ async def chat_message_to_gemini(
     message: ChatMessage, use_file_api: bool = False, client: Optional[Client] = None
 ) -> Union[types.Content, types.File]:
     """Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
+    unique_tool_calls = []
     parts = []
     part = None
     for index, block in enumerate(message.blocks):
@@ -326,6 +338,11 @@
             part.thought_signature = block.additional_information.get(
                 "thought_signature", None
             )
+        elif isinstance(block, ToolCallBlock):
+            part = types.Part.from_function_call(
+                name=block.tool_name, args=cast(Dict[str, Any], block.tool_kwargs)
+            )
+            unique_tool_calls.append((block.tool_name, str(block.tool_kwargs)))
         else:
             msg = f"Unsupported content block type: {type(block).__name__}"
             raise ValueError(msg)
@@ -343,15 +360,20 @@
 
     for tool_call in message.additional_kwargs.get("tool_calls", []):
         if isinstance(tool_call, dict):
-            part = types.Part.from_function_call(
-                name=tool_call.get("name"), args=tool_call.get("args")
-            )
-            part.thought_signature = tool_call.get("thought_signature")
+            if (
+                tool_call.get("name", ""),
+                str(tool_call.get("args", {})),
+            ) not in unique_tool_calls:
+                part = types.Part.from_function_call(
+                    name=tool_call.get("name", ""), args=tool_call.get("args", {})
+                )
+                part.thought_signature = tool_call.get("thought_signature")
         else:
-            part = types.Part.from_function_call(
-                name=tool_call.name, args=tool_call.args
-            )
-            part.thought_signature = tool_call.thought_signature
+            if (tool_call.name, str(tool_call.args)) not in unique_tool_calls:
+                part = types.Part.from_function_call(
+                    name=tool_call.name, args=tool_call.args
+                )
+                part.thought_signature = tool_call.thought_signature
         parts.append(part)
 
     # the tool call id is the name of the tool
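
The unique_tool_calls list introduced above deduplicates between the new block-based path and the legacy additional_kwargs["tool_calls"] path: each ToolCallBlock records a (tool_name, stringified kwargs) key, and legacy entries that map to an existing key are skipped. A self-contained sketch of the key scheme; the record/is_duplicate helpers are invented here for illustration:

    # Hypothetical helpers mirroring the dedup key used in chat_message_to_gemini.
    unique_tool_calls = []

    def record(tool_name: str, tool_kwargs: dict) -> None:
        # Same key shape as the committed code: (name, str(args)).
        unique_tool_calls.append((tool_name, str(tool_kwargs)))

    def is_duplicate(tool_call: dict) -> bool:
        key = (tool_call.get("name", ""), str(tool_call.get("args", {})))
        return key in unique_tool_calls

    record("search", {"query": "puppies"})
    assert is_duplicate({"name": "search", "args": {"query": "puppies"}})
    assert not is_duplicate({"name": "search", "args": {"query": "kittens"}})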

llama-index-integrations/llms/llama-index-llms-google-genai/pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ dev = [
 
 [project]
 name = "llama-index-llms-google-genai"
-version = "0.6.2"
+version = "0.7.0"
 description = "llama-index llms google genai integration"
 authors = [{name = "Your Name", email = "[email protected]"}]
 requires-python = ">=3.9,<4.0"
@@ -36,7 +36,7 @@ license = "MIT"
 dependencies = [
     "pillow>=10.2.0",
     "google-genai>=1.24.0,<2",
-    "llama-index-core>=0.14.3,<0.15",
+    "llama-index-core>=0.14.5,<0.15",
 ]
 
 [tool.codespell]

llama-index-integrations/llms/llama-index-llms-google-genai/tests/test_llms_google_genai.py

Lines changed: 45 additions & 5 deletions
@@ -11,6 +11,7 @@
     TextBlock,
     VideoBlock,
     ThinkingBlock,
+    ToolCallBlock,
 )
 from llama_index.core.llms.llm import ToolSelection
 from llama_index.core.program.function_program import get_function_tool
@@ -564,8 +565,16 @@ def test_tool_required_integration(llm: GoogleGenAI) -> None:
         tools=[search_tool],
         tool_required=True,
     )
-    assert response.message.additional_kwargs.get("tool_calls") is not None
-    assert len(response.message.additional_kwargs["tool_calls"]) > 0
+    assert (
+        len(
+            [
+                block
+                for block in response.message.blocks
+                if isinstance(block, ToolCallBlock)
+            ]
+        )
+        > 0
+    )
 
     # Test with tool_required=False
     response = llm.chat_with_tools(
@@ -729,6 +738,10 @@ async def test_prepare_chat_params_more_than_2_tool_calls():
                 )
             ],
         ),
+        ChatMessage(
+            blocks=[ToolCallBlock(tool_name="get_available_tools", tool_kwargs={})],
+            role=MessageRole.ASSISTANT,
+        ),
         ChatMessage(
             content="Let me search for puppies.",
             role=MessageRole.ASSISTANT,
@@ -777,10 +790,11 @@
                     text="The user is asking me for a puppy, so I should search for puppies using the available tools.",
                     thought=True,
                 ),
+                types.Part.from_function_call(name="get_available_tools", args={}),
                 types.Part(text="Let me search for puppies."),
-                types.Part.from_function_call(name="tool_1", args=None),
-                types.Part.from_function_call(name="tool_2", args=None),
-                types.Part.from_function_call(name="tool_3", args=None),
+                types.Part.from_function_call(name="tool_1", args={}),
+                types.Part.from_function_call(name="tool_2", args={}),
+                types.Part.from_function_call(name="tool_3", args={}),
             ],
             role=MessageRole.MODEL,
         ),
@@ -872,6 +886,10 @@ def test_cached_content_in_response() -> None:
     mock_response.candidates[0].content.parts[0].text = "Test response"
     mock_response.candidates[0].content.parts[0].thought = False
     mock_response.candidates[0].content.parts[0].inline_data = None
+    mock_response.candidates[0].content.parts[0].function_call.id = ""
+    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
+    mock_response.candidates[0].content.parts[0].function_call.args = {}
+    mock_response.candidates[0].content.parts[0].function_response = None
     mock_response.prompt_feedback = None
     mock_response.usage_metadata = None
    mock_response.function_calls = None
@@ -899,6 +917,10 @@ def test_cached_content_without_cached_content() -> None:
     mock_response.candidates[0].content.parts[0].text = "Test response"
     mock_response.candidates[0].content.parts[0].thought = False
     mock_response.candidates[0].content.parts[0].inline_data = None
+    mock_response.candidates[0].content.parts[0].function_call.id = ""
+    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
+    mock_response.candidates[0].content.parts[0].function_call.args = {}
+    mock_response.candidates[0].content.parts[0].function_response = None
     mock_response.prompt_feedback = None
     mock_response.usage_metadata = None
     mock_response.function_calls = None
@@ -923,9 +945,15 @@ def test_thoughts_in_response() -> None:
     mock_response.candidates[0].content.parts[0].text = "This is a thought."
     mock_response.candidates[0].content.parts[0].inline_data = None
     mock_response.candidates[0].content.parts[0].thought = True
+    mock_response.candidates[0].content.parts[0].function_call.id = ""
+    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
+    mock_response.candidates[0].content.parts[0].function_call.args = {}
     mock_response.candidates[0].content.parts[1].text = "This is not a thought."
     mock_response.candidates[0].content.parts[1].inline_data = None
     mock_response.candidates[0].content.parts[1].thought = None
+    mock_response.candidates[0].content.parts[1].function_call = None
+    mock_response.candidates[0].content.parts[1].function_response = None
+    mock_response.candidates[0].content.parts[0].function_response = None
     mock_response.candidates[0].content.parts[0].model_dump = MagicMock(return_value={})
     mock_response.candidates[0].content.parts[1].model_dump = MagicMock(return_value={})
     mock_response.prompt_feedback = None
@@ -967,6 +995,8 @@ def test_thoughts_without_thought_response() -> None:
     mock_response.candidates[0].content.parts[0].text = "This is not a thought."
     mock_response.candidates[0].content.parts[0].inline_data = None
     mock_response.candidates[0].content.parts[0].thought = None
+    mock_response.candidates[0].content.parts[0].function_call = None
+    mock_response.candidates[0].content.parts[0].function_response = None
     mock_response.prompt_feedback = None
     mock_response.usage_metadata = None
     mock_response.function_calls = None
@@ -1084,6 +1114,8 @@ def test_built_in_tool_in_response() -> None:
     ].text = "Test response with search results"
     mock_response.candidates[0].content.parts[0].inline_data = None
     mock_response.candidates[0].content.parts[0].thought = None
+    mock_response.candidates[0].content.parts[0].function_call = None
+    mock_response.candidates[0].content.parts[0].function_response = None
     mock_response.prompt_feedback = None
     mock_response.usage_metadata = MagicMock()
     mock_response.usage_metadata.model_dump.return_value = {
@@ -1523,6 +1555,8 @@ def test_code_execution_response_parts() -> None:
     )
     mock_text_part.inline_data = None
     mock_text_part.thought = None
+    mock_text_part.function_call = None
+    mock_text_part.function_response = None
 
     mock_code_part = MagicMock()
     mock_code_part.text = None
@@ -1532,6 +1566,8 @@
         "code": "def is_prime(n):\n    if n < 2:\n        return False\n    for i in range(2, int(n**0.5) + 1):\n        if n % i == 0:\n            return False\n    return True\n\nprimes = []\nn = 2\nwhile len(primes) < 50:\n    if is_prime(n):\n        primes.append(n)\n    n += 1\n\nprint(f'Sum of first 50 primes: {sum(primes)}')",
         "language": types.Language.PYTHON,
     }
+    mock_code_part.function_call = None
+    mock_code_part.function_response = None
 
     mock_result_part = MagicMock()
     mock_result_part.text = None
@@ -1541,11 +1577,15 @@
         "outcome": types.Outcome.OUTCOME_OK,
         "output": "Sum of first 50 primes: 5117",
     }
+    mock_result_part.function_call = None
+    mock_result_part.function_response = None
 
     mock_final_text_part = MagicMock()
     mock_final_text_part.text = "The sum of the first 50 prime numbers is 5117."
     mock_final_text_part.inline_data = None
     mock_final_text_part.thought = None
+    mock_final_text_part.function_call = None
+    mock_final_text_part.function_response = None
 
     mock_candidate.content.parts = [
         mock_text_part,

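A pattern worth noting in the mock updates above: chat_from_gemini_response now checks part.function_call and part.function_response on every part, and an unset MagicMock attribute is truthy, so any mocked part that should not take the tool-call path must set both attributes to None explicitly. A minimal sketch of a plain-text mock part under that assumption:

    # Sketch: a mocked part that avoids the new function_call / function_response
    # branches; the attribute list mirrors the tests above.
    from unittest.mock import MagicMock

    part = MagicMock()
    part.text = "plain text"
    part.inline_data = None
    part.thought = None
    part.function_call = None  # a bare MagicMock here would be truthy
    part.function_response = None  # likewise for the function-response early return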