diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index b4e58546c05a..b17bf18c309e 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,5 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, List, Mapping, Sequence +from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args, get_origin + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated from autogen_core import CancellationToken @@ -79,12 +84,12 @@ async def run( output_messages.append(text_msg) elif isinstance(task, list): for msg in task: - if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(msg) output_messages.append(msg) else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(task) output_messages.append(task) else: @@ -116,13 +121,13 @@ async def run_stream( yield text_msg elif isinstance(task, list): for msg in task: - if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(msg) output_messages.append(msg) yield msg else: raise ValueError(f"Invalid message type in list: {type(msg)}") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: input_messages.append(task) output_messages.append(task) yield task diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 3856ea7fcf0a..2de743c6297d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -2,7 +2,12 @@ import logging import uuid from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, List, Mapping +from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args, get_origin + +try: + from typing import Annotated +except ImportError: + from typing_extensions import Annotated from autogen_core import ( AgentId, @@ -368,12 +373,15 @@ async def main() -> None: pass elif isinstance(task, str): first_chat_message = TextMessage(content=task, source="user") - elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)): + elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__: first_chat_message = task elif isinstance(task, list): if not task: raise ValueError("Task list cannot be empty") - if not all(isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)) for msg in task): + if not all( + get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__ + for msg in task + ): raise ValueError("All messages in task list must be valid ChatMessage types") first_chat_message = task[0] # Queue remaining messages for processing