Skip to content

Commit b2e6499

Browse files
Fix swarm context variable function injection (ag2ai#1264)
* Fix swarm tool registration to align with standardised approach
* Move Field to a new file, update references
* Added tests
* Tidy
* Restore originals
* Dependency Injection for context_variables parameter
* Fix swarm test
* Tests
* Update autogen/agentchat/contrib/swarm_agent.py

  Thanks!

  Co-authored-by: Davor Runje <[email protected]>
* Pre-commit fixes

---------

Co-authored-by: Davor Runje <[email protected]>
1 parent 2572ee2 commit b2e6499

File tree

4 files changed

+354
-72
lines changed

4 files changed

+354
-72
lines changed

autogen/agentchat/contrib/swarm_agent.py

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import copy
6-
import json
6+
import inspect
77
import warnings
88
from dataclasses import dataclass
99
from enum import Enum
1010
from functools import partial
11-
from inspect import signature
1211
from types import MethodType
13-
from typing import Any, Callable, Literal, Optional, Union
12+
from typing import Annotated, Any, Callable, Literal, Optional, Union
1413

1514
from pydantic import BaseModel, field_serializer
1615

1716
from ...doc_utils import export_module
1817
from ...oai import OpenAIWrapper
18+
from ...tools import Depends, Tool
19+
from ...tools.dependency_injection import inject_params, on
1920
from ..agent import Agent
2021
from ..chat import ChatResult
2122
from ..conversable_agent import __CONTEXT_VARIABLES_PARAM_NAME__, ConversableAgent
@@ -298,8 +299,7 @@ def _link_agents_to_swarm_manager(agents: list[Agent], group_chat_manager: Agent
298299
Does not link the Tool Executor agent.
299300
"""
300301
for agent in agents:
301-
if agent.name not in [__TOOL_EXECUTOR_NAME__]:
302-
agent._swarm_manager = group_chat_manager # type: ignore[attr-defined]
302+
agent._swarm_manager = group_chat_manager # type: ignore[attr-defined]
303303

304304

305305
def _run_oncontextconditions(
@@ -339,16 +339,72 @@ def _run_oncontextconditions(
339339
return False, None
340340

341341

342+
def _modify_context_variables_param(f: Callable[..., Any], context_variables: dict[str, Any]) -> Callable[..., Any]:
    """Re-annotate a function's context_variables parameter so it is resolved by dependency injection.

    When ``f`` declares a parameter named ``__CONTEXT_VARIABLES_PARAM_NAME__``, its annotation is
    replaced so that, for example:

        def some_function(some_variable: int, context_variables: dict[str, Any]) -> str:

    is presented as:

        def some_function(some_variable: int, context_variables: Annotated[dict[str, Any], Depends(on(context_variables))]) -> str:

    The rewritten signature is stored on ``f.__signature__``; ``f`` itself is modified in place.
    Functions without that parameter are returned untouched.

    Args:
        f: The function whose signature may be rewritten.
        context_variables: The dictionary handed to ``on(...)`` for the injected parameter.

    Returns:
        The same function object that was passed in.
    """
    original_sig = inspect.signature(f)

    # Nothing to rewrite when the function does not take context_variables
    if __CONTEXT_VARIABLES_PARAM_NAME__ not in original_sig.parameters:
        return f

    injected_annotation = Annotated[dict[str, Any], Depends(on(context_variables))]

    # Keep every parameter as-is except context_variables, which gets the Depends annotation
    rewritten_params = [
        param.replace(annotation=injected_annotation) if param_name == __CONTEXT_VARIABLES_PARAM_NAME__ else param
        for param_name, param in original_sig.parameters.items()
    ]

    f.__signature__ = original_sig.replace(parameters=rewritten_params)  # type: ignore[attr-defined]
    return f
370+
371+
372+
def _change_tool_context_variables_to_depends(
    agent: ConversableAgent, current_tool: Tool, context_variables: dict[str, Any]
) -> None:
    """Re-register a tool so that its context_variables parameter is supplied by dependency injection.

    If the tool's schema exposes a context_variables parameter, the tool is removed from the
    agent's LLM tools, its underlying function is re-annotated and passed through
    ``inject_params``, and a replacement tool with the same name and description is registered
    with the agent. Tools without that parameter are left untouched.

    Args:
        agent: The agent the tool is registered with for LLM use.
        current_tool: The tool to inspect and, if needed, replace.
        context_variables: The context variables dictionary to inject into the tool's function.
    """
    schema_properties = current_tool.tool_schema["function"]["parameters"]["properties"]
    if __CONTEXT_VARIABLES_PARAM_NAME__ not in schema_properties:
        return

    # Capture everything needed from the existing tool before it is removed from the agent
    tool_func = current_tool._func
    name = current_tool._name
    description = current_tool._description

    agent.remove_tool_for_llm(current_tool)

    # Switch the context_variables parameter over to a Depends annotation and let
    # inject_params wrap the function so the value is supplied by injection
    tool_func = _modify_context_variables_param(tool_func, context_variables)
    tool_func = inject_params(tool_func)
    new_tool = ConversableAgent._create_tool_if_needed(func_or_tool=tool_func, name=name, description=description)

    # Re-register the replacement tool with the agent
    agent.register_for_llm()(new_tool)
394+
395+
342396
def _prepare_swarm_agents(
343397
initial_agent: ConversableAgent,
344398
agents: list[ConversableAgent],
399+
context_variables: dict[str, Any],
345400
exclude_transit_message: bool = True,
346401
) -> tuple[ConversableAgent, list[ConversableAgent]]:
347402
"""Validates agents, create the tool executor, configure nested chats.
348403
349404
Args:
350405
initial_agent (ConversableAgent): The first agent in the conversation.
351406
agents (list[ConversableAgent]): List of all agents in the conversation.
407+
context_variables (dict[str, Any]): Context variables to assign to all agents.
352408
exclude_transit_message (bool): Whether to exclude transit messages from the agents.
353409
354410
Returns:
@@ -382,16 +438,24 @@ def _prepare_swarm_agents(
382438
for agent in agents:
383439
_create_nested_chats(agent, nested_chat_agents)
384440

441+
# Update any agent's tools that have context_variables as a parameter
442+
# To use Dependency Injection
443+
385444
# Update tool execution agent with all the functions from all the agents
386445
for agent in agents + nested_chat_agents:
387446
tool_execution._function_map.update(agent._function_map)
447+
388448
# Add conditional functions to the tool_execution agent
389449
for func_name, (func, _) in agent._swarm_conditional_functions.items(): # type: ignore[attr-defined]
390450
tool_execution._function_map[func_name] = func
391451

392-
# Register tools from the tools property
452+
# Update any agent tools that have context_variables parameters to use Dependency Injection
393453
for tool in agent.tools:
394-
tool_execution.register_for_execution()(tool)
454+
_change_tool_context_variables_to_depends(agent, tool, context_variables)
455+
456+
# Add all tools to the Tool Executor agent
457+
for tool in agent.tools:
458+
tool_execution.register_for_execution(serialize=False)(tool)
395459

396460
if exclude_transit_message:
397461
# get all transit functions names
@@ -876,8 +940,11 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
876940
dict[str, Any]: Updated Context variables.
877941
ConversableAgent: Last speaker.
878942
"""
943+
context_variables = context_variables or {}
879944

880-
tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents, exclude_transit_message)
945+
tool_execution, nested_chat_agents = _prepare_swarm_agents(
946+
initial_agent, agents, context_variables, exclude_transit_message
947+
)
881948

882949
processed_messages, last_agent, swarm_agent_names, temp_user_list = _process_initial_messages(
883950
messages, user_agent, agents, nested_chat_agents
@@ -902,7 +969,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
902969
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)
903970

904971
# Point all ConversableAgent's context variables to this function's context_variables
905-
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
972+
_setup_context_variables(tool_execution, agents, manager, context_variables)
906973

907974
# Link all agents with the GroupChatManager to allow access to the group chat
908975
# and other agents, particularly the tool executor for setting _swarm_next_agent
@@ -926,7 +993,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
926993

927994
_cleanup_temp_user_messages(chat_result)
928995

929-
return chat_result, context_variables, manager.last_speaker # type: ignore[return-value]
996+
return chat_result, context_variables if context_variables != {} else None, manager.last_speaker # type: ignore[return-value]
930997

931998

932999
@export_module("autogen")
@@ -976,7 +1043,10 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
9761043
dict[str, Any]: Updated Context variables.
9771044
ConversableAgent: Last speaker.
9781045
"""
979-
tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents, exclude_transit_message)
1046+
context_variables = context_variables or {}
1047+
tool_execution, nested_chat_agents = _prepare_swarm_agents(
1048+
initial_agent, agents, context_variables, exclude_transit_message
1049+
)
9801050

9811051
processed_messages, last_agent, swarm_agent_names, temp_user_list = _process_initial_messages(
9821052
messages, user_agent, agents, nested_chat_agents
@@ -1001,7 +1071,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
10011071
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)
10021072

10031073
# Point all ConversableAgent's context variables to this function's context_variables
1004-
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
1074+
_setup_context_variables(tool_execution, agents, manager, context_variables)
10051075

10061076
if len(processed_messages) > 1:
10071077
last_agent, last_message = await manager.a_resume(messages=processed_messages)
@@ -1021,7 +1091,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
10211091

10221092
_cleanup_temp_user_messages(chat_result)
10231093

1024-
return chat_result, context_variables, manager.last_speaker # type: ignore[return-value]
1094+
return chat_result, context_variables if context_variables != {} else None, manager.last_speaker # type: ignore[return-value]
10251095

10261096

10271097
class SwarmResult(BaseModel):
@@ -1155,6 +1225,8 @@ def _update_conditional_functions(agent: ConversableAgent, messages: Optional[li
11551225
condition = condition.format(context_variables=agent._context_variables)
11561226
elif callable(condition):
11571227
condition = condition(agent, messages)
1228+
1229+
# TODO: Don't add it if it's already there
11581230
agent._add_single_function(func, func_name, condition)
11591231

11601232

@@ -1185,26 +1257,11 @@ def _generate_swarm_tool_reply(
11851257
tool_responses_inner = []
11861258
contents = []
11871259
for index in range(tool_call_count):
1188-
# Deep copy to ensure no changes to messages when we insert the context variables
11891260
message_copy = copy.deepcopy(message)
11901261

11911262
# 1. add context_variables to the tool call arguments
11921263
tool_call = message_copy["tool_calls"][index]
11931264

1194-
if tool_call["type"] == "function":
1195-
function_name = tool_call["function"]["name"]
1196-
1197-
# Check if this function exists in our function map
1198-
if function_name in agent._function_map:
1199-
func = agent._function_map[function_name] # Get the original function
1200-
1201-
# Inject the context variables into the tool call if it has the parameter
1202-
sig = signature(func)
1203-
if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
1204-
current_args = json.loads(tool_call["function"]["arguments"])
1205-
current_args[__CONTEXT_VARIABLES_PARAM_NAME__] = agent._context_variables
1206-
tool_call["function"]["arguments"] = json.dumps(current_args, default=BaseModel.model_dump_json)
1207-
12081265
# Ensure we are only executing the one tool at a time
12091266
message_copy["tool_calls"] = [tool_call]
12101267

@@ -1217,6 +1274,7 @@ def _generate_swarm_tool_reply(
12171274
# 3. update context_variables and next_agent, convert content to string
12181275
for tool_response in tool_message["tool_responses"]:
12191276
content = tool_response.get("content")
1277+
12201278
if isinstance(content, SwarmResult):
12211279
if content.context_variables != {}:
12221280
agent._context_variables.update(content.context_variables)
@@ -1225,6 +1283,10 @@ def _generate_swarm_tool_reply(
12251283
elif isinstance(content, Agent):
12261284
next_agent = content
12271285

1286+
# Serialize the content to a string
1287+
if content is not None:
1288+
tool_response["content"] = str(content)
1289+
12281290
tool_responses_inner.append(tool_response)
12291291
contents.append(str(tool_response["content"]))
12301292

autogen/agentchat/conversable_agent.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
)
5959
from ..oai.client import ModelClient, OpenAIWrapper
6060
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
61-
from ..tools import ChatContext, Tool, get_function_schema, load_basemodels_if_needed, serialize_to_str
61+
from ..tools import ChatContext, Tool, load_basemodels_if_needed, serialize_to_str
6262
from .agent import Agent, LLMAgent
6363
from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats
6464
from .utils import consolidate_chat_info, gather_usage_summary
@@ -400,7 +400,7 @@ def _add_functions(self, func_list: list[Callable[..., Any]]):
400400
self._add_single_function(func)
401401

402402
def _add_single_function(self, func: Callable, name: Optional[str] = None, description: Optional[str] = ""):
403-
"""Add a single function to the agent, removing context variables for LLM use.
403+
"""Add a single function to the agent
404404
405405
Args:
406406
func (Callable): The function to register.
@@ -418,23 +418,8 @@ def _add_single_function(self, func: Callable, name: Optional[str] = None, descr
418418
# Use function's docstring, strip whitespace, fall back to empty string
419419
func._description = (func.__doc__ or "").strip()
420420

421-
f = get_function_schema(func, name=func._name, description=func._description)
422-
423-
# Remove context_variables parameter from function schema
424-
f_no_context = f.copy()
425-
if __CONTEXT_VARIABLES_PARAM_NAME__ in f_no_context["function"]["parameters"]["properties"]:
426-
del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__]
427-
if "required" in f_no_context["function"]["parameters"]:
428-
required = f_no_context["function"]["parameters"]["required"]
429-
f_no_context["function"]["parameters"]["required"] = [
430-
param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__
431-
]
432-
# If required list is empty, remove it
433-
if not f_no_context["function"]["parameters"]["required"]:
434-
del f_no_context["function"]["parameters"]["required"]
435-
436-
self.update_tool_signature(f_no_context, is_remove=False)
437-
self.register_function({func._name: func})
421+
# Register the function
422+
self.register_for_llm(name=name, description=description)(func)
438423

439424
def _register_update_agent_state_before_reply(
440425
self, functions: Optional[Union[list[Callable[..., Any]], Callable[..., Any]]]
@@ -3038,14 +3023,15 @@ def function_map(self) -> dict[str, Callable[..., Any]]:
30383023
"""Return the function map."""
30393024
return self._function_map
30403025

3041-
def _wrap_function(self, func: F, inject_params: dict[str, Any] = {}) -> F:
3026+
def _wrap_function(self, func: F, inject_params: dict[str, Any] = {}, *, serialize: bool = True) -> F:
30423027
"""Wrap the function inject chat context parameters and to dump the return value to json.
30433028
30443029
Handles both sync and async functions.
30453030
30463031
Args:
30473032
func: the function to be wrapped.
30483033
inject_params: the chat context parameters which will be passed to the function.
3034+
serialize: whether to serialize the return value
30493035
30503036
Returns:
30513037
The wrapped function.
@@ -3057,15 +3043,15 @@ def _wrapped_func(*args, **kwargs):
30573043
retval = func(*args, **kwargs, **inject_params)
30583044
if logging_enabled():
30593045
log_function_use(self, func, kwargs, retval)
3060-
return serialize_to_str(retval)
3046+
return serialize_to_str(retval) if serialize else retval
30613047

30623048
@load_basemodels_if_needed
30633049
@functools.wraps(func)
30643050
async def _a_wrapped_func(*args, **kwargs):
30653051
retval = await func(*args, **kwargs, **inject_params)
30663052
if logging_enabled():
30673053
log_function_use(self, func, kwargs, retval)
3068-
return serialize_to_str(retval)
3054+
return serialize_to_str(retval) if serialize else retval
30693055

30703056
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func
30713057

@@ -3181,6 +3167,8 @@ def register_for_execution(
31813167
self,
31823168
name: Optional[str] = None,
31833169
description: Optional[str] = None,
3170+
*,
3171+
serialize: bool = True,
31843172
) -> Callable[[Union[Tool, F]], Tool]:
31853173
"""Decorator factory for registering a function to be executed by an agent.
31863174
@@ -3189,6 +3177,7 @@ def register_for_execution(
31893177
Args:
31903178
name: name of the function. If None, the function name will be used (default: None).
31913179
description: description of the function (default: None).
3180+
serialize: whether to serialize the return value
31923181
31933182
Returns:
31943183
The decorator for registering a function to be used by an agent.
@@ -3223,7 +3212,9 @@ def _decorator(
32233212
chat_context = ChatContext(self)
32243213
chat_context_params = {param: chat_context for param in tool._chat_context_param_names}
32253214

3226-
self.register_function({tool.name: self._wrap_function(tool.func, chat_context_params)})
3215+
self.register_function({
3216+
tool.name: self._wrap_function(tool.func, chat_context_params, serialize=serialize)
3217+
})
32273218

32283219
return tool
32293220

autogen/tools/dependency_injection.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"get_context_params",
2828
"inject_params",
2929
"on",
30+
"remove_params",
3031
]
3132

3233

@@ -135,7 +136,7 @@ def _is_depends_param(param: inspect.Parameter) -> bool:
135136
)
136137

137138

138-
def _remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
139+
def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
    """Drop the named parameters from the signature advertised by ``func``.

    A copy of ``sig`` without the parameters listed in ``params`` is stored on
    ``func.__signature__``. Only the signature metadata is replaced.

    Args:
        func: The callable whose ``__signature__`` attribute is overwritten.
        sig: The signature to filter.
        params: Names of the parameters to remove. Any iterable is accepted, including
            one-shot iterators and generators.
    """
    # Materialise once: a membership test against a one-shot iterator consumes it, so later
    # parameters would silently be kept. A set also makes each lookup O(1).
    names_to_remove = set(params)
    new_signature = sig.replace(parameters=[p for p in sig.parameters.values() if p.name not in names_to_remove])
    func.__signature__ = new_signature  # type: ignore[attr-defined]
141142

@@ -147,7 +148,7 @@ def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable
147148

148149
sig = inspect.signature(func)
149150
params_to_remove = [p.name for p in sig.parameters.values() if _is_context_param(p) or _is_depends_param(p)]
150-
_remove_params(func, sig, params_to_remove)
151+
remove_params(func, sig, params_to_remove)
151152
return func
152153

153154

0 commit comments

Comments
 (0)