diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 37d4646c685b..8ef47806ac4d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -266,8 +266,8 @@ async def on_messages_stream( while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name)) # Add the tool call message to the output. - inner_messages.append(ToolCallMessage(content=result.content, source=self.name)) - yield ToolCallMessage(content=result.content, source=self.name) + inner_messages.append(ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage)) + yield ToolCallMessage(content=result.content, source=self.name, model_usage=result.usage) # Execute the tool calls. results = await asyncio.gather( @@ -303,7 +303,8 @@ async def on_messages_stream( assert isinstance(result.content, str) yield Response( - chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages + chat_message=TextMessage(content=result.content, source=self.name, model_usage=result.usage), + inner_messages=inner_messages, ) async def _execute_tool_call( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 51dbcca333d7..c8037671e131 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -1,7 +1,7 @@ from typing import List from autogen_core.components import FunctionCall, Image -from autogen_core.components.models import FunctionExecutionResult +from autogen_core.components.models import FunctionExecutionResult, RequestUsage from pydantic import BaseModel @@ -11,6 +11,9 @@ class BaseMessage(BaseModel): source: str """The name of the agent that sent this message.""" + model_usage: RequestUsage | None = None + """The model client usage incurred when producing this message.""" + class TextMessage(BaseMessage): """A text message.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 4589f86860d3..20556ad783cb 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -78,7 +78,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -88,7 +88,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ChatCompletion( id="id2", @@ -100,7 +100,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0), ), ] mock = _MockChatCompletion(chat_completions) @@ -113,9 +113,17 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: result = await tool_use_agent.run("task") assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].model_usage is None assert isinstance(result.messages[1], ToolCallMessage) + assert result.messages[1].model_usage is not None + assert result.messages[1].model_usage.completion_tokens == 5 + assert result.messages[1].model_usage.prompt_tokens == 10 assert isinstance(result.messages[2], ToolCallResultMessage) + assert result.messages[2].model_usage is None assert isinstance(result.messages[3], TextMessage) + assert result.messages[3].model_usage is not None + assert result.messages[3].model_usage.completion_tokens == 5 + assert result.messages[3].model_usage.prompt_tokens == 10 # Test streaming. mock._curr_index = 0 # pyright: ignore @@ -158,7 +166,7 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: created=0, model=model, object="chat.completion", - usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0), + usage=CompletionUsage(prompt_tokens=42, completion_tokens=43, total_tokens=85), ), ] mock = _MockChatCompletion(chat_completions) @@ -173,9 +181,17 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None: result = await tool_use_agent.run("task") assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) + assert result.messages[0].model_usage is None assert isinstance(result.messages[1], ToolCallMessage) + assert result.messages[1].model_usage is not None + assert result.messages[1].model_usage.completion_tokens == 43 + assert result.messages[1].model_usage.prompt_tokens == 42 assert isinstance(result.messages[2], ToolCallResultMessage) + assert result.messages[2].model_usage is None assert isinstance(result.messages[3], HandoffMessage) + assert result.messages[3].content == handoff.message + assert result.messages[3].target == handoff.target + assert result.messages[3].model_usage is None # Test streaming. mock._curr_index = 0 # pyright: ignore