Skip to content

Commit 8cb530f

Browse files
authored
Simplify handler decorator (#50)
* Simplify handler decorator * add more tests * mypy * formatting * fix 3.10 and improve type handling of decorator * test fix * format
1 parent ad513d5 commit 8cb530f

10 files changed

+191
-29
lines changed

examples/futures.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class Inner(TypeRoutedAgent):
1616
def __init__(self, name: str, router: AgentRuntime) -> None:
1717
super().__init__(name, router)
1818

19-
@message_handler(MessageType)
19+
@message_handler()
2020
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
2121
return MessageType(body=f"Inner: {message.body}", sender=self.name)
2222

@@ -26,7 +26,7 @@ def __init__(self, name: str, router: AgentRuntime, inner: Agent) -> None:
2626
super().__init__(name, router)
2727
self._inner = inner
2828

29-
@message_handler(MessageType)
29+
@message_handler()
3030
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
3131
inner_response = self._send_message(message, self._inner)
3232
inner_message = await inner_response

src/agnext/chat/agents/chat_completion_agent.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,17 @@ def __init__(
3838
self._chat_messages: List[Message] = []
3939
self._function_executor = function_executor
4040

41-
@message_handler(TextMessage)
41+
@message_handler()
4242
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
4343
# Add a user message.
4444
self._chat_messages.append(message)
4545

46-
@message_handler(Reset)
46+
@message_handler()
4747
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
4848
# Reset the chat messages.
4949
self._chat_messages = []
5050

51-
@message_handler(RespondNow)
51+
@message_handler()
5252
async def on_respond_now(
5353
self, message: RespondNow, cancellation_token: CancellationToken
5454
) -> TextMessage | FunctionCallMessage:
@@ -101,7 +101,7 @@ async def on_respond_now(
101101
# Return the response.
102102
return final_response
103103

104-
@message_handler(FunctionCallMessage)
104+
@message_handler()
105105
async def on_tool_call_message(
106106
self, message: FunctionCallMessage, cancellation_token: CancellationToken
107107
) -> FunctionExecutionResultMessage:

src/agnext/chat/agents/oai_assistant.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(
2424
self._assistant_id = assistant_id
2525
self._thread_id = thread_id
2626

27-
@message_handler(TextMessage)
27+
@message_handler()
2828
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
2929
# Save the message to the thread.
3030
_ = await self._client.beta.threads.messages.create(
@@ -34,7 +34,7 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel
3434
metadata={"sender": message.source},
3535
)
3636

37-
@message_handler(Reset)
37+
@message_handler()
3838
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
3939
# Get all messages in this thread.
4040
all_msgs: List[str] = []
@@ -52,7 +52,7 @@ async def on_reset(self, message: Reset, cancellation_token: CancellationToken)
5252
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
5353
assert status.deleted is True
5454

55-
@message_handler(RespondNow)
55+
@message_handler()
5656
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
5757
# Handle response format.
5858
if message.response_format == ResponseFormat.json_object:

src/agnext/chat/patterns/group_chat.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,16 @@ def subscriptions(self) -> Sequence[type]:
3535
agent_sublists = [agent.subscriptions for agent in self._agents]
3636
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
3737

38-
@message_handler(Reset)
38+
@message_handler()
3939
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
4040
self._history.clear()
4141

42-
@message_handler(RespondNow)
42+
@message_handler()
4343
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
4444
return self._output.get_output()
4545

46-
@message_handler(TextMessage)
47-
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
46+
@message_handler()
47+
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> Any:
4848
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
4949
# Should this class disallow it?
5050

src/agnext/chat/patterns/orchestrator_chat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __init__(
3434
def children(self) -> Sequence[str]:
3535
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
3636

37-
@message_handler(TextMessage)
37+
@message_handler()
3838
async def on_text_message(
3939
self,
4040
message: TextMessage,
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
2+
3+
from ...core import AgentRuntime
4+
from ..agents.base import BaseChatAgent
5+
6+
7+
class TwoAgentChat(GroupChat):
8+
def __init__(
9+
self,
10+
name: str,
11+
description: str,
12+
runtime: AgentRuntime,
13+
agent1: BaseChatAgent,
14+
agent2: BaseChatAgent,
15+
num_rounds: int,
16+
output: GroupChatOutput,
17+
) -> None:
18+
super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output)

src/agnext/components/type_routed_agent.py

+117-12
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,132 @@
1-
from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar
1+
import logging
2+
from functools import wraps
3+
from types import NoneType, UnionType
4+
from typing import (
5+
Any,
6+
Callable,
7+
Coroutine,
8+
Dict,
9+
Literal,
10+
NoReturn,
11+
Optional,
12+
Protocol,
13+
Sequence,
14+
Type,
15+
TypeVar,
16+
Union,
17+
cast,
18+
get_args,
19+
get_origin,
20+
get_type_hints,
21+
runtime_checkable,
22+
)
223

324
from agnext.core import AgentRuntime, BaseAgent, CancellationToken
425
from agnext.core.exceptions import CantHandleException
526

6-
ReceivesT = TypeVar("ReceivesT")
27+
logger = logging.getLogger("agnext")
28+
29+
ReceivesT = TypeVar("ReceivesT", contravariant=True)
730
ProducesT = TypeVar("ProducesT", covariant=True)
831

932
# TODO: Generic typevar bound binding U to agent type
1033
# Can't do because python doesnt support it
1134

1235

36+
def is_union(t: object) -> bool:
37+
origin = get_origin(t)
38+
return origin is Union or origin is UnionType
39+
40+
41+
def is_optional(t: object) -> bool:
42+
origin = get_origin(t)
43+
return origin is Optional
44+
45+
46+
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
47+
class AnyType:
48+
pass
49+
50+
51+
def get_types(t: object) -> Sequence[Type[Any]] | None:
52+
if is_union(t):
53+
return get_args(t)
54+
elif is_optional(t):
55+
return tuple(list(get_args(t)) + [NoneType])
56+
elif t is Any:
57+
return (AnyType,)
58+
elif isinstance(t, type):
59+
return (t,)
60+
elif isinstance(t, NoneType):
61+
return (NoneType,)
62+
else:
63+
return None
64+
65+
66+
@runtime_checkable
67+
class MessageHandler(Protocol[ReceivesT, ProducesT]):
68+
target_types: Sequence[type]
69+
produces_types: Sequence[type]
70+
is_message_handler: Literal[True]
71+
72+
async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
73+
74+
1375
# NOTE: this works on concrete types and not inheritance
76+
# TODO: Use a protocl for the outer function to check checked arg names
1477
def message_handler(
15-
*target_types: Type[ReceivesT],
78+
strict: bool = True,
1679
) -> Callable[
17-
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
18-
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
80+
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
81+
MessageHandler[ReceivesT, ProducesT],
1982
]:
2083
def decorator(
21-
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
22-
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
84+
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
85+
) -> MessageHandler[ReceivesT, ProducesT]:
86+
type_hints = get_type_hints(func)
87+
if "message" not in type_hints:
88+
raise AssertionError("message parameter not found in function signature")
89+
90+
if "return" not in type_hints:
91+
raise AssertionError("return not found in function signature")
92+
93+
# Get the type of the message parameter
94+
target_types = get_types(type_hints["message"])
95+
if target_types is None:
96+
raise AssertionError("Message type not found")
97+
98+
print(type_hints)
99+
return_types = get_types(type_hints["return"])
100+
101+
if return_types is None:
102+
raise AssertionError("Return type not found")
103+
23104
# Convert target_types to list and stash
24-
func._target_types = list(target_types) # type: ignore
25-
return func
105+
106+
@wraps(func)
107+
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
108+
if strict:
109+
if type(message) not in target_types:
110+
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
111+
else:
112+
logger.warning(f"Message type {type(message)} not in target types {target_types}")
113+
114+
return_value = await func(self, message, cancellation_token)
115+
116+
if strict:
117+
if return_value is not AnyType and type(return_value) not in return_types:
118+
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
119+
elif return_value is not AnyType:
120+
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
121+
122+
return return_value
123+
124+
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
125+
wrapper_handler.target_types = list(target_types)
126+
wrapper_handler.produces_types = list(return_types)
127+
wrapper_handler.is_message_handler = True
128+
129+
return wrapper_handler
26130

27131
return decorator
28132

@@ -35,9 +139,10 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
35139
for attr in dir(self):
36140
if callable(getattr(self, attr, None)):
37141
handler = getattr(self, attr)
38-
if hasattr(handler, "_target_types"):
39-
for target_type in handler._target_types:
40-
self._handlers[target_type] = handler
142+
if hasattr(handler, "is_message_handler"):
143+
message_handler = cast(MessageHandler[Any, Any], handler)
144+
for target_type in message_handler.target_types:
145+
self._handlers[target_type] = message_handler
41146

42147
super().__init__(name, router)
43148

tests/test_cancellation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
2222
self.called = False
2323
self.cancelled = False
2424

25-
@message_handler(MessageType)
25+
@message_handler()
2626
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
2727
self.called = True
2828
sleep = asyncio.ensure_future(asyncio.sleep(100))
@@ -41,7 +41,7 @@ def __init__(self, name: str, router: AgentRuntime, nested_agent: Agent) -> None
4141
self.cancelled = False
4242
self._nested_agent = nested_agent
4343

44-
@message_handler(MessageType)
44+
@message_handler()
4545
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
4646
self.called = True
4747
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)

tests/test_intervention.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
1919
self.num_calls = 0
2020

2121

22-
@message_handler(MessageType)
22+
@message_handler()
2323
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
2424
self.num_calls += 1
2525
return message

tests/test_types.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from types import NoneType
2+
from typing import Any, Optional, Union
3+
4+
from agnext.components.type_routed_agent import AnyType, get_types, message_handler
5+
from agnext.core import CancellationToken
6+
7+
8+
def test_get_types() -> None:
9+
assert get_types(Union[int, str]) == (int, str)
10+
assert get_types(int | str) == (int, str)
11+
assert get_types(int) == (int,)
12+
assert get_types(str) == (str,)
13+
assert get_types("test") is None
14+
assert get_types(Optional[int]) == (int, NoneType)
15+
assert get_types(NoneType) == (NoneType,)
16+
assert get_types(None) == (NoneType,)
17+
18+
19+
def test_handler() -> None:
20+
21+
class HandlerClass:
22+
@message_handler()
23+
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
24+
return None
25+
26+
@message_handler()
27+
async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None:
28+
return None
29+
30+
assert HandlerClass.handler.target_types == [int]
31+
assert HandlerClass.handler.produces_types == [AnyType]
32+
33+
assert HandlerClass.handler2.target_types == [str, bool]
34+
assert HandlerClass.handler2.produces_types == [NoneType]
35+
36+
class HandlerClass:
37+
@message_handler()
38+
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
39+
return None

0 commit comments

Comments
 (0)