Skip to content

Commit

Permalink
Tool call result summary message (#4755)
Browse files Browse the repository at this point in the history
* Adding ToolCallResultSummaryMessage

* Support for ToolCallResultSummaryMessage

* Added ToolCallSummaryMessage

* ruff format

* Add ToolCallSummaryMessage to ChatMessage

* typing and tests for ToolCallSummaryMessage

* PR Feedback

---------

Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Hussein Mozannar <[email protected]>
  • Loading branch information
3 people authored Dec 20, 2024
1 parent 3dd4be9 commit a271708
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
Expand Down Expand Up @@ -62,7 +63,7 @@ class AssistantAgent(BaseChatAgent):
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
Hand off behavior:
Expand Down Expand Up @@ -280,9 +281,12 @@ def __init__(
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
return [TextMessage, HandoffMessage]
return [TextMessage]
message_types.append(HandoffMessage)
if self._tools:
message_types.append(ToolCallSummaryMessage)
return message_types

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -379,7 +383,7 @@ async def on_messages_stream(
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name),
chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"


ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
class ToolCallSummaryMessage(BaseMessage):
"""A message signaling the summary of tool call results."""

content: str
"""Summary of the the tool call results."""

type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


Expand All @@ -110,7 +121,13 @@ class ToolCallExecutionEvent(BaseMessage):


AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
TextMessage
| MultiModalMessage
| StopMessage
| HandoffMessage
| ToolCallRequestEvent
| ToolCallExecutionEvent
| ToolCallSummaryMessage,
Field(discriminator="type"),
]
"""(Deprecated, will be removed in 0.4.0) All message and event types."""
Expand All @@ -126,6 +143,7 @@ class ToolCallExecutionEvent(BaseMessage):
"ToolCallExecutionEvent",
"ToolCallMessage",
"ToolCallResultMessage",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"AgentMessage",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
Expand Down Expand Up @@ -433,7 +434,7 @@ def _thread_to_context(self) -> List[LLMMessage]:
elif isinstance(m, StopMessage | HandoffMessage):
context.append(UserMessage(content=m.content, source=m.source))
elif m.source == self._name:
assert isinstance(m, TextMessage)
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
context.append(AssistantMessage(content=m.content, source=m.source))
else:
assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -100,7 +101,7 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
if isinstance(msg, TextMessage | StopMessage | HandoffMessage):
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
message += f" {msg.content}"
elif isinstance(msg, MultiModalMessage):
for item in msg.content:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.tools import FunctionTool
Expand Down Expand Up @@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert result.messages[3].models_usage is None

Expand Down
4 changes: 3 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_agentchat.teams import (
RoundRobinGroupChat,
Expand Down Expand Up @@ -325,7 +326,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
assert result.messages[3].content == "pass" # ensure the tool call was executed
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert isinstance(result.messages[6], TextMessage) # echo agent response
Expand Down

0 comments on commit a271708

Please sign in to comment.