1
- from typing import Any , Callable , Coroutine , Dict , NoReturn , Sequence , Type , TypeVar
1
+ import logging
2
+ from functools import wraps
3
+ from types import NoneType , UnionType
4
+ from typing import (
5
+ Any ,
6
+ Callable ,
7
+ Coroutine ,
8
+ Dict ,
9
+ Literal ,
10
+ NoReturn ,
11
+ Optional ,
12
+ Protocol ,
13
+ Sequence ,
14
+ Type ,
15
+ TypeVar ,
16
+ Union ,
17
+ cast ,
18
+ get_args ,
19
+ get_origin ,
20
+ get_type_hints ,
21
+ runtime_checkable ,
22
+ )
2
23
3
24
from agnext .core import AgentRuntime , BaseAgent , CancellationToken
4
25
from agnext .core .exceptions import CantHandleException
5
26
6
- ReceivesT = TypeVar ("ReceivesT" )
27
+ logger = logging .getLogger ("agnext" )
28
+
29
+ ReceivesT = TypeVar ("ReceivesT" , contravariant = True )
7
30
ProducesT = TypeVar ("ProducesT" , covariant = True )
8
31
9
32
# TODO: Generic typevar bound binding U to agent type
10
33
# Can't do because python doesnt support it
11
34
12
35
36
+ def is_union (t : object ) -> bool :
37
+ origin = get_origin (t )
38
+ return origin is Union or origin is UnionType
39
+
40
+
41
+ def is_optional (t : object ) -> bool :
42
+ origin = get_origin (t )
43
+ return origin is Optional
44
+
45
+
46
+ # Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
47
+ class AnyType :
48
+ pass
49
+
50
+
51
+ def get_types (t : object ) -> Sequence [Type [Any ]] | None :
52
+ if is_union (t ):
53
+ return get_args (t )
54
+ elif is_optional (t ):
55
+ return tuple (list (get_args (t )) + [NoneType ])
56
+ elif t is Any :
57
+ return (AnyType ,)
58
+ elif isinstance (t , type ):
59
+ return (t ,)
60
+ elif isinstance (t , NoneType ):
61
+ return (NoneType ,)
62
+ else :
63
+ return None
64
+
65
+
66
+ @runtime_checkable
67
+ class MessageHandler (Protocol [ReceivesT , ProducesT ]):
68
+ target_types : Sequence [type ]
69
+ produces_types : Sequence [type ]
70
+ is_message_handler : Literal [True ]
71
+
72
+ async def __call__ (self , message : ReceivesT , cancellation_token : CancellationToken ) -> ProducesT : ...
73
+
74
+
13
75
# NOTE: this works on concrete types and not inheritance
76
+ # TODO: Use a protocl for the outer function to check checked arg names
14
77
def message_handler (
15
- * target_types : Type [ ReceivesT ] ,
78
+ strict : bool = True ,
16
79
) -> Callable [
17
- [Callable [[Any , ReceivesT , CancellationToken ], Coroutine [Any , Any , ProducesT | None ]]],
18
- Callable [[ Any , ReceivesT , CancellationToken ], Coroutine [ Any , Any , ProducesT | None ] ],
80
+ [Callable [[Any , ReceivesT , CancellationToken ], Coroutine [Any , Any , ProducesT ]]],
81
+ MessageHandler [ ReceivesT , ProducesT ],
19
82
]:
20
83
def decorator (
21
- func : Callable [[Any , ReceivesT , CancellationToken ], Coroutine [Any , Any , ProducesT | None ]],
22
- ) -> Callable [[Any , ReceivesT , CancellationToken ], Coroutine [Any , Any , ProducesT | None ]]:
84
+ func : Callable [[Any , ReceivesT , CancellationToken ], Coroutine [Any , Any , ProducesT ]],
85
+ ) -> MessageHandler [ReceivesT , ProducesT ]:
86
+ type_hints = get_type_hints (func )
87
+ if "message" not in type_hints :
88
+ raise AssertionError ("message parameter not found in function signature" )
89
+
90
+ if "return" not in type_hints :
91
+ raise AssertionError ("return not found in function signature" )
92
+
93
+ # Get the type of the message parameter
94
+ target_types = get_types (type_hints ["message" ])
95
+ if target_types is None :
96
+ raise AssertionError ("Message type not found" )
97
+
98
+ print (type_hints )
99
+ return_types = get_types (type_hints ["return" ])
100
+
101
+ if return_types is None :
102
+ raise AssertionError ("Return type not found" )
103
+
23
104
# Convert target_types to list and stash
24
- func ._target_types = list (target_types ) # type: ignore
25
- return func
105
+
106
+ @wraps (func )
107
+ async def wrapper (self : Any , message : ReceivesT , cancellation_token : CancellationToken ) -> ProducesT :
108
+ if strict :
109
+ if type (message ) not in target_types :
110
+ raise CantHandleException (f"Message type { type (message )} not in target types { target_types } " )
111
+ else :
112
+ logger .warning (f"Message type { type (message )} not in target types { target_types } " )
113
+
114
+ return_value = await func (self , message , cancellation_token )
115
+
116
+ if strict :
117
+ if return_value is not AnyType and type (return_value ) not in return_types :
118
+ raise ValueError (f"Return type { type (return_value )} not in return types { return_types } " )
119
+ elif return_value is not AnyType :
120
+ logger .warning (f"Return type { type (return_value )} not in return types { return_types } " )
121
+
122
+ return return_value
123
+
124
+ wrapper_handler = cast (MessageHandler [ReceivesT , ProducesT ], wrapper )
125
+ wrapper_handler .target_types = list (target_types )
126
+ wrapper_handler .produces_types = list (return_types )
127
+ wrapper_handler .is_message_handler = True
128
+
129
+ return wrapper_handler
26
130
27
131
return decorator
28
132
@@ -35,9 +139,10 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
35
139
for attr in dir (self ):
36
140
if callable (getattr (self , attr , None )):
37
141
handler = getattr (self , attr )
38
- if hasattr (handler , "_target_types" ):
39
- for target_type in handler ._target_types :
40
- self ._handlers [target_type ] = handler
142
+ if hasattr (handler , "is_message_handler" ):
143
+ message_handler = cast (MessageHandler [Any , Any ], handler )
144
+ for target_type in message_handler .target_types :
145
+ self ._handlers [target_type ] = message_handler
41
146
42
147
super ().__init__ (name , router )
43
148
0 commit comments