Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tool call result summary message #4755

Merged
merged 13 commits into from
Dec 20, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
Expand Down Expand Up @@ -62,7 +63,7 @@ class AssistantAgent(BaseChatAgent):

* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
- When `reflect_on_tool_use` is True, the another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.

Hand off behavior:
Expand Down Expand Up @@ -280,9 +281,12 @@ def __init__(
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
return [TextMessage, HandoffMessage]
return [TextMessage]
message_types.append(HandoffMessage)
if self._tools:
message_types.append(ToolCallSummaryMessage)
return message_types

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -379,7 +383,7 @@ async def on_messages_stream(
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
chat_message=TextMessage(content=tool_call_summary, source=self.name),
chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"


ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
class ToolCallSummaryMessage(BaseMessage):
"""A message signaling the summary of tool call results."""

content: str
"""Summary of the the tool call results."""

type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


Expand All @@ -110,7 +121,13 @@ class ToolCallExecutionEvent(BaseMessage):


AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
TextMessage
| MultiModalMessage
| StopMessage
| HandoffMessage
| ToolCallRequestEvent
| ToolCallExecutionEvent
| ToolCallSummaryMessage,
Field(discriminator="type"),
]
"""(Deprecated, will be removed in 0.4.0) All message and event types."""
Expand All @@ -126,6 +143,7 @@ class ToolCallExecutionEvent(BaseMessage):
"ToolCallExecutionEvent",
"ToolCallMessage",
"ToolCallResultMessage",
"ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"AgentMessage",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
Expand Down Expand Up @@ -433,7 +434,7 @@ def _thread_to_context(self) -> List[LLMMessage]:
elif isinstance(m, StopMessage | HandoffMessage):
context.append(UserMessage(content=m.content, source=m.source))
elif m.source == self._name:
assert isinstance(m, TextMessage)
assert isinstance(m, TextMessage | ToolCallSummaryMessage)
context.append(AssistantMessage(content=m.content, source=m.source))
else:
assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
Expand Down Expand Up @@ -100,7 +101,7 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
if isinstance(msg, TextMessage | StopMessage | HandoffMessage):
if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
message += f" {msg.content}"
elif isinstance(msg, MultiModalMessage):
for item in msg.content:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.tools import FunctionTool
Expand Down Expand Up @@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
assert isinstance(result.messages[3], TextMessage)
assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert result.messages[3].models_usage is None

Expand Down
4 changes: 3 additions & 1 deletion python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
ToolCallSummaryMessage,
)
from autogen_agentchat.teams import (
RoundRobinGroupChat,
Expand Down Expand Up @@ -325,7 +326,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
assert result.messages[3].content == "pass" # ensure the tool call was executed
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert isinstance(result.messages[6], TextMessage) # echo agent response
Expand Down
Loading