Skip to content
Open
149 changes: 88 additions & 61 deletions owl/utils/enhanced_role_playing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Dict, List, Optional, Tuple
import threading


from camel.agents import ChatAgent
Expand All @@ -38,6 +39,8 @@ def __init__(self, **kwargs):
self.assistant_agent_kwargs: dict = kwargs.get("assistant_agent_kwargs", {})

self.output_language = kwargs.get("output_language", None)

self.stop_event = kwargs.get("stop_event", None)

super().__init__(**kwargs)

Expand All @@ -62,6 +65,7 @@ def __init__(self, **kwargs):
user_agent_kwargs=self.user_agent_kwargs,
output_language=self.output_language,
# is_reasoning_task=self.is_reasoning_task
stop_event=self.stop_event
)

def _init_agents(
Expand All @@ -72,6 +76,7 @@ def _init_agents(
user_agent_kwargs: Optional[Dict] = None,
output_language: Optional[str] = None,
is_reasoning_task: bool = False,
stop_event: Optional[threading.Event] = None,
) -> None:
r"""Initialize assistant and user agents with their system messages.

Expand All @@ -86,6 +91,9 @@ def _init_agents(
pass to the user agent. (default: :obj:`None`)
output_language (str, optional): The language to be output by the
agents. (default: :obj:`None`)
stop_event (Optional[threading.Event], optional): Event to signal
termination of the agent's operation. When set, the agent will
terminate its execution. (default: :obj:`None`)
"""
if self.model is not None:
if assistant_agent_kwargs is None:
Expand All @@ -107,13 +115,15 @@ def _init_agents(
self.assistant_agent = ChatAgent(
init_assistant_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(assistant_agent_kwargs or {}),
)
self.assistant_sys_msg = self.assistant_agent.system_message

self.user_agent = ChatAgent(
init_user_sys_msg,
output_language=output_language,
stop_event=stop_event,
**(user_agent_kwargs or {}),
)
self.user_sys_msg = self.user_agent.system_message
Expand Down Expand Up @@ -217,12 +227,8 @@ def step(
user_response = self.user_agent.step(assistant_msg)
if user_response.terminated or user_response.msgs is None:
return (
ChatAgentResponse(msgs=[], terminated=False, info={}),
ChatAgentResponse(
msgs=[],
terminated=user_response.terminated,
info=user_response.info,
),
ChatAgentResponse(msgs=[assistant_msg], terminated=False, info={}),
user_response
)
user_msg = self._reduce_message_options(user_response.msgs)

Expand All @@ -247,13 +253,9 @@ def step(
assistant_response = self.assistant_agent.step(modified_user_msg)
if assistant_response.terminated or assistant_response.msgs is None:
return (
assistant_response,
ChatAgentResponse(
msgs=[],
terminated=assistant_response.terminated,
info=assistant_response.info,
),
ChatAgentResponse(
msgs=[user_msg], terminated=False, info=user_response.info
msgs=[modified_user_msg], terminated=False, info=user_response.info
),
)
assistant_msg = self._reduce_message_options(assistant_response.msgs)
Expand Down Expand Up @@ -436,10 +438,10 @@ def step(
),
)


def run_society(
society: OwlRolePlaying,
round_limit: int = 15,
stop_event: Optional[threading.Event] = None
) -> Tuple[str, List[dict], dict]:
overall_completion_token_count = 0
overall_prompt_token_count = 0
Expand All @@ -448,58 +450,83 @@ def run_society(
init_prompt = """
Now please give me instructions to solve over overall task step by step. If the task requires some specific knowledge, please instruct me to use tools to complete the task.
"""
input_msg = society.init_chat(init_prompt)
for _round in range(round_limit):
assistant_response, user_response = society.step(input_msg)
# Check if usage info is available before accessing it
if assistant_response.info.get("usage") and user_response.info.get("usage"):
overall_completion_token_count += assistant_response.info["usage"].get(
"completion_tokens", 0
) + user_response.info["usage"].get("completion_tokens", 0)
overall_prompt_token_count += assistant_response.info["usage"].get(
"prompt_tokens", 0
) + user_response.info["usage"].get("prompt_tokens", 0)

# convert tool call to dict
tool_call_records: List[dict] = []
if assistant_response.info.get("tool_calls"):
for tool_call in assistant_response.info["tool_calls"]:
tool_call_records.append(tool_call.as_dict())

_data = {
"user": user_response.msg.content
if hasattr(user_response, "msg") and user_response.msg
else "",
"assistant": assistant_response.msg.content
if hasattr(assistant_response, "msg") and assistant_response.msg
else "",
"tool_calls": tool_call_records,
}
society.stop_event = stop_event

try:
input_msg = society.init_chat(init_prompt)
for _round in range(round_limit):
assistant_response, user_response = society.step(input_msg)
# Check if usage info is available before accessing it
if assistant_response.info.get("usage") and user_response.info.get("usage"):
overall_completion_token_count += assistant_response.info["usage"].get(
"completion_tokens", 0
) + user_response.info["usage"].get("completion_tokens", 0)
overall_prompt_token_count += assistant_response.info["usage"].get(
"prompt_tokens", 0
) + user_response.info["usage"].get("prompt_tokens", 0)

# convert tool call to dict
tool_call_records: List[dict] = []
if assistant_response.info.get("tool_calls"):
for tool_call in assistant_response.info["tool_calls"]:
tool_call_records.append(tool_call.as_dict())

_data = {
"user": user_response.msg.content
if hasattr(user_response, "msg") and user_response.msg
else "",
"assistant": assistant_response.msg.content
if hasattr(assistant_response, "msg") and assistant_response.msg
else "",
"tool_calls": tool_call_records,
}

chat_history.append(_data)
logger.info(
f"Round #{_round} user_response:\n {user_response.msgs[0].content if user_response.msgs and len(user_response.msgs) > 0 else ''}"
)
logger.info(
f"Round #{_round} assistant_response:\n {assistant_response.msgs[0].content if assistant_response.msgs and len(assistant_response.msgs) > 0 else ''}"
)

chat_history.append(_data)
logger.info(
f"Round #{_round} user_response:\n {user_response.msgs[0].content if user_response.msgs and len(user_response.msgs) > 0 else ''}"
)
logger.info(
f"Round #{_round} assistant_response:\n {assistant_response.msgs[0].content if assistant_response.msgs and len(assistant_response.msgs) > 0 else ''}"
)
if (
assistant_response.terminated
or user_response.terminated
or "TASK_DONE" in user_response.msg.content
or (stop_event and stop_event.is_set())
):
break

if (
assistant_response.terminated
or user_response.terminated
or "TASK_DONE" in user_response.msg.content
):
break

input_msg = assistant_response.msg
input_msg = assistant_response.msg

answer = chat_history[-1]["assistant"]
token_info = {
"completion_token_count": overall_completion_token_count,
"prompt_token_count": overall_prompt_token_count,
}
answer = chat_history[-1]["assistant"] if chat_history else ""
token_info = {
"completion_token_count": overall_completion_token_count,
"prompt_token_count": overall_prompt_token_count,
}

return answer, chat_history, token_info
return answer, chat_history, token_info

except Exception as e:
logger.error(f"Exception in run_society: {e}")
# Add empty results for proper return type in case of error
answer = f"Error: {str(e)}"
token_info = {
"completion_token_count": overall_completion_token_count,
"prompt_token_count": overall_prompt_token_count,
}
# Re-raise after cleanup
raise

finally:
# Always attempt to terminate browser, regardless of how we exit the function
if hasattr(society, 'assistant_agent') and hasattr(society.assistant_agent, 'tool_dict') and society.assistant_agent.tool_dict and 'terminate_browser' in society.assistant_agent.tool_dict:
try:
flag, msg = society.assistant_agent.tool_dict['terminate_browser']()
logger.info(f"Browser termination result: success={flag}, message='{msg}'")
except Exception as term_error:
logger.error(f"Failed to terminate browser: {term_error}")
# We don't re-raise browser termination errors to ensure the original error (if any) is preserved


async def arun_society(
Expand Down
Loading