Skip to content

Commit a5681d7

Browse files
authored
Allow closure agent to ignore unknown messages, add docs (#4836)
Allow closure agent to ignore unknown messages
1 parent 2819515 commit a5681d7

File tree

1 file changed

+80
-8
lines changed

1 file changed

+80
-8
lines changed

python/packages/autogen-core/src/autogen_core/_closure_agent.py

+80-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

33
import inspect
4-
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints
4+
import warnings
5+
from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints
56

67
from ._agent_id import AgentId
78
from ._agent_instantiation import AgentInstantiationContext
@@ -73,7 +74,11 @@ async def publish_message(
7374

7475
class ClosureAgent(BaseAgent, ClosureContext):
7576
def __init__(
76-
self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]
77+
self,
78+
description: str,
79+
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
80+
*,
81+
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
7782
) -> None:
7883
try:
7984
runtime = AgentInstantiationContext.current_runtime()
@@ -89,6 +94,7 @@ def __init__(
8994
handled_types = get_handled_types_from_closure(closure)
9095
self._expected_types = handled_types
9196
self._closure = closure
97+
self._unknown_type_policy = unknown_type_policy
9298
super().__init__(description)
9399

94100
@property
@@ -110,9 +116,17 @@ def runtime(self) -> AgentRuntime:
110116

111117
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
112118
if type(message) not in self._expected_types:
113-
raise CantHandleException(
114-
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
115-
)
119+
if self._unknown_type_policy == "warn":
120+
warnings.warn(
121+
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
122+
stacklevel=1,
123+
)
124+
return None
125+
elif self._unknown_type_policy == "error":
126+
raise CantHandleException(
127+
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
128+
)
129+
116130
return await self._closure(self, message, ctx)
117131

118132
async def save_state(self) -> Mapping[str, Any]:
@@ -130,19 +144,77 @@ async def register_closure(
130144
type: str,
131145
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
132146
*,
133-
skip_class_subscriptions: bool = False,
147+
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
134148
skip_direct_message_subscription: bool = False,
135149
description: str = "",
136150
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
137151
) -> AgentType:
152+
"""The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime.
153+
154+
The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.
155+
156+
Example:
157+
158+
.. code-block:: python
159+
160+
import asyncio
161+
from autogen_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
162+
from dataclasses import dataclass
163+
164+
from autogen_core._default_subscription import DefaultSubscription
165+
from autogen_core._default_topic import DefaultTopicId
166+
167+
168+
@dataclass
169+
class MyMessage:
170+
content: str
171+
172+
173+
async def main():
174+
queue = asyncio.Queue[MyMessage]()
175+
176+
async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
177+
await queue.put(message)
178+
179+
runtime = SingleThreadedAgentRuntime()
180+
await ClosureAgent.register_closure(
181+
runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
182+
)
183+
184+
runtime.start()
185+
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
186+
await runtime.stop_when_idle()
187+
188+
result = await queue.get()
189+
print(result)
190+
191+
192+
asyncio.run(main())
193+
194+
195+
Args:
196+
runtime (AgentRuntime): Runtime to register the agent to
197+
type (str): Agent type of registered agent
198+
closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages
199+
unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
200+
skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False.
201+
description (str, optional): Description of what agent does. Defaults to "".
202+
subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None.
203+
204+
Returns:
205+
AgentType: Type of the agent that was registered
206+
"""
207+
138208
def factory() -> ClosureAgent:
139-
return ClosureAgent(description=description, closure=closure)
209+
return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)
140210

211+
assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions"
141212
agent_type = await cls.register(
142213
runtime=runtime,
143214
type=type,
144215
factory=factory, # type: ignore
145-
skip_class_subscriptions=skip_class_subscriptions,
216+
# There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s
217+
skip_class_subscriptions=True,
146218
skip_direct_message_subscription=skip_direct_message_subscription,
147219
)
148220

0 commit comments

Comments
 (0)