Skip to content

Fix swarm context variable function injection #1264

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 95 additions & 28 deletions autogen/agentchat/contrib/swarm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import json
import inspect
import warnings
from dataclasses import dataclass
from enum import Enum
from functools import partial
from inspect import signature
from types import MethodType
from typing import Any, Callable, Literal, Optional, Union
from typing import Annotated, Any, Callable, Literal, Optional, Union

from pydantic import BaseModel, field_serializer

from autogen.tools.dependency_injection import (
Depends,
inject_params,
on,
)
from autogen.tools.tool import Tool

from ...doc_utils import export_module
from ...oai import OpenAIWrapper
from ..agent import Agent
Expand Down Expand Up @@ -298,8 +304,7 @@ def _link_agents_to_swarm_manager(agents: list[Agent], group_chat_manager: Agent
Does not link the Tool Executor agent.
"""
for agent in agents:
if agent.name not in [__TOOL_EXECUTOR_NAME__]:
agent._swarm_manager = group_chat_manager # type: ignore[attr-defined]
agent._swarm_manager = group_chat_manager # type: ignore[attr-defined]


def _run_oncontextconditions(
Expand Down Expand Up @@ -339,16 +344,72 @@ def _run_oncontextconditions(
return False, None


def _modify_context_variables_param(f: Callable[..., Any], context_variables: dict[str, Any]) -> Callable[..., Any]:
"""Modifies the context_variables parameter to use dependency injection and link it to the swarm context variables.

This essentially changes:
def some_function(some_variable: int, context_variables: dict[str, Any]) -> str:

to:

def some_function(some_variable: int, context_variables: Annotated[dict[str, Any], Depends(on(self._context_variables))]) -> str:
"""
sig = inspect.signature(f)

# Check if context_variables parameter exists and update it if so
if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
new_params = []
for name, param in sig.parameters.items():
if name == __CONTEXT_VARIABLES_PARAM_NAME__:
# Replace with new annotation using Depends
new_param = param.replace(annotation=Annotated[dict[str, Any], Depends(on(context_variables))])
new_params.append(new_param)
else:
new_params.append(param)

# Update signature
new_sig = sig.replace(parameters=new_params)
f.__signature__ = new_sig # type: ignore[attr-defined]

return f


def _change_tool_context_variables_to_depends(
agent: ConversableAgent, current_tool: Tool, context_variables: dict[str, Any]
) -> None:
"""Checks for the context_variables parameter in the tool and updates it to use dependency injection."""

# If the tool has a context_variables parameter, remove the tool and reregister it without the parameter
if __CONTEXT_VARIABLES_PARAM_NAME__ in current_tool.tool_schema["function"]["parameters"]["properties"]:
# We'll replace the tool, so start with getting the underlying function
tool_func = current_tool._func

# Remove the Tool from the agent
name = current_tool._name
description = current_tool._description
agent.remove_tool_for_llm(current_tool)

# Recreate the tool without the context_variables parameter
tool_func = _modify_context_variables_param(current_tool._func, context_variables)
tool_func = inject_params(tool_func)
new_tool = ConversableAgent._create_tool_if_needed(func_or_tool=tool_func, name=name, description=description)

# Re-register with the agent
agent.register_for_llm()(new_tool)


def _prepare_swarm_agents(
initial_agent: ConversableAgent,
agents: list[ConversableAgent],
context_variables: dict[str, Any],
exclude_transit_message: bool = True,
) -> tuple[ConversableAgent, list[ConversableAgent]]:
"""Validates agents, create the tool executor, configure nested chats.

Args:
initial_agent (ConversableAgent): The first agent in the conversation.
agents (list[ConversableAgent]): List of all agents in the conversation.
context_variables (dict[str, Any]): Context variables to assign to all agents.
exclude_transit_message (bool): Whether to exclude transit messages from the agents.

Returns:
Expand Down Expand Up @@ -382,16 +443,24 @@ def _prepare_swarm_agents(
for agent in agents:
_create_nested_chats(agent, nested_chat_agents)

# Update any agent's tools that have context_variables as a parameter
# To use Dependency Injection

# Update tool execution agent with all the functions from all the agents
for agent in agents + nested_chat_agents:
tool_execution._function_map.update(agent._function_map)

# Add conditional functions to the tool_execution agent
for func_name, (func, _) in agent._swarm_conditional_functions.items(): # type: ignore[attr-defined]
tool_execution._function_map[func_name] = func

# Register tools from the tools property
# Update any agent tools that have context_variables parameters to use Dependency Injection
for tool in agent.tools:
_change_tool_context_variables_to_depends(agent, tool, context_variables)

# Add all tools to the Tool Executor agent
for tool in agent.tools:
tool_execution.register_for_execution()(tool)
tool_execution.register_for_execution(serialize=False)(tool)

if exclude_transit_message:
# get all transit functions names
Expand Down Expand Up @@ -876,8 +945,11 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
dict[str, Any]: Updated Context variables.
ConversableAgent: Last speaker.
"""
context_variables = context_variables or {}

tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents, exclude_transit_message)
tool_execution, nested_chat_agents = _prepare_swarm_agents(
initial_agent, agents, context_variables, exclude_transit_message
)

processed_messages, last_agent, swarm_agent_names, temp_user_list = _process_initial_messages(
messages, user_agent, agents, nested_chat_agents
Expand All @@ -902,7 +974,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)

# Point all ConversableAgent's context variables to this function's context_variables
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
_setup_context_variables(tool_execution, agents, manager, context_variables)

# Link all agents with the GroupChatManager to allow access to the group chat
# and other agents, particularly the tool executor for setting _swarm_next_agent
Expand All @@ -926,7 +998,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st

_cleanup_temp_user_messages(chat_result)

return chat_result, context_variables, manager.last_speaker # type: ignore[return-value]
return chat_result, context_variables if context_variables != {} else None, manager.last_speaker # type: ignore[return-value]


@export_module("autogen")
Expand Down Expand Up @@ -976,7 +1048,10 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
dict[str, Any]: Updated Context variables.
ConversableAgent: Last speaker.
"""
tool_execution, nested_chat_agents = _prepare_swarm_agents(initial_agent, agents, exclude_transit_message)
context_variables = context_variables or {}
tool_execution, nested_chat_agents = _prepare_swarm_agents(
initial_agent, agents, context_variables, exclude_transit_message
)

processed_messages, last_agent, swarm_agent_names, temp_user_list = _process_initial_messages(
messages, user_agent, agents, nested_chat_agents
Expand All @@ -1001,7 +1076,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st
manager = _create_swarm_manager(groupchat, swarm_manager_args, agents)

# Point all ConversableAgent's context variables to this function's context_variables
_setup_context_variables(tool_execution, agents, manager, context_variables or {})
_setup_context_variables(tool_execution, agents, manager, context_variables)

if len(processed_messages) > 1:
last_agent, last_message = await manager.a_resume(messages=processed_messages)
Expand All @@ -1021,7 +1096,7 @@ def custom_afterwork_func(last_speaker: ConversableAgent, messages: list[dict[st

_cleanup_temp_user_messages(chat_result)

return chat_result, context_variables, manager.last_speaker # type: ignore[return-value]
return chat_result, context_variables if context_variables != {} else None, manager.last_speaker # type: ignore[return-value]


class SwarmResult(BaseModel):
Expand Down Expand Up @@ -1155,6 +1230,8 @@ def _update_conditional_functions(agent: ConversableAgent, messages: Optional[li
condition = condition.format(context_variables=agent._context_variables)
elif callable(condition):
condition = condition(agent, messages)

# TODO: Don't add it if it's already there
agent._add_single_function(func, func_name, condition)


Expand Down Expand Up @@ -1185,26 +1262,11 @@ def _generate_swarm_tool_reply(
tool_responses_inner = []
contents = []
for index in range(tool_call_count):
# Deep copy to ensure no changes to messages when we insert the context variables
message_copy = copy.deepcopy(message)

# 1. add context_variables to the tool call arguments
tool_call = message_copy["tool_calls"][index]

if tool_call["type"] == "function":
function_name = tool_call["function"]["name"]

# Check if this function exists in our function map
if function_name in agent._function_map:
func = agent._function_map[function_name] # Get the original function

# Inject the context variables into the tool call if it has the parameter
sig = signature(func)
if __CONTEXT_VARIABLES_PARAM_NAME__ in sig.parameters:
current_args = json.loads(tool_call["function"]["arguments"])
current_args[__CONTEXT_VARIABLES_PARAM_NAME__] = agent._context_variables
tool_call["function"]["arguments"] = json.dumps(current_args, default=BaseModel.model_dump_json)

# Ensure we are only executing the one tool at a time
message_copy["tool_calls"] = [tool_call]

Expand All @@ -1217,6 +1279,7 @@ def _generate_swarm_tool_reply(
# 3. update context_variables and next_agent, convert content to string
for tool_response in tool_message["tool_responses"]:
content = tool_response.get("content")

if isinstance(content, SwarmResult):
if content.context_variables != {}:
agent._context_variables.update(content.context_variables)
Expand All @@ -1225,6 +1288,10 @@ def _generate_swarm_tool_reply(
elif isinstance(content, Agent):
next_agent = content

# Serialize the content to a string
if content is not None:
tool_response["content"] = str(content)

tool_responses_inner.append(tool_response)
contents.append(str(tool_response["content"]))

Expand Down
35 changes: 12 additions & 23 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
)
from ..oai.client import ModelClient, OpenAIWrapper
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
from ..tools import ChatContext, Tool, get_function_schema, load_basemodels_if_needed, serialize_to_str
from ..tools import ChatContext, Tool, load_basemodels_if_needed, serialize_to_str
from .agent import Agent, LLMAgent
from .chat import ChatResult, _post_process_carryover_item, a_initiate_chats, initiate_chats
from .utils import consolidate_chat_info, gather_usage_summary
Expand Down Expand Up @@ -400,7 +400,7 @@ def _add_functions(self, func_list: list[Callable[..., Any]]):
self._add_single_function(func)

def _add_single_function(self, func: Callable, name: Optional[str] = None, description: Optional[str] = ""):
"""Add a single function to the agent, removing context variables for LLM use.
"""Add a single function to the agent

Args:
func (Callable): The function to register.
Expand All @@ -418,23 +418,8 @@ def _add_single_function(self, func: Callable, name: Optional[str] = None, descr
# Use function's docstring, strip whitespace, fall back to empty string
func._description = (func.__doc__ or "").strip()

f = get_function_schema(func, name=func._name, description=func._description)

# Remove context_variables parameter from function schema
f_no_context = f.copy()
if __CONTEXT_VARIABLES_PARAM_NAME__ in f_no_context["function"]["parameters"]["properties"]:
del f_no_context["function"]["parameters"]["properties"][__CONTEXT_VARIABLES_PARAM_NAME__]
if "required" in f_no_context["function"]["parameters"]:
required = f_no_context["function"]["parameters"]["required"]
f_no_context["function"]["parameters"]["required"] = [
param for param in required if param != __CONTEXT_VARIABLES_PARAM_NAME__
]
# If required list is empty, remove it
if not f_no_context["function"]["parameters"]["required"]:
del f_no_context["function"]["parameters"]["required"]

self.update_tool_signature(f_no_context, is_remove=False)
self.register_function({func._name: func})
# Register the function
self.register_for_llm(name=name, description=description)(func)

def _register_update_agent_state_before_reply(
self, functions: Optional[Union[list[Callable[..., Any]], Callable[..., Any]]]
Expand Down Expand Up @@ -3038,7 +3023,7 @@ def function_map(self) -> dict[str, Callable[..., Any]]:
"""Return the function map."""
return self._function_map

def _wrap_function(self, func: F, inject_params: dict[str, Any] = {}) -> F:
def _wrap_function(self, func: F, inject_params: dict[str, Any] = {}, *, serialize: bool = True) -> F:
"""Wrap the function inject chat context parameters and to dump the return value to json.

Handles both sync and async functions.
Expand All @@ -3057,15 +3042,15 @@ def _wrapped_func(*args, **kwargs):
retval = func(*args, **kwargs, **inject_params)
if logging_enabled():
log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
return serialize_to_str(retval) if serialize else retval

@load_basemodels_if_needed
@functools.wraps(func)
async def _a_wrapped_func(*args, **kwargs):
retval = await func(*args, **kwargs, **inject_params)
if logging_enabled():
log_function_use(self, func, kwargs, retval)
return serialize_to_str(retval)
return serialize_to_str(retval) if serialize else retval

wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func

Expand Down Expand Up @@ -3181,6 +3166,8 @@ def register_for_execution(
self,
name: Optional[str] = None,
description: Optional[str] = None,
*,
serialize: bool = True,
) -> Callable[[Union[Tool, F]], Tool]:
"""Decorator factory for registering a function to be executed by an agent.

Expand Down Expand Up @@ -3223,7 +3210,9 @@ def _decorator(
chat_context = ChatContext(self)
chat_context_params = {param: chat_context for param in tool._chat_context_param_names}

self.register_function({tool.name: self._wrap_function(tool.func, chat_context_params)})
self.register_function({
tool.name: self._wrap_function(tool.func, chat_context_params, serialize=serialize)
})

return tool

Expand Down
5 changes: 3 additions & 2 deletions autogen/tools/dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"get_context_params",
"inject_params",
"on",
"remove_params",
]


Expand Down Expand Up @@ -135,7 +136,7 @@ def _is_depends_param(param: inspect.Parameter) -> bool:
)


def _remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
def remove_params(func: Callable[..., Any], sig: inspect.Signature, params: Iterable[str]) -> None:
new_signature = sig.replace(parameters=[p for p in sig.parameters.values() if p.name not in params])
func.__signature__ = new_signature # type: ignore[attr-defined]

Expand All @@ -147,7 +148,7 @@ def _remove_injected_params_from_signature(func: Callable[..., Any]) -> Callable

sig = inspect.signature(func)
params_to_remove = [p.name for p in sig.parameters.values() if _is_context_param(p) or _is_depends_param(p)]
_remove_params(func, sig, params_to_remove)
remove_params(func, sig, params_to_remove)
return func


Expand Down
Loading
Loading