Skip to content
Open
21 changes: 21 additions & 0 deletions camel/societies/workforce/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class WorkforceEventBase(BaseModel):
model_config = ConfigDict(frozen=True, extra='forbid')
event_type: Literal[
"log",
"task_decomposed",
"task_created",
"task_assigned",
Expand All @@ -39,6 +40,25 @@ class WorkforceEventBase(BaseModel):
)


class LogEvent(WorkforceEventBase):
event_type: Literal["log"] = "log"
message: str
level: Literal["debug", "info", "warning", "error", "critical"]
color: (
Literal[
"red",
"green",
"yellow",
"blue",
"cyan",
"magenta",
"gray",
"black",
]
| None
) = None


class WorkerCreatedEvent(WorkforceEventBase):
event_type: Literal["worker_created"] = "worker_created"
worker_id: str
Expand Down Expand Up @@ -109,6 +129,7 @@ class QueueStatusEvent(WorkforceEventBase):


WorkforceEvent = Union[
LogEvent,
TaskDecomposedEvent,
TaskCreatedEvent,
TaskAssignedEvent,
Expand Down
38 changes: 29 additions & 9 deletions camel/societies/workforce/workforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@

from .events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand Down Expand Up @@ -3813,10 +3814,17 @@ async def _post_task(self, task: Task, assignee_id: str) -> None:
logger.error(
f"Failed to post task {task.id} to {assignee_id}: {e}"
)
print(
f"{Fore.RED}Failed to post task {task.id} to {assignee_id}: "
f"{e}{Fore.RESET}"
)
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=(
f"Failed to post task {task.id} to {assignee_id}: "
f"{e}"
),
level="error",
color="red",
)
)

async def _post_dependency(self, dependency: Task) -> None:
await self._channel.post_dependency(dependency, self.node_id)
Expand Down Expand Up @@ -3943,7 +3951,12 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
)
new_node.set_channel(self._channel)

print(f"{Fore.CYAN}{new_node} created.{Fore.RESET}")
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=f"{new_node} created.", level="info", color="cyan"
)
)

self._children.append(new_node)

Expand Down Expand Up @@ -4505,10 +4518,17 @@ async def _handle_completed_task(self, task: Task) -> None:
tasks_list.pop(i)
self._pending_tasks = deque(tasks_list)
found_and_removed = True
print(
f"{Fore.GREEN}✅ Task {task.id} completed and removed "
f"from queue.{Fore.RESET}"
)
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=(
f"✅ Task {task.id} completed and removed "
f"from queue."
),
level="info",
color="green",
)
)
break

if not found_and_removed:
Expand Down
28 changes: 28 additions & 0 deletions camel/societies/workforce/workforce_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

import typing
from abc import ABC, abstractmethod

from colorama import Fore

from .events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand All @@ -34,6 +38,30 @@ class WorkforceCallback(ABC):
Implementations should persist or stream events as appropriate.
"""

__COLOR_MAP: typing.ClassVar = {
"yellow": Fore.YELLOW,
"red": Fore.RED,
"green": Fore.GREEN,
"cyan": Fore.CYAN,
"magenta": Fore.MAGENTA,
"gray": Fore.LIGHTBLACK_EX,
"black": Fore.BLACK,
}

def _get_color_message(self, event: LogEvent) -> str:
r"""Gets a colored message for a log event."""
if event.color is None or event.color not in self.__COLOR_MAP:
return event.message
color = self.__COLOR_MAP.get(event.color)
return f"{color}{event.message}{Fore.RESET}"

@abstractmethod
def log_message(
self,
event: LogEvent,
) -> None:
pass

@abstractmethod
def log_task_created(
self,
Expand Down
6 changes: 6 additions & 0 deletions camel/societies/workforce/workforce_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from camel.logger import get_logger
from camel.societies.workforce.events import (
AllTasksCompletedEvent,
LogEvent,
QueueStatusEvent,
TaskAssignedEvent,
TaskCompletedEvent,
Expand Down Expand Up @@ -50,6 +51,11 @@ def __init__(self, workforce_id: str):
self._worker_information: Dict[str, Dict[str, Any]] = {}
self._initial_worker_logs: List[Dict[str, Any]] = []

def log_message(self, event: LogEvent) -> None:
r"""Logs a message to the console with color."""
colored_message = self._get_color_message(event)
print(colored_message)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted back to using print instead of logger.xxx


def _log_event(self, event_type: str, **kwargs: Any) -> None:
r"""Internal method to create and store a log entry.

Expand Down
12 changes: 11 additions & 1 deletion examples/workforce/workforce_callbacks_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from camel.models import ModelFactory
from camel.societies.workforce.events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand All @@ -38,6 +39,7 @@
)
from camel.societies.workforce.workforce import Workforce
from camel.societies.workforce.workforce_callback import WorkforceCallback
from camel.societies.workforce.workforce_logger import WorkforceLogger
from camel.types import ModelPlatformType, ModelType

logger = get_logger(__name__)
Expand All @@ -46,6 +48,12 @@
class PrintCallback(WorkforceCallback):
r"""Simple callback printing events to logs to observe ordering."""

def log_message(self, event: LogEvent) -> None:
print(
f"[PrintCallback] {event.message} level={event.level}, "
f"color={event.color}"
)

def log_task_created(self, event: TaskCreatedEvent) -> None:
print(
f"[PrintCallback] task_created: id={event.task_id}, "
Expand Down Expand Up @@ -115,7 +123,9 @@ def build_student_agent() -> ChatAgent:


async def run_demo() -> None:
callbacks = [PrintCallback()]
logger_cb = WorkforceLogger('demo-logger')
print_cb = PrintCallback()
callbacks = [logger_cb, print_cb]

workforce = Workforce(
"Workforce Callbacks Demo",
Expand Down
Loading