Skip to content

Commit

Permalink
Load and Save state in AgentChat (#4436)
Browse files Browse the repository at this point in the history
1. convert dataclass types to pydantic basemodel 
2. add save_state and load_state for ChatAgent
3. state types for AgentChat
---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
victordibia and ekzhu authored Dec 5, 2024
1 parent fef06fd commit 777f2ab
Show file tree
Hide file tree
Showing 39 changed files with 3,684 additions and 2,964 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def main() -> None:
lambda: Coder(
model_client=client,
system_messages=[
SystemMessage("""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
SystemMessage(content="""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence

from autogen_core import CancellationToken, FunctionCall
from autogen_core.components.models import (
Expand All @@ -29,6 +29,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent

event_logger = logging.getLogger(EVENT_LOGGER_NAME)
Expand All @@ -49,6 +50,12 @@ def model_post_init(self, __context: Any) -> None:
class AssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
```{note}
The assistant agent is not thread-safe or coroutine-safe.
It should not be shared between multiple tasks or coroutines, and it should
not call its methods concurrently.
```
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
Expand Down Expand Up @@ -224,6 +231,7 @@ def __init__(
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
self._is_running = False

@property
def produced_message_types(self) -> List[type[ChatMessage]]:
Expand Down Expand Up @@ -327,3 +335,13 @@ async def _execute_tool_call(
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
self._model_context.clear()

async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
self._model_context.clear()
self._model_context.extend(assistant_agent_state.llm_messages)
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence

from autogen_core import CancellationToken

from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..state import BaseState


class BaseChatAgent(ChatAgent, ABC):
Expand Down Expand Up @@ -117,3 +118,11 @@ async def run_stream(
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Resets the agent to its initialization state."""
...

async def save_state(self) -> Mapping[str, Any]:
"""Export state. Default implementation for stateless agents."""
return BaseState().model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state. Default implementation for stateless agents."""
BaseState.model_validate(state)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable

from autogen_core import CancellationToken

Expand Down Expand Up @@ -54,3 +54,11 @@ def on_messages_stream(
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Resets the agent to its initialization state."""
...

async def save_state(self) -> Mapping[str, Any]:
"""Save agent state for later restoration"""
...

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state"""
...
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol
from typing import Any, Mapping, Protocol

from ._task import TaskRunner

Expand All @@ -7,3 +7,11 @@ class Team(TaskRunner, Protocol):
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
...

async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the team."""
...

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the team."""
...
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
from typing import List, Literal

from autogen_core import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated


class BaseMessage(BaseModel):
Expand All @@ -23,20 +24,26 @@ class TextMessage(BaseMessage):
content: str
"""The content of the message."""

type: Literal["TextMessage"] = "TextMessage"


class MultiModalMessage(BaseMessage):
"""A multimodal message."""

content: List[str | Image]
"""The content of the message."""

type: Literal["MultiModalMessage"] = "MultiModalMessage"


class StopMessage(BaseMessage):
"""A message requesting stop of a conversation."""

content: str
"""The content for the stop message."""

type: Literal["StopMessage"] = "StopMessage"


class HandoffMessage(BaseMessage):
"""A message requesting handoff of a conversation to another agent."""
Expand All @@ -47,26 +54,35 @@ class HandoffMessage(BaseMessage):
content: str
"""The handoff message to the target agent."""

type: Literal["HandoffMessage"] = "HandoffMessage"


class ToolCallMessage(BaseMessage):
"""A message signaling the use of tools."""

content: List[FunctionCall]
"""The tool calls."""

type: Literal["ToolCallMessage"] = "ToolCallMessage"


class ToolCallResultMessage(BaseMessage):
"""A message signaling the results of tool calls."""

content: List[FunctionExecutionResult]
"""The tool call results."""

type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"


ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
"""Messages for agent-to-agent communication."""


AgentMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage
AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
Field(discriminator="type"),
]
"""All message types."""


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""State management for agents, teams and termination conditions."""

from ._states import (
AssistantAgentState,
BaseGroupChatManagerState,
BaseState,
ChatAgentContainerState,
MagenticOneOrchestratorState,
RoundRobinManagerState,
SelectorManagerState,
SwarmManagerState,
TeamState,
)

__all__ = [
"BaseState",
"AssistantAgentState",
"BaseGroupChatManagerState",
"ChatAgentContainerState",
"RoundRobinManagerState",
"SelectorManagerState",
"SwarmManagerState",
"MagenticOneOrchestratorState",
"TeamState",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Any, List, Mapping, Optional

from autogen_core.components.models import (
LLMMessage,
)
from pydantic import BaseModel, Field

from ..messages import (
AgentMessage,
ChatMessage,
)


class BaseState(BaseModel):
"""Base class for all saveable state"""

type: str = Field(default="BaseState")
version: str = Field(default="1.0.0")


class AssistantAgentState(BaseState):
"""State for an assistant agent."""

llm_messages: List[LLMMessage] = Field(default_factory=list)
type: str = Field(default="AssistantAgentState")


class TeamState(BaseState):
"""State for a team of agents."""

agent_states: Mapping[str, Any] = Field(default_factory=dict)
team_id: str = Field(default="")
type: str = Field(default="TeamState")


class BaseGroupChatManagerState(BaseState):
"""Base state for all group chat managers."""

message_thread: List[AgentMessage] = Field(default_factory=list)
current_turn: int = Field(default=0)
type: str = Field(default="BaseGroupChatManagerState")


class ChatAgentContainerState(BaseState):
"""State for a container of chat agents."""

agent_state: Mapping[str, Any] = Field(default_factory=dict)
message_buffer: List[ChatMessage] = Field(default_factory=list)
type: str = Field(default="ChatAgentContainerState")


class RoundRobinManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.RoundRobinGroupChat` manager."""

next_speaker_index: int = Field(default=0)
type: str = Field(default="RoundRobinManagerState")


class SelectorManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""

previous_speaker: Optional[str] = Field(default=None)
type: str = Field(default="SelectorManagerState")


class SwarmManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.Swarm` manager."""

current_speaker: str = Field(default="")
type: str = Field(default="SwarmManagerState")


class MagenticOneOrchestratorState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.MagneticOneGroupChat` orchestrator."""

task: str = Field(default="")
facts: str = Field(default="")
plan: str = Field(default="")
n_rounds: int = Field(default=0)
n_stalls: int = Field(default=0)
type: str = Field(default="MagenticOneOrchestratorState")
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Callable, List
from typing import Any, AsyncGenerator, Callable, List, Mapping

from autogen_core import (
AgentId,
Expand All @@ -20,6 +20,7 @@
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
from ._sequential_routed_agent import SequentialRoutedAgent
Expand Down Expand Up @@ -493,3 +494,38 @@ async def main() -> None:

# Indicate that the team is no longer running.
self._is_running = False

async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the group chat team."""
if not self._initialized:
raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")

if self._is_running:
raise RuntimeError("The team cannot be saved while it is running.")
self._is_running = True

try:
# Save the state of the runtime. This will save the state of the participants and the group chat manager.
agent_states = await self._runtime.save_state()
return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
finally:
# Indicate that the team is no longer running.
self._is_running = False

async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the group chat team."""
if not self._initialized:
await self._init(self._runtime)

if self._is_running:
raise RuntimeError("The team cannot be loaded while it is running.")
self._is_running = True

try:
# Load the state of the runtime. This will load the state of the participants and the group chat manager.
team_state = TeamState.model_validate(state)
self._team_id = team_state.team_id
await self._runtime.load_state(team_state.agent_states)
finally:
# Indicate that the team is no longer running.
self._is_running = False
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from typing import Any, List
from typing import Any, List, Mapping

from autogen_core import DefaultTopicId, MessageContext, event, rpc

from ...base import ChatAgent, Response
from ...messages import ChatMessage
from ...state import ChatAgentContainerState
from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
from ._sequential_routed_agent import SequentialRoutedAgent

Expand Down Expand Up @@ -75,3 +76,13 @@ async def handle_request(self, message: GroupChatRequestPublish, ctx: MessageCon

async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")

async def save_state(self) -> Mapping[str, Any]:
agent_state = await self._agent.save_state()
state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
return state.model_dump()

async def load_state(self, state: Mapping[str, Any]) -> None:
container_state = ChatAgentContainerState.model_validate(state)
self._message_buffer = list(container_state.message_buffer)
await self._agent.load_state(container_state.agent_state)
Loading

0 comments on commit 777f2ab

Please sign in to comment.