Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix tracking SelectorGroupChat tokens with new Message type #4768

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import AsyncGenerator, List, Protocol, Sequence

from autogen_core import CancellationToken
from autogen_core.models._types import RequestUsage

from ..messages import AgentEvent, ChatMessage

Expand All @@ -16,6 +17,9 @@ class TaskResult:
stop_reason: str | None = None
"""The reason the task stopped."""

usage: RequestUsage | None = None
"""The usage of the task."""


class TaskRunner(Protocol):
"""A task runner."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,23 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"


class UsageEvent(BaseMessage):
"""An event signaling the usage of a model."""

content: str = ""
"""The content of the usage event."""

models_usage: RequestUsage
"""The model client usage incurred when producing this message."""

type: Literal["UsageEvent"] = "UsageEvent"


ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent, Field(discriminator="type")]
AgentEvent = Annotated[ToolCallRequestEvent | ToolCallExecutionEvent | UsageEvent, Field(discriminator="type")]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
TypeSubscription,
)
from autogen_core._closure_agent import ClosureContext
from autogen_core.models._types import RequestUsage

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(
# Flag to track if the group chat is running.
self._is_running = False

self._total_usage = RequestUsage(prompt_tokens=0, completion_tokens=0)

@abstractmethod
def _create_group_chat_manager_factory(
self,
Expand Down Expand Up @@ -418,8 +421,14 @@ async def stop_runtime() -> None:
yield message
output_messages.append(message)

usage = RequestUsage(prompt_tokens=0, completion_tokens=0)
for message in output_messages:
if message.models_usage:
usage.prompt_tokens += message.models_usage.prompt_tokens
usage.completion_tokens += message.models_usage.completion_tokens

# Yield the final result.
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason)
yield TaskResult(messages=output_messages, stop_reason=self._stop_reason, usage=usage)

finally:
# Wait for the shutdown task to finish.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
import re
from typing import Any, Callable, Dict, List, Mapping, Sequence

from autogen_core._default_topic import DefaultTopicId
from autogen_core.models import ChatCompletionClient, SystemMessage

from autogen_agentchat.teams._group_chat._events import GroupChatMessage

from ... import TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import (
Expand All @@ -15,6 +18,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
UsageEvent,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -152,6 +156,12 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
agent_name = participants[0]
self._previous_speaker = agent_name
trace_logger.debug(f"Selected speaker: {agent_name}")

await self.publish_message(
GroupChatMessage(message=UsageEvent(source=self._id._type, models_usage=response.usage)),
topic_id=DefaultTopicId(type=self._output_topic_type),
)

return agent_name

def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from autogen_core.models import RequestUsage

from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UsageEvent


def _is_running_in_iterm() -> bool:
Expand Down Expand Up @@ -90,7 +90,9 @@ async def Console(
sys.stdout.flush()
# mypy ignore
last_processed = message # type: ignore

elif isinstance(message, UsageEvent):
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
Expand Down