Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Add message context to signature of intervention handler, add more to docs #4882

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"from typing import Any\n",
"\n",
"from autogen_core import (\n",
" AgentId,\n",
" DefaultInterventionHandler,\n",
" DefaultTopicId,\n",
" MessageContext,\n",
Expand Down Expand Up @@ -100,7 +99,7 @@
" def __init__(self) -> None:\n",
" self._termination_value: Termination | None = None\n",
"\n",
" async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:\n",
" async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:\n",
" if isinstance(message, Termination):\n",
" self._termination_value = message\n",
" return message\n",
Expand Down Expand Up @@ -171,7 +170,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" async def on_send(\n",
" self, message: Any, *, message_context: MessageContext, recipient: AgentId\n",
" ) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
Expand Down
5 changes: 2 additions & 3 deletions python/packages/autogen-core/samples/slow_human_in_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from typing import Any, Mapping, Optional

from autogen_core import (
AgentId,
CancellationToken,
DefaultInterventionHandler,
DefaultTopicId,
Expand Down Expand Up @@ -211,7 +210,7 @@ class NeedsUserInputHandler(DefaultInterventionHandler):
def __init__(self):
self.question_for_user: GetSlowUserMessage | None = None

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, GetSlowUserMessage):
self.question_for_user = message
return message
Expand All @@ -231,7 +230,7 @@ class TerminationHandler(DefaultInterventionHandler):
def __init__(self):
self.terminateMessage: TerminateMessage | None = None

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any:
if isinstance(message, TerminateMessage):
self.terminateMessage = message
return message
Expand Down
58 changes: 50 additions & 8 deletions python/packages/autogen-core/src/autogen_core/_intervention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Protocol, final

from ._agent_id import AgentId
from ._message_context import MessageContext

__all__ = [
"DropMessage",
Expand All @@ -10,31 +11,72 @@


@final
class DropMessage: ...
class DropMessage:
"""Marker type for signalling that a message should be dropped by an intervention handler. The type itself should be returned from the handler."""

...


class InterventionHandler(Protocol):
"""An intervention handler is a class that can be used to modify, log or drop messages that are being processed by the :class:`autogen_core.base.AgentRuntime`.

The handler is called when the message is submitted to the runtime.

Currently the only runtime which supports this is the :class:`autogen_core.base.SingleThreadedAgentRuntime`.

Note: Returning None from any of the intervention handler methods will result in a warning being issued and treated as "no change". If you intend to drop a message, you should return :class:`DropMessage` explicitly.

Example:

.. code-block:: python

from autogen_core import DefaultInterventionHandler, MessageContext, AgentId, SingleThreadedAgentRuntime
from dataclasses import dataclass
from typing import Any


@dataclass
class MyMessage:
content: str


class MyInterventionHandler(DefaultInterventionHandler):
async def on_send(self, message: Any, *, message_context: MessageContext, recipient: AgentId) -> MyMessage:
if isinstance(message, MyMessage):
message.content = message.content.upper()
return message


runtime = SingleThreadedAgentRuntime(intervention_handlers=[MyInterventionHandler()])

"""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]: ...
async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]: ...
async def on_response(
self, message: Any, *, sender: AgentId, recipient: AgentId | None
) -> Any | type[DropMessage]: ...
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
"""Called when a message is submitted to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.send_message`."""
...

Check warning on line 58 in python/packages/autogen-core/src/autogen_core/_intervention.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_intervention.py#L58

Added line #L58 was not covered by tests

async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
"""Called when a message is published to the AgentRuntime using :meth:`autogen_core.base.AgentRuntime.publish_message`."""
...

Check warning on line 62 in python/packages/autogen-core/src/autogen_core/_intervention.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_intervention.py#L62

Added line #L62 was not covered by tests

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
"""Called when a response is received by the AgentRuntime from an Agent's message handler returning a value."""
...

Check warning on line 66 in python/packages/autogen-core/src/autogen_core/_intervention.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_intervention.py#L66

Added line #L66 was not covered by tests


class DefaultInterventionHandler(InterventionHandler):
"""Simple class that provides a default implementation for all intervention
handler methods, that simply returns the message unchanged. Allows for easy
subclassing to override only the desired methods."""

async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
async def on_send(
self, message: Any, *, message_context: MessageContext, recipient: AgentId
) -> Any | type[DropMessage]:
return message

async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any | type[DropMessage]:
async def on_publish(self, message: Any, *, message_context: MessageContext) -> Any | type[DropMessage]:
return message

async def on_response(self, message: Any, *, sender: AgentId, recipient: AgentId | None) -> Any | type[DropMessage]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,16 @@
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
message_context = MessageContext(
sender=sender,
topic_id=None,
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_send(
message, message_context=message_context, recipient=recipient
)
_warn_if_none(temp_message, "on_send")
except BaseException as e:
future.set_exception(e)
Expand Down Expand Up @@ -506,7 +515,14 @@
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
message_context = MessageContext(

Check warning on line 518 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L518

Added line #L518 was not covered by tests
Copy link
Collaborator

@ekzhu ekzhu Jan 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like some missing tests to make sure the runtime actually invokes the intervention handlers?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The publish_message code path is missing test.

sender=sender,
topic_id=topic_id,
is_rpc=False,
cancellation_token=message_envelope.cancellation_token,
message_id=message_envelope.message_id,
)
temp_message = await handler.on_publish(message, message_context=message_context)

Check warning on line 525 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L525

Added line #L525 was not covered by tests
_warn_if_none(temp_message, "on_publish")
except BaseException as e:
# TODO: we should raise the intervention exception to the publisher.
Expand Down
10 changes: 6 additions & 4 deletions python/packages/autogen-core/tests/test_intervention.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, SingleThreadedAgentRuntime
from autogen_core import AgentId, DefaultInterventionHandler, DropMessage, MessageContext, SingleThreadedAgentRuntime
from autogen_core.exceptions import MessageDroppedException
from autogen_test_utils import LoopbackAgent, MessageType

Expand All @@ -10,7 +10,9 @@ class DebugInterventionHandler(DefaultInterventionHandler):
def __init__(self) -> None:
self.num_messages = 0

async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType:
async def on_send(
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType:
self.num_messages += 1
return message

Expand All @@ -33,7 +35,7 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie
async def test_intervention_drop_send() -> None:
class DropSendInterventionHandler(DefaultInterventionHandler):
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]:
return DropMessage

Expand Down Expand Up @@ -81,7 +83,7 @@ class InterventionException(Exception):

class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore
async def on_send(
self, message: MessageType, *, sender: AgentId | None, recipient: AgentId
self, message: MessageType, *, message_context: MessageContext, recipient: AgentId
) -> MessageType | type[DropMessage]: # type: ignore
raise InterventionException

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Sequence

import pytest
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter, SpanExportResult

Expand Down
Loading