Skip to content

Commit

Permalink
feat!: Add message context to signature of intervention handler, add …
Browse files Browse the repository at this point in the history
…more to docs (#4882)

* Add message context to signature of intervention handler, add more to docs

* example

* Add to test

* Fix pyright

* mypy
  • Loading branch information
jackgerrits authored Jan 7, 2025
1 parent f4382f0 commit 5b9be79
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"from typing import Any\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultInterventionHandler,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
Expand Down Expand Up @@ -100,7 +99,7 @@
" def __init__(self) -> None:\n",
" self._termination_value: Termination | None = None\n",
"\n",
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
" async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:\n",
" if isinstance(message, Termination):\n",
" self._termination_value = message\n",
" return message\n",
Expand Down Expand Up @@ -171,7 +170,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" async def on_send(\n",
" self, message: Any, *, message_context: MessageContext, recipient: AgentId\n",
" ) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
Expand Down
5 changes: 2 additions & 3 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from typing import Any, Mapping, Optional

from autogen_core import (
AgentId,
CancellationToken,
DefaultInterventionHandler,
DefaultTopicId,
Expand Down Expand Up @@ -211,7 +210,7 @@ class NeedsUserInputHandler(DefaultInterventionHandler):
def __init__(self):
self.question_for_user: GetSlowUserMessage | None = None

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, GetSlowUserMessage):
self.question_for_user = message
return message
Expand All @@ -231,7 +230,7 @@ class TerminationHandler(DefaultInterventionHandler):
def __init__(self):
self.terminateMessage: TerminateMessage | None = None

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, TerminateMessage):
self.terminateMessage = message
return message
Expand Down
58 changes: 50 additions & 8 deletions python/packages/autogen-core/src/autogen_core/_intervention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Protocol, final

from ._agent_id import AgentId
from ._message_context import MessageContext

__all__ = [
"DropMessage",
Expand All @@ -10,31 +11,72 @@


@final
class DropMessage: ...
class DropMessage:
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""

...


class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.
The handler is called when the message is submitted to the runtime.
Currently the only runtime which supports this is the :class:`autogen_core.base.SingleThreadedAgentRuntime`.
Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.
Example:
.. code-block:: python
from autogen_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
from dataclasses import dataclass
from typing import Any
@dataclass
class MyMessage:
content: str
class MyInterventionHandler(DefaultInterventionHandler):
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
if isinstance(message, MyMessage):
message.content = message.content.upper()
return message
runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])
"""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
"""Called when a message is submitted to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.send_message`."""
...

async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
"""Called when a message is published to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.publish_message`."""
...

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
...


class DefaultInterventionHandler(InterventionHandler):
"""Simple class that provides a default implementation for all intervention
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
return message

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
return message

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,16 @@ async def _process_next(self) -> None:
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_send(
message, message_context=message_context, recipient=recipient
)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
Expand Down Expand Up @@ -506,7 +515,14 @@ async def _process_next(self) -> None:
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
message_context = MessageContext(
sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_publish(message, message_context=message_context)
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
Expand Down
46 changes: 38 additions & 8 deletions python/packages/autogen-core/tests/test_intervention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from typing import Any

import pytest
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime
from autogen_core import (
AgentId,
DefaultInterventionHandler,
DefaultSubscription,
DefaultTopicId,
DropMessage,
MessageContext,
SingleThreadedAgentRuntime,
)
from autogen_core.exceptions import MessageDroppedException
from autogen_test_utils import LoopbackAgent, MessageType

Expand All @@ -8,10 +18,20 @@
async def test_intervention_count_messages() -> None:
class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self) -> None:
self.num_messages = 0
self.num_send_messages = 0
self.num_publish_messages = 0
self.num_response_messages = 0

async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> Any:
self.num_send_messages += 1
return message

async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
self.num_publish_messages += 1
return message

async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
self.num_messages += 1
async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any:
self.num_response_messages += 1
return message

handler = DebugInterventionHandler()
Expand All @@ -22,18 +42,28 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie

_response = await runtime.send_message(MessageType(), recipient=loopback)

await runtime.stop()
await runtime.stop_when_idle()

assert handler.num_messages == 1
assert handler.num_send_messages == 1
assert handler.num_response_messages == 1
loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent)
assert loopback_agent.num_calls == 1

runtime.start()
await runtime.add_subscription(DefaultSubscription(agent_type="name"))

await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())

await runtime.stop_when_idle()
assert loopback_agent.num_calls == 2
assert handler.num_publish_messages == 1


@pytest.mark.asyncio
async def test_intervention_drop_send() -> None:
class DropSendInterventionHandler(DefaultInterventionHandler):
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]:
return DropMessage

Expand Down Expand Up @@ -81,7 +111,7 @@ class InterventionException(Exception):

class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]: # type: ignore
raise InterventionException

Expand Down

0 comments on commit 5b9be79

Please sign in to comment.