Decouple model_context from AssistantAgent #4681

Merged · 13 commits · Dec 20, 2024
@@ -2,15 +2,28 @@
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
from typing import (
    Any,
    AsyncGenerator,
    Awaitable,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.model_context import (
    ChatCompletionContext,
    UnboundedBufferedChatCompletionContext,
)
from autogen_core.models import (
    AssistantMessage,
    ChatCompletionClient,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    SystemMessage,
    UserMessage,
)
@@ -215,12 +228,14 @@ def __init__(
self,
name: str,
model_client: ChatCompletionClient,
model_context: Optional[ChatCompletionContext] = None,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
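With this change, a caller can inject any ChatCompletionContext implementation instead of the built-in unbounded buffer. A minimal usage sketch, assuming the import paths used in this repo at the time (the model client setup is illustrative only):

    from autogen_agentchat.agents import AssistantAgent
    from autogen_core.model_context import BufferedChatCompletionContext
    from autogen_ext.models import OpenAIChatCompletionClient

    # Keep only the five most recent messages in the model context.
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_context=BufferedChatCompletionContext(buffer_size=5),
    )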
@@ -272,7 +287,8 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
if not model_context:
self._model_context = UnboundedBufferedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
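As rendered here, this hunk shows only the default branch. Presumably a caller-supplied context is assigned directly when one is given; a sketch of that full initialization pattern (an assumption, not verbatim PR code):

    # Sketch: fall back to the unbounded buffer only when no context is supplied.
    if model_context is not None:
        self._model_context = model_context
    else:
        self._model_context = UnboundedBufferedChatCompletionContext()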
@@ -297,19 +313,19 @@ async def on_messages_stream(
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))

# Check if the response is a string and return it.
if isinstance(result.content, str):
@@ -331,7 +347,7 @@ async def on_messages_stream(
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg

@@ -356,11 +372,11 @@ async def on_messages_stream(

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
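Every context access in this method is now awaited, so a replacement context must expose an async surface. A sketch of the interface implied by the calls above (the actual base class is ChatCompletionContext from autogen_core.model_context; only the methods this diff exercises are shown):

    from typing import List

    from autogen_core.models import LLMMessage

    class ContextLikeSketch:
        """Illustrative only: the calls AssistantAgent makes on its model context."""

        async def add_message(self, message: LLMMessage) -> None: ...

        async def get_messages(self) -> List[LLMMessage]: ...

        async def clear(self) -> None: ...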
@@ -402,14 +418,18 @@ async def _execute_tool_call(

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
self._model_context.clear()
await self._model_context.clear()

async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
current_model_ctx_state = self._model_context.save_state()
return AssistantAgentState(llm_messages=current_model_ctx_state["messages"]).model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
self._model_context.clear()
self._model_context.extend(assistant_agent_state.llm_messages)
await self._model_context.clear()

current_model_ctx_state = dict(self._model_context.save_state())
current_model_ctx_state["messages"] = assistant_agent_state.llm_messages
self._model_context.load_state(current_model_ctx_state)
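save_state and load_state now round-trip through the context instead of copying a raw message list. A usage sketch, assuming two AssistantAgent instances with the hypothetical names agent1 and agent2:

    # Persist agent1's conversation and replay it into agent2's context.
    state = await agent1.save_state()   # {"llm_messages": [...], ...} per AssistantAgentState
    await agent2.load_state(state)      # clears agent2's context, then loads the messages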
26 changes: 18 additions & 8 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -238,8 +238,13 @@ async def test_round_robin_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore

agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
@@ -335,7 +340,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -349,7 +354,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
index += 1

# Test Console.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -577,8 +582,13 @@ async def test_selector_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore

agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
@@ -929,7 +939,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -942,7 +952,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
index += 1

# Test Console
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
4 changes: 4 additions & 0 deletions python/packages/autogen-core/src/autogen_core/model_context/__init__.py
@@ -1,9 +1,13 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._unbounded_buffered_chat_completion_context import (
    UnboundedBufferedChatCompletionContext,
)

__all__ = [
    "ChatCompletionContext",
    "UnboundedBufferedChatCompletionContext",
    "BufferedChatCompletionContext",
    "HeadAndTailChatCompletionContext",
]
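The new class is exported alongside the existing contexts, so all of the interchangeable implementations import from one place:

    from autogen_core.model_context import (
        BufferedChatCompletionContext,
        HeadAndTailChatCompletionContext,
        UnboundedBufferedChatCompletionContext,
    )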
31 changes: 31 additions & 0 deletions python/packages/autogen-core/src/autogen_core/model_context/_unbounded_buffered_chat_completion_context.py
@@ -0,0 +1,31 @@
from typing import Any, List, Mapping

from ..models import LLMMessage
from ._chat_completion_context import ChatCompletionContext


class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
    """An unbounded buffered chat completion context that keeps a view of all the messages."""

    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
        self._messages: List[LLMMessage] = initial_messages or []

    async def add_message(self, message: LLMMessage) -> None:
        """Add a message to the memory."""
        self._messages.append(message)

    async def get_messages(self) -> List[LLMMessage]:
        """Return all messages in the context."""
        return self._messages

    async def clear(self) -> None:
        """Clear the message memory."""
        self._messages = []

    def save_state(self) -> Mapping[str, Any]:
        return {
            "messages": [message for message in self._messages],
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        self._messages = state["messages"]
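A short round-trip sketch of the class above (the message content is illustrative; note that save_state and load_state are synchronous while the other methods are async):

    import asyncio

    from autogen_core.model_context import UnboundedBufferedChatCompletionContext
    from autogen_core.models import UserMessage

    async def main() -> None:
        ctx = UnboundedBufferedChatCompletionContext()
        await ctx.add_message(UserMessage(content="hello", source="user"))
        print(await ctx.get_messages())  # [UserMessage(content='hello', ...)]

        saved = ctx.save_state()  # synchronous, per the class definition above
        await ctx.clear()
        ctx.load_state(saved)     # restores the buffered messages

    asyncio.run(main())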