Skip to content

Commit f2d6831

Browse files
yaron2 and Cyb3rWard0g authored
Refactor LLM Workflows and Orchestrators for Unified Response Handling and Iteration (#163) (#165)
* Refactor ChatClientBase: drop Pydantic inheritance and add typed generate() overloads * Align all LLM chat clients with refactored base and unified response models * Unify LLM utils across providers and delegate streaming/response to provider‑specific handlers * Refactor LLM pipeline: add HuggingFace tool calls, unify chat client/response types, and switch DurableAgent to loop‑based workflow * Refactor orchestrators with loops and unify LLM response handling using LLMChatResponse * test remaining quickstarts after all changes * run pytest after all changes * Run linting and formatting checks to ensure code quality * Update logging, Orchestrator Name and OTel module name --------- Signed-off-by: Roberto Rodriguez <[email protected]> Co-authored-by: Roberto Rodriguez <[email protected]>
1 parent 3e767e0 commit f2d6831

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

67 files changed

+3131
-2038
lines changed

dapr_agents/agents/agent/agent.py

Lines changed: 37 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
11
import asyncio
22
import logging
3-
from enum import Enum
43
from typing import Any, Dict, List, Optional, Union
54

65
from dapr_agents.agents.base import AgentBase
76
from dapr_agents.types import (
87
AgentError,
9-
AssistantMessage,
10-
ChatCompletion,
118
ToolCall,
129
ToolExecutionRecord,
1310
ToolMessage,
1411
UserMessage,
12+
LLMChatResponse,
1513
)
1614

1715
logger = logging.getLogger(__name__)
1816

1917

20-
class FinishReason(str, Enum):
21-
STOP = "stop"
22-
LENGTH = "length"
23-
CONTENT_FILTER = "content_filter"
24-
TOOL_CALLS = "tool_calls"
25-
FUNCTION_CALL = "function_call" # deprecated
26-
27-
2818
class Agent(AgentBase):
2919
"""
3020
Agent that manages tool calls and conversations using a language model.
@@ -168,8 +158,10 @@ async def run_and_record(tool_call: ToolCall) -> ToolMessage:
168158
name=function_name,
169159
content=result_str,
170160
)
171-
# Printing the tool message for visibility
161+
# Print the tool message for visibility
172162
self.text_formatter.print_message(tool_message)
163+
# Add tool message to memory
164+
self.memory.add_message(tool_message)
173165
# Append tool message to the persistent audit log
174166
tool_execution_record = ToolExecutionRecord(
175167
tool_call_id=tool_id,
@@ -201,69 +193,54 @@ async def process_iterations(self, messages: List[Dict[str, Any]]) -> Any:
201193
Raises:
202194
AgentError: On chat failure or tool issues.
203195
"""
204-
for iteration in range(self.max_iterations):
205-
logger.info(f"Iteration {iteration + 1}/{self.max_iterations} started.")
206-
196+
final_reply = None
197+
for turn in range(1, self.max_iterations + 1):
198+
logger.info(f"Iteration {turn}/{self.max_iterations} started.")
207199
try:
208200
# Generate response using the LLM
209-
response = self.llm.generate(
201+
response: LLMChatResponse = self.llm.generate(
210202
messages=messages,
211203
tools=self.get_llm_tools(),
212204
tool_choice=self.tool_choice,
213205
)
214-
# If response is a dict, convert to ChatCompletion
215-
if isinstance(response, dict):
216-
response = ChatCompletion(**response)
217-
elif not isinstance(response, ChatCompletion):
218-
# If response is an iterator (stream), raise TypeError
219-
raise TypeError(f"Expected ChatCompletion, got {type(response)}")
220-
# Get the response message and print it
206+
# Get the first candidate from the response
221207
response_message = response.get_message()
222-
if response_message is not None:
223-
self.text_formatter.print_message(response_message)
224-
225-
# Get Reason for the response
226-
reason = FinishReason(response.get_reason())
208+
# Check if the response contains an assistant message
209+
if response_message is None:
210+
raise AgentError("LLM returned no assistant message")
211+
else:
212+
assistant = response_message
213+
self.text_formatter.print_message(assistant)
214+
self.memory.add_message(assistant)
227215

228216
# Handle tool calls response
229-
if reason == FinishReason.TOOL_CALLS:
230-
tool_calls = response.get_tool_calls()
217+
if assistant is not None and assistant.has_tool_calls():
218+
tool_calls = assistant.get_tool_calls()
231219
if tool_calls:
232-
# Add the assistant message with tool calls to the conversation
233-
if response_message is not None:
234-
messages.append(response_message)
235-
# Execute tools and collect results for this iteration only
236-
tool_messages = await self.execute_tools(tool_calls)
237-
# Add tool results to messages for the next iteration
238-
messages.extend([tm.model_dump() for tm in tool_messages])
239-
# Continue to next iteration to let LLM process tool results
220+
messages.append(assistant.model_dump())
221+
tool_msgs = await self.execute_tools(tool_calls)
222+
messages.extend([tm.model_dump() for tm in tool_msgs])
223+
if turn == self.max_iterations:
224+
final_reply = assistant
225+
logger.info("Reached max turns after tool calls; stopping.")
226+
break
240227
continue
241-
# Handle stop response
242-
elif reason == FinishReason.STOP:
243-
# Append AssistantMessage to memory
244-
msg = AssistantMessage(content=response.get_content() or "")
245-
self.memory.add_message(msg)
246-
return msg.content
247-
# Handle Function call response
248-
elif reason == FinishReason.FUNCTION_CALL:
249-
logger.warning(
250-
"LLM returned a deprecated function_call. Function calls are not processed by this agent."
251-
)
252-
msg = AssistantMessage(
253-
content="Function calls are not supported or processed by this agent."
254-
)
255-
self.memory.add_message(msg)
256-
return msg.content
257-
else:
258-
logger.error(f"Unknown finish reason: {reason}")
259-
raise AgentError(f"Unknown finish reason: {reason}")
228+
229+
# No tool calls => done
230+
final_reply = assistant
231+
break
260232

261233
except Exception as e:
262-
logger.error(f"Error during chat generation: {e}")
234+
logger.error(f"Error on turn {turn}: {e}")
263235
raise AgentError(f"Failed during chat generation: {e}") from e
264236

265-
logger.info("Max iterations reached. Agent has stopped.")
266-
return None
237+
# Post-loop
238+
if final_reply is None:
239+
logger.warning("No reply generated; hitting max iterations.")
240+
return None
241+
242+
logger.info(f"Agent conversation completed after {turn} turns.")
243+
return final_reply
267244

268245
async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
269246
"""

dapr_agents/agents/base.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,11 @@
2626
ClassVar,
2727
)
2828
from pydantic import BaseModel, Field, PrivateAttr, model_validator, ConfigDict
29+
from dapr_agents.llm.chat import ChatClientBase
2930
from dapr_agents.llm.openai import OpenAIChatClient
30-
from dapr_agents.llm.huggingface import HFHubChatClient
31-
from dapr_agents.llm.nvidia import NVIDIAChatClient
32-
from dapr_agents.llm.dapr import DaprChatClient
3331

3432
logger = logging.getLogger(__name__)
3533

36-
# Type alias for all concrete chat client implementations
37-
ChatClientType = Union[
38-
OpenAIChatClient, HFHubChatClient, NVIDIAChatClient, DaprChatClient
39-
]
40-
4134

4235
class AgentBase(BaseModel, ABC):
4336
"""
@@ -73,7 +66,7 @@ class AgentBase(BaseModel, ABC):
7366
default=None,
7467
description="A custom system prompt, overriding name, role, goal, and instructions.",
7568
)
76-
llm: ChatClientType = Field(
69+
llm: ChatClientBase = Field(
7770
default_factory=OpenAIChatClient,
7871
description="Language model client for generating responses.",
7972
)

0 commit comments

Comments (0)