Skip to content

Commit 51c8b67

Browse files
authored
exception in intervention (#82)
* exception in intervention * log
1 parent 641d83f commit 51c8b67

File tree

2 files changed

+77
-15
lines changed

2 files changed

+77
-15
lines changed

src/agnext/application/_single_threaded_agent_runtime.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,11 @@ async def process_next(self) -> None:
272272
match message_envelope:
273273
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
274274
if self._before_send is not None:
275-
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
275+
try:
276+
temp_message = await self._before_send.on_send(message, sender=sender, recipient=recipient)
277+
except BaseException as e:
278+
future.set_exception(e)
279+
return
276280
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
277281
future.set_exception(MessageDroppedException())
278282
return
@@ -285,7 +289,12 @@ async def process_next(self) -> None:
285289
sender=sender,
286290
):
287291
if self._before_send is not None:
288-
temp_message = await self._before_send.on_publish(message, sender=sender)
292+
try:
293+
temp_message = await self._before_send.on_publish(message, sender=sender)
294+
except BaseException as e:
295+
# TODO: we should raise the intervention exception to the publisher.
296+
logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True)
297+
return
289298
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
290299
# TODO log message dropped
291300
return
@@ -295,7 +304,12 @@ async def process_next(self) -> None:
295304
asyncio.create_task(self._process_publish(message_envelope))
296305
case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
297306
if self._before_send is not None:
298-
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
307+
try:
308+
temp_message = await self._before_send.on_response(message, sender=sender, recipient=recipient)
309+
except BaseException as e:
310+
# TODO: should we raise the exception to sender of the response instead?
311+
future.set_exception(e)
312+
return
299313
if temp_message is DropMessage or isinstance(temp_message, DropMessage):
300314
future.set_exception(MessageDroppedException())
301315
return

tests/test_intervention.py

+60-12
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient
3535
return message
3636

3737
handler = DebugInterventionHandler()
38-
router = SingleThreadedAgentRuntime(before_send=handler)
38+
runtime = SingleThreadedAgentRuntime(before_send=handler)
3939

40-
long_running = LoopbackAgent("name", router)
41-
response = router.send_message(MessageType(), recipient=long_running)
40+
long_running = LoopbackAgent("name", runtime)
41+
response = runtime.send_message(MessageType(), recipient=long_running)
4242

4343
while not response.done():
44-
await router.process_next()
44+
await runtime.process_next()
4545

4646
assert handler.num_messages == 1
4747
assert long_running.num_calls == 1
@@ -54,13 +54,13 @@ async def on_send(self, message: MessageType, *, sender: Agent | None, recipient
5454
return DropMessage # type: ignore
5555

5656
handler = DropSendInterventionHandler()
57-
router = SingleThreadedAgentRuntime(before_send=handler)
57+
runtime = SingleThreadedAgentRuntime(before_send=handler)
5858

59-
long_running = LoopbackAgent("name", router)
60-
response = router.send_message(MessageType(), recipient=long_running)
59+
long_running = LoopbackAgent("name", runtime)
60+
response = runtime.send_message(MessageType(), recipient=long_running)
6161

6262
while not response.done():
63-
await router.process_next()
63+
await runtime.process_next()
6464

6565
with pytest.raises(MessageDroppedException):
6666
await response
@@ -76,15 +76,63 @@ async def on_response(self, message: MessageType, *, sender: Agent, recipient: A
7676
return DropMessage # type: ignore
7777

7878
handler = DropResponseInterventionHandler()
79-
router = SingleThreadedAgentRuntime(before_send=handler)
79+
runtime = SingleThreadedAgentRuntime(before_send=handler)
8080

81-
long_running = LoopbackAgent("name", router)
82-
response = router.send_message(MessageType(), recipient=long_running)
81+
long_running = LoopbackAgent("name", runtime)
82+
response = runtime.send_message(MessageType(), recipient=long_running)
8383

8484
while not response.done():
85-
await router.process_next()
85+
await runtime.process_next()
8686

8787
with pytest.raises(MessageDroppedException):
8888
await response
8989

9090
assert long_running.num_calls == 1
91+
92+
@pytest.mark.asyncio
93+
async def test_intervention_raise_exception_on_send() -> None:
94+
95+
class InterventionException(Exception):
96+
pass
97+
98+
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
99+
async def on_send(self, message: MessageType, *, sender: Agent | None, recipient: Agent) -> MessageType | type[DropMessage]: # type: ignore
100+
raise InterventionException
101+
102+
handler = ExceptionInterventionHandler()
103+
runtime = SingleThreadedAgentRuntime(before_send=handler)
104+
105+
long_running = LoopbackAgent("name", runtime)
106+
response = runtime.send_message(MessageType(), recipient=long_running)
107+
108+
while not response.done():
109+
await runtime.process_next()
110+
111+
with pytest.raises(InterventionException):
112+
await response
113+
114+
assert long_running.num_calls == 0
115+
116+
@pytest.mark.asyncio
117+
async def test_intervention_raise_exception_on_respond() -> None:
118+
119+
class InterventionException(Exception):
120+
pass
121+
122+
class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
123+
async def on_response(self, message: MessageType, *, sender: Agent, recipient: Agent | None) -> MessageType | type[DropMessage]: # type: ignore
124+
raise InterventionException
125+
126+
handler = ExceptionInterventionHandler()
127+
runtime = SingleThreadedAgentRuntime(before_send=handler)
128+
129+
long_running = LoopbackAgent("name", runtime)
130+
response = runtime.send_message(MessageType(), recipient=long_running)
131+
132+
while not response.done():
133+
await runtime.process_next()
134+
135+
with pytest.raises(InterventionException):
136+
await response
137+
138+
assert long_running.num_calls == 1

0 commit comments

Comments
 (0)