@@ -1,30 +1,20 @@
 import asyncio
 import logging
-from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
 from dapr_agents.agents.base import AgentBase
 from dapr_agents.types import (
     AgentError,
-    AssistantMessage,
-    ChatCompletion,
     ToolCall,
     ToolExecutionRecord,
     ToolMessage,
     UserMessage,
+    LLMChatResponse,
 )
 
 logger = logging.getLogger(__name__)
 
 
-class FinishReason(str, Enum):
-    STOP = "stop"
-    LENGTH = "length"
-    CONTENT_FILTER = "content_filter"
-    TOOL_CALLS = "tool_calls"
-    FUNCTION_CALL = "function_call"  # deprecated
-
-
 class Agent(AgentBase):
     """
     Agent that manages tool calls and conversations using a language model.
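Note: with the local FinishReason enum gone, the agent no longer dispatches on provider finish reasons; it asks the returned message directly. The sketch below is only one reading of the interface this diff assumes. StubAssistantMessage and StubLLMChatResponse are hypothetical stand-ins, not the real dapr_agents.types definitions:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

# Inferred from how the hunks below use these types; the actual
# dapr_agents.types.LLMChatResponse may be shaped differently.
@dataclass
class StubAssistantMessage:
    content: Optional[str] = None
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)

    def has_tool_calls(self) -> bool:
        return bool(self.tool_calls)

    def get_tool_calls(self) -> List[Dict[str, Any]]:
        return self.tool_calls

@dataclass
class StubLLMChatResponse:
    candidates: List[StubAssistantMessage] = field(default_factory=list)

    def get_message(self) -> Optional[StubAssistantMessage]:
        # First candidate, or None when the provider returned nothing
        return self.candidates[0] if self.candidates else None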
@@ -168,8 +158,10 @@ async def run_and_record(tool_call: ToolCall) -> ToolMessage:
                name=function_name,
                content=result_str,
            )
-            # Printing the tool message for visibility
+            # Print the tool message for visibility
            self.text_formatter.print_message(tool_message)
+            # Add tool message to memory
+            self.memory.add_message(tool_message)
            # Append tool message to the persistent audit log
            tool_execution_record = ToolExecutionRecord(
                tool_call_id=tool_id,
@@ -201,69 +193,54 @@ async def process_iterations(self, messages: List[Dict[str, Any]]) -> Any:
        Raises:
            AgentError: On chat failure or tool issues.
        """
-        for iteration in range(self.max_iterations):
-            logger.info(f"Iteration {iteration + 1}/{self.max_iterations} started.")
-
+        final_reply = None
+        for turn in range(1, self.max_iterations + 1):
+            logger.info(f"Iteration {turn}/{self.max_iterations} started.")
            try:
                # Generate response using the LLM
-                response = self.llm.generate(
+                response: LLMChatResponse = self.llm.generate(
                    messages=messages,
                    tools=self.get_llm_tools(),
                    tool_choice=self.tool_choice,
                )
-                # If response is a dict, convert to ChatCompletion
-                if isinstance(response, dict):
-                    response = ChatCompletion(**response)
-                elif not isinstance(response, ChatCompletion):
-                    # If response is an iterator (stream), raise TypeError
-                    raise TypeError(f"Expected ChatCompletion, got {type(response)}")
-                # Get the response message and print it
+                # Get the first candidate from the response
                response_message = response.get_message()
-                if response_message is not None:
-                    self.text_formatter.print_message(response_message)
-
-                # Get Reason for the response
-                reason = FinishReason(response.get_reason())
+                # Make sure the response contains an assistant message
+                if response_message is None:
+                    raise AgentError("LLM returned no assistant message")
+                assistant = response_message
+                self.text_formatter.print_message(assistant)
+                self.memory.add_message(assistant)
 
                # Handle tool calls response
-                if reason == FinishReason.TOOL_CALLS:
-                    tool_calls = response.get_tool_calls()
+                if assistant.has_tool_calls():
+                    tool_calls = assistant.get_tool_calls()
                    if tool_calls:
-                        # Add the assistant message with tool calls to the conversation
-                        if response_message is not None:
-                            messages.append(response_message)
-                        # Execute tools and collect results for this iteration only
-                        tool_messages = await self.execute_tools(tool_calls)
-                        # Add tool results to messages for the next iteration
-                        messages.extend([tm.model_dump() for tm in tool_messages])
-                        # Continue to next iteration to let LLM process tool results
+                        # Run the requested tools and feed their results into the next turn
+                        messages.append(assistant.model_dump())
+                        tool_msgs = await self.execute_tools(tool_calls)
+                        messages.extend([tm.model_dump() for tm in tool_msgs])
+                        if turn == self.max_iterations:
+                            final_reply = assistant
+                            logger.info("Reached max turns after tool calls; stopping.")
+                            break
                        continue
-                # Handle stop response
-                elif reason == FinishReason.STOP:
-                    # Append AssistantMessage to memory
-                    msg = AssistantMessage(content=response.get_content() or "")
-                    self.memory.add_message(msg)
-                    return msg.content
-                # Handle Function call response
-                elif reason == FinishReason.FUNCTION_CALL:
-                    logger.warning(
-                        "LLM returned a deprecated function_call. Function calls are not processed by this agent."
-                    )
-                    msg = AssistantMessage(
-                        content="Function calls are not supported or processed by this agent."
-                    )
-                    self.memory.add_message(msg)
-                    return msg.content
-                else:
-                    logger.error(f"Unknown finish reason: {reason}")
-                    raise AgentError(f"Unknown finish reason: {reason}")
+
+                # No tool calls => the conversation is done
+                final_reply = assistant
+                break
 
            except Exception as e:
-                logger.error(f"Error during chat generation: {e}")
+                logger.error(f"Error on turn {turn}: {e}")
                raise AgentError(f"Failed during chat generation: {e}") from e
 
-        logger.info("Max iterations reached. Agent has stopped.")
-        return None
+        # Post-loop: either we broke out with a reply or never produced one
+        if final_reply is None:
+            logger.warning("No reply was generated before reaching max iterations.")
+            return None
+
+        logger.info(f"Agent conversation completed after {turn} turn(s).")
+        return final_reply
 
    async def run_tool(self, tool_name: str, *args, **kwargs) -> Any:
        """
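To summarize the new control flow for reviewers: a reply that carries tool calls feeds the results back into messages and loops; a plain reply ends the loop; exhausting the turn budget right after tool calls returns the last assistant message instead of None. A minimal, self-contained sketch of those semantics (StubMessage and turn_loop are stand-ins, not dapr_agents APIs):

import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional

@dataclass
class StubMessage:
    # Stand-in for the assistant message type used in the diff above
    content: str
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)

    def has_tool_calls(self) -> bool:
        return bool(self.tool_calls)

    def model_dump(self) -> Dict[str, Any]:
        return {"role": "assistant", "content": self.content, "tool_calls": self.tool_calls}

async def turn_loop(
    generate: Callable[[List[Dict[str, Any]]], StubMessage],
    max_iterations: int = 3,
) -> Optional[StubMessage]:
    messages: List[Dict[str, Any]] = [{"role": "user", "content": "hi"}]
    final_reply: Optional[StubMessage] = None
    for turn in range(1, max_iterations + 1):
        assistant = generate(messages)
        if assistant.has_tool_calls():
            messages.append(assistant.model_dump())
            # ...run tools here and extend `messages` with their results...
            if turn == max_iterations:
                final_reply = assistant  # out of budget: keep the last reply
                break
            continue
        final_reply = assistant  # plain reply ends the conversation
        break
    return final_reply

# Example: the first turn requests a tool, the second turn answers.
replies = iter([
    StubMessage("", tool_calls=[{"name": "search"}]),
    StubMessage("done"),
])
print(asyncio.run(turn_loop(lambda msgs: next(replies))).content)  # -> "done"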