Skip to content

Commit

Permalink
fix: type check for Annotated types
Browse files Browse the repository at this point in the history
  • Loading branch information
iamarunbrahma committed Dec 10, 2024
1 parent 376f337 commit 9666740
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence, get_args, get_origin

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

from autogen_core import CancellationToken

Expand Down Expand Up @@ -79,12 +84,12 @@ async def run(
output_messages.append(text_msg)
elif isinstance(task, list):
for msg in task:
if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__:
input_messages.append(msg)
output_messages.append(msg)
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__:
input_messages.append(task)
output_messages.append(task)
else:
Expand Down Expand Up @@ -116,13 +121,13 @@ async def run_stream(
yield text_msg
elif isinstance(task, list):
for msg in task:
if isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
if get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__:
input_messages.append(msg)
output_messages.append(msg)
yield msg
else:
raise ValueError(f"Invalid message type in list: {type(msg)}")
elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__:
input_messages.append(task)
output_messages.append(task)
yield task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
import logging
import uuid
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Callable, List, Mapping
from typing import Any, AsyncGenerator, Callable, List, Mapping, get_args, get_origin

try:
from typing import Annotated
except ImportError:
from typing_extensions import Annotated

from autogen_core import (
AgentId,
Expand Down Expand Up @@ -368,12 +373,15 @@ async def main() -> None:
pass
elif isinstance(task, str):
first_chat_message = TextMessage(content=task, source="user")
elif isinstance(task, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)):
elif get_origin(ChatMessage) is Annotated and task.__class__ in get_args(ChatMessage)[0].__args__:
first_chat_message = task
elif isinstance(task, list):
if not task:
raise ValueError("Task list cannot be empty")
if not all(isinstance(msg, (TextMessage, MultiModalMessage, StopMessage, HandoffMessage)) for msg in task):
if not all(
get_origin(ChatMessage) is Annotated and msg.__class__ in get_args(ChatMessage)[0].__args__
for msg in task
):
raise ValueError("All messages in task list must be valid ChatMessage types")
first_chat_message = task[0]
# Queue remaining messages for processing
Expand Down

0 comments on commit 9666740

Please sign in to comment.