Merge branch 'main' into patch-1
ekzhu authored Dec 20, 2024
2 parents 73330da + c989181 commit 2e47f9e
Showing 18 changed files with 221 additions and 109 deletions.
@@ -2,15 +2,27 @@
import json
import logging
import warnings
- from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
+ from typing import (
+ Any,
+ AsyncGenerator,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ Mapping,
+ Sequence,
+ )

from autogen_core import CancellationToken, FunctionCall
+ from autogen_core.model_context import (
+ ChatCompletionContext,
+ UnboundedChatCompletionContext,
+ )
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
@@ -28,6 +40,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
+ ToolCallSummaryMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
@@ -62,7 +75,7 @@ class AssistantAgent(BaseChatAgent):
* If the model returns no tool call, then the response is immediately returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
* When the model returns tool calls, they will be executed right away:
-   - When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
+   - When `reflect_on_tool_use` is False (default), the tool call results are returned as a :class:`~autogen_agentchat.messages.ToolCallSummaryMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`. `tool_call_summary_format` can be used to customize the tool call summary.
    - When `reflect_on_tool_use` is True, another model inference is made using the tool calls and results, and the text response is returned as a :class:`~autogen_agentchat.messages.TextMessage` in :attr:`~autogen_agentchat.base.Response.chat_message`.
Hand off behavior:
@@ -86,7 +99,6 @@ class AssistantAgent(BaseChatAgent):
If multiple handoffs are detected, only the first handoff is executed.
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
@@ -95,8 +107,9 @@ class AssistantAgent(BaseChatAgent):
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
If a handoff is a string, it should represent the target agent's name.
+ model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
- system_message (str, optional): The system message for the model.
+ system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -218,9 +231,11 @@ def __init__(
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
+ model_context: ChatCompletionContext | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
- system_message: str
- | None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+ system_message: (
+ str | None
+ ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
@@ -272,17 +287,21 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
- self._model_context: List[LLMMessage] = []
+ if not model_context:
+     self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
+ message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
-     return [TextMessage, HandoffMessage]
- return [TextMessage]
+     message_types.append(HandoffMessage)
+ if self._tools:
+     message_types.append(ToolCallSummaryMessage)
+ return message_types

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
@@ -297,19 +316,19 @@ async def on_messages_stream(
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
- self._model_context.append(UserMessage(content=msg.content, source=msg.source))
+ await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []

# Generate an inference result based on the current model context.
- llm_messages = self._system_messages + self._model_context
+ llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
- self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+ await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))

# Check if the response is a string and return it.
if isinstance(result.content, str):
@@ -331,7 +350,7 @@ async def on_messages_stream(
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
- self._model_context.append(FunctionExecutionResultMessage(content=results))
+ await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

@@ -356,11 +375,11 @@ async def on_messages_stream(

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
- llm_messages = self._system_messages + self._model_context
+ llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
- self._model_context.append(AssistantMessage(content=result.content, source=self.name))
+ await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
@@ -379,7 +398,7 @@
)
tool_call_summary = "\n".join(tool_call_summaries)
yield Response(
- chat_message=TextMessage(content=tool_call_summary, source=self.name),
+ chat_message=ToolCallSummaryMessage(content=tool_call_summary, source=self.name),
inner_messages=inner_messages,
)

@@ -402,14 +421,15 @@ async def _execute_tool_call(

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
- self._model_context.clear()
+ await self._model_context.clear()

async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
- return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+ model_context_state = await self._model_context.save_state()
+ return AssistantAgentState(llm_context=model_context_state).model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
- self._model_context.clear()
- self._model_context.extend(assistant_agent_state.llm_messages)
+ # Load the model context state.
+ await self._model_context.load_state(assistant_agent_state.llm_context)
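
Taken together, the hunks above swap the assistant agent's raw List[LLMMessage] for a pluggable, async ChatCompletionContext. Below is a minimal usage sketch of the post-merge API; the model client is a stand-in for any ChatCompletionClient implementation, and only the parameters shown in the __init__ diff above are assumed.

from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import UnboundedChatCompletionContext

async def demo(model_client) -> None:
    # model_client: any ChatCompletionClient implementation (construction omitted).
    agent = AssistantAgent(
        name="assistant",
        model_client=model_client,
        model_context=UnboundedChatCompletionContext(),  # new: inject the context explicitly
        reflect_on_tool_use=False,  # tool results come back as a ToolCallSummaryMessage
    )
    # Saving and loading state now round-trips through the model context.
    state = await agent.save_state()
    await agent.load_state(state)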
@@ -101,7 +101,18 @@ class ToolCallExecutionEvent(BaseMessage):
type: Literal["ToolCallExecutionEvent"] = "ToolCallExecutionEvent"


- ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
+ class ToolCallSummaryMessage(BaseMessage):
+     """A message signaling the summary of tool call results."""
+
+     content: str
+     """Summary of the tool call results."""
+
+     type: Literal["ToolCallSummaryMessage"] = "ToolCallSummaryMessage"
+
+
+ ChatMessage = Annotated[
+     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
+ ]
"""Messages for agent-to-agent communication only."""


@@ -110,7 +121,13 @@ class ToolCallExecutionEvent(BaseMessage):


AgentMessage = Annotated[
- TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallRequestEvent | ToolCallExecutionEvent,
+ TextMessage
+ | MultiModalMessage
+ | StopMessage
+ | HandoffMessage
+ | ToolCallRequestEvent
+ | ToolCallExecutionEvent
+ | ToolCallSummaryMessage,
Field(discriminator="type"),
]
"""(Deprecated, will be removed in 0.4.0) All message and event types."""
@@ -126,6 +143,7 @@ class ToolCallExecutionEvent(BaseMessage):
"ToolCallExecutionEvent",
"ToolCallMessage",
"ToolCallResultMessage",
+ "ToolCallSummaryMessage",
"ChatMessage",
"AgentEvent",
"AgentMessage",
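
These hunks are from the agentchat messages module: they add ToolCallSummaryMessage and widen the ChatMessage and (deprecated) AgentMessage unions. Since ChatMessage is a pydantic discriminated union keyed on the `type` field, serialized messages round-trip into the right class. A small sketch, assuming standard pydantic v2 and the exports listed above:

from pydantic import TypeAdapter

from autogen_agentchat.messages import ChatMessage, ToolCallSummaryMessage

adapter = TypeAdapter(ChatMessage)
raw = {"type": "ToolCallSummaryMessage", "content": "pass", "source": "assistant"}
msg = adapter.validate_python(raw)  # the "type" discriminator picks the class
assert isinstance(msg, ToolCallSummaryMessage)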
@@ -1,8 +1,5 @@
from typing import Any, List, Mapping, Optional

- from autogen_core.models import (
-     LLMMessage,
- )
from pydantic import BaseModel, Field

from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
class AssistantAgentState(BaseState):
"""State for an assistant agent."""

- llm_messages: List[LLMMessage] = Field(default_factory=list)
+ llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
type: str = Field(default="AssistantAgentState")


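
This hunk is from the state models: AssistantAgentState no longer serializes a typed list of LLMMessage but whatever mapping the model context's save_state returns (a dict with a "messages" list by default). A sketch of the round-trip; the autogen_agentchat.state import path is an assumption based on the relative import shown above:

from autogen_agentchat.state import AssistantAgentState  # assumed export path

state = AssistantAgentState()
print(state.llm_context)  # {'messages': []} by default

dumped = state.model_dump()  # the shape save_state() returns
restored = AssistantAgentState.model_validate(dumped)  # what load_state() consumes
assert restored.llm_context == state.llm_context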
@@ -21,6 +21,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
+ ToolCallSummaryMessage,
)
from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
@@ -433,7 +434,7 @@ def _thread_to_context(self) -> List[LLMMessage]:
elif isinstance(m, StopMessage | HandoffMessage):
context.append(UserMessage(content=m.content, source=m.source))
elif m.source == self._name:
- assert isinstance(m, TextMessage)
+ assert isinstance(m, TextMessage | ToolCallSummaryMessage)
context.append(AssistantMessage(content=m.content, source=m.source))
else:
assert isinstance(m, TextMessage) or isinstance(m, MultiModalMessage)
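
In the Magentic-One orchestrator's thread-to-context conversion, messages authored by the orchestrator itself may now be ToolCallSummaryMessage as well as TextMessage; both become AssistantMessage entries. A standalone sketch of that rule (the helper name is illustrative, not from the source):

from autogen_agentchat.messages import ChatMessage, TextMessage, ToolCallSummaryMessage
from autogen_core.models import AssistantMessage, LLMMessage, UserMessage

def to_llm_message(m: ChatMessage, self_name: str) -> LLMMessage:
    # Own text output, including tool-call summaries, maps to AssistantMessage;
    # messages from other agents are presented to the model as UserMessage.
    if m.source == self_name:
        assert isinstance(m, TextMessage | ToolCallSummaryMessage)
        return AssistantMessage(content=m.content, source=m.source)
    return UserMessage(content=m.content, source=m.source)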
@@ -15,6 +15,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
+ ToolCallSummaryMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
@@ -100,7 +101,7 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str:
continue
# The agent type must be the same as the topic type, which we use as the agent name.
message = f"{msg.source}:"
- if isinstance(msg, TextMessage | StopMessage | HandoffMessage):
+ if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
message += f" {msg.content}"
elif isinstance(msg, MultiModalMessage):
for item in msg.content:
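
For the selector group chat, the effect is that tool-call summaries now contribute their text to the transcript the speaker-selection prompt sees. A simplified sketch of that transcript assembly under the same message types (not the exact select_speaker body):

from typing import List

from autogen_agentchat.messages import (
    HandoffMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallSummaryMessage,
)

def transcript(thread) -> List[str]:
    lines: List[str] = []
    for msg in thread:
        line = f"{msg.source}:"
        if isinstance(msg, TextMessage | StopMessage | HandoffMessage | ToolCallSummaryMessage):
            line += f" {msg.content}"
        elif isinstance(msg, MultiModalMessage):
            line += " [multimodal message]"
        else:
            continue  # agent events are not shown to the selector
        lines.append(line)
    return lines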
@@ -14,6 +14,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
+ ToolCallSummaryMessage,
)
from autogen_core import Image
from autogen_core.tools import FunctionTool
@@ -142,7 +143,7 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[2].models_usage is None
- assert isinstance(result.messages[3], TextMessage)
+ assert isinstance(result.messages[3], ToolCallSummaryMessage)
assert result.messages[3].content == "pass"
assert result.messages[3].models_usage is None

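
The assertion change above captures this PR's behavioral contract in miniature: with reflect_on_tool_use left at False, a tool-using run now ends in a ToolCallSummaryMessage instead of a TextMessage. A reusable sketch of the expected stream shape (the helper is ours, not from the test suite):

from autogen_agentchat.messages import (
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
    ToolCallSummaryMessage,
)

def check_tool_run_shape(messages) -> None:
    assert isinstance(messages[0], TextMessage)             # the task
    assert isinstance(messages[1], ToolCallRequestEvent)    # the model's tool call
    assert isinstance(messages[2], ToolCallExecutionEvent)  # the tool's result
    assert isinstance(messages[3], ToolCallSummaryMessage)  # summary, formatted by tool_call_summary_format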
30 changes: 21 additions & 9 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -22,6 +22,7 @@
TextMessage,
ToolCallExecutionEvent,
ToolCallRequestEvent,
+ ToolCallSummaryMessage,
)
from autogen_agentchat.teams import (
RoundRobinGroupChat,
@@ -238,8 +239,13 @@ async def test_round_robin_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
- assert agent3._model_context == agent1._model_context # pyright: ignore
- assert agent4._model_context == agent2._model_context # pyright: ignore
+
+ agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
+ agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
+ agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
+ agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
+ assert agent3_model_ctx_messages == agent1_model_ctx_messages
+ assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
@@ -325,7 +331,8 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
- assert isinstance(result.messages[3], TextMessage) # tool use agent response
+ assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
+ assert result.messages[3].content == "pass" # ensure the tool call was executed
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response
assert isinstance(result.messages[6], TextMessage) # echo agent response
@@ -335,7 +342,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
- tool_use_agent._model_context.clear() # pyright: ignore
+ await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -349,7 +356,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
index += 1

# Test Console.
- tool_use_agent._model_context.clear() # pyright: ignore
+ await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -577,8 +584,13 @@ async def test_selector_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
- assert agent3._model_context == agent1._model_context # pyright: ignore
- assert agent4._model_context == agent2._model_context # pyright: ignore
+
+ agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
+ agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
+ agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
+ agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
+ assert agent3_model_ctx_messages == agent1_model_ctx_messages
+ assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
@@ -929,7 +941,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
- agent1._model_context.clear() # pyright: ignore
+ await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -942,7 +954,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
index += 1

# Test Console
- agent1._model_context.clear() # pyright: ignore
+ await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()