Skip to content

Commit e31e100

Browse files
author
aditya.kurniawan
committed
add UnboundedBufferedChatCompletionContext to mimic previous model_context behaviour on AssistantAgent
1 parent 12aba19 commit e31e100

File tree

3 files changed

+49
-7
lines changed

3 files changed

+49
-7
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
Dict,
1111
List,
1212
Mapping,
13+
Optional,
1314
Sequence,
1415
)
1516

1617
from autogen_core import CancellationToken, FunctionCall
1718
from autogen_core.model_context import (
18-
BufferedChatCompletionContext,
1919
ChatCompletionContext,
20+
UnboundedBufferedChatCompletionContext,
2021
)
2122
from autogen_core.models import (
2223
AssistantMessage,
@@ -227,7 +228,7 @@ def __init__(
227228
self,
228229
name: str,
229230
model_client: ChatCompletionClient,
230-
model_context: ChatCompletionContext = BufferedChatCompletionContext(0),
231+
model_context: Optional[ChatCompletionContext] = None,
231232
*,
232233
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
233234
handoffs: List[HandoffBase | str] | None = None,
@@ -286,7 +287,8 @@ def __init__(
286287
raise ValueError(
287288
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
288289
)
289-
self._model_context = model_context
290+
if not model_context:
291+
self._model_context = UnboundedBufferedChatCompletionContext()
290292
self._reflect_on_tool_use = reflect_on_tool_use
291293
self._tool_call_summary_format = tool_call_summary_format
292294
self._is_running = False
@@ -420,11 +422,14 @@ async def on_reset(self, cancellation_token: CancellationToken) -> None:
420422

421423
async def save_state(self) -> Mapping[str, Any]:
422424
"""Save the current state of the assistant agent."""
423-
return AssistantAgentState(llm_messages=await self._model_context.get_messages()).model_dump()
425+
current_model_ctx_state = self._model_context.save_state()
426+
return AssistantAgentState(llm_messages=current_model_ctx_state["messages"]).model_dump()
424427

425428
async def load_state(self, state: Mapping[str, Any]) -> None:
426429
"""Load the state of the assistant agent"""
427430
assistant_agent_state = AssistantAgentState.model_validate(state)
428431
await self._model_context.clear()
429-
for message in assistant_agent_state.llm_messages:
430-
await self._model_context.add_message(message)
432+
433+
current_model_ctx_state = dict(self._model_context.save_state())
434+
current_model_ctx_state["messages"] = assistant_agent_state.llm_messages
435+
self._model_context.load_state(current_model_ctx_state)
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
from ._buffered_chat_completion_context import BufferedChatCompletionContext
1+
from ._buffered_chat_completion_context import (
2+
BufferedChatCompletionContext,
3+
UnboundedBufferedChatCompletionContext,
4+
)
25
from ._chat_completion_context import ChatCompletionContext
36
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext
47

58
__all__ = [
69
"ChatCompletionContext",
10+
"UnboundedBufferedChatCompletionContext",
711
"BufferedChatCompletionContext",
812
"HeadAndTailChatCompletionContext",
913
]

python/packages/autogen-core/src/autogen_core/model_context/_buffered_chat_completion_context.py

+33
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,36 @@ def save_state(self) -> Mapping[str, Any]:
4343
def load_state(self, state: Mapping[str, Any]) -> None:
4444
self._messages = state["messages"]
4545
self._buffer_size = state["buffer_size"]
46+
47+
48+
class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
    """A chat completion context that retains every message added to it.

    Unlike ``BufferedChatCompletionContext``, no buffer-size limit is
    applied: ``get_messages`` always returns the complete history.

    Args:
        initial_messages (List[LLMMessage] | None): Messages to seed the
            context with. Defaults to an empty history.
    """

    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
        # Copy defensively so that later mutation of the caller's list does
        # not silently change this context's history.
        self._messages: List[LLMMessage] = list(initial_messages) if initial_messages else []

    async def add_message(self, message: LLMMessage) -> None:
        """Append a message to the history."""
        self._messages.append(message)

    async def get_messages(self) -> List[LLMMessage]:
        """Return all messages in the history (no truncation is applied)."""
        return self._messages

    async def clear(self) -> None:
        """Remove all messages from the history."""
        self._messages = []

    def save_state(self) -> Mapping[str, Any]:
        """Serialize the context to a mapping holding a copy of the message list."""
        return {"messages": list(self._messages)}

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the context from a mapping produced by ``save_state``."""
        self._messages = state["messages"]

0 commit comments

Comments
 (0)