Decouple model_context from AssistantAgent #4681

Merged: 13 commits, Dec 20, 2024
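In short, this PR replaces `AssistantAgent`'s internal `List[LLMMessage]` history with a pluggable `ChatCompletionContext`, defaulting to `UnboundedChatCompletionContext` when none is supplied. A minimal sketch of what the new parameter enables (the client class and model name are assumptions for illustration, not part of this diff):

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models import OpenAIChatCompletionClient  # assumed client; any ChatCompletionClient works

# Cap the inference context at the 10 most recent messages instead of the
# default unbounded history.
agent = AssistantAgent(
    name="assistant",
    model_client=OpenAIChatCompletionClient(model="gpt-4o"),
    model_context=BufferedChatCompletionContext(buffer_size=10),
)
```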
@@ -2,15 +2,27 @@
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Dict,
List,
Mapping,
Sequence,
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.model_context import (
ChatCompletionContext,
UnboundedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
@@ -87,7 +99,6 @@ class AssistantAgent(BaseChatAgent):
If multiple handoffs are detected, only the first handoff is executed.



Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
@@ -96,8 +107,9 @@ class AssistantAgent(BaseChatAgent):
allowing it to transfer to other agents by responding with a :class:`HandoffMessage`.
The transfer is only executed when the team is in :class:`~autogen_agentchat.teams.Swarm`.
If a handoff is a string, it should represent the target agent's name.
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model.
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
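The new `model_context` parameter above says the context "can be preloaded with initial messages". A sketch of that preloading, assuming `UnboundedChatCompletionContext` forwards `initial_messages` to the base class the way `BufferedChatCompletionContext` does at the bottom of this diff:

```python
from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage

# Seed the context with prior turns; per the docstring, these are cleared
# when the agent is reset.
preloaded_context = UnboundedChatCompletionContext(
    initial_messages=[
        UserMessage(content="My name is Ada.", source="user"),
        AssistantMessage(content="Nice to meet you, Ada.", source="assistant"),
    ]
)
```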
@@ -219,9 +231,11 @@ def __init__(
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
model_context: ChatCompletionContext | None = None,
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str
| None = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
):
@@ -273,7 +287,8 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
if not model_context:
self._model_context = UnboundedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
@@ -301,19 +316,19 @@ async def on_messages_stream(
for msg in messages:
if isinstance(msg, MultiModalMessage) and self._model_client.capabilities["vision"] is False:
raise ValueError("The model does not support vision.")
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
await self._model_context.add_message(UserMessage(content=msg.content, source=msg.source))

# Inner messages.
inner_messages: List[AgentEvent | ChatMessage] = []

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))

# Check if the response is a string and return it.
if isinstance(result.content, str):
@@ -335,7 +350,7 @@ async def on_messages_stream(
results = await asyncio.gather(*[self._execute_tool_call(call, cancellation_token) for call in result.content])
tool_call_result_msg = ToolCallExecutionEvent(content=results, source=self.name)
event_logger.debug(tool_call_result_msg)
self._model_context.append(FunctionExecutionResultMessage(content=results))
await self._model_context.add_message(FunctionExecutionResultMessage(content=results))
inner_messages.append(tool_call_result_msg)
yield tool_call_result_msg
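Every context interaction in this method goes through the same small async surface. A self-contained sketch of that protocol using the default unbounded context (behavior inferred from the calls in this diff):

```python
import asyncio

from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage

async def main() -> None:
    ctx = UnboundedChatCompletionContext()
    # The agent adds user input, model output, and tool results via add_message().
    await ctx.add_message(UserMessage(content="What is 2 + 2?", source="user"))
    await ctx.add_message(AssistantMessage(content="4", source="assistant"))
    # get_messages() returns the context's view; unbounded means all of it.
    assert len(await ctx.get_messages()) == 2
    # on_reset() delegates to clear().
    await ctx.clear()
    assert await ctx.get_messages() == []

asyncio.run(main())
```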

@@ -360,11 +375,11 @@ async def on_messages_stream(

if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + self._model_context
llm_messages = self._system_messages + await self._model_context.get_messages()
result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(result.content, str)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
await self._model_context.add_message(AssistantMessage(content=result.content, source=self.name))
# Yield the response.
yield Response(
chat_message=TextMessage(content=result.content, source=self.name, models_usage=result.usage),
@@ -406,14 +421,15 @@ async def _execute_tool_call(

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
self._model_context.clear()
await self._model_context.clear()

async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
model_context_state = await self._model_context.save_state()
return AssistantAgentState(llm_context=model_context_state).model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
self._model_context.clear()
self._model_context.extend(assistant_agent_state.llm_messages)
# Load the model context state.
await self._model_context.load_state(assistant_agent_state.llm_context)
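A hedged sketch of the round trip the two methods above now support; note the saved mapping nests the context's serialized state under `llm_context` instead of carrying raw `LLMMessage` objects (client class assumed, as before):

```python
from typing import Any, Mapping

from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models import OpenAIChatCompletionClient  # assumed client

async def checkpoint_and_restore(agent: AssistantAgent) -> AssistantAgent:
    state: Mapping[str, Any] = await agent.save_state()  # JSON-serializable
    fresh = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_context=BufferedChatCompletionContext(buffer_size=10),
    )
    await fresh.load_state(state)  # conversation restored into the new context
    return fresh
```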
@@ -1,8 +1,5 @@
from typing import Any, List, Mapping, Optional

from autogen_core.models import (
LLMMessage,
)
from pydantic import BaseModel, Field

from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
class AssistantAgentState(BaseState):
"""State for an assistant agent."""

llm_messages: List[LLMMessage] = Field(default_factory=list)
llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
type: str = Field(default="AssistantAgentState")


26 changes: 18 additions & 8 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -239,8 +239,13 @@ async def test_round_robin_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore

agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
@@ -337,7 +342,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -351,7 +356,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
index += 1

# Test Console.
tool_use_agent._model_context.clear() # pyright: ignore
await tool_use_agent._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -579,8 +584,13 @@ async def test_selector_group_chat_state() -> None:
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore

agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
assert agent3_model_ctx_messages == agent1_model_ctx_messages
assert agent4_model_ctx_messages == agent2_model_ctx_messages
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
@@ -931,7 +941,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

# Test streaming.
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -944,7 +954,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
index += 1

# Test Console
agent1._model_context.clear() # pyright: ignore
await agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
await team.reset()
@@ -454,17 +454,17 @@
"\n",
"The above `SimpleAgent` always responds with a fresh context that contains only\n",
"the system message and the latest user's message.\n",
"We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
"We can use model context classes from {py:mod}`autogen_core.model_context`\n",
"to make the agent \"remember\" previous conversations.\n",
"A model context supports storage and retrieval of Chat Completion messages.\n",
"It is always used together with a model client to generate LLM-based responses.\n",
"\n",
"For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
"For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
"is a most-recent-used (MRU) context that stores the most recent `buffer_size`\n",
"number of messages. This is useful to avoid context overflow in many LLMs.\n",
"\n",
"Let's update the previous example to use\n",
"{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
"{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
]
},
{
@@ -473,7 +473,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autogen_core.components.model_context import BufferedChatCompletionContext\n",
"from autogen_core.model_context import BufferedChatCompletionContext\n",
"from autogen_core.models import AssistantMessage\n",
"\n",
"\n",
@@ -615,7 +615,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.7"
}
},
"nbformat": 4,
@@ -254,10 +254,10 @@ async def _execute_function(

async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"memory": await self._model_context.save_state(),
"system_messages": self._system_messages,
}

async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["memory"])
await self._model_context.load_state(state["memory"])
self._system_messages = state["system_messages"]
@@ -143,10 +143,12 @@ async def on_new_message(self, message: TextMessage | MultiModalMessage, ctx: Me

async def save_state(self) -> Mapping[str, Any]:
return {
"chat_history": self._model_context.save_state(),
"chat_history": await self._model_context.save_state(),
"termination_word": self._termination_word,
}

async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state(state["chat_history"])
# Load the chat history.
await self._model_context.load_state(state["chat_history"])
# Load the termination word.
self._termination_word = state["termination_word"]
6 changes: 3 additions & 3 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
@@ -114,7 +114,7 @@ async def save_state(self) -> Mapping[str, Any]:
return state_to_save

async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])


class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ async def handle_message(self, message: UserTextMessage, ctx: MessageContext) ->

async def save_state(self) -> Mapping[str, Any]:
return {
"memory": self._model_context.save_state(),
"memory": await self._model_context.save_state(),
}

async def load_state(self, state: Mapping[str, Any]) -> None:
self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
await self._model_context.load_state(state["memory"])


class NeedsUserInputHandler(DefaultInterventionHandler):
@@ -1,9 +1,14 @@
from ._buffered_chat_completion_context import BufferedChatCompletionContext
from ._chat_completion_context import ChatCompletionContext
from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
from ._unbounded_chat_completion_context import (
UnboundedChatCompletionContext,
)

__all__ = [
"ChatCompletionContext",
"ChatCompletionContextState",
"UnboundedChatCompletionContext",
"BufferedChatCompletionContext",
"HeadAndTailChatCompletionContext",
]
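Exporting `ChatCompletionContext` and `ChatCompletionContextState` makes custom contexts a supported extension point. A hypothetical subclass following the pattern `BufferedChatCompletionContext` adopts below, under the assumption that the base class stores messages in `self._messages` and owns `add_message`, `clear`, and state handling:

```python
from typing import List

from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import LLMMessage

class HeadChatCompletionContext(ChatCompletionContext):
    """Hypothetical context: expose only the first `head_size` messages."""

    def __init__(self, head_size: int) -> None:
        super().__init__()
        if head_size <= 0:
            raise ValueError("head_size must be greater than 0.")
        self._head_size = head_size

    async def get_messages(self) -> List[LLMMessage]:
        # Customize only the view; storage stays in the base class.
        return self._messages[: self._head_size]
```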
@@ -1,4 +1,4 @@
from typing import Any, List, Mapping
from typing import List

from ..models import FunctionExecutionResultMessage, LLMMessage
from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):

Args:
buffer_size (int): The size of the buffer.

initial_messages (List[LLMMessage] | None): The initial messages.
"""

def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
self._messages: List[LLMMessage] = initial_messages or []
super().__init__(initial_messages)
if buffer_size <= 0:
raise ValueError("buffer_size must be greater than 0.")
self._buffer_size = buffer_size

async def add_message(self, message: LLMMessage) -> None:
"""Add a message to the memory."""
self._messages.append(message)

async def get_messages(self) -> List[LLMMessage]:
"""Get at most `buffer_size` recent messages."""
messages = self._messages[-self._buffer_size :]
@@ -29,17 +27,3 @@ async def get_messages(self) -> List[LLMMessage]:
# Remove the first message from the list.
messages = messages[1:]
return messages

async def clear(self) -> None:
"""Clear the message memory."""
self._messages = []

def save_state(self) -> Mapping[str, Any]:
return {
"messages": [message for message in self._messages],
"buffer_size": self._buffer_size,
}

def load_state(self, state: Mapping[str, Any]) -> None:
self._messages = state["messages"]
self._buffer_size = state["buffer_size"]
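The fold at `@@ -29,17 +27,3` hides the condition that guards the trim in `get_messages` above; given the `FunctionExecutionResultMessage` import, it presumably drops a window that would open with a tool result whose paired tool call fell outside the buffer. A sketch of that inferred logic:

```python
from typing import List

from autogen_core.models import FunctionExecutionResultMessage, LLMMessage

def buffered_view(messages: List[LLMMessage], buffer_size: int) -> List[LLMMessage]:
    """Inferred behavior of BufferedChatCompletionContext.get_messages()."""
    window = messages[-buffer_size:]
    # A FunctionExecutionResultMessage is only valid right after its tool-call
    # message; if the window starts with one, the pair was split, so drop it.
    if window and isinstance(window[0], FunctionExecutionResultMessage):
        window = window[1:]
    return window
```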