From d349011d6162291b0f817761ed3a1f864f1e26da Mon Sep 17 00:00:00 2001
From: gziz
Date: Thu, 19 Dec 2024 21:41:42 -0800
Subject: [PATCH] Track SelectorGroup select_speaker tokens with new Message type

---
 .../src/autogen_agentchat/base/_task.py       |  4 ++++
 .../src/autogen_agentchat/messages.py         | 14 +++++++++++++-
 .../teams/_group_chat/_base_group_chat.py     | 11 ++++++++++-
 .../teams/_group_chat/_selector_group_chat.py | 10 ++++++++++
 .../src/autogen_agentchat/ui/_console.py      |  6 ++++--
 5 files changed, 41 insertions(+), 4 deletions(-)

diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py
index ecf05b170866..d6b25b91e95f 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py
@@ -2,6 +2,7 @@
 from typing import AsyncGenerator, List, Protocol, Sequence
 
 from autogen_core import CancellationToken
+from autogen_core.models._types import RequestUsage
 
 from ..messages import AgentEvent, ChatMessage
 
@@ -16,6 +17,9 @@ class TaskResult:
     stop_reason: str | None = None
     """The reason the task stopped."""
 
+    usage: RequestUsage | None = None
+    """The usage of the task."""
+
 
 class TaskRunner(Protocol):
     """A task runner."""
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
index 7237812ca72f..83a3dc785d79 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py
@@ -110,13 +110,25 @@ class ToolCallSummaryMessage(BaseMessage):
     type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
 
 
+class UsageEvent(BaseMessage):
+    """An event signaling the usage of a model."""
+
+    content: str = ""
+    """The content of the usage event."""
+
+    models_usage: RequestUsage
+    """The model client usage incurred when producing this message."""
+
+    type: Literal["UsageEvent"] = "UsageEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
 """Messages for agent-to-agent communication only."""
 
 
-AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
+AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | UsageEvent, Field(discriminator="type")]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
index 1d8d30d586d4..e04c1ddb55f8 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py
@@ -16,6 +16,7 @@
     TypeSubscription,
 )
 from autogen_core._closure_agent import ClosureContext
+from autogen_core.models._types import RequestUsage
 
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
@@ -74,6 +75,8 @@ def __init__(
         # Flag to track if the group chat is running.
         self._is_running = False
 
+        self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
+
     @abstractmethod
     def _create_group_chat_manager_factory(
         self,
@@ -418,8 +421,14 @@ async def stop_runtime() -> None:
                 yield message
                 output_messages.append(message)
 
+            usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
+            for message in output_messages:
+                if message.models_usage:
+                    usage.prompt_tokens += message.models_usage.prompt_tokens
+                    usage.completion_tokens += message.models_usage.completion_tokens
+
             # Yield the final result.
-            yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)
+            yield TaskResult(messages=output_messages, stop_reason=self._stop_reason, usage=usage)
 
         finally:
             # Wait for the shutdown task to finish.
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
index be8ec726c301..c6fb01b1ffd9 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py
@@ -2,8 +2,11 @@
 import re
 from typing import Any, Callable, Dict, List, Mapping, Sequence
 
+from autogen_core._default_topic import DefaultTopicId
 from autogen_core.models import ChatCompletionClient, SystemMessage
 
+from autogen_agentchat.teams._group_chat._events import GroupChatMessage
+
 from ... import TRACE_LOGGER_NAME
 from ...base import ChatAgent, TerminationCondition
 from ...messages import (
@@ -16,6 +19,7 @@
     ToolCallExecutionEvent,
     ToolCallRequestEvent,
     ToolCallSummaryMessage,
+    UsageEvent,
 )
 from ...state import SelectorManagerState
 from ._base_group_chat import BaseGroupChat
@@ -153,6 +157,12 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
             agent_name = participants[0]
         self._previous_speaker = agent_name
         trace_logger.debug(f"Selected speaker: {agent_name}")
+
+        await self.publish_message(
+            GroupChatMessage(message=UsageEvent(source=self._id._type, models_usage=response.usage)),
+            topic_id=DefaultTopicId(type=self._output_topic_type),
+        )
+
         return agent_name
 
     def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py
index 6315b504977c..d0930f71b4bd 100644
--- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py
+++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py
@@ -7,7 +7,7 @@
 from autogen_core.models import RequestUsage
 
 from autogen_agentchat.base import Response, TaskResult
-from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
+from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UsageEvent
 
 
 def _is_running_in_iterm() -> bool:
@@ -90,7 +90,9 @@ async def Console(
                 sys.stdout.flush()
             # mypy ignore
             last_processed = message  # type: ignore
-
+        elif isinstance(message, UsageEvent):
+            total_usage.completion_tokens += message.models_usage.completion_tokens
+            total_usage.prompt_tokens += message.models_usage.prompt_tokens
         else:
            # Cast required for mypy to be happy
            message = cast(AgentEvent | ChatMessage, message)  # type: ignore
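
Note (reviewer sketch, not part of the patch): a minimal example of how a caller could observe the selector's token spend through the new UsageEvent and the aggregated TaskResult.usage introduced above. The agent names, model name, and termination condition are illustrative assumptions and not taken from this change.

    # usage_sketch.py -- assumes an OpenAI API key is configured in the environment.
    import asyncio

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.base import TaskResult
    from autogen_agentchat.conditions import MaxMessageTermination
    from autogen_agentchat.messages import UsageEvent
    from autogen_agentchat.teams import SelectorGroupChat
    from autogen_ext.models.openai import OpenAIChatCompletionClient


    async def main() -> None:
        model_client = OpenAIChatCompletionClient(model="gpt-4o")  # assumed model
        writer = AssistantAgent("writer", model_client=model_client)
        critic = AssistantAgent("critic", model_client=model_client)
        team = SelectorGroupChat(
            [writer, critic],
            model_client=model_client,
            termination_condition=MaxMessageTermination(6),
        )

        async for item in team.run_stream(task="Draft a haiku about token counting."):
            if isinstance(item, UsageEvent):
                # Tokens spent by the selector's select_speaker call.
                print("selector usage:", item.models_usage)
            elif isinstance(item, TaskResult):
                # Usage aggregated over all output messages by BaseGroupChat.run_stream.
                print("total usage:", item.usage)


    asyncio.run(main())

Passing the same stream to Console(stream) would instead fold the UsageEvent tokens into its running total via the new elif branch in _console.py.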