diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py index 002d5fb472a..77d625d0e20 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/state/_states.py @@ -1,4 +1,4 @@ -from typing import Any, List, Mapping, Optional +from typing import Annotated, Any, List, Mapping, Optional from pydantic import BaseModel, Field @@ -7,6 +7,9 @@ ChatMessage, ) +# Ensures pydantic can distinguish between types of events & messages. +_AgentMessage = Annotated[AgentEvent | ChatMessage, Field(discriminator="type")] + class BaseState(BaseModel): """Base class for all saveable state""" @@ -33,7 +36,7 @@ class TeamState(BaseState): class BaseGroupChatManagerState(BaseState): """Base state for all group chat managers.""" - message_thread: List[AgentEvent | ChatMessage] = Field(default_factory=list) + message_thread: List[_AgentMessage] = Field(default_factory=list) current_turn: int = Field(default=0) type: str = Field(default="BaseGroupChatManagerState")