3
3
# SPDX-License-Identifier: Apache-2.0
4
4
5
5
import copy
6
- import json
6
+ import inspect
7
7
import warnings
8
8
from dataclasses import dataclass
9
9
from enum import Enum
10
10
from functools import partial
11
- from inspect import signature
12
11
from types import MethodType
13
- from typing import Any , Callable , Literal , Optional , Union
12
+ from typing import Annotated , Any , Callable , Literal , Optional , Union
14
13
15
14
from pydantic import BaseModel , field_serializer
16
15
17
16
from ...doc_utils import export_module
18
17
from ...oai import OpenAIWrapper
18
+ from ...tools import Depends , Tool
19
+ from ...tools .dependency_injection import inject_params , on
19
20
from ..agent import Agent
20
21
from ..chat import ChatResult
21
22
from ..conversable_agent import __CONTEXT_VARIABLES_PARAM_NAME__ , ConversableAgent
@@ -298,8 +299,7 @@ def _link_agents_to_swarm_manager(agents: list[Agent], group_chat_manager: Agent
298
299
Does not link the Tool Executor agent.
299
300
"""
300
301
for agent in agents :
301
- if agent .name not in [__TOOL_EXECUTOR_NAME__ ]:
302
- agent ._swarm_manager = group_chat_manager # type: ignore[attr-defined]
302
+ agent ._swarm_manager = group_chat_manager # type: ignore[attr-defined]
303
303
304
304
305
305
def _run_oncontextconditions (
@@ -339,16 +339,72 @@ def _run_oncontextconditions(
339
339
return False , None
340
340
341
341
342
+ def _modify_context_variables_param (f : Callable [..., Any ], context_variables : dict [str , Any ]) -> Callable [..., Any ]:
343
+ """Modifies the context_variables parameter to use dependency injection and link it to the swarm context variables.
344
+
345
+ This essentially changes:
346
+ def some_function(some_variable: int, context_variables: dict[str, Any]) -> str:
347
+
348
+ to:
349
+
350
+ def some_function(some_variable: int, context_variables: Annotated[dict[str, Any], Depends(on(self._context_variables))]) -> str:
351
+ """
352
+ sig = inspect .signature (f )
353
+
354
+ # Check if context_variables parameter exists and update it if so
355
+ if __CONTEXT_VARIABLES_PARAM_NAME__ in sig .parameters :
356
+ new_params = []
357
+ for name , param in sig .parameters .items ():
358
+ if name == __CONTEXT_VARIABLES_PARAM_NAME__ :
359
+ # Replace with new annotation using Depends
360
+ new_param = param .replace (annotation = Annotated [dict [str , Any ], Depends (on (context_variables ))])
361
+ new_params .append (new_param )
362
+ else :
363
+ new_params .append (param )
364
+
365
+ # Update signature
366
+ new_sig = sig .replace (parameters = new_params )
367
+ f .__signature__ = new_sig # type: ignore[attr-defined]
368
+
369
+ return f
370
+
371
+
372
+ def _change_tool_context_variables_to_depends (
373
+ agent : ConversableAgent , current_tool : Tool , context_variables : dict [str , Any ]
374
+ ) -> None :
375
+ """Checks for the context_variables parameter in the tool and updates it to use dependency injection."""
376
+
377
+ # If the tool has a context_variables parameter, remove the tool and reregister it without the parameter
378
+ if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool .tool_schema ["function" ]["parameters" ]["properties" ]:
379
+ # We'll replace the tool, so start with getting the underlying function
380
+ tool_func = current_tool ._func
381
+
382
+ # Remove the Tool from the agent
383
+ name = current_tool ._name
384
+ description = current_tool ._description
385
+ agent .remove_tool_for_llm (current_tool )
386
+
387
+ # Recreate the tool without the context_variables parameter
388
+ tool_func = _modify_context_variables_param (current_tool ._func , context_variables )
389
+ tool_func = inject_params (tool_func )
390
+ new_tool = ConversableAgent ._create_tool_if_needed (func_or_tool = tool_func , name = name , description = description )
391
+
392
+ # Re-register with the agent
393
+ agent .register_for_llm ()(new_tool )
394
+
395
+
342
396
def _prepare_swarm_agents (
343
397
initial_agent : ConversableAgent ,
344
398
agents : list [ConversableAgent ],
399
+ context_variables : dict [str , Any ],
345
400
exclude_transit_message : bool = True ,
346
401
) -> tuple [ConversableAgent , list [ConversableAgent ]]:
347
402
"""Validates agents, create the tool executor, configure nested chats.
348
403
349
404
Args:
350
405
initial_agent (ConversableAgent): The first agent in the conversation.
351
406
agents (list[ConversableAgent]): List of all agents in the conversation.
407
+ context_variables (dict[str, Any]): Context variables to assign to all agents.
352
408
exclude_transit_message (bool): Whether to exclude transit messages from the agents.
353
409
354
410
Returns:
@@ -382,16 +438,24 @@ def _prepare_swarm_agents(
382
438
for agent in agents :
383
439
_create_nested_chats (agent , nested_chat_agents )
384
440
441
+ # Update any agent's tools that have context_variables as a parameter
442
+ # To use Dependency Injection
443
+
385
444
# Update tool execution agent with all the functions from all the agents
386
445
for agent in agents + nested_chat_agents :
387
446
tool_execution ._function_map .update (agent ._function_map )
447
+
388
448
# Add conditional functions to the tool_execution agent
389
449
for func_name , (func , _ ) in agent ._swarm_conditional_functions .items (): # type: ignore[attr-defined]
390
450
tool_execution ._function_map [func_name ] = func
391
451
392
- # Register tools from the tools property
452
+ # Update any agent tools that have context_variables parameters to use Dependency Injection
393
453
for tool in agent .tools :
394
- tool_execution .register_for_execution ()(tool )
454
+ _change_tool_context_variables_to_depends (agent , tool , context_variables )
455
+
456
+ # Add all tools to the Tool Executor agent
457
+ for tool in agent .tools :
458
+ tool_execution .register_for_execution (serialize = False )(tool )
395
459
396
460
if exclude_transit_message :
397
461
# get all transit functions names
@@ -876,8 +940,11 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
876
940
dict[str, Any]: Updated Context variables.
877
941
ConversableAgent: Last speaker.
878
942
"""
943
+ context_variables = context_variables or {}
879
944
880
- tool_execution , nested_chat_agents = _prepare_swarm_agents (initial_agent , agents , exclude_transit_message )
945
+ tool_execution , nested_chat_agents = _prepare_swarm_agents (
946
+ initial_agent , agents , context_variables , exclude_transit_message
947
+ )
881
948
882
949
processed_messages , last_agent , swarm_agent_names , temp_user_list = _process_initial_messages (
883
950
messages , user_agent , agents , nested_chat_agents
@@ -902,7 +969,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
902
969
manager = _create_swarm_manager (groupchat , swarm_manager_args , agents )
903
970
904
971
# Point all ConversableAgent's context variables to this function's context_variables
905
- _setup_context_variables (tool_execution , agents , manager , context_variables or {} )
972
+ _setup_context_variables (tool_execution , agents , manager , context_variables )
906
973
907
974
# Link all agents with the GroupChatManager to allow access to the group chat
908
975
# and other agents, particularly the tool executor for setting _swarm_next_agent
@@ -926,7 +993,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
926
993
927
994
_cleanup_temp_user_messages (chat_result )
928
995
929
- return chat_result , context_variables , manager .last_speaker # type: ignore[return-value]
996
+ return chat_result , context_variables if context_variables != {} else None , manager .last_speaker # type: ignore[return-value]
930
997
931
998
932
999
@export_module ("autogen" )
@@ -976,7 +1043,10 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
976
1043
dict[str, Any]: Updated Context variables.
977
1044
ConversableAgent: Last speaker.
978
1045
"""
979
- tool_execution , nested_chat_agents = _prepare_swarm_agents (initial_agent , agents , exclude_transit_message )
1046
+ context_variables = context_variables or {}
1047
+ tool_execution , nested_chat_agents = _prepare_swarm_agents (
1048
+ initial_agent , agents , context_variables , exclude_transit_message
1049
+ )
980
1050
981
1051
processed_messages , last_agent , swarm_agent_names , temp_user_list = _process_initial_messages (
982
1052
messages , user_agent , agents , nested_chat_agents
@@ -1001,7 +1071,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
1001
1071
manager = _create_swarm_manager (groupchat , swarm_manager_args , agents )
1002
1072
1003
1073
# Point all ConversableAgent's context variables to this function's context_variables
1004
- _setup_context_variables (tool_execution , agents , manager , context_variables or {} )
1074
+ _setup_context_variables (tool_execution , agents , manager , context_variables )
1005
1075
1006
1076
if len (processed_messages ) > 1 :
1007
1077
last_agent , last_message = await manager .a_resume (messages = processed_messages )
@@ -1021,7 +1091,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
1021
1091
1022
1092
_cleanup_temp_user_messages (chat_result )
1023
1093
1024
- return chat_result , context_variables , manager .last_speaker # type: ignore[return-value]
1094
+ return chat_result , context_variables if context_variables != {} else None , manager .last_speaker # type: ignore[return-value]
1025
1095
1026
1096
1027
1097
class SwarmResult (BaseModel ):
@@ -1155,6 +1225,8 @@ def _update_conditional_functions(agent: ConversableAgent, messages: Optional[li
1155
1225
condition = condition .format (context_variables = agent ._context_variables )
1156
1226
elif callable (condition ):
1157
1227
condition = condition (agent , messages )
1228
+
1229
+ # TODO: Don't add it if it's already there
1158
1230
agent ._add_single_function (func , func_name , condition )
1159
1231
1160
1232
@@ -1185,26 +1257,11 @@ def _generate_swarm_tool_reply(
1185
1257
tool_responses_inner = []
1186
1258
contents = []
1187
1259
for index in range (tool_call_count ):
1188
- # Deep copy to ensure no changes to messages when we insert the context variables
1189
1260
message_copy = copy .deepcopy (message )
1190
1261
1191
1262
# 1. add context_variables to the tool call arguments
1192
1263
tool_call = message_copy ["tool_calls" ][index ]
1193
1264
1194
- if tool_call ["type" ] == "function" :
1195
- function_name = tool_call ["function" ]["name" ]
1196
-
1197
- # Check if this function exists in our function map
1198
- if function_name in agent ._function_map :
1199
- func = agent ._function_map [function_name ] # Get the original function
1200
-
1201
- # Inject the context variables into the tool call if it has the parameter
1202
- sig = signature (func )
1203
- if __CONTEXT_VARIABLES_PARAM_NAME__ in sig .parameters :
1204
- current_args = json .loads (tool_call ["function" ]["arguments" ])
1205
- current_args [__CONTEXT_VARIABLES_PARAM_NAME__ ] = agent ._context_variables
1206
- tool_call ["function" ]["arguments" ] = json .dumps (current_args , default = BaseModel .model_dump_json )
1207
-
1208
1265
# Ensure we are only executing the one tool at a time
1209
1266
message_copy ["tool_calls" ] = [tool_call ]
1210
1267
@@ -1217,6 +1274,7 @@ def _generate_swarm_tool_reply(
1217
1274
# 3. update context_variables and next_agent, convert content to string
1218
1275
for tool_response in tool_message ["tool_responses" ]:
1219
1276
content = tool_response .get ("content" )
1277
+
1220
1278
if isinstance (content , SwarmResult ):
1221
1279
if content .context_variables != {}:
1222
1280
agent ._context_variables .update (content .context_variables )
@@ -1225,6 +1283,10 @@ def _generate_swarm_tool_reply(
1225
1283
elif isinstance (content , Agent ):
1226
1284
next_agent = content
1227
1285
1286
+ # Serialize the content to a string
1287
+ if content is not None :
1288
+ tool_response ["content" ] = str (content )
1289
+
1228
1290
tool_responses_inner .append (tool_response )
1229
1291
contents .append (str (tool_response ["content" ]))
1230
1292
0 commit comments