Skip to content

Commit

Permalink
add UnboundedBufferedChatCompletionContext to mimic previous model_co…
Browse files Browse the repository at this point in the history
…ntext behaviour on AssistantAgent
  • Loading branch information
aditya.kurniawan authored and akurniawan committed Dec 16, 2024
1 parent 8cb546f commit fc6c184
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
Dict,
List,
Mapping,
Optional,
Sequence,
)

from autogen_core import CancellationToken, FunctionCall
from autogen_core.model_context import (
BufferedChatCompletionContext,
ChatCompletionContext,
UnboundedBufferedChatCompletionContext,
)
from autogen_core.models import (
AssistantMessage,
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(
self,
name: str,
model_client: ChatCompletionClient,
model_context: ChatCompletionContext = BufferedChatCompletionContext(0),
model_context: Optional[ChatCompletionContext] = None,
*,
tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]] | None = None,
handoffs: List[HandoffBase | str] | None = None,
Expand Down Expand Up @@ -286,7 +287,8 @@ def __init__(
raise ValueError(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context = model_context
if not model_context:
self._model_context = UnboundedBufferedChatCompletionContext()
self._reflect_on_tool_use = reflect_on_tool_use
self._tool_call_summary_format = tool_call_summary_format
self._is_running = False
Expand Down Expand Up @@ -420,11 +422,14 @@ async def on_reset(self, cancellation_token: CancellationToken) -> None:

async def save_state(self) -> Mapping[str, Any]:
    """Save the current state of the assistant agent.

    Returns:
        A plain mapping (via pydantic ``model_dump``) containing the
        model context's message history, suitable for ``load_state``.
    """
    # Read the messages through the context's own serialized state so the
    # saved form stays in sync with whatever the context implementation stores.
    current_model_ctx_state = self._model_context.save_state()
    return AssistantAgentState(llm_messages=current_model_ctx_state["messages"]).model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
    """Load the state of the assistant agent.

    Args:
        state: A mapping previously produced by ``save_state``.
    """
    assistant_agent_state = AssistantAgentState.model_validate(state)
    # Copy the context's current state, then replace only the message
    # history, so any context-specific fields (e.g. a buffer size) survive.
    current_model_ctx_state = dict(self._model_context.save_state())
    current_model_ctx_state["messages"] = assistant_agent_state.llm_messages
    self._model_context.load_state(current_model_ctx_state)
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Public API of the model_context package.
from ._buffered_chat_completion_context import (
    BufferedChatCompletionContext,
    UnboundedBufferedChatCompletionContext,
)
from ._chat_completion_context import ChatCompletionContext
from ._head_and_tail_chat_completion_context import HeadAndTailChatCompletionContext

__all__ = [
    "ChatCompletionContext",
    "UnboundedBufferedChatCompletionContext",
    "BufferedChatCompletionContext",
    "HeadAndTailChatCompletionContext",
]
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,36 @@ def save_state(self) -> Mapping[str, Any]:
def load_state(self, state: Mapping[str, Any]) -> None:
    """Restore the message history and buffer size from a saved state mapping."""
    self._messages, self._buffer_size = state["messages"], state["buffer_size"]


class UnboundedBufferedChatCompletionContext(ChatCompletionContext):
    """An unbounded chat completion context that retains every message.

    Unlike :class:`BufferedChatCompletionContext`, there is no buffer size:
    no messages are ever evicted, and ``get_messages`` returns the full
    history.

    Args:
        initial_messages (List[LLMMessage] | None): Optional messages to
            seed the context with.
    """

    def __init__(self, initial_messages: List[LLMMessage] | None = None) -> None:
        # `or []` also covers an explicitly passed empty list / None.
        self._messages: List[LLMMessage] = initial_messages or []

    async def add_message(self, message: LLMMessage) -> None:
        """Add a message to the memory."""
        self._messages.append(message)

    async def get_messages(self) -> List[LLMMessage]:
        """Return the full message history; nothing is truncated."""
        return self._messages

    async def clear(self) -> None:
        """Clear the message memory."""
        self._messages = []

    def save_state(self) -> Mapping[str, Any]:
        """Serialize the context state; the message list is the only state."""
        # Shallow-copy so later mutations of the context don't alter the snapshot.
        return {"messages": list(self._messages)}

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the message list from a mapping produced by ``save_state``."""
        self._messages = state["messages"]

0 comments on commit fc6c184

Please sign in to comment.