Skip to content

Commit d4dd2bb

Browse files
AnnhilucMarkDaoust
authored andcommitted
fix: Update parse_config_for_mcp_tools to remove the deep copy of the config and filter tools
PiperOrigin-RevId: 759882878
1 parent e18af97 commit d4dd2bb

File tree

4 files changed

+304
-14
lines changed

4 files changed

+304
-14
lines changed

google/genai/_extra_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -453,17 +453,19 @@ async def parse_config_for_mcp_tools(
453453
mcp_to_genai_tool_adapters: dict[str, McpToGenAiToolAdapter] = {}
454454
if not config:
455455
return None, mcp_to_genai_tool_adapters
456-
config_model = _create_generate_content_config_model(config).model_copy(
457-
deep=True
458-
)
456+
config_model = _create_generate_content_config_model(config)
457+
# Create a copy of the config model with the tools field cleared as they will
458+
# be replaced with the MCP tools converted to GenAI tools.
459+
config_model_copy = config_model.model_copy(update={'tools': None})
459460
if config_model.tools:
461+
config_model_copy.tools = []
460462
for tool in config_model.tools:
461463
if McpClientSession is not None and isinstance(tool, McpClientSession):
462464
mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
463465
tool, await tool.list_tools()
464466
)
465467
# Extend the config with the MCP session tools converted to GenAI tools.
466-
config_model.tools.extend(mcp_to_genai_tool_adapter.tools)
468+
config_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
467469
for genai_tool in mcp_to_genai_tool_adapter.tools:
468470
if genai_tool.function_declarations:
469471
for function_declaration in genai_tool.function_declarations:
@@ -477,9 +479,9 @@ async def parse_config_for_mcp_tools(
477479
mcp_to_genai_tool_adapter
478480
)
479481
if McpClientSession is not None:
480-
config_model.tools = [
482+
config_model_copy.tools.extend(
481483
tool
482484
for tool in config_model.tools
483485
if not isinstance(tool, McpClientSession)
484-
]
485-
return config_model, mcp_to_genai_tool_adapters
486+
)
487+
return config_model_copy, mcp_to_genai_tool_adapters

google/genai/live.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,23 @@
5252

5353
if typing.TYPE_CHECKING:
5454
from mcp import ClientSession as McpClientSession
55+
from mcp.types import Tool as McpTool
5556
from ._adapters import McpToGenAiToolAdapter
57+
from ._mcp_utils import mcp_to_gemini_tool
5658
else:
5759
McpClientSession: typing.Type = Any
60+
McpTool: typing.Type = Any
5861
McpToGenAiToolAdapter: typing.Type = Any
5962
try:
6063
from mcp import ClientSession as McpClientSession
64+
from mcp.types import Tool as McpTool
6165
from ._adapters import McpToGenAiToolAdapter
66+
from ._mcp_utils import mcp_to_gemini_tool
6267
except ImportError:
6368
McpClientSession = None
69+
McpTool = None
6470
McpToGenAiToolAdapter = None
71+
mcp_to_gemini_tool = None
6572

6673
logger = logging.getLogger('google_genai.live')
6774

@@ -1005,22 +1012,31 @@ async def _t_live_connect_config(
10051012
parameter_model = config
10061013
parameter_model.system_instruction = system_instruction
10071014

1015+
# Create a copy of the config model with the tools field cleared as they will
1016+
# be replaced with the MCP tools converted to GenAI tools.
1017+
parameter_model_copy = parameter_model.model_copy(update={'tools': None})
10081018
if parameter_model.tools:
1019+
parameter_model_copy.tools = []
10091020
for tool in parameter_model.tools:
10101021
if McpClientSession is not None and isinstance(tool, McpClientSession):
10111022
mcp_to_genai_tool_adapter = McpToGenAiToolAdapter(
10121023
tool, await tool.list_tools()
10131024
)
10141025
# Extend the config with the MCP session tools converted to GenAI tools.
1015-
parameter_model.tools.extend(mcp_to_genai_tool_adapter.tools)
1026+
parameter_model_copy.tools.extend(mcp_to_genai_tool_adapter.tools)
1027+
if McpTool is not None and isinstance(tool, McpTool):
1028+
parameter_model_copy.tools.append(mcp_to_gemini_tool(tool))
10161029
if McpClientSession is not None:
1017-
parameter_model.tools = [
1030+
parameter_model_copy.tools.extend(
10181031
tool
10191032
for tool in parameter_model.tools
1020-
if not isinstance(tool, McpClientSession)
1021-
]
1033+
if (
1034+
not isinstance(tool, McpClientSession)
1035+
and not isinstance(tool, McpTool)
1036+
)
1037+
)
10221038

1023-
if parameter_model.generation_config is not None:
1039+
if parameter_model_copy.generation_config is not None:
10241040
warnings.warn(
10251041
'Setting `LiveConnectConfig.generation_config` is deprecated, '
10261042
'please set the fields on `LiveConnectConfig` directly. This will '
@@ -1029,4 +1045,4 @@ async def _t_live_connect_config(
10291045
stacklevel=4,
10301046
)
10311047

1032-
return parameter_model
1048+
return parameter_model_copy

google/genai/tests/live/test_live.py

Lines changed: 180 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
"""Tests for live.py."""
1818
import contextlib
1919
import json
20-
from typing import AsyncIterator
20+
import typing
21+
from typing import Any, AsyncIterator
2122
from unittest import mock
2223
from unittest.mock import AsyncMock
2324
from unittest.mock import Mock
@@ -36,6 +37,19 @@
3637
from ... import types
3738
from .. import pytest_helper
3839

40+
if typing.TYPE_CHECKING:
41+
from mcp import types as mcp_types
42+
from mcp import ClientSession as McpClientSession
43+
else:
44+
mcp_types: typing.Type = Any
45+
McpClientSession: typing.Type = Any
46+
try:
47+
from mcp import types as mcp_types
48+
from mcp import ClientSession as McpClientSession
49+
except ImportError:
50+
mcp_types = None
51+
McpClientSession = None
52+
3953

4054
function_declarations = [{
4155
'name': 'get_current_weather',
@@ -979,6 +993,171 @@ async def test_bidi_setup_to_api_with_tools_function_behavior(vertexai):
979993
)
980994

981995

996+
@pytest.mark.parametrize('vertexai', [True, False])
997+
@pytest.mark.asyncio
998+
async def test_bidi_setup_to_api_with_config_mcp_tools(
999+
vertexai,
1000+
):
1001+
if mcp_types is None:
1002+
return
1003+
1004+
expected_result_googleai = {
1005+
'setup': {
1006+
'model': 'models/test_model',
1007+
'tools': [{
1008+
'functionDeclarations': [{
1009+
'parameters': {
1010+
'type': 'OBJECT',
1011+
'properties': {
1012+
'location': {
1013+
'type': 'STRING',
1014+
},
1015+
},
1016+
},
1017+
'name': 'get_weather',
1018+
'description': 'Get the weather in a city.',
1019+
}],
1020+
}],
1021+
}
1022+
}
1023+
expected_result_vertexai = {
1024+
'setup': {
1025+
'generationConfig': {
1026+
'responseModalities': [
1027+
'AUDIO',
1028+
],
1029+
},
1030+
'model': (
1031+
'projects/test_project/locations/us-central1/publishers/google/models/test_model'
1032+
),
1033+
'tools': [{
1034+
'functionDeclarations': [{
1035+
'parameters': {
1036+
'type': 'OBJECT',
1037+
'properties': {
1038+
'location': {
1039+
'type': 'STRING',
1040+
},
1041+
},
1042+
},
1043+
'name': 'get_weather',
1044+
'description': 'Get the weather in a city.',
1045+
}],
1046+
}],
1047+
}
1048+
}
1049+
result = await get_connect_message(
1050+
mock_api_client(vertexai=vertexai),
1051+
model='test_model',
1052+
config={
1053+
'tools': [
1054+
mcp_types.Tool(
1055+
name='get_weather',
1056+
description='Get the weather in a city.',
1057+
inputSchema={
1058+
'type': 'object',
1059+
'properties': {'location': {'type': 'string'}},
1060+
},
1061+
)
1062+
],
1063+
},
1064+
)
1065+
1066+
assert (
1067+
result == expected_result_vertexai
1068+
if vertexai
1069+
else expected_result_googleai
1070+
)
1071+
1072+
1073+
@pytest.mark.parametrize('vertexai', [True, False])
1074+
@pytest.mark.asyncio
1075+
async def test_bidi_setup_to_api_with_config_mcp_session(
1076+
vertexai,
1077+
):
1078+
if mcp_types is None:
1079+
return
1080+
1081+
class MockMcpClientSession(McpClientSession):
1082+
1083+
def __init__(self):
1084+
self._read_stream = None
1085+
self._write_stream = None
1086+
1087+
async def list_tools(self):
1088+
return mcp_types.ListToolsResult(
1089+
tools=[
1090+
mcp_types.Tool(
1091+
name='get_weather',
1092+
description='Get the weather in a city.',
1093+
inputSchema={
1094+
'type': 'object',
1095+
'properties': {'location': {'type': 'string'}},
1096+
},
1097+
),
1098+
]
1099+
)
1100+
1101+
expected_result_googleai = {
1102+
'setup': {
1103+
'model': 'models/test_model',
1104+
'tools': [{
1105+
'functionDeclarations': [{
1106+
'parameters': {
1107+
'type': 'OBJECT',
1108+
'properties': {
1109+
'location': {
1110+
'type': 'STRING',
1111+
},
1112+
},
1113+
},
1114+
'name': 'get_weather',
1115+
'description': 'Get the weather in a city.',
1116+
}],
1117+
}],
1118+
}
1119+
}
1120+
expected_result_vertexai = {
1121+
'setup': {
1122+
'generationConfig': {
1123+
'responseModalities': [
1124+
'AUDIO',
1125+
],
1126+
},
1127+
'model': (
1128+
'projects/test_project/locations/us-central1/publishers/google/models/test_model'
1129+
),
1130+
'tools': [{
1131+
'functionDeclarations': [{
1132+
'parameters': {
1133+
'type': 'OBJECT',
1134+
'properties': {
1135+
'location': {
1136+
'type': 'STRING',
1137+
},
1138+
},
1139+
},
1140+
'name': 'get_weather',
1141+
'description': 'Get the weather in a city.',
1142+
}],
1143+
}],
1144+
}
1145+
}
1146+
result = await get_connect_message(
1147+
mock_api_client(vertexai=vertexai),
1148+
model='test_model',
1149+
config={
1150+
'tools': [MockMcpClientSession()],
1151+
},
1152+
)
1153+
1154+
assert (
1155+
result == expected_result_vertexai
1156+
if vertexai
1157+
else expected_result_googleai
1158+
)
1159+
1160+
9821161
@pytest.mark.parametrize('vertexai', [True, False])
9831162
@pytest.mark.asyncio
9841163
async def test_bidi_setup_to_api_with_config_tools_code_execution(

0 commit comments

Comments
 (0)