
Decouple model_context from AssistantAgent #4681

Merged · 13 commits · Dec 20, 2024
@@ -10,14 +10,13 @@
     Dict,
     List,
     Mapping,
-    Optional,
     Sequence,
 )
 
 from autogen_core import CancellationToken, FunctionCall
 from autogen_core.model_context import (
     ChatCompletionContext,
-    UnboundedBufferedChatCompletionContext,
+    UnboundedChatCompletionContext,
 )
 from autogen_core.models import (
     AssistantMessage,
@@ -228,10 +227,10 @@ def __init__(
         self,
         name: str,
         model_client: ChatCompletionClient,
-        model_context: Optional[ChatCompletionContext] = None,
         *,
         tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
         handoffs: List[HandoffBase | str] | None = None,
+        model_context: ChatCompletionContext | None = None,
         description: str = "An agent that provides assistance with ability to use tools.",
         system_message: (
             str | None
@@ -288,7 +287,7 @@ def __init__(
                 f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
         if not model_context:
-            self._model_context = UnboundedBufferedChatCompletionContext()
+            self._model_context = UnboundedChatCompletionContext()
         self._reflect_on_tool_use = reflect_on_tool_use
         self._tool_call_summary_format = tool_call_summary_format
         self._is_running = False
@@ -422,14 +421,11 @@ async def on_reset(self, cancellation_token: CancellationToken) -> None:
 
     async def save_state(self) -> Mapping[str, Any]:
         """Save the current state of the assistant agent."""
-        current_model_ctx_state = self._model_context.save_state()
-        return AssistantAgentState(llm_messages=current_model_ctx_state["messages"]).model_dump()
+        model_context_state = await self._model_context.save_state()
+        return AssistantAgentState(llm_context=model_context_state).model_dump()
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
         """Load the state of the assistant agent"""
         assistant_agent_state = AssistantAgentState.model_validate(state)
-        await self._model_context.clear()
-
-        current_model_ctx_state = dict(self._model_context.save_state())
-        current_model_ctx_state["messages"] = assistant_agent_state.llm_messages
-        self._model_context.load_state(current_model_ctx_state)
+        # Load the model context state.
+        await self._model_context.load_state(assistant_agent_state.llm_context)
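
Taken together with the signature change above, model_context is now an injectable keyword argument and the state round trip is fully async. A minimal sketch of the new usage (the OpenAIChatCompletionClient import path is an assumption and may vary by autogen-ext version):

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_core.model_context import BufferedChatCompletionContext
from autogen_ext.models import OpenAIChatCompletionClient  # assumed import path


async def main() -> None:
    # Any ChatCompletionContext implementation can now be injected.
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_context=BufferedChatCompletionContext(buffer_size=10),
    )
    state = await agent.save_state()  # {"type": "AssistantAgentState", "llm_context": {...}}
    await agent.load_state(state)  # restores the model context from llm_context


asyncio.run(main())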
@@ -1,8 +1,5 @@
 from typing import Any, List, Mapping, Optional
 
-from autogen_core.models import (
-    LLMMessage,
-)
 from pydantic import BaseModel, Field
 
 from ..messages import (
@@ -21,7 +18,7 @@ class BaseState(BaseModel):
 class AssistantAgentState(BaseState):
     """State for an assistant agent."""
 
-    llm_messages: List[LLMMessage] = Field(default_factory=list)
+    llm_context: Mapping[str, Any] = Field(default_factory=lambda: dict([("messages", [])]))
     type: str = Field(default="AssistantAgentState")
 
 
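For comparison, the serialized state changes shape: the old dump carried a bare llm_messages list, while the new one nests the whole context snapshot under llm_context. An illustrative sketch (message values are made up):

# Old shape (before this PR):
old_state = {
    "type": "AssistantAgentState",
    "llm_messages": [{"content": "Hi", "source": "user", "type": "UserMessage"}],
}

# New shape: the full ChatCompletionContextState dump is nested.
new_state = {
    "type": "AssistantAgentState",
    "llm_context": {
        "messages": [{"content": "Hi", "source": "user", "type": "UserMessage"}],
    },
}
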
16 changes: 8 additions & 8 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
@@ -239,10 +239,10 @@ async def test_round_robin_group_chat_state() -> None:
     state2 = await team2.save_state()
     assert state == state2
 
-    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
-    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
-    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
-    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
     assert agent3_model_ctx_messages == agent1_model_ctx_messages
     assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
@@ -583,10 +583,10 @@ async def test_selector_group_chat_state() -> None:
     state2 = await team2.save_state()
     assert state == state2
 
-    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
-    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
-    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
-    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
+    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
+    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
+    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
+    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
     assert agent3_model_ctx_messages == agent1_model_ctx_messages
     assert agent4_model_ctx_messages == agent2_model_ctx_messages
     manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
@@ -454,17 +454,17 @@
     "\n",
     "The above `SimpleAgent` always responds with a fresh context that contains only\n",
     "the system message and the latest user's message.\n",
-    "We can use model context classes from {py:mod}`autogen_core.components.model_context`\n",
+    "We can use model context classes from {py:mod}`autogen_core.model_context`\n",
     "to make the agent \"remember\" previous conversations.\n",
     "A model context supports storage and retrieval of Chat Completion messages.\n",
     "It is always used together with a model client to generate LLM-based responses.\n",
     "\n",
-    "For example, {py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`\n",
+    "For example, {py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`\n",
     "is a most-recently-used (MRU) context that stores the most recent `buffer_size`\n",
     "number of messages. This is useful to avoid context overflow in many LLMs.\n",
     "\n",
     "Let's update the previous example to use\n",
-    "{py:mod}`~autogen_core.components.model_context.BufferedChatCompletionContext`."
+    "{py:mod}`~autogen_core.model_context.BufferedChatCompletionContext`."
    ]
   },
   {
@@ -473,7 +473,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from autogen_core.components.model_context import BufferedChatCompletionContext\n",
+    "from autogen_core.model_context import BufferedChatCompletionContext\n",
     "from autogen_core.models import AssistantMessage\n",
     "\n",
     "\n",
@@ -615,7 +615,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.11.5"
+   "version": "3.12.7"
   }
  },
 "nbformat": 4,
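As a quick illustration of the renamed module, a self-contained sketch of BufferedChatCompletionContext's eviction behavior (messages are illustrative):

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext
from autogen_core.models import AssistantMessage, UserMessage


async def main() -> None:
    # Keep only the two most recent messages.
    context = BufferedChatCompletionContext(buffer_size=2)
    await context.add_message(UserMessage(content="Hello", source="user"))
    await context.add_message(AssistantMessage(content="Hi there!", source="assistant"))
    await context.add_message(UserMessage(content="What's new?", source="user"))
    messages = await context.get_messages()
    assert len(messages) == 2  # the oldest message was dropped


asyncio.run(main())
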
@@ -254,10 +254,10 @@ async def _execute_function(
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "memory": await self._model_context.save_state(),
             "system_messages": self._system_messages,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["memory"])
+        await self._model_context.load_state(state["memory"])
         self._system_messages = state["system_messages"]
@@ -143,10 +143,12 @@ async def on_new_message(self, message: TextMessage | MultiModalMessage, ctx: MessageContext) -> None:
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "chat_history": self._model_context.save_state(),
+            "chat_history": await self._model_context.save_state(),
             "termination_word": self._termination_word,
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state(state["chat_history"])
+        # Load the chat history.
+        await self._model_context.load_state(state["chat_history"])
+        # Load the termination word.
         self._termination_word = state["termination_word"]
6 changes: 3 additions & 3 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
@@ -114,7 +114,7 @@ async def save_state(self) -> Mapping[str, Any]:
         return state_to_save
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
 
 
 class ScheduleMeetingInput(BaseModel):
@@ -200,11 +200,11 @@ async def handle_message(self, message: UserTextMessage, ctx: MessageContext) -> None:
 
     async def save_state(self) -> Mapping[str, Any]:
         return {
-            "memory": self._model_context.save_state(),
+            "memory": await self._model_context.save_state(),
         }
 
     async def load_state(self, state: Mapping[str, Any]) -> None:
-        self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]})
+        await self._model_context.load_state(state["memory"])
 
 
 class NeedsUserInputHandler(DefaultInterventionHandler):
@@ -1,13 +1,14 @@
 from ._buffered_chat_completion_context import BufferedChatCompletionContext
-from ._chat_completion_context import ChatCompletionContext
+from ._chat_completion_context import ChatCompletionContext, ChatCompletionContextState
 from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
-from ._unbounded_buffered_chat_completion_context import (
-    UnboundedBufferedChatCompletionContext,
+from ._unbounded_chat_completion_context import (
+    UnboundedChatCompletionContext,
 )
 
 __all__ = [
     "ChatCompletionContext",
-    "UnboundedBufferedChatCompletionContext",
+    "ChatCompletionContextState",
+    "UnboundedChatCompletionContext",
     "BufferedChatCompletionContext",
     "HeadAndTailChatCompletionContext",
 ]
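
Call sites migrating to this release only need to swap the class name; a minimal before/after sketch:

# Before:
# from autogen_core.model_context import UnboundedBufferedChatCompletionContext
# context = UnboundedBufferedChatCompletionContext()

# After:
from autogen_core.model_context import UnboundedChatCompletionContext

context = UnboundedChatCompletionContext()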
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from ..models import FunctionExecutionResultMessage, LLMMessage
 from ._chat_completion_context import ChatCompletionContext
@@ -10,17 +10,15 @@ class BufferedChatCompletionContext(ChatCompletionContext):
 
     Args:
         buffer_size (int): The size of the buffer.
+        initial_messages (List[LLMMessage] | None): The initial messages.
     """
 
     def __init__(self, buffer_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
-        self._messages: List[LLMMessage] = initial_messages or []
+        super().__init__(initial_messages)
         if buffer_size <= 0:
             raise ValueError("buffer_size must be greater than 0.")
         self._buffer_size = buffer_size
 
-    async def add_message(self, message: LLMMessage) -> None:
-        """Add a message to the memory."""
-        self._messages.append(message)
-
     async def get_messages(self) -> List[LLMMessage]:
         """Get at most `buffer_size` recent messages."""
         messages = self._messages[-self._buffer_size :]
@@ -29,17 +27,3 @@ async def get_messages(self) -> List[LLMMessage]:
             # Remove the first message from the list.
             messages = messages[1:]
         return messages
-
-    async def clear(self) -> None:
-        """Clear the message memory."""
-        self._messages = []
-
-    def save_state(self) -> Mapping[str, Any]:
-        return {
-            "messages": [message for message in self._messages],
-            "buffer_size": self._buffer_size,
-        }
-
-    def load_state(self, state: Mapping[str, Any]) -> None:
-        self._messages = state["messages"]
-        self._buffer_size = state["buffer_size"]
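
With the bespoke save_state/load_state removed, serialization is inherited from the base class and captures only the messages; buffer_size stays constructor-only configuration. A small sketch of the resulting behavior:

import asyncio

from autogen_core.model_context import BufferedChatCompletionContext


async def main() -> None:
    context = BufferedChatCompletionContext(buffer_size=5)
    state = await context.save_state()
    print(state)  # {"messages": []} -- buffer_size is no longer serialized
    # Restoring into a context with a different buffer_size keeps the new size.
    restored = BufferedChatCompletionContext(buffer_size=2)
    await restored.load_state(state)


asyncio.run(main())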
@@ -1,19 +1,40 @@
-from typing import List, Mapping, Protocol
+from abc import ABC, abstractmethod
+from typing import Any, List, Mapping
 
+from pydantic import BaseModel, Field
+
 from ..models import LLMMessage
 
 
-class ChatCompletionContext(Protocol):
-    """A protocol for defining the interface of a chat completion context.
+class ChatCompletionContext(ABC):
+    """An abstract base class for defining the interface of a chat completion context.
     A chat completion context lets agents store and retrieve LLM messages.
-    It can be implemented with different recall strategies."""
+    It can be implemented with different recall strategies.
+
+    Args:
+        initial_messages (List[LLMMessage] | None): The initial messages.
+    """
+
+    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
+        self._messages: List[LLMMessage] = initial_messages or []
 
-    async def add_message(self, message: LLMMessage) -> None: ...
+    async def add_message(self, message: LLMMessage) -> None:
+        """Add a message to the context."""
+        self._messages.append(message)
 
+    @abstractmethod
     async def get_messages(self) -> List[LLMMessage]: ...
 
-    async def clear(self) -> None: ...
+    async def clear(self) -> None:
+        """Clear the context."""
+        self._messages = []
 
-    def save_state(self) -> Mapping[str, LLMMessage]: ...
-
-    def load_state(self, state: Mapping[str, LLMMessage]) -> None: ...
+    async def save_state(self) -> Mapping[str, Any]:
+        return ChatCompletionContextState(messages=self._messages).model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        self._messages = ChatCompletionContextState.model_validate(state).messages
+
+
+class ChatCompletionContextState(BaseModel):
+    messages: List[LLMMessage] = Field(default_factory=list)
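
With the Protocol replaced by an ABC, subclasses inherit storage, clear, and the Pydantic-backed save_state/load_state, and only get_messages must be overridden. A minimal sketch of a custom context under the new base class (the class name is hypothetical):

from typing import List

from autogen_core.model_context import ChatCompletionContext
from autogen_core.models import LLMMessage


class EveryOtherChatCompletionContext(ChatCompletionContext):
    """Hypothetical context that returns every other stored message."""

    async def get_messages(self) -> List[LLMMessage]:
        # self._messages, add_message, clear, save_state, and load_state
        # are all inherited from the base class.
        return self._messages[::2]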
@@ -1,4 +1,4 @@
-from typing import Any, List, Mapping
+from typing import List
 
 from .._types import FunctionCall
 from ..models import AssistantMessage, FunctionExecutionResultMessage, LLMMessage, UserMessage
@@ -13,17 +13,18 @@ class HeadAndTailChatCompletionContext(ChatCompletionContext):
     Args:
         head_size (int): The size of the head.
         tail_size (int): The size of the tail.
+        initial_messages (List[LLMMessage] | None): The initial messages.
     """
 
-    def __init__(self, head_size: int, tail_size: int) -> None:
-        self._messages: List[LLMMessage] = []
+    def __init__(self, head_size: int, tail_size: int, initial_messages: List[LLMMessage] | None = None) -> None:
+        super().__init__(initial_messages)
         if head_size <= 0:
             raise ValueError("head_size must be greater than 0.")
         if tail_size <= 0:
             raise ValueError("tail_size must be greater than 0.")
         self._head_size = head_size
         self._tail_size = tail_size
 
-    async def add_message(self, message: LLMMessage) -> None:
-        """Add a message to the memory."""
-        self._messages.append(message)
-
     async def get_messages(self) -> List[LLMMessage]:
         """Get at most `head_size` recent messages and `tail_size` oldest messages."""
         head_messages = self._messages[: self._head_size]
@@ -51,21 +52,3 @@ async def get_messages(self) -> List[LLMMessage]:
 
         placeholder_messages = [UserMessage(content=f"Skipped {num_skipped} messages.", source="System")]
         return head_messages + placeholder_messages + tail_messages
-
-    async def clear(self) -> None:
-        """Clear the message memory."""
-        self._messages = []
-
-    def save_state(self) -> Mapping[str, Any]:
-        return {
-            "messages": [message for message in self._messages],
-            "head_size": self._head_size,
-            "tail_size": self._tail_size,
-            "placeholder_message": self._placeholder_message,
-        }
-
-    def load_state(self, state: Mapping[str, Any]) -> None:
-        self._messages = state["messages"]
-        self._head_size = state["head_size"]
-        self._tail_size = state["tail_size"]
-        self._placeholder_message = state["placeholder_message"]
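
A quick sketch of the head-and-tail behavior after the refactor (sizes and messages are illustrative):

import asyncio

from autogen_core.model_context import HeadAndTailChatCompletionContext
from autogen_core.models import UserMessage


async def main() -> None:
    context = HeadAndTailChatCompletionContext(head_size=1, tail_size=1)
    for i in range(5):
        await context.add_message(UserMessage(content=f"message {i}", source="user"))
    messages = await context.get_messages()
    # Yields: "message 0", a "Skipped 3 messages." placeholder, "message 4".
    for message in messages:
        print(message.content)


asyncio.run(main())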

This file was deleted.
