Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions camel/societies/workforce/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
class WorkforceEventBase(BaseModel):
model_config = ConfigDict(frozen=True, extra='forbid')
event_type: Literal[
"log",
"task_decomposed",
"task_created",
"task_assigned",
Expand All @@ -39,6 +40,25 @@ class WorkforceEventBase(BaseModel):
)


class LogEvent(WorkforceEventBase):
event_type: Literal["log"] = "log"
message: str
level: Literal["debug", "info", "warning", "error", "critical"]
color: (
Literal[
"red",
"green",
"yellow",
"blue",
"cyan",
"magenta",
"gray",
"black",
]
| None
) = None


class WorkerCreatedEvent(WorkforceEventBase):
event_type: Literal["worker_created"] = "worker_created"
worker_id: str
Expand Down Expand Up @@ -109,6 +129,7 @@ class QueueStatusEvent(WorkforceEventBase):


WorkforceEvent = Union[
LogEvent,
TaskDecomposedEvent,
TaskCreatedEvent,
TaskAssignedEvent,
Expand Down
38 changes: 29 additions & 9 deletions camel/societies/workforce/workforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@

from .events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand Down Expand Up @@ -3803,10 +3804,17 @@ async def _post_task(self, task: Task, assignee_id: str) -> None:
logger.error(
f"Failed to post task {task.id} to {assignee_id}: {e}"
)
print(
f"{Fore.RED}Failed to post task {task.id} to {assignee_id}: "
f"{e}{Fore.RESET}"
)
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=(
f"Failed to post task {task.id} to {assignee_id}: "
f"{e}"
),
level="error",
color="red",
)
)

async def _post_dependency(self, dependency: Task) -> None:
await self._channel.post_dependency(dependency, self.node_id)
Expand Down Expand Up @@ -3933,7 +3941,12 @@ async def _create_worker_node_for_task(self, task: Task) -> Worker:
)
new_node.set_channel(self._channel)

print(f"{Fore.CYAN}{new_node} created.{Fore.RESET}")
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=f"{new_node} created.", level="info", color="cyan"
)
)

self._children.append(new_node)

Expand Down Expand Up @@ -4495,10 +4508,17 @@ async def _handle_completed_task(self, task: Task) -> None:
tasks_list.pop(i)
self._pending_tasks = deque(tasks_list)
found_and_removed = True
print(
f"{Fore.GREEN}✅ Task {task.id} completed and removed "
f"from queue.{Fore.RESET}"
)
for cb in self._callbacks:
cb.log_message(
LogEvent(
message=(
f"✅ Task {task.id} completed and removed "
f"from queue."
),
level="info",
color="green",
)
)
break

if not found_and_removed:
Expand Down
8 changes: 8 additions & 0 deletions camel/societies/workforce/workforce_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand All @@ -34,6 +35,13 @@ class WorkforceCallback(ABC):
Implementations should persist or stream events as appropriate.
"""

@abstractmethod
def log_message(
self,
event: LogEvent,
) -> None:
pass

@abstractmethod
def log_task_created(
self,
Expand Down
34 changes: 34 additions & 0 deletions camel/societies/workforce/workforce_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from colorama import Fore

from camel.logger import get_logger
from camel.societies.workforce.events import (
AllTasksCompletedEvent,
LogEvent,
QueueStatusEvent,
TaskAssignedEvent,
TaskCompletedEvent,
Expand All @@ -34,6 +37,16 @@

logger = get_logger(__name__)

_COLOR_MAP = {
"yellow": Fore.YELLOW,
"red": Fore.RED,
"green": Fore.GREEN,
"cyan": Fore.CYAN,
"magenta": Fore.MAGENTA,
"gray": Fore.LIGHTBLACK_EX,
"black": Fore.BLACK,
}


class WorkforceLogger(WorkforceCallback, WorkforceMetrics):
r"""Logs events and metrics for a Workforce instance."""
Expand All @@ -50,6 +63,27 @@ def __init__(self, workforce_id: str):
self._worker_information: Dict[str, Dict[str, Any]] = {}
self._initial_worker_logs: List[Dict[str, Any]] = []

def _get_color_message(self, event: LogEvent) -> str:
r"""Gets a colored message for a log event."""
if event.color is None or event.color not in _COLOR_MAP:
return event.message
color = _COLOR_MAP.get(event.color)
return f"{color}{event.message}{Fore.RESET}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move _get_color_message and _COLOR_MAP into WorkforceCallback. Also rename _COLOR_MAP to __COLOR_MAP to mark it as a private attribute


def log_message(self, event: LogEvent) -> None:
r"""Logs a message to the console with color."""
colored_message = self._get_color_message(event)
if event.level == 'debug':
logger.debug(colored_message)
if event.level == 'info':
logger.info(colored_message)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: the debug and info logs don't show on console out due to minimum=WARNING

if event.level == 'warning':
logger.warning(colored_message)
if event.level == 'error':
logger.error(colored_message)
if event.level == 'critical':
logger.critical(colored_message)

def _log_event(self, event_type: str, **kwargs: Any) -> None:
r"""Internal method to create and store a log entry.

Expand Down
12 changes: 11 additions & 1 deletion examples/workforce/workforce_callbacks_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from camel.models import ModelFactory
from camel.societies.workforce.events import (
AllTasksCompletedEvent,
LogEvent,
TaskAssignedEvent,
TaskCompletedEvent,
TaskCreatedEvent,
Expand All @@ -38,6 +39,7 @@
)
from camel.societies.workforce.workforce import Workforce
from camel.societies.workforce.workforce_callback import WorkforceCallback
from camel.societies.workforce.workforce_logger import WorkforceLogger
from camel.types import ModelPlatformType, ModelType

logger = get_logger(__name__)
Expand All @@ -46,6 +48,12 @@
class PrintCallback(WorkforceCallback):
r"""Simple callback printing events to logs to observe ordering."""

def log_message(self, event: LogEvent) -> None:
print(
f"[PrintCallback] {event.message} level={event.level}, "
f"color={event.color}"
)

def log_task_created(self, event: TaskCreatedEvent) -> None:
print(
f"[PrintCallback] task_created: id={event.task_id}, "
Expand Down Expand Up @@ -115,7 +123,9 @@ def build_student_agent() -> ChatAgent:


async def run_demo() -> None:
callbacks = [PrintCallback()]
logger_cb = WorkforceLogger('demo-logger')
print_cb = PrintCallback()
callbacks = [logger_cb, print_cb]

workforce = Workforce(
"Workforce Callbacks Demo",
Expand Down
Loading