From 0c9666f5b7a940011574329ddffea33251fd6534 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Wed, 27 Nov 2024 16:32:04 -0500 Subject: [PATCH 01/10] WIP - rpc over events --- docs/design/02 - Topics.md | 2 +- .../framework/agent-and-agent-runtime.ipynb | 2 +- .../_single_threaded_agent_runtime.py | 101 +++++++++--------- .../application/_worker_runtime.py | 1 - .../src/autogen_core/base/_agent.py | 5 +- .../src/autogen_core/base/_agent_id.py | 5 +- .../src/autogen_core/base/_agent_runtime.py | 1 - .../src/autogen_core/base/_base_agent.py | 51 +++++++-- .../src/autogen_core/base/_message_context.py | 3 +- .../src/autogen_core/base/_rpc.py | 31 ++++++ .../autogen_core/components/_closure_agent.py | 16 +-- .../autogen_core/components/_routed_agent.py | 59 ++++++++-- .../components/send_message_mixin.py | 61 +++++++++++ python/packages/autogen-core/test.py | 35 ++++++ .../packages/autogen-core/tests/test_state.py | 2 +- .../packages/autogen-core/tests/test_types.py | 6 +- .../autogen-core/tests/test_utils/__init__.py | 2 +- 17 files changed, 292 insertions(+), 91 deletions(-) create mode 100644 python/packages/autogen-core/src/autogen_core/base/_rpc.py create mode 100644 python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py create mode 100644 python/packages/autogen-core/test.py diff --git a/docs/design/02 - Topics.md b/docs/design/02 - Topics.md index bf3ed8d9dcac..c64ac50a2740 100644 --- a/docs/design/02 - Topics.md +++ b/docs/design/02 - Topics.md @@ -61,6 +61,6 @@ For this subscription source should map directly to agent key. This subscription will therefore receive all events for the following well known topics: - `{AgentType}:` - General purpose direct messages. These should be routed to the approriate message handler. -- `{AgentType}:rpc_request` - RPC request messages. These should be routed to the approriate RPC handler. +- `{AgentType}:rpc_request={RequesterAgentType}` - RPC request messages. These should be routed to the approriate RPC handler. - `{AgentType}:rpc_response={RequestId}` - RPC response messages. These should be routed back to the response future of the caller. - `{AgentType}:error={RequestId}` - Error message that corresponds to the given request. diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb index fdd7aed5644e..5884967418d5 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/framework/agent-and-agent-runtime.ipynb @@ -67,7 +67,7 @@ " def __init__(self) -> None:\n", " super().__init__(\"MyAgent\")\n", "\n", - " async def on_message(self, message: MyMessageType, ctx: MessageContext) -> None:\n", + " async def on_message_impl(self, message: MyMessageType, ctx: MessageContext) -> None:\n", " print(f\"Received message: {message.content}\") # type: ignore" ] }, diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 3d81f15eb330..9b402c11c4a7 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -16,6 +16,7 @@ from typing_extensions import deprecated from autogen_core.base._serialization import MessageSerializer, SerializationRegistry +from autogen_core.components.send_message_mixin import PublishBasedRpcMixin from ..base import ( Agent, @@ -166,7 +167,7 @@ def _warn_if_none(value: Any, handler_name: str) -> None: ) -class SingleThreadedAgentRuntime(AgentRuntime): +class SingleThreadedAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, *, @@ -202,54 +203,54 @@ def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) # Returns the response of the message - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if cancellation_token is None: - cancellation_token = CancellationToken() - - # event_logger.info( - # MessageEvent( - # payload=message, - # sender=sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.SEND, - # ) - # ) - - with self._tracer_helper.trace_block( - "create", - recipient, - parent=None, - extraAttributes={"message_type": type(message).__name__}, - ): - future = asyncio.get_event_loop().create_future() - if recipient.type not in self._known_agent_names: - future.set_exception(Exception("Recipient not found")) - - content = message.__dict__ if hasattr(message, "__dict__") else message - logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - - self._message_queue.append( - SendMessageEnvelope( - message=message, - recipient=recipient, - future=future, - cancellation_token=cancellation_token, - sender=sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - - cancellation_token.link_future(future) - - return await future + # async def send_message( + # self, + # message: Any, + # recipient: AgentId, + # *, + # sender: AgentId | None = None, + # cancellation_token: CancellationToken | None = None, + # ) -> Any: + # if cancellation_token is None: + # cancellation_token = CancellationToken() + + # # event_logger.info( + # # MessageEvent( + # # payload=message, + # # sender=sender, + # # receiver=recipient, + # # kind=MessageKind.DIRECT, + # # delivery_stage=DeliveryStage.SEND, + # # ) + # # ) + + # with self._tracer_helper.trace_block( + # "create", + # recipient, + # parent=None, + # extraAttributes={"message_type": type(message).__name__}, + # ): + # future = asyncio.get_event_loop().create_future() + # if recipient.type not in self._known_agent_names: + # future.set_exception(Exception("Recipient not found")) + + # content = message.__dict__ if hasattr(message, "__dict__") else message + # logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") + + # self._message_queue.append( + # SendMessageEnvelope( + # message=message, + # recipient=recipient, + # future=future, + # cancellation_token=cancellation_token, + # sender=sender, + # metadata=get_telemetry_envelope_metadata(), + # ) + # ) + + # cancellation_token.link_future(future) + + # return await future async def publish_message( self, @@ -332,7 +333,6 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: message_context = MessageContext( sender=message_envelope.sender, topic_id=None, - is_rpc=True, cancellation_token=message_envelope.cancellation_token, # Will be fixed when send API removed message_id="NOT_DEFINED_TODO_FIX", @@ -392,7 +392,6 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No message_context = MessageContext( sender=message_envelope.sender, topic_id=message_envelope.topic_id, - is_rpc=False, cancellation_token=message_envelope.cancellation_token, message_id=message_envelope.message_id, ) diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 24007fadfc7d..24714f29155b 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -497,7 +497,6 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: message_context = MessageContext( sender=sender, topic_id=None, - is_rpc=True, cancellation_token=CancellationToken(), message_id=request.request_id, ) diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent.py b/python/packages/autogen-core/src/autogen_core/base/_agent.py index edb5e59b1ce3..0202522d08a3 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent.py @@ -17,16 +17,13 @@ def id(self) -> AgentId: """ID of the agent.""" ... - async def on_message(self, message: Any, ctx: MessageContext) -> Any: + async def on_message(self, message: Any, ctx: MessageContext) -> None: """Message handler for the agent. This should only be called by the runtime, not by other agents. Args: message (Any): Received message. Type is one of the types in `subscriptions`. ctx (MessageContext): Context of the message. - Returns: - Any: Response to the message. Can be None. - Raises: asyncio.CancelledError: If the message was cancelled. CantHandleException: If the agent cannot handle the message. diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_id.py b/python/packages/autogen-core/src/autogen_core/base/_agent_id.py index 06f163ed9c30..b3a129ac4176 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_id.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_id.py @@ -8,8 +8,9 @@ def __init__(self, type: str | AgentType, key: str) -> None: if isinstance(type, AgentType): type = type.type - if type.isidentifier() is False: - raise ValueError(f"Invalid type: {type}") + # TODO: fixme + # if type.isidentifier() is False: + # raise ValueError(f"Invalid type: {type}") self._type = type self._key = key diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py index 27c37ad9f349..0cd65ef383be 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_runtime.py @@ -26,7 +26,6 @@ async def send_message( message: Any, recipient: AgentId, *, - sender: AgentId | None = None, cancellation_token: CancellationToken | None = None, ) -> Any: """Send a message to an agent and get a response. diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 70481705ca6e..25ce6179b432 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -1,13 +1,17 @@ from __future__ import annotations import inspect +import uuid import warnings from abc import ABC, abstractmethod +from asyncio import Future from collections.abc import Sequence -from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar +from typing import Any, Awaitable, Callable, ClassVar, Dict, List, Mapping, Tuple, Type, TypeVar, final from typing_extensions import Self +from autogen_core.base._rpc import format_rpc_request_topic, is_rpc_response + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -53,7 +57,6 @@ def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: return decorator - class BaseAgent(ABC, Agent): internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = [] internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = [] @@ -77,7 +80,17 @@ def metadata(self) -> AgentMetadata: assert self._id is not None return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description) - def __init__(self, description: str) -> None: + def __init__(self, description: str, *, forward_unbound_rpc_responses_to_handler: bool = False) -> None: + """Base agent that all agents should inherit from. Puts in place assumed common functionality. + + Args: + description (str): Description of the agent. + forward_unbound_rpc_responses_to_handler (bool, optional): If an rpc request ID is not know to the agent, should the rpc request be forwarded to the handler. Defaults to False. + + Raises: + RuntimeError: If the agent is not instantiated within the context of an AgentRuntime. + ValueError: If there is an argument type error. + """ try: runtime = AgentInstantiationContext.current_runtime() id = AgentInstantiationContext.current_agent_id() @@ -91,6 +104,8 @@ def __init__(self, description: str) -> None: if not isinstance(description, str): raise ValueError("Agent description must be a string") self._description = description + self._pending_rpc_requests: Dict[str, Future[Any]] = {} + self._forward_unbound_rpc_responses_to_handler = forward_unbound_rpc_responses_to_handler @property def type(self) -> str: @@ -105,7 +120,21 @@ def runtime(self) -> AgentRuntime: return self._runtime @abstractmethod - async def on_message(self, message: Any, ctx: MessageContext) -> Any: ... + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... + + @final + async def on_message(self, message: Any, ctx: MessageContext) -> None: + # Intercept RPC responses + if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: + if request_id in self._pending_rpc_requests: + self._pending_rpc_requests[request_id].set_result(message) + elif self._forward_unbound_rpc_responses_to_handler: + await self.on_message_impl(message, ctx) + else: + warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) + return None + + return await self.on_message_impl(message, ctx) async def send_message( self, @@ -118,13 +147,23 @@ async def send_message( if cancellation_token is None: cancellation_token = CancellationToken() - return await self._runtime.send_message( + recipient_topic = TopicId(type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), source=recipient.key) + request_id = str(uuid.uuid4()) + + future = Future[Any]() + + await self._runtime.publish_message( message, sender=self.id, - recipient=recipient, + topic_id=recipient_topic, cancellation_token=cancellation_token, + message_id=request_id, ) + self._pending_rpc_requests[request_id] = future + + return future + async def publish_message( self, message: Any, diff --git a/python/packages/autogen-core/src/autogen_core/base/_message_context.py b/python/packages/autogen-core/src/autogen_core/base/_message_context.py index c5c00559ed0e..65cbbb64d4bb 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_message_context.py +++ b/python/packages/autogen-core/src/autogen_core/base/_message_context.py @@ -8,7 +8,6 @@ @dataclass class MessageContext: sender: AgentId | None - topic_id: TopicId | None - is_rpc: bool + topic_id: TopicId cancellation_token: CancellationToken message_id: str diff --git a/python/packages/autogen-core/src/autogen_core/base/_rpc.py b/python/packages/autogen-core/src/autogen_core/base/_rpc.py new file mode 100644 index 000000000000..d6b857afd0e8 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/base/_rpc.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Optional + + + + +def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str: + return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" + +def format_rpc_response_topic(rpc_sender_agent_type: str,request_id: str) -> str: + return f"{rpc_sender_agent_type}:rpc_response={request_id}" + +# If is an rpc response, return the request id +def is_rpc_response(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_response= + for segment in topic_segments: + if segment.startswith("rpc_response="): + return segment[len("rpc_response=") :] + return None + + +# If is an rpc response, return the requestor agent type +def is_rpc_request(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_request= + for segment in topic_segments: + if segment.startswith("rpc_request="): + return segment[len("rpc_request=") :] + return None diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index 12e5faae6bf1..d6b100547bd1 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -18,7 +18,7 @@ Subscription, TopicId, ) -from ..base._type_helpers import get_types +from ..base._type_helpers import AnyType, get_types from ..base.exceptions import CantHandleException T = TypeVar("T") @@ -76,7 +76,7 @@ async def publish_message( class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]] + self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, forward_unbound_rpc_responses_to_handler: bool = False ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -92,7 +92,7 @@ def __init__( handled_types = get_handled_types_from_closure(closure) self._expected_types = handled_types self._closure = closure - super().__init__(description) + super().__init__(description, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) @property def metadata(self) -> AgentMetadata: @@ -111,8 +111,8 @@ def id(self) -> AgentId: def runtime(self) -> AgentRuntime: return self._runtime - async def on_message(self, message: Any, ctx: MessageContext) -> Any: - if type(message) not in self._expected_types: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: + if AnyType not in self._expected_types and type(message) not in self._expected_types: raise CantHandleException( f"Message type {type(message)} not in target types {self._expected_types} of {self.id}" ) @@ -131,19 +131,19 @@ async def register_closure( type: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, - skip_class_subscriptions: bool = False, skip_direct_message_subscription: bool = False, + forward_unbound_rpc_responses_to_handler: bool = False, description: str = "", subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure) + return ClosureAgent(description=description, closure=closure, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) agent_type = await cls.register( runtime=runtime, type=type, factory=factory, # type: ignore - skip_class_subscriptions=skip_class_subscriptions, + skip_class_subscriptions=True, skip_direct_message_subscription=skip_direct_message_subscription, ) diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index e7f266bf49d6..50ef43aa67f9 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -19,13 +19,16 @@ runtime_checkable, ) +from autogen_core.base._rpc import format_rpc_response_topic, is_rpc_request +from autogen_core.base._topic import TopicId + from ..base import BaseAgent, MessageContext, MessageSerializer, try_get_known_serializers_for_type from ..base._type_helpers import AnyType, get_types from ..base.exceptions import CantHandleException logger = logging.getLogger("autogen_core") -AgentT = TypeVar("AgentT") +AgentT = TypeVar("AgentT", bound=BaseAgent) ReceivesT = TypeVar("ReceivesT") ProducesT = TypeVar("ProducesT", covariant=True) @@ -138,7 +141,7 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") @@ -153,7 +156,26 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish it if you need to... + # Any return is treated as a response to the RPC request and is published accordingly + + if return_value is not None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + + await self.publish_message( + message=return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + ) + else: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) @@ -278,8 +300,8 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - # Wrap the match function with a check on the is_rpc flag. - wrapper_handler.router = lambda _message, _ctx: (not _ctx.is_rpc) and (match(_message, _ctx) if match else True) + # Wrap the match function with a check on the topic for rpc + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and (match(_message, _ctx) if match else True) return wrapper_handler @@ -378,7 +400,7 @@ def decorator( # Convert target_types to list and stash @wraps(func) - async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> ProducesT: + async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None: if type(message) not in target_types: if strict: raise CantHandleException(f"Message type {type(message)} not in target types {target_types}") @@ -393,13 +415,32 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> Prod else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - return return_value + # Dont return, but publish it if you need to... + # Any return is treated as a response to the RPC request and is published accordingly + + if return_value is not None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + + await self.publish_message( + message=return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + ) + else: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - wrapper_handler.router = lambda _message, _ctx: (_ctx.is_rpc) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and (match(_message, _ctx) if match else True) return wrapper_handler @@ -470,7 +511,7 @@ def __init__(self, description: str) -> None: super().__init__(description) - async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext): """Handle a message by routing it to the appropriate message handler. Do not override this method in subclasses. Instead, add message handlers as methods decorated with either the :func:`event` or :func:`rpc` decorator.""" diff --git a/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py b/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py new file mode 100644 index 000000000000..2ae0e09eebaf --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py @@ -0,0 +1,61 @@ + +from typing import Any +import uuid +import warnings + +from autogen_core.base._rpc import format_rpc_request_topic, format_rpc_response_topic +from autogen_core.base._topic import TopicId + +from ..base._message_context import MessageContext + +from ..base._agent_id import AgentId +from ..base._cancellation_token import CancellationToken +from ._closure_agent import ClosureAgent, ClosureContext +from ..base._agent_runtime import AgentRuntime + + +import asyncio + +class PublishBasedRpcMixin(AgentRuntime): + async def send_message( + self: AgentRuntime, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + ) -> Any: + + rpc_request_id = str(uuid.uuid4()) + # TODO add "-" to topic and agent type allowed characters in spec + closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" + + future: asyncio.Future[Any] = asyncio.Future() + expected_response_topic_type = format_rpc_response_topic(rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id) + async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageContext) -> None: + assert ctx.topic_id is not None + if ctx.topic_id.type == expected_response_topic_type: + future.set_result(message) + else: + warnings.warn(f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", stacklevel=2) + + # TODO: remove agent after response is received + + await ClosureAgent.register_closure( + runtime=self, + type=closure_agent_type, + closure=set_result, + forward_unbound_rpc_responses_to_handler=True, + ) + + rpc_request_topic_id = format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type) + await self.publish_message( + message=message, + topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), + cancellation_token=cancellation_token, + message_id=rpc_request_id, + ) + + return await future + + # register a closure agent... + diff --git a/python/packages/autogen-core/test.py b/python/packages/autogen-core/test.py new file mode 100644 index 000000000000..d8c5672d1feb --- /dev/null +++ b/python/packages/autogen-core/test.py @@ -0,0 +1,35 @@ +from dataclasses import dataclass + +from autogen_core.base import MessageContext +from autogen_core.base._agent_id import AgentId +from autogen_core.components import RoutedAgent +from autogen_core.components._routed_agent import rpc + +from autogen_core.application import SingleThreadedAgentRuntime +import asyncio + +@dataclass +class Message: + content: str + +class MyAgent(RoutedAgent): + def __init__(self) -> None: + super().__init__("My agent") + + @rpc + async def handle_message(self, message: Message, ctx: MessageContext) -> Message: + print(f"Received message: {message.content}") + return Message(content=f"I got: {message.content}") + +async def main(): + runtime = SingleThreadedAgentRuntime() + + await MyAgent.register(runtime, "my_agent", MyAgent) + + runtime.start() + print(await runtime.send_message( + Message("I'm sending you this"), recipient=AgentId("my_agent", "default") + )) + await runtime.stop_when_idle() + +asyncio.run(main()) diff --git a/python/packages/autogen-core/tests/test_state.py b/python/packages/autogen-core/tests/test_state.py index 7120a9baab41..ba4fe86cf13e 100644 --- a/python/packages/autogen-core/tests/test_state.py +++ b/python/packages/autogen-core/tests/test_state.py @@ -10,7 +10,7 @@ def __init__(self) -> None: super().__init__("A stateful agent") self.state = 0 - async def on_message(self, message: Any, ctx: MessageContext) -> None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: raise NotImplementedError async def save_state(self) -> Mapping[str, Any]: diff --git a/python/packages/autogen-core/tests/test_types.py b/python/packages/autogen-core/tests/test_types.py index 1dbc02c4fa96..3959456b35b4 100644 --- a/python/packages/autogen-core/tests/test_types.py +++ b/python/packages/autogen-core/tests/test_types.py @@ -5,7 +5,7 @@ from autogen_core.base import MessageContext from autogen_core.base._serialization import has_nested_base_model from autogen_core.base._type_helpers import AnyType, get_types -from autogen_core.components._routed_agent import message_handler +from autogen_core.components._routed_agent import RoutedAgent, message_handler from pydantic import BaseModel @@ -21,7 +21,7 @@ def test_get_types() -> None: def test_handler() -> None: - class HandlerClass: + class HandlerClass(RoutedAgent): @message_handler() async def handler(self, message: int, ctx: MessageContext) -> Any: return None @@ -37,7 +37,7 @@ async def handler2(self, message: str | bool, ctx: MessageContext) -> None: assert HandlerClass.handler2.produces_types == [NoneType] -class HandlerClass: +class HandlerClass(RoutedAgent): @message_handler() async def handler(self, message: int, ctx: MessageContext) -> Any: return None diff --git a/python/packages/autogen-core/tests/test_utils/__init__.py b/python/packages/autogen-core/tests/test_utils/__init__.py index 5de7519fc49b..3b1ac1101fcd 100644 --- a/python/packages/autogen-core/tests/test_utils/__init__.py +++ b/python/packages/autogen-core/tests/test_utils/__init__.py @@ -57,5 +57,5 @@ class NoopAgent(BaseAgent): def __init__(self) -> None: super().__init__("A no op agent") - async def on_message(self, message: Any, ctx: MessageContext) -> Any: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: raise NotImplementedError From f7a6d481c7489867717d1e73952dd5aae9a1ed56 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:43:35 -0500 Subject: [PATCH 02/10] remove handled rpc --- .../packages/autogen-core/src/autogen_core/base/_base_agent.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 25ce6179b432..f2baa6d1a07e 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -128,13 +128,14 @@ async def on_message(self, message: Any, ctx: MessageContext) -> None: if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: if request_id in self._pending_rpc_requests: self._pending_rpc_requests[request_id].set_result(message) + del self._pending_rpc_requests[request_id] elif self._forward_unbound_rpc_responses_to_handler: await self.on_message_impl(message, ctx) else: warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) return None - return await self.on_message_impl(message, ctx) + await self.on_message_impl(message, ctx) async def send_message( self, From 4170415e27e07459202f247a1088a969c5972413 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:54:40 -0500 Subject: [PATCH 03/10] move module --- .../application/_single_threaded_agent_runtime.py | 4 ++-- .../{send_message_mixin.py => _publish_based_rpc.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename python/packages/autogen-core/src/autogen_core/components/{send_message_mixin.py => _publish_based_rpc.py} (100%) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 9b402c11c4a7..f56618da512f 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -15,8 +15,8 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from autogen_core.base._serialization import MessageSerializer, SerializationRegistry -from autogen_core.components.send_message_mixin import PublishBasedRpcMixin +from ..base._serialization import MessageSerializer, SerializationRegistry +from ..components._publish_based_rpc import PublishBasedRpcMixin from ..base import ( Agent, diff --git a/python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py similarity index 100% rename from python/packages/autogen-core/src/autogen_core/components/send_message_mixin.py rename to python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py From 7e62aad51ab16dcf5c2baf5ca0ad86927b9b21db Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 14:59:49 -0500 Subject: [PATCH 04/10] remove rpc from single threaded runtime --- .../_single_threaded_agent_runtime.py | 249 ++---------------- 1 file changed, 28 insertions(+), 221 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index f56618da512f..763dada6ab5d 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -6,7 +6,7 @@ import threading import uuid import warnings -from asyncio import CancelledError, Future, Task +from asyncio import CancelledError, Task from collections.abc import Sequence from dataclasses import dataclass from enum import Enum @@ -32,7 +32,7 @@ SubscriptionInstantiationContext, TopicId, ) -from ..base.exceptions import MessageDroppedException + from ..base.intervention import DropMessage, InterventionHandler from ._helpers import SubscriptionManager, get_impl from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata @@ -58,30 +58,6 @@ class PublishMessageEnvelope: message_id: str -@dataclass(kw_only=True) -class SendMessageEnvelope: - """A message envelope for sending a message to a specific agent that can handle - the message of the type T.""" - - message: Any - sender: AgentId | None - recipient: AgentId - future: Future[Any] - cancellation_token: CancellationToken - metadata: EnvelopeMetadata | None = None - - -@dataclass(kw_only=True) -class ResponseMessageEnvelope: - """A message envelope for sending a response to a message.""" - - message: Any - future: Future[Any] - sender: AgentId - recipient: AgentId | None - metadata: EnvelopeMetadata | None = None - - P = ParamSpec("P") T = TypeVar("T", bound=Agent) @@ -175,7 +151,7 @@ def __init__( tracer_provider: TracerProvider | None = None, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) - self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = [] + self._message_queue: List[PublishMessageEnvelope] = [] # (namespace, type) -> List[AgentId] self._agent_factories: Dict[ str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]] @@ -191,7 +167,7 @@ def __init__( @property def unprocessed_messages( self, - ) -> Sequence[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope]: + ) -> Sequence[PublishMessageEnvelope]: return self._message_queue @property @@ -202,56 +178,6 @@ def outstanding_tasks(self) -> int: def _known_agent_names(self) -> Set[str]: return set(self._agent_factories.keys()) - # Returns the response of the message - # async def send_message( - # self, - # message: Any, - # recipient: AgentId, - # *, - # sender: AgentId | None = None, - # cancellation_token: CancellationToken | None = None, - # ) -> Any: - # if cancellation_token is None: - # cancellation_token = CancellationToken() - - # # event_logger.info( - # # MessageEvent( - # # payload=message, - # # sender=sender, - # # receiver=recipient, - # # kind=MessageKind.DIRECT, - # # delivery_stage=DeliveryStage.SEND, - # # ) - # # ) - - # with self._tracer_helper.trace_block( - # "create", - # recipient, - # parent=None, - # extraAttributes={"message_type": type(message).__name__}, - # ): - # future = asyncio.get_event_loop().create_future() - # if recipient.type not in self._known_agent_names: - # future.set_exception(Exception("Recipient not found")) - - # content = message.__dict__ if hasattr(message, "__dict__") else message - # logger.info(f"Sending message of type {type(message).__name__} to {recipient.type}: {content}") - - # self._message_queue.append( - # SendMessageEnvelope( - # message=message, - # recipient=recipient, - # future=future, - # cancellation_token=cancellation_token, - # sender=sender, - # metadata=get_telemetry_envelope_metadata(), - # ) - # ) - - # cancellation_token.link_future(future) - - # return await future - async def publish_message( self, message: Any, @@ -308,61 +234,6 @@ async def load_state(self, state: Mapping[str, Any]) -> None: if agent_id.type in self._known_agent_names: await (await self._get_agent(agent_id)).load_state(state[str(agent_id)]) - async def _process_send(self, message_envelope: SendMessageEnvelope) -> None: - with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata): - recipient = message_envelope.recipient - # todo: check if recipient is in the known namespaces - # assert recipient in self._agents - - try: - # TODO use id - sender_name = message_envelope.sender.type if message_envelope.sender is not None else "Unknown" - logger.info( - f"Calling message handler for {recipient} with message type {type(message_envelope.message).__name__} sent by {sender_name}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=recipient, - # kind=MessageKind.DIRECT, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - recipient_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=message_envelope.sender, - topic_id=None, - cancellation_token=message_envelope.cancellation_token, - # Will be fixed when send API removed - message_id="NOT_DEFINED_TODO_FIX", - ) - with MessageHandlerContext.populate_context(recipient_agent.id): - response = await recipient_agent.on_message( - message_envelope.message, - ctx=message_context, - ) - except CancelledError as e: - if not message_envelope.future.cancelled(): - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - except BaseException as e: - message_envelope.future.set_exception(e) - self._outstanding_tasks.decrement() - return - - self._message_queue.append( - ResponseMessageEnvelope( - message=response, - future=message_envelope.future, - sender=message_envelope.recipient, - recipient=message_envelope.sender, - metadata=get_telemetry_envelope_metadata(), - ) - ) - self._outstanding_tasks.decrement() - async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None: with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata): try: @@ -418,29 +289,6 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: self._outstanding_tasks.decrement() # TODO if responses are given for a publish - async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None: - with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata): - content = ( - message_envelope.message.__dict__ - if hasattr(message_envelope.message, "__dict__") - else message_envelope.message - ) - logger.info( - f"Resolving response with message type {type(message_envelope.message).__name__} for recipient {message_envelope.recipient} from {message_envelope.sender.type}: {content}" - ) - # event_logger.info( - # MessageEvent( - # payload=message_envelope.message, - # sender=message_envelope.sender, - # receiver=message_envelope.recipient, - # kind=MessageKind.RESPOND, - # delivery_stage=DeliveryStage.DELIVER, - # ) - # ) - self._outstanding_tasks.decrement() - if not message_envelope.future.cancelled(): - message_envelope.future.set_result(message_envelope.message) - async def process_next(self) -> None: """Process the next message in the queue.""" @@ -450,71 +298,30 @@ async def process_next(self) -> None: return message_envelope = self._message_queue.pop(0) - match message_envelope: - case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_send(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_send") - except BaseException as e: - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_send(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case PublishMessageEnvelope( - message=message, - sender=sender, - ): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - with self._tracer_helper.trace_block( - "intercept", handler.__class__.__name__, parent=message_envelope.metadata - ): - try: - temp_message = await handler.on_publish(message, sender=sender) - _warn_if_none(temp_message, "on_publish") - except BaseException as e: - # TODO: we should raise the intervention exception to the publisher. - logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - # TODO log message dropped - return - - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_publish(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - case ResponseMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future): - if self._intervention_handlers is not None: - for handler in self._intervention_handlers: - try: - temp_message = await handler.on_response(message, sender=sender, recipient=recipient) - _warn_if_none(temp_message, "on_response") - except BaseException as e: - # TODO: should we raise the exception to sender of the response instead? - future.set_exception(e) - return - if temp_message is DropMessage or isinstance(temp_message, DropMessage): - future.set_exception(MessageDroppedException()) - return - message_envelope.message = temp_message - self._outstanding_tasks.increment() - task = asyncio.create_task(self._process_response(message_envelope)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) + message = message_envelope.message + sender = message_envelope.sender + + if self._intervention_handlers is not None: + for handler in self._intervention_handlers: + with self._tracer_helper.trace_block( + "intercept", handler.__class__.__name__, parent=message_envelope.metadata + ): + try: + temp_message = await handler.on_publish(message, sender=sender) + _warn_if_none(temp_message, "on_publish") + except BaseException as e: + # TODO: we should raise the intervention exception to the publisher. + logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) + return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): + # TODO log message dropped + return + + message_envelope.message = temp_message + self._outstanding_tasks.increment() + task = asyncio.create_task(self._process_publish(message_envelope)) + self._background_tasks.add(task) + task.add_done_callback(self._background_tasks.discard) # Yield control to the message loop to allow other tasks to run await asyncio.sleep(0) From 13cc05ac0946063824dee47feba65fbdb43b4100 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:01:58 -0500 Subject: [PATCH 05/10] remove rpc from worker runtime --- protos/agent_worker.proto | 47 +---- .../application/_worker_runtime.py | 162 +------------- .../application/protos/agent_worker_pb2.py | 76 +++---- .../application/protos/agent_worker_pb2.pyi | 198 +----------------- 4 files changed, 38 insertions(+), 445 deletions(-) diff --git a/protos/agent_worker.proto b/protos/agent_worker.proto index 4d346dfecd63..8260a1d77f9f 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -7,46 +7,11 @@ option csharp_namespace = "Microsoft.AutoGen.Abstractions"; import "cloudevent.proto"; import "google/protobuf/any.proto"; -message TopicId { - string type = 1; - string source = 2; -} - message AgentId { string type = 1; string key = 2; } -message Payload { - string data_type = 1; - string data_content_type = 2; - bytes data = 3; -} - -message RpcRequest { - string request_id = 1; - optional AgentId source = 2; - AgentId target = 3; - string method = 4; - Payload payload = 5; - map metadata = 6; -} - -message RpcResponse { - string request_id = 1; - Payload payload = 2; - string error = 3; - map metadata = 4; -} - -message Event { - string topic_type = 1; - string topic_source = 2; - optional AgentId source = 3; - Payload payload = 4; - map metadata = 5; -} - message RegisterAgentTypeRequest { string request_id = 1; string type = 2; @@ -115,13 +80,11 @@ message SaveStateResponse { message Message { oneof message { - RpcRequest request = 1; - RpcResponse response = 2; - cloudevent.CloudEvent cloudEvent = 3; - RegisterAgentTypeRequest registerAgentTypeRequest = 4; - RegisterAgentTypeResponse registerAgentTypeResponse = 5; - AddSubscriptionRequest addSubscriptionRequest = 6; - AddSubscriptionResponse addSubscriptionResponse = 7; + cloudevent.CloudEvent cloudEvent = 1; + RegisterAgentTypeRequest registerAgentTypeRequest = 2; + RegisterAgentTypeResponse registerAgentTypeResponse = 3; + AddSubscriptionRequest addSubscriptionRequest = 4; + AddSubscriptionResponse addSubscriptionResponse = 5; } } diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index 24714f29155b..d72d15554863 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,6 +28,7 @@ cast, ) + from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated @@ -53,6 +54,7 @@ from ..base._serialization import MessageSerializer, SerializationRegistry from ..base._type_helpers import ChannelArgumentType from ..components import TypePrefixSubscription, TypeSubscription +from ..components._publish_based_rpc import PublishBasedRpcMixin from . import _constants from ._constants import GRPC_IMPORT_ERROR_STR from ._helpers import SubscriptionManager, get_impl @@ -177,7 +179,7 @@ async def recv(self) -> agent_worker_pb2.Message: return await self._recv_queue.get() -class WorkerAgentRuntime(AgentRuntime): +class WorkerAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, host_address: str, @@ -237,16 +239,6 @@ async def _run_read_loop(self) -> None: match oneofcase: case "registerAgentTypeRequest" | "addSubscriptionRequest": logger.warning(f"Cant handle {oneofcase}, skipping.") - case "request": - task = asyncio.create_task(self._process_request(message.request)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - task = asyncio.create_task(self._process_response(message.response)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -331,51 +323,6 @@ async def _send_message( with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata): await self._host_connection.send(runtime_message) - async def send_message( - self, - message: Any, - recipient: AgentId, - *, - sender: AgentId | None = None, - cancellation_token: CancellationToken | None = None, - ) -> Any: - if not self._running: - raise ValueError("Runtime must be running when sending message.") - if self._host_connection is None: - raise RuntimeError("Host connection is not set.") - data_type = self._serialization_registry.type_name(message) - with self._trace_helper.trace_block( - "create", recipient, parent=None, extraAttributes={"message_type": data_type} - ): - # create a new future for the result - future = asyncio.get_event_loop().create_future() - request_id = await self._get_new_request_id() - self._pending_requests[request_id] = future - serialized_message = self._serialization_registry.serialize( - message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - telemetry_metadata = get_telemetry_grpc_metadata() - runtime_message = agent_worker_pb2.Message( - request=agent_worker_pb2.RpcRequest( - request_id=request_id, - target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key), - source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None, - metadata=telemetry_metadata, - payload=agent_worker_pb2.Payload( - data_type=data_type, - data=serialized_message, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - ) - ) - - # TODO: Find a way to handle timeouts/errors - task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - return await future - async def publish_message( self, message: Any, @@ -475,98 +422,6 @@ async def _get_new_request_id(self) -> str: self._next_request_id += 1 return str(self._next_request_id) - async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None: - assert self._host_connection is not None - recipient = AgentId(request.target.type, request.target.key) - sender: AgentId | None = None - if request.HasField("source"): - sender = AgentId(request.source.type, request.source.key) - logging.info(f"Processing request from {sender} to {recipient}") - else: - logging.info(f"Processing request from unknown source to {recipient}") - - # Deserialize the message. - message = self._serialization_registry.deserialize( - request.payload.data, - type_name=request.payload.data_type, - data_content_type=request.payload.data_content_type, - ) - - # Get the receiving agent and prepare the message context. - rec_agent = await self._get_agent(recipient) - message_context = MessageContext( - sender=sender, - topic_id=None, - cancellation_token=CancellationToken(), - message_id=request.request_id, - ) - - # Call the receiving agent. - try: - with MessageHandlerContext.populate_context(rec_agent.id): - with self._trace_helper.trace_block( - "process", - rec_agent.id, - parent=request.metadata, - attributes={"request_id": request.request_id}, - extraAttributes={"message_type": request.payload.data_type}, - ): - result = await rec_agent.on_message(message, ctx=message_context) - except BaseException as e: - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - error=str(e), - metadata=get_telemetry_grpc_metadata(), - ), - ) - # Send the error response. - await self._host_connection.send(response_message) - return - - # Serialize the result. - result_type = self._serialization_registry.type_name(result) - serialized_result = self._serialization_registry.serialize( - result, type_name=result_type, data_content_type=JSON_DATA_CONTENT_TYPE - ) - - # Create the response message. - response_message = agent_worker_pb2.Message( - response=agent_worker_pb2.RpcResponse( - request_id=request.request_id, - payload=agent_worker_pb2.Payload( - data_type=result_type, - data=serialized_result, - data_content_type=JSON_DATA_CONTENT_TYPE, - ), - metadata=get_telemetry_grpc_metadata(), - ) - ) - - # Send the response. - await self._host_connection.send(response_message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None: - with self._trace_helper.trace_block( - "ack", - None, - parent=response.metadata, - attributes={"request_id": response.request_id}, - extraAttributes={"message_type": response.payload.data_type}, - ): - # Deserialize the result. - result = self._serialization_registry.deserialize( - response.payload.data, - type_name=response.payload.data_type, - data_content_type=response.payload.data_content_type, - ) - # Get the future and set the result. - future = self._pending_requests.pop(response.request_id) - if len(response.error) > 0: - future.set_exception(Exception(response.error)) - else: - future.set_result(result) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: event_attributes = event.attributes sender: AgentId | None = None @@ -598,16 +453,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: else: raise ValueError(f"Unsupported message content type: {message_content_type}") - # TODO: dont read these values in the runtime - topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else "" - is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - is_marked_rpc_type = ( - _constants.MESSAGE_KIND_ATTR in event_attributes - and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST - ) - if is_rpc and not is_marked_rpc_type: - warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2) - # Send the message to each recipient. responses: List[Awaitable[Any]] = [] for agent_id in recipients: @@ -616,7 +461,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: message_context = MessageContext( sender=sender, topic_id=topic_id, - is_rpc=is_rpc, cancellation_token=CancellationToken(), message_id=event.id, ) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py index 319ee2c6365d..ca08dcb1db83 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"\'\n\x07TopicId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0e\n\x06source\x18\x02 \x01(\t\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xe4\x01\n\x05\x45vent\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x14\n\x0ctopic_source\x18\x02 \x01(\t\x12$\n\x06source\x18\x03 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12 \n\x07payload\x18\x04 \x01(\x0b\x32\x0f.agents.Payload\x12-\n\x08metadata\x18\x05 \x03(\x0b\x32\x1b.agents.Event.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xa6\x03\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12,\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x04 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x05 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x06 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x07 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"<\n\x18RegisterAgentTypeRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\"^\n\x19RegisterAgentTypeResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\x96\x01\n\x0cSubscription\x12\x34\n\x10typeSubscription\x18\x01 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x02 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"X\n\x16\x41\x64\x64SubscriptionRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12*\n\x0csubscription\x18\x02 \x01(\x0b\x32\x14.agents.Subscription\"\\\n\x17\x41\x64\x64SubscriptionResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\"j\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\x12\x0f\n\x07success\x18\x02 \x01(\x08\x12\x12\n\x05\x65rror\x18\x03 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"B\n\x11SaveStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\xd6\x02\n\x07Message\x12,\n\ncloudEvent\x18\x01 \x01(\x0b\x32\x16.cloudevent.CloudEventH\x00\x12\x44\n\x18registerAgentTypeRequest\x18\x02 \x01(\x0b\x32 .agents.RegisterAgentTypeRequestH\x00\x12\x46\n\x19registerAgentTypeResponse\x18\x03 \x01(\x0b\x32!.agents.RegisterAgentTypeResponseH\x00\x12@\n\x16\x61\x64\x64SubscriptionRequest\x18\x04 \x01(\x0b\x32\x1e.agents.AddSubscriptionRequestH\x00\x12\x42\n\x17\x61\x64\x64SubscriptionResponse\x18\x05 \x01(\x0b\x32\x1f.agents.AddSubscriptionResponseH\x00\x42\t\n\x07message2\xb2\x01\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponseB!\xaa\x02\x1eMicrosoft.AutoGen.Abstractionsb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -24,52 +24,30 @@ if _descriptor._USE_C_DESCRIPTORS == False: _globals['DESCRIPTOR']._options = None _globals['DESCRIPTOR']._serialized_options = b'\252\002\036Microsoft.AutoGen.Abstractions' - _globals['_RPCREQUEST_METADATAENTRY']._options = None - _globals['_RPCREQUEST_METADATAENTRY']._serialized_options = b'8\001' - _globals['_RPCRESPONSE_METADATAENTRY']._options = None - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_options = b'8\001' - _globals['_EVENT_METADATAENTRY']._options = None - _globals['_EVENT_METADATAENTRY']._serialized_options = b'8\001' - _globals['_TOPICID']._serialized_start=75 - _globals['_TOPICID']._serialized_end=114 - _globals['_AGENTID']._serialized_start=116 - _globals['_AGENTID']._serialized_end=152 - _globals['_PAYLOAD']._serialized_start=154 - _globals['_PAYLOAD']._serialized_end=223 - _globals['_RPCREQUEST']._serialized_start=226 - _globals['_RPCREQUEST']._serialized_end=491 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_start=433 - _globals['_RPCREQUEST_METADATAENTRY']._serialized_end=480 - _globals['_RPCRESPONSE']._serialized_start=494 - _globals['_RPCRESPONSE']._serialized_end=678 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_start=433 - _globals['_RPCRESPONSE_METADATAENTRY']._serialized_end=480 - _globals['_EVENT']._serialized_start=681 - _globals['_EVENT']._serialized_end=909 - _globals['_EVENT_METADATAENTRY']._serialized_start=433 - _globals['_EVENT_METADATAENTRY']._serialized_end=480 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=911 - _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=971 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=973 - _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=1067 - _globals['_TYPESUBSCRIPTION']._serialized_start=1069 - _globals['_TYPESUBSCRIPTION']._serialized_end=1127 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=1129 - _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=1200 - _globals['_SUBSCRIPTION']._serialized_start=1203 - _globals['_SUBSCRIPTION']._serialized_end=1353 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=1355 - _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=1443 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=1445 - _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=1537 - _globals['_AGENTSTATE']._serialized_start=1540 - _globals['_AGENTSTATE']._serialized_end=1697 - _globals['_GETSTATERESPONSE']._serialized_start=1699 - _globals['_GETSTATERESPONSE']._serialized_end=1805 - _globals['_SAVESTATERESPONSE']._serialized_start=1807 - _globals['_SAVESTATERESPONSE']._serialized_end=1873 - _globals['_MESSAGE']._serialized_start=1876 - _globals['_MESSAGE']._serialized_end=2298 - _globals['_AGENTRPC']._serialized_start=2301 - _globals['_AGENTRPC']._serialized_end=2479 + _globals['_AGENTID']._serialized_start=75 + _globals['_AGENTID']._serialized_end=111 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_start=113 + _globals['_REGISTERAGENTTYPEREQUEST']._serialized_end=173 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_start=175 + _globals['_REGISTERAGENTTYPERESPONSE']._serialized_end=269 + _globals['_TYPESUBSCRIPTION']._serialized_start=271 + _globals['_TYPESUBSCRIPTION']._serialized_end=329 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_start=331 + _globals['_TYPEPREFIXSUBSCRIPTION']._serialized_end=402 + _globals['_SUBSCRIPTION']._serialized_start=405 + _globals['_SUBSCRIPTION']._serialized_end=555 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_start=557 + _globals['_ADDSUBSCRIPTIONREQUEST']._serialized_end=645 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_start=647 + _globals['_ADDSUBSCRIPTIONRESPONSE']._serialized_end=739 + _globals['_AGENTSTATE']._serialized_start=742 + _globals['_AGENTSTATE']._serialized_end=899 + _globals['_GETSTATERESPONSE']._serialized_start=901 + _globals['_GETSTATERESPONSE']._serialized_end=1007 + _globals['_SAVESTATERESPONSE']._serialized_start=1009 + _globals['_SAVESTATERESPONSE']._serialized_end=1075 + _globals['_MESSAGE']._serialized_start=1078 + _globals['_MESSAGE']._serialized_end=1420 + _globals['_AGENTRPC']._serialized_start=1423 + _globals['_AGENTRPC']._serialized_end=1601 # @@protoc_insertion_point(module_scope) diff --git a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi index 79e384ab948b..7c9baa5e9ca7 100644 --- a/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-core/src/autogen_core/application/protos/agent_worker_pb2.pyi @@ -5,33 +5,13 @@ isort:skip_file import builtins import cloudevent_pb2 -import collections.abc import google.protobuf.any_pb2 import google.protobuf.descriptor -import google.protobuf.internal.containers import google.protobuf.message import typing DESCRIPTOR: google.protobuf.descriptor.FileDescriptor -@typing.final -class TopicId(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - TYPE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - type: builtins.str - source: builtins.str - def __init__( - self, - *, - type: builtins.str = ..., - source: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["source", b"source", "type", b"type"]) -> None: ... - -global___TopicId = TopicId - @typing.final class AgentId(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -50,170 +30,6 @@ class AgentId(google.protobuf.message.Message): global___AgentId = AgentId -@typing.final -class Payload(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - DATA_TYPE_FIELD_NUMBER: builtins.int - DATA_CONTENT_TYPE_FIELD_NUMBER: builtins.int - DATA_FIELD_NUMBER: builtins.int - data_type: builtins.str - data_content_type: builtins.str - data: builtins.bytes - def __init__( - self, - *, - data_type: builtins.str = ..., - data_content_type: builtins.str = ..., - data: builtins.bytes = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["data", b"data", "data_content_type", b"data_content_type", "data_type", b"data_type"]) -> None: ... - -global___Payload = Payload - -@typing.final -class RpcRequest(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - TARGET_FIELD_NUMBER: builtins.int - METHOD_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - method: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def target(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - source: global___AgentId | None = ..., - target: global___AgentId | None = ..., - method: builtins.str = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source", "target", b"target"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "method", b"method", "payload", b"payload", "request_id", b"request_id", "source", b"source", "target", b"target"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___RpcRequest = RpcRequest - -@typing.final -class RpcResponse(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - REQUEST_ID_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - ERROR_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - request_id: builtins.str - error: builtins.str - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - request_id: builtins.str = ..., - payload: global___Payload | None = ..., - error: builtins.str = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["payload", b"payload"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["error", b"error", "metadata", b"metadata", "payload", b"payload", "request_id", b"request_id"]) -> None: ... - -global___RpcResponse = RpcResponse - -@typing.final -class Event(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - @typing.final - class MetadataEntry(google.protobuf.message.Message): - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - KEY_FIELD_NUMBER: builtins.int - VALUE_FIELD_NUMBER: builtins.int - key: builtins.str - value: builtins.str - def __init__( - self, - *, - key: builtins.str = ..., - value: builtins.str = ..., - ) -> None: ... - def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... - - TOPIC_TYPE_FIELD_NUMBER: builtins.int - TOPIC_SOURCE_FIELD_NUMBER: builtins.int - SOURCE_FIELD_NUMBER: builtins.int - PAYLOAD_FIELD_NUMBER: builtins.int - METADATA_FIELD_NUMBER: builtins.int - topic_type: builtins.str - topic_source: builtins.str - @property - def source(self) -> global___AgentId: ... - @property - def payload(self) -> global___Payload: ... - @property - def metadata(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... - def __init__( - self, - *, - topic_type: builtins.str = ..., - topic_source: builtins.str = ..., - source: global___AgentId | None = ..., - payload: global___Payload | None = ..., - metadata: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., - ) -> None: ... - def HasField(self, field_name: typing.Literal["_source", b"_source", "payload", b"payload", "source", b"source"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["_source", b"_source", "metadata", b"metadata", "payload", b"payload", "source", b"source", "topic_source", b"topic_source", "topic_type", b"topic_type"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["_source", b"_source"]) -> typing.Literal["source"] | None: ... - -global___Event = Event - @typing.final class RegisterAgentTypeRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -435,18 +251,12 @@ global___SaveStateResponse = SaveStateResponse class Message(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - REQUEST_FIELD_NUMBER: builtins.int - RESPONSE_FIELD_NUMBER: builtins.int CLOUDEVENT_FIELD_NUMBER: builtins.int REGISTERAGENTTYPEREQUEST_FIELD_NUMBER: builtins.int REGISTERAGENTTYPERESPONSE_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONREQUEST_FIELD_NUMBER: builtins.int ADDSUBSCRIPTIONRESPONSE_FIELD_NUMBER: builtins.int @property - def request(self) -> global___RpcRequest: ... - @property - def response(self) -> global___RpcResponse: ... - @property def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... @property def registerAgentTypeRequest(self) -> global___RegisterAgentTypeRequest: ... @@ -459,16 +269,14 @@ class Message(google.protobuf.message.Message): def __init__( self, *, - request: global___RpcRequest | None = ..., - response: global___RpcResponse | None = ..., cloudEvent: cloudevent_pb2.CloudEvent | None = ..., registerAgentTypeRequest: global___RegisterAgentTypeRequest | None = ..., registerAgentTypeResponse: global___RegisterAgentTypeResponse | None = ..., addSubscriptionRequest: global___AddSubscriptionRequest | None = ..., addSubscriptionResponse: global___AddSubscriptionResponse | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... + def HasField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["addSubscriptionRequest", b"addSubscriptionRequest", "addSubscriptionResponse", b"addSubscriptionResponse", "cloudEvent", b"cloudEvent", "message", b"message", "registerAgentTypeRequest", b"registerAgentTypeRequest", "registerAgentTypeResponse", b"registerAgentTypeResponse"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["cloudEvent", "registerAgentTypeRequest", "registerAgentTypeResponse", "addSubscriptionRequest", "addSubscriptionResponse"] | None: ... global___Message = Message From 88113d3aa75ddf756a8fb032593a95aae0f9f551 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:02:14 -0500 Subject: [PATCH 06/10] fmt --- .../_single_threaded_agent_runtime.py | 6 +-- .../application/_worker_runtime.py | 1 - .../src/autogen_core/base/_base_agent.py | 11 ++++- .../src/autogen_core/base/_rpc.py | 6 +-- .../autogen_core/components/_closure_agent.py | 12 +++++- .../components/_publish_based_rpc.py | 41 ++++++++++--------- .../autogen_core/components/_routed_agent.py | 8 +++- 7 files changed, 52 insertions(+), 33 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py index 763dada6ab5d..06eabddf4f9c 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_single_threaded_agent_runtime.py @@ -15,9 +15,6 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from ..base._serialization import MessageSerializer, SerializationRegistry -from ..components._publish_based_rpc import PublishBasedRpcMixin - from ..base import ( Agent, AgentId, @@ -32,8 +29,9 @@ SubscriptionInstantiationContext, TopicId, ) - +from ..base._serialization import MessageSerializer, SerializationRegistry from ..base.intervention import DropMessage, InterventionHandler +from ..components._publish_based_rpc import PublishBasedRpcMixin from ._helpers import SubscriptionManager, get_impl from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py index d72d15554863..a9d2a3c3570d 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py @@ -28,7 +28,6 @@ cast, ) - from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index f2baa6d1a07e..318178913f02 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -57,6 +57,7 @@ def decorator(cls: Type[BaseAgentType]) -> Type[BaseAgentType]: return decorator + class BaseAgent(ABC, Agent): internal_unbound_subscriptions_list: ClassVar[List[UnboundSubscription]] = [] internal_extra_handles_types: ClassVar[List[Tuple[Type[Any], List[MessageSerializer[Any]]]]] = [] @@ -132,7 +133,10 @@ async def on_message(self, message: Any, ctx: MessageContext) -> None: elif self._forward_unbound_rpc_responses_to_handler: await self.on_message_impl(message, ctx) else: - warnings.warn(f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", stacklevel=2) + warnings.warn( + f"Received RPC response for unknown request {request_id}. To forward unbound rpc responses to the handler, set forward_unbound_rpc_responses_to_handler=True", + stacklevel=2, + ) return None await self.on_message_impl(message, ctx) @@ -148,7 +152,10 @@ async def send_message( if cancellation_token is None: cancellation_token = CancellationToken() - recipient_topic = TopicId(type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), source=recipient.key) + recipient_topic = TopicId( + type=format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=self.id.type), + source=recipient.key, + ) request_id = str(uuid.uuid4()) future = Future[Any]() diff --git a/python/packages/autogen-core/src/autogen_core/base/_rpc.py b/python/packages/autogen-core/src/autogen_core/base/_rpc.py index d6b857afd0e8..d6554e71844d 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/base/_rpc.py @@ -3,14 +3,14 @@ from typing import Optional - - def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_type: str) -> str: return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" -def format_rpc_response_topic(rpc_sender_agent_type: str,request_id: str) -> str: + +def format_rpc_response_topic(rpc_sender_agent_type: str, request_id: str) -> str: return f"{rpc_sender_agent_type}:rpc_response={request_id}" + # If is an rpc response, return the request id def is_rpc_response(topic_type: str) -> Optional[str]: topic_segments = topic_type.split(":") diff --git a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py index d6b100547bd1..36566f4ecbc1 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_closure_agent.py @@ -76,7 +76,11 @@ async def publish_message( class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, forward_unbound_rpc_responses_to_handler: bool = False + self, + description: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + forward_unbound_rpc_responses_to_handler: bool = False, ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -137,7 +141,11 @@ async def register_closure( subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure, forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler) + return ClosureAgent( + description=description, + closure=closure, + forward_unbound_rpc_responses_to_handler=forward_unbound_rpc_responses_to_handler, + ) agent_type = await cls.register( runtime=runtime, diff --git a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py index 2ae0e09eebaf..a091fe61e788 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py @@ -1,42 +1,44 @@ - -from typing import Any +import asyncio import uuid import warnings +from typing import Any from autogen_core.base._rpc import format_rpc_request_topic, format_rpc_response_topic from autogen_core.base._topic import TopicId -from ..base._message_context import MessageContext - from ..base._agent_id import AgentId +from ..base._agent_runtime import AgentRuntime from ..base._cancellation_token import CancellationToken +from ..base._message_context import MessageContext from ._closure_agent import ClosureAgent, ClosureContext -from ..base._agent_runtime import AgentRuntime - -import asyncio class PublishBasedRpcMixin(AgentRuntime): async def send_message( - self: AgentRuntime, - message: Any, - recipient: AgentId, - *, - cancellation_token: CancellationToken | None = None, - ) -> Any: - + self: AgentRuntime, + message: Any, + recipient: AgentId, + *, + cancellation_token: CancellationToken | None = None, + ) -> Any: rpc_request_id = str(uuid.uuid4()) # TODO add "-" to topic and agent type allowed characters in spec closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" future: asyncio.Future[Any] = asyncio.Future() - expected_response_topic_type = format_rpc_response_topic(rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id) - async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageContext) -> None: + expected_response_topic_type = format_rpc_response_topic( + rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id + ) + + async def set_result(closure_context: ClosureContext, message: Any, ctx: MessageContext) -> None: assert ctx.topic_id is not None if ctx.topic_id.type == expected_response_topic_type: future.set_result(message) else: - warnings.warn(f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", stacklevel=2) + warnings.warn( + f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", + stacklevel=2, + ) # TODO: remove agent after response is received @@ -47,7 +49,9 @@ async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageC forward_unbound_rpc_responses_to_handler=True, ) - rpc_request_topic_id = format_rpc_request_topic(rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type) + rpc_request_topic_id = format_rpc_request_topic( + rpc_recipient_agent_type=recipient.type, rpc_sender_agent_type=closure_agent_type + ) await self.publish_message( message=message, topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), @@ -58,4 +62,3 @@ async def set_result(closure_context:ClosureContext, message: Any, ctx: MessageC return await future # register a closure agent... - diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index 50ef43aa67f9..c8e6da3b2c9a 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -301,7 +301,9 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True # Wrap the match function with a check on the topic for rpc - wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler @@ -440,7 +442,9 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None wrapper_handler.target_types = list(target_types) wrapper_handler.produces_types = list(return_types) wrapper_handler.is_message_handler = True - wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and (match(_message, _ctx) if match else True) + wrapper_handler.router = lambda _message, _ctx: (is_rpc_request(_ctx.topic_id.type) is not None) and ( + match(_message, _ctx) if match else True + ) return wrapper_handler From 11dea884fbf9d18f466d5bbce91043ad4c3e7440 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Sun, 1 Dec 2024 15:21:32 -0500 Subject: [PATCH 07/10] lint, type, fmt fixes --- .../_group_chat/_sequential_routed_agent.py | 4 +- .../cookbook/local-llms-ollama-litellm.ipynb | 3 +- .../_worker_runtime_host_servicer.py | 49 ------------------- .../src/autogen_core/base/_agent_proxy.py | 2 - .../src/autogen_core/base/_base_agent.py | 2 +- .../autogen_core/components/_routed_agent.py | 6 +-- .../autogen-core/tests/test_routed_agent.py | 5 +- .../headless_web_surfer/test_web_surfer.py | 5 -- 8 files changed, 10 insertions(+), 66 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py index fe80c9e93923..6b92b21e883b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_sequential_routed_agent.py @@ -43,9 +43,9 @@ def __init__(self, description: str) -> None: super().__init__(description=description) self._fifo_lock = FIFOLock() - async def on_message(self, message: Any, ctx: MessageContext) -> Any | None: + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: await self._fifo_lock.acquire() try: - return await super().on_message(message, ctx) + await super().on_message_impl(message, ctx) finally: self._fifo_lock.release() diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb index 80fde2b71017..e9274ae0a523 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/local-llms-ollama-litellm.ipynb @@ -174,7 +174,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -211,7 +211,6 @@ "await runtime.send_message(\n", " Message(\"Joe, tell me a joke.\"),\n", " recipient=AgentId(joe, \"default\"),\n", - " sender=AgentId(cathy, \"default\"),\n", ")\n", "await runtime.stop_when_idle()" ] diff --git a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py index e24a7db3f30a..4dfd52b9949e 100644 --- a/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py @@ -102,18 +102,6 @@ async def _receive_messages( logger.info(f"Received message from client {client_id}: {message}") oneofcase = message.WhichOneof("message") match oneofcase: - case "request": - request: agent_worker_pb2.RpcRequest = message.request - task = asyncio.create_task(self._process_request(request, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - response: agent_worker_pb2.RpcResponse = message.response - task = asyncio.create_task(self._process_response(response, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) case "cloudEvent": # The proto typing doesnt resolve this one event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore @@ -140,43 +128,6 @@ async def _receive_messages( case None: logger.warning("Received empty message") - async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None: - # Deliver the message to a client given the target agent type. - async with self._agent_type_to_client_id_lock: - target_client_id = self._agent_type_to_client_id.get(request.target.type) - if target_client_id is None: - logger.error(f"Agent {request.target.type} not found, failed to deliver message.") - return - target_send_queue = self._send_queues.get(target_client_id) - if target_send_queue is None: - logger.error(f"Client {target_client_id} not found, failed to deliver message.") - return - await target_send_queue.put(agent_worker_pb2.Message(request=request)) - - # Create a future to wait for the response from the target. - future = asyncio.get_event_loop().create_future() - self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future - - # Create a task to wait for the response and send it back to the client. - send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) - self._background_tasks.add(send_response_task) - send_response_task.add_done_callback(self._raise_on_exception) - send_response_task.add_done_callback(self._background_tasks.discard) - - async def _wait_and_send_response(self, future: Future[agent_worker_pb2.RpcResponse], client_id: int) -> None: - response = await future - message = agent_worker_pb2.Message(response=response) - send_queue = self._send_queues.get(client_id) - if send_queue is None: - logger.error(f"Client {client_id} not found, failed to send response message.") - return - await send_queue.put(message) - - async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: int) -> None: - # Setting the result of the future will send the response back to the original sender. - future = self._pending_responses[client_id].pop(response.request_id) - future.set_result(response) - async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: topic_id = TopicId(type=event.type, source=event.source) recipients = await self._subscription_manager.get_subscribed_recipients(topic_id) diff --git a/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py b/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py index f3eb70f28270..5b8ddc314a0c 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py +++ b/python/packages/autogen-core/src/autogen_core/base/_agent_proxy.py @@ -29,13 +29,11 @@ async def send_message( self, message: Any, *, - sender: AgentId, cancellation_token: CancellationToken | None = None, ) -> Any: return await self._runtime.send_message( message, recipient=self._agent, - sender=sender, cancellation_token=cancellation_token, ) diff --git a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py index 318178913f02..0df1cc25caad 100644 --- a/python/packages/autogen-core/src/autogen_core/base/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/base/_base_agent.py @@ -126,7 +126,7 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... @final async def on_message(self, message: Any, ctx: MessageContext) -> None: # Intercept RPC responses - if ctx.topic_id is not None and (request_id := is_rpc_response(ctx.topic_id.type)) is not None: + if (request_id := is_rpc_response(ctx.topic_id.type)) is not None: if request_id in self._pending_rpc_requests: self._pending_rpc_requests[request_id].set_result(message) del self._pending_rpc_requests[request_id] diff --git a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py index c8e6da3b2c9a..b7a169f184c9 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/components/_routed_agent.py @@ -515,7 +515,7 @@ def __init__(self, description: str) -> None: super().__init__(description) - async def on_message_impl(self, message: Any, ctx: MessageContext): + async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: """Handle a message by routing it to the appropriate message handler. Do not override this method in subclasses. Instead, add message handlers as methods decorated with either the :func:`event` or :func:`rpc` decorator.""" @@ -526,8 +526,8 @@ async def on_message_impl(self, message: Any, ctx: MessageContext): # Call the first handler whose router returns True and then return the result. for h in handlers: if h.router(message, ctx): - return await h(self, message, ctx) - return await self.on_unhandled_message(message, ctx) # type: ignore + await h(self, message, ctx) + await self.on_unhandled_message(message, ctx) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: """Called when a message is received that does not have a matching message handler. diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index cab1b1d467ff..55de2432f4fd 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -5,6 +5,7 @@ import pytest from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, MessageContext, TopicId +from autogen_core.base._rpc import is_rpc_request from autogen_core.components import RoutedAgent, TypeSubscription, event, message_handler, rpc from test_utils import LoopbackAgent @@ -23,12 +24,12 @@ def __init__(self) -> None: self.num_calls_rpc = 0 self.num_calls_broadcast = 0 - @message_handler(match=lambda _, ctx: ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is not None) async def on_rpc_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.num_calls_rpc += 1 return message - @message_handler(match=lambda _, ctx: not ctx.is_rpc) + @message_handler(match=lambda _, ctx: is_rpc_request(ctx.topic_id.type) is None) async def on_broadcast_message(self, message: MessageType, ctx: MessageContext) -> None: self.num_calls_broadcast += 1 diff --git a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py index 769ac5080e88..6106a9f219a0 100644 --- a/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py +++ b/python/packages/autogen-magentic-one/tests/headless_web_surfer/test_web_surfer.py @@ -218,27 +218,22 @@ async def test_web_surfer_oai() -> None: ) ), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll down.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="Please scroll up.", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="When was it founded?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.send_message( BroadcastMessage(content=UserMessage(content="What's this page about?", source="user")), recipient=web_surfer.id, - sender=user_proxy.id, ) await runtime.stop_when_idle() From aae93c0c0aab553fdcb243cf44aa584da7a76e99 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 19 Dec 2024 13:28:23 -0500 Subject: [PATCH 08/10] WIP --- .../autogen-core/src/autogen_core/_base_agent.py | 3 +-- .../src/autogen_core/_closure_agent.py | 2 +- .../src/autogen_core/_intervention.py | 13 +++++++++++++ .../{components => }/_publish_based_rpc.py | 13 ++++++------- .../src/autogen_core/_routed_agent.py | 3 +++ .../src/autogen_core/{base => }/_rpc.py | 0 .../_single_threaded_agent_runtime.py | 15 ++++++++------- .../src/autogen_core/base/intervention.py | 7 +------ .../autogen-core/tests/test_routed_agent.py | 4 +--- python/packages/autogen-core/tests/test_types.py | 1 - .../autogen_ext/runtimes/grpc/_worker_runtime.py | 2 +- 11 files changed, 35 insertions(+), 28 deletions(-) create mode 100644 python/packages/autogen-core/src/autogen_core/_intervention.py rename python/packages/autogen-core/src/autogen_core/{components => }/_publish_based_rpc.py (86%) rename python/packages/autogen-core/src/autogen_core/{base => }/_rpc.py (100%) diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 855726592202..927470122538 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -10,7 +10,7 @@ from typing_extensions import Self -from autogen_core.base._rpc import format_rpc_request_topic, is_rpc_response +from autogen_core._rpc import format_rpc_request_topic, is_rpc_response from ._agent import Agent from ._agent_id import AgentId @@ -124,7 +124,6 @@ def id(self) -> AgentId: def runtime(self) -> AgentRuntime: return self._runtime - @abstractmethod async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... diff --git a/python/packages/autogen-core/src/autogen_core/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/_closure_agent.py index 6c762f2855ac..66a990a29204 100644 --- a/python/packages/autogen-core/src/autogen_core/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_closure_agent.py @@ -15,7 +15,7 @@ from ._subscription import Subscription from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId -from ._type_helpers import get_types +from ._type_helpers import AnyType, get_types from .exceptions import CantHandleException T = TypeVar("T") diff --git a/python/packages/autogen-core/src/autogen_core/_intervention.py b/python/packages/autogen-core/src/autogen_core/_intervention.py new file mode 100644 index 000000000000..c18a529aae6f --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/_intervention.py @@ -0,0 +1,13 @@ +from typing import Any, Awaitable, Callable, Protocol, final + +__all__ = [ + "DropMessage", + "InterventionFunction", +] + + +@final +class DropMessage: ... + + +InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]] diff --git a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py similarity index 86% rename from python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py rename to python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py index a091fe61e788..6134844abcfb 100644 --- a/python/packages/autogen-core/src/autogen_core/components/_publish_based_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py @@ -3,14 +3,13 @@ import warnings from typing import Any -from autogen_core.base._rpc import format_rpc_request_topic, format_rpc_response_topic -from autogen_core.base._topic import TopicId - -from ..base._agent_id import AgentId -from ..base._agent_runtime import AgentRuntime -from ..base._cancellation_token import CancellationToken -from ..base._message_context import MessageContext +from ._agent_id import AgentId +from ._agent_runtime import AgentRuntime +from ._cancellation_token import CancellationToken from ._closure_agent import ClosureAgent, ClosureContext +from ._message_context import MessageContext +from ._rpc import format_rpc_request_topic, format_rpc_response_topic +from ._topic import TopicId class PublishBasedRpcMixin(AgentRuntime): diff --git a/python/packages/autogen-core/src/autogen_core/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/_routed_agent.py index 4b6c04afcb75..5c41ce28de7f 100644 --- a/python/packages/autogen-core/src/autogen_core/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_routed_agent.py @@ -1,4 +1,5 @@ import logging +import warnings from functools import wraps from typing import ( Any, @@ -20,7 +21,9 @@ from ._base_agent import BaseAgent from ._message_context import MessageContext +from ._rpc import format_rpc_response_topic, is_rpc_request from ._serialization import MessageSerializer, try_get_known_serializers_for_type +from ._topic import TopicId from ._type_helpers import AnyType, get_types from .exceptions import CantHandleException diff --git a/python/packages/autogen-core/src/autogen_core/base/_rpc.py b/python/packages/autogen-core/src/autogen_core/_rpc.py similarity index 100% rename from python/packages/autogen-core/src/autogen_core/base/_rpc.py rename to python/packages/autogen-core/src/autogen_core/_rpc.py diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 4cc2b1659e97..29b917e19149 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -15,7 +15,7 @@ from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from ..base import ( +from . import ( Agent, AgentId, AgentInstantiationContext, @@ -29,11 +29,12 @@ SubscriptionInstantiationContext, TopicId, ) -from ..base._serialization import MessageSerializer, SerializationRegistry -from ..base.intervention import DropMessage, InterventionHandler -from ..components._publish_based_rpc import PublishBasedRpcMixin -from ._helpers import SubscriptionManager, get_impl -from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata +from ._intervention import DropMessage +from ._publish_based_rpc import PublishBasedRpcMixin +from ._runtime_impl_helpers import SubscriptionManager, get_impl +from ._serialization import MessageSerializer, SerializationRegistry +from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata +from .base.intervention import InterventionHandler logger = logging.getLogger("autogen_core") event_logger = logging.getLogger("autogen_core.events") @@ -490,4 +491,4 @@ async def get( ) def add_message_serializer(self, serializer: MessageSerializer[Any] | Sequence[MessageSerializer[Any]]) -> None: - self._serialization_registry.add_serializer(serializer) \ No newline at end of file + self._serialization_registry.add_serializer(serializer) diff --git a/python/packages/autogen-core/src/autogen_core/base/intervention.py b/python/packages/autogen-core/src/autogen_core/base/intervention.py index 5fe337b8776d..1e7d1937414a 100644 --- a/python/packages/autogen-core/src/autogen_core/base/intervention.py +++ b/python/packages/autogen-core/src/autogen_core/base/intervention.py @@ -9,12 +9,7 @@ "DefaultInterventionHandler", ] - -@final -class DropMessage: ... - - -InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]] +from .._intervention import DropMessage, InterventionFunction class InterventionHandler(Protocol): diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index 04d362320bfe..1256f3103bd2 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -2,9 +2,7 @@ from dataclasses import dataclass from typing import Callable, cast -from autogen_core.base._rpc import is_rpc_request import pytest - from autogen_core import ( AgentId, MessageContext, @@ -16,10 +14,10 @@ message_handler, rpc, ) +from autogen_core._rpc import is_rpc_request from autogen_test_utils import LoopbackAgent - @dataclass class UnhandledMessageType: ... diff --git a/python/packages/autogen-core/tests/test_types.py b/python/packages/autogen-core/tests/test_types.py index 5ce4cab2e735..16697e006e6e 100644 --- a/python/packages/autogen-core/tests/test_types.py +++ b/python/packages/autogen-core/tests/test_types.py @@ -6,7 +6,6 @@ from autogen_core._routed_agent import RoutedAgent, message_handler from autogen_core._serialization import has_nested_base_model from autogen_core._type_helpers import AnyType, get_types - from pydantic import BaseModel diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py index 1156d06500c6..4a66e0872b84 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime.py @@ -52,7 +52,7 @@ SerializationRegistry, ) from autogen_core._telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata -from autogen_core.components._publish_based_rpc import PublishBasedRpcMixin +from autogen_core._publish_based_rpc import PublishBasedRpcMixin from google.protobuf import any_pb2 from opentelemetry.trace import TracerProvider from typing_extensions import Self, deprecated From 23c9e7956a9c841f46b2a592e769fe6d1e75a379 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 19 Dec 2024 17:11:00 -0500 Subject: [PATCH 09/10] fix circular imports --- .../src/autogen_core/_base_agent.py | 3 +-- .../_single_threaded_agent_runtime.py | 26 +++++++++---------- .../src/autogen_core/_subscription_context.py | 2 +- .../autogen-core/src/autogen_core/logging.py | 2 +- 4 files changed, 15 insertions(+), 18 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 927470122538..7af46268c907 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -10,8 +10,7 @@ from typing_extensions import Self -from autogen_core._rpc import format_rpc_request_topic, is_rpc_response - +from ._rpc import format_rpc_request_topic, is_rpc_response from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 29b917e19149..8db453f6dc82 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -12,23 +12,21 @@ from enum import Enum from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast +from ._message_context import MessageContext +from ._message_handler_context import MessageHandlerContext +from ._subscription import Subscription from opentelemetry.trace import TracerProvider from typing_extensions import deprecated -from . import ( - Agent, - AgentId, - AgentInstantiationContext, - AgentMetadata, - AgentRuntime, - AgentType, - CancellationToken, - MessageContext, - MessageHandlerContext, - Subscription, - SubscriptionInstantiationContext, - TopicId, -) +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_instantiation import AgentInstantiationContext +from ._agent_metadata import AgentMetadata +from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType +from ._topic import TopicId +from ._subscription_context import SubscriptionInstantiationContext +from ._cancellation_token import CancellationToken from ._intervention import DropMessage from ._publish_based_rpc import PublishBasedRpcMixin from ._runtime_impl_helpers import SubscriptionManager, get_impl diff --git a/python/packages/autogen-core/src/autogen_core/_subscription_context.py b/python/packages/autogen-core/src/autogen_core/_subscription_context.py index 1cfd3fd882ed..29b1e1629798 100644 --- a/python/packages/autogen-core/src/autogen_core/_subscription_context.py +++ b/python/packages/autogen-core/src/autogen_core/_subscription_context.py @@ -2,7 +2,7 @@ from contextvars import ContextVar from typing import Any, ClassVar, Generator -from autogen_core._agent_type import AgentType +from ._agent_type import AgentType class SubscriptionInstantiationContext: diff --git a/python/packages/autogen-core/src/autogen_core/logging.py b/python/packages/autogen-core/src/autogen_core/logging.py index 5e3870203e57..34a3012fe423 100644 --- a/python/packages/autogen-core/src/autogen_core/logging.py +++ b/python/packages/autogen-core/src/autogen_core/logging.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, cast -from autogen_core import AgentId +from ._agent_id import AgentId class LLMCallEvent: From 39a7a80485aaf9fe4ccc03425a93207c345acb78 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Fri, 20 Dec 2024 16:16:26 -0500 Subject: [PATCH 10/10] WIP checkpoint --- .../samples/slow_human_in_loop.py | 10 +- .../autogen-core/src/autogen_core/__init__.py | 3 + .../src/autogen_core/_base_agent.py | 105 +++++++++++++++++- .../src/autogen_core/_cancellation_token.py | 13 ++- .../src/autogen_core/_intervention.py | 8 +- .../src/autogen_core/_publish_based_rpc.py | 46 +++++++- .../src/autogen_core/_routed_agent.py | 70 ++++++------ .../_single_threaded_agent_runtime.py | 52 +++++++-- .../autogen-core/src/autogen_core/_types.py | 26 +++++ .../{_rpc.py => _well_known_topics.py} | 28 +++++ .../src/autogen_core/base/intervention.py | 2 +- .../autogen_core/tool_agent/_tool_agent.py | 21 ++-- .../autogen-core/tests/test_cancellation.py | 25 +++-- .../autogen-core/tests/test_intervention.py | 94 ++++------------ .../autogen-core/tests/test_routed_agent.py | 16 +-- .../autogen-core/tests/test_tool_agent.py | 18 +-- 16 files changed, 358 insertions(+), 179 deletions(-) rename python/packages/autogen-core/src/autogen_core/{_rpc.py => _well_known_topics.py} (52%) diff --git a/python/packages/autogen-core/samples/slow_human_in_loop.py b/python/packages/autogen-core/samples/slow_human_in_loop.py index 9c4476d06b5c..c8a70cca607b 100644 --- a/python/packages/autogen-core/samples/slow_human_in_loop.py +++ b/python/packages/autogen-core/samples/slow_human_in_loop.py @@ -31,7 +31,6 @@ from typing import Any, Mapping, Optional from autogen_core import ( - AgentId, CancellationToken, DefaultTopicId, FunctionCall, @@ -41,7 +40,6 @@ message_handler, type_subscription, ) -from autogen_core.base.intervention import DefaultInterventionHandler from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import ( AssistantMessage, @@ -207,11 +205,11 @@ async def load_state(self, state: Mapping[str, Any]) -> None: self._model_context.load_state({**state["memory"], "messages": [m for m in state["memory"]["messages"]]}) -class NeedsUserInputHandler(DefaultInterventionHandler): +class NeedsUserInputHandler: def __init__(self): self.question_for_user: GetSlowUserMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def __call__(self, message: Any, message_context: MessageContext) -> Any: if isinstance(message, GetSlowUserMessage): self.question_for_user = message return message @@ -227,11 +225,11 @@ def user_input_content(self) -> str | None: return self.question_for_user.content -class TerminationHandler(DefaultInterventionHandler): +class TerminationHandler: def __init__(self): self.terminateMessage: TerminateMessage | None = None - async def on_publish(self, message: Any, *, sender: AgentId | None) -> Any: + async def __call__(self, message: Any, message_context: MessageContext) -> Any: if isinstance(message, TerminateMessage): self.terminateMessage = message return message diff --git a/python/packages/autogen-core/src/autogen_core/__init__.py b/python/packages/autogen-core/src/autogen_core/__init__.py index 0f085d29bdfe..ffd85d16a18d 100644 --- a/python/packages/autogen-core/src/autogen_core/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/__init__.py @@ -24,6 +24,7 @@ from ._default_subscription import DefaultSubscription, default_subscription, type_subscription from ._default_topic import DefaultTopicId from ._image import Image +from ._intervention import DropMessage, InterventionFunction from ._message_context import MessageContext from ._message_handler_context import MessageHandlerContext from ._routed_agent import RoutedAgent, event, message_handler, rpc @@ -99,4 +100,6 @@ "ROOT_LOGGER_NAME", "EVENT_LOGGER_NAME", "TRACE_LOGGER_NAME", + "DropMessage", + "InterventionFunction", ] diff --git a/python/packages/autogen-core/src/autogen_core/_base_agent.py b/python/packages/autogen-core/src/autogen_core/_base_agent.py index 7af46268c907..b5cf4a94a924 100644 --- a/python/packages/autogen-core/src/autogen_core/_base_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_base_agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import inspect import uuid import warnings @@ -10,7 +11,15 @@ from typing_extensions import Self -from ._rpc import format_rpc_request_topic, is_rpc_response +from autogen_core._types import ( + CancelledRpc, + CancelRpc, + CantHandleMessageResponse, + RpcMessageDroppedResponse, + RpcNoneResponse, +) +from autogen_core.exceptions import CantHandleException + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -24,6 +33,15 @@ from ._subscription_context import SubscriptionInstantiationContext from ._topic import TopicId from ._type_prefix_subscription import TypePrefixSubscription +from ._well_known_topics import ( + format_error_topic, + format_rpc_request_topic, + format_rpc_response_topic, + is_error_message, + is_rpc_cancel, + is_rpc_request, + is_rpc_response, +) T = TypeVar("T", bound=Agent) @@ -109,6 +127,13 @@ def __init__(self, description: str, *, forward_unbound_rpc_responses_to_handler raise ValueError("Agent description must be a string") self._description = description self._pending_rpc_requests: Dict[str, Future[Any]] = {} + self._self_rpc_handlers_in_progress: Dict[str, Future[Any]] = {} + + # TODO: find a way to clean this up over time. + # Essentially, the reason for this existing is if a response is sent but we get an error back for that response + # We need to forward this error back to the original sender, so they can fail their RPC. + # Map of request_id -> (rpc_request_message_id, agent_type_of_rpc_sender) + self._sent_rpc_responses: Dict[str, Tuple[str, str]] = {} self._forward_unbound_rpc_responses_to_handler = forward_unbound_rpc_responses_to_handler @property @@ -128,9 +153,44 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: ... @final async def on_message(self, message: Any, ctx: MessageContext) -> None: + # Intercept errors for outstanding rpc requests, let the others pass through + if (request_id := is_error_message(ctx.topic_id.type)) is not None: + # Check if this error corresponds to an RPC response we have sent + if request_id in self._sent_rpc_responses: + # The recipient we were trying to send a response to never got this response, so we're going to send an error to them instead of the original message + # If this message gets dropped, we're just going to ignore things + original_rpc_request_message_id, agent_type_of_rpc_sender = self._sent_rpc_responses[request_id] + error_topic = format_error_topic( + error_recipient_agent_type=agent_type_of_rpc_sender, request_id=original_rpc_request_message_id + ) + await self.publish_message( + RpcMessageDroppedResponse(original_rpc_request_message_id), TopicId(error_topic, self.id.key) + ) + # Check if we have a pending RPC that is error corresponds to + elif request_id in self._pending_rpc_requests: + self._pending_rpc_requests[request_id].set_exception(message) + del self._pending_rpc_requests[request_id] + else: + await self.on_message_impl(message, ctx) + + return None + + # Intercept RPC cancel + if (request_id := is_rpc_cancel(ctx.topic_id.type)) is not None: + if request_id in self._self_rpc_handlers_in_progress: + if isinstance(message, CancelRpc): + self._self_rpc_handlers_in_progress[request_id].cancel() + del self._self_rpc_handlers_in_progress[request_id] + + return None + # Intercept RPC responses if (request_id := is_rpc_response(ctx.topic_id.type)) is not None: if request_id in self._pending_rpc_requests: + if isinstance(message, RpcNoneResponse): + message = None + if isinstance(message, CancelledRpc): + self._pending_rpc_requests[request_id].cancel() self._pending_rpc_requests[request_id].set_result(message) del self._pending_rpc_requests[request_id] elif self._forward_unbound_rpc_responses_to_handler: @@ -142,7 +202,17 @@ async def on_message(self, message: Any, ctx: MessageContext) -> None: ) return None - await self.on_message_impl(message, ctx) + try: + await self.on_message_impl(message, ctx) + # If the agent signalled it cannot handle this message, and it was an RPC request. Let's deliver this error to the RPC sender so they know. + except CantHandleException: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + error_topic = format_error_topic(error_recipient_agent_type=requestor_type, request_id=ctx.message_id) + await self.publish_message( + CantHandleMessageResponse(message_id=ctx.message_id), TopicId(error_topic, self.id.key) + ) + else: + raise async def send_message( self, @@ -150,6 +220,7 @@ async def send_message( recipient: AgentId, *, cancellation_token: CancellationToken | None = None, + timeout: float | None = None, ) -> Any: """See :py:meth:`autogen_core.AgentRuntime.send_message` for more information.""" if cancellation_token is None: @@ -173,7 +244,30 @@ async def send_message( self._pending_rpc_requests[request_id] = future - return future + async with asyncio.timeout(timeout): + return await future + + async def _rpc_response(self, handler_return_value: Any, ctx: MessageContext) -> None: + if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: + if handler_return_value is None: + handler_return_value = RpcNoneResponse() + + response_topic_id = TopicId( + type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), + source=self.id.key, + ) + message_id = str(uuid.uuid4()) + # Intentionally accessing a private attribute here + # We store this so that if the response is dropped, we can send an error to the client instead. + # request_id -> (rpc_request_message_id, agent_type_of_rpc_sender) + self._sent_rpc_responses[message_id] = (ctx.message_id, requestor_type) # type: ignore + + await self.publish_message( + message=handler_return_value, + topic_id=response_topic_id, + cancellation_token=ctx.cancellation_token, + message_id=message_id, + ) async def publish_message( self, @@ -181,8 +275,11 @@ async def publish_message( topic_id: TopicId, *, cancellation_token: CancellationToken | None = None, + message_id: str | None = None, ) -> None: - await self._runtime.publish_message(message, topic_id, sender=self.id, cancellation_token=cancellation_token) + await self._runtime.publish_message( + message, topic_id, sender=self.id, cancellation_token=cancellation_token, message_id=message_id + ) async def save_state(self) -> Mapping[str, Any]: warnings.warn("save_state not implemented", stacklevel=2) diff --git a/python/packages/autogen-core/src/autogen_core/_cancellation_token.py b/python/packages/autogen-core/src/autogen_core/_cancellation_token.py index 5aa44903963f..a4a089b3113c 100644 --- a/python/packages/autogen-core/src/autogen_core/_cancellation_token.py +++ b/python/packages/autogen-core/src/autogen_core/_cancellation_token.py @@ -1,26 +1,29 @@ +import inspect import threading from asyncio import Future -from typing import Any, Callable, List +from typing import Any, Awaitable, Callable, List class CancellationToken: def __init__(self) -> None: self._cancelled: bool = False self._lock: threading.Lock = threading.Lock() - self._callbacks: List[Callable[[], None]] = [] + self._callbacks: List[Callable[[], None] | Callable[[], Awaitable[None]]] = [] - def cancel(self) -> None: + async def cancel(self) -> None: with self._lock: if not self._cancelled: self._cancelled = True for callback in self._callbacks: - callback() + res = callback() + if inspect.isawaitable(res): + await res def is_cancelled(self) -> bool: with self._lock: return self._cancelled - def add_callback(self, callback: Callable[[], None]) -> None: + def add_callback(self, callback: Callable[[], None] | Callable[[], Awaitable[None]]) -> None: with self._lock: if self._cancelled: callback() diff --git a/python/packages/autogen-core/src/autogen_core/_intervention.py b/python/packages/autogen-core/src/autogen_core/_intervention.py index c18a529aae6f..e13359747188 100644 --- a/python/packages/autogen-core/src/autogen_core/_intervention.py +++ b/python/packages/autogen-core/src/autogen_core/_intervention.py @@ -1,4 +1,6 @@ -from typing import Any, Awaitable, Callable, Protocol, final +from typing import Any, Awaitable, Callable, final + +from autogen_core._message_context import MessageContext __all__ = [ "DropMessage", @@ -10,4 +12,6 @@ class DropMessage: ... -InterventionFunction = Callable[[Any], Any | Awaitable[type[DropMessage]]] +InterventionFunction = Callable[ + [Any, MessageContext], Any | Awaitable[Any] | type[DropMessage] | Awaitable[type[DropMessage]] +] diff --git a/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py index 6134844abcfb..6d8171dae52b 100644 --- a/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/_publish_based_rpc.py @@ -3,13 +3,21 @@ import warnings from typing import Any +from autogen_core._types import CancelledRpc, CancelRpc, CantHandleMessageResponse, RpcMessageDroppedResponse +from autogen_core.exceptions import CantHandleException, MessageDroppedException + from ._agent_id import AgentId from ._agent_runtime import AgentRuntime from ._cancellation_token import CancellationToken from ._closure_agent import ClosureAgent, ClosureContext from ._message_context import MessageContext -from ._rpc import format_rpc_request_topic, format_rpc_response_topic from ._topic import TopicId +from ._well_known_topics import ( + format_error_topic, + format_rpc_cancel_topic, + format_rpc_request_topic, + format_rpc_response_topic, +) class PublishBasedRpcMixin(AgentRuntime): @@ -19,7 +27,11 @@ async def send_message( recipient: AgentId, *, cancellation_token: CancellationToken | None = None, + timeout: float | None = None, ) -> Any: + if cancellation_token is None: + cancellation_token = CancellationToken() + rpc_request_id = str(uuid.uuid4()) # TODO add "-" to topic and agent type allowed characters in spec closure_agent_type = f"rpc_receiver_{recipient.type}_{rpc_request_id}" @@ -28,11 +40,27 @@ async def send_message( expected_response_topic_type = format_rpc_response_topic( rpc_sender_agent_type=closure_agent_type, request_id=rpc_request_id ) + expected_error_topic_type = format_error_topic(closure_agent_type, request_id=rpc_request_id) async def set_result(closure_context: ClosureContext, message: Any, ctx: MessageContext) -> None: assert ctx.topic_id is not None if ctx.topic_id.type == expected_response_topic_type: - future.set_result(message) + if isinstance(message, CancelledRpc): + future.cancel() + else: + future.set_result(message) + elif ctx.topic_id.type == expected_error_topic_type: + # Well known things we handle - dropped message, cant handle + # If the message is for a dropped message + if isinstance(message, CantHandleMessageResponse): + future.set_exception(CantHandleException()) + if isinstance(message, RpcMessageDroppedResponse): + future.set_exception(MessageDroppedException()) + else: + warnings.warn( + f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}.", + stacklevel=2, + ) else: warnings.warn( f"{closure_agent_type} received an unexpected message on topic type {ctx.topic_id.type}. Expected {expected_response_topic_type}", @@ -54,10 +82,20 @@ async def set_result(closure_context: ClosureContext, message: Any, ctx: Message await self.publish_message( message=message, topic_id=TopicId(type=rpc_request_topic_id, source=recipient.key), - cancellation_token=cancellation_token, message_id=rpc_request_id, + sender=AgentId(type=closure_agent_type, key=recipient.key), ) - return await future + async def send_cancel(): + cancel_topic = format_rpc_cancel_topic(rpc_recipient_agent_type=recipient.type, request_id=rpc_request_id) + await self.publish_message( + message=CancelRpc(), + topic_id=TopicId(cancel_topic, recipient.key), + ) + + cancellation_token.add_callback(send_cancel) + + async with asyncio.timeout(timeout): + return await future # register a closure agent... diff --git a/python/packages/autogen-core/src/autogen_core/_routed_agent.py b/python/packages/autogen-core/src/autogen_core/_routed_agent.py index 5c41ce28de7f..a7cfd5ed55a0 100644 --- a/python/packages/autogen-core/src/autogen_core/_routed_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_routed_agent.py @@ -1,5 +1,7 @@ +import asyncio import logging import warnings +from asyncio import CancelledError from functools import wraps from typing import ( Any, @@ -19,12 +21,13 @@ runtime_checkable, ) +from autogen_core._types import CancelledRpc + from ._base_agent import BaseAgent from ._message_context import MessageContext -from ._rpc import format_rpc_response_topic, is_rpc_request from ._serialization import MessageSerializer, try_get_known_serializers_for_type -from ._topic import TopicId from ._type_helpers import AnyType, get_types +from ._well_known_topics import is_rpc_request from .exceptions import CantHandleException logger = logging.getLogger("autogen_core") @@ -160,23 +163,13 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None # Dont return, but publish it if you need to... # Any return is treated as a response to the RPC request and is published accordingly - if return_value is not None: - if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: - response_topic_id = TopicId( - type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), - source=self.id.key, - ) - - await self.publish_message( - message=return_value, - topic_id=response_topic_id, - cancellation_token=ctx.cancellation_token, - ) - else: - warnings.warn( - "Returning a value from a message handler that is not an RPC request. This value will be ignored.", - stacklevel=2, - ) + if return_value is not None and is_rpc_request(ctx.topic_id.type) is None: + warnings.warn( + "Returning a value from a message handler that is not an RPC request. This value will be ignored.", + stacklevel=2, + ) + else: + await self._rpc_response(return_value, ctx) # type: ignore wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) @@ -410,7 +403,18 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None else: logger.warning(f"Message type {type(message)} not in target types {target_types}") - return_value = await func(self, message, ctx) + # Should be an rpc request, as the match function should have filtered it + assert is_rpc_request(ctx.topic_id.type) is not None + + try: + future = asyncio.ensure_future(func(self, message, ctx)) + self._self_rpc_handlers_in_progress[ctx.message_id] = future # type: ignore + return_value = await future + except CancelledError: + await self._rpc_response(CancelledRpc(), ctx) # type: ignore + return + finally: + del self._self_rpc_handlers_in_progress[ctx.message_id] # type: ignore if AnyType not in return_types and type(return_value) not in return_types: if strict: @@ -418,26 +422,9 @@ async def wrapper(self: AgentT, message: ReceivesT, ctx: MessageContext) -> None else: logger.warning(f"Return type {type(return_value)} not in return types {return_types}") - # Dont return, but publish it if you need to... + # Dont return, but publish # Any return is treated as a response to the RPC request and is published accordingly - - if return_value is not None: - if (requestor_type := is_rpc_request(ctx.topic_id.type)) is not None: - response_topic_id = TopicId( - type=format_rpc_response_topic(rpc_sender_agent_type=requestor_type, request_id=ctx.message_id), - source=self.id.key, - ) - - await self.publish_message( - message=return_value, - topic_id=response_topic_id, - cancellation_token=ctx.cancellation_token, - ) - else: - warnings.warn( - "Returning a value from a message handler that is not an RPC request. This value will be ignored.", - stacklevel=2, - ) + await self._rpc_response(return_value, ctx) # type: ignore wrapper_handler = cast(MessageHandler[AgentT, ReceivesT, ProducesT], wrapper) wrapper_handler.target_types = list(target_types) @@ -528,6 +515,11 @@ async def on_message_impl(self, message: Any, ctx: MessageContext) -> None: for h in handlers: if h.router(message, ctx): await h(self, message, ctx) + return + + if is_rpc_request(ctx.topic_id.type): + raise CantHandleException(f"No RPC handler found for message type {key_type}") + await self.on_unhandled_message(message, ctx) async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: diff --git a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py index 8db453f6dc82..b46fbacdcbb0 100644 --- a/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py +++ b/python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py @@ -12,27 +12,29 @@ from enum import Enum from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast -from ._message_context import MessageContext -from ._message_handler_context import MessageHandlerContext -from ._subscription import Subscription from opentelemetry.trace import TracerProvider from typing_extensions import deprecated +from autogen_core._types import RpcMessageDroppedResponse + from ._agent import Agent from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext from ._agent_metadata import AgentMetadata from ._agent_runtime import AgentRuntime from ._agent_type import AgentType -from ._topic import TopicId -from ._subscription_context import SubscriptionInstantiationContext from ._cancellation_token import CancellationToken -from ._intervention import DropMessage +from ._intervention import DropMessage, InterventionFunction +from ._message_context import MessageContext +from ._message_handler_context import MessageHandlerContext from ._publish_based_rpc import PublishBasedRpcMixin from ._runtime_impl_helpers import SubscriptionManager, get_impl from ._serialization import MessageSerializer, SerializationRegistry +from ._subscription import Subscription +from ._subscription_context import SubscriptionInstantiationContext from ._telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata -from .base.intervention import InterventionHandler +from ._topic import TopicId +from ._well_known_topics import format_error_topic logger = logging.getLogger("autogen_core") event_logger = logging.getLogger("autogen_core.events") @@ -144,7 +146,7 @@ class SingleThreadedAgentRuntime(PublishBasedRpcMixin, AgentRuntime): def __init__( self, *, - intervention_handlers: List[InterventionHandler] | None = None, + intervention_handlers: List[InterventionFunction] | None = None, tracer_provider: TracerProvider | None = None, ) -> None: self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime")) @@ -286,6 +288,16 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any: self._outstanding_tasks.decrement() # TODO if responses are given for a publish + async def _send_error(self, exception: Any, for_message_id: str, recipient: AgentId) -> None: + topic = format_error_topic(recipient.type, for_message_id) + + # Errors don't have an originating sender + await self.publish_message( + message=exception, + topic_id=TopicId(topic, recipient.key), + sender=None, + ) + async def process_next(self) -> None: """Process the next message in the queue.""" @@ -296,22 +308,40 @@ async def process_next(self) -> None: message_envelope = self._message_queue.pop(0) message = message_envelope.message - sender = message_envelope.sender if self._intervention_handlers is not None: + message_context = MessageContext( + sender=message_envelope.sender, + topic_id=message_envelope.topic_id, + cancellation_token=message_envelope.cancellation_token, + message_id=message_envelope.message_id, + ) for handler in self._intervention_handlers: with self._tracer_helper.trace_block( "intercept", handler.__class__.__name__, parent=message_envelope.metadata ): try: - temp_message = await handler.on_publish(message, sender=sender) + temp_message = handler(message, message_context) + if inspect.isawaitable(temp_message): + temp_message = await temp_message + _warn_if_none(temp_message, "on_publish") except BaseException as e: - # TODO: we should raise the intervention exception to the publisher. logger.error(f"Exception raised in in intervention handler: {e}", exc_info=True) return + if temp_message is DropMessage or isinstance(temp_message, DropMessage): # TODO log message dropped + # Send message dropped to sender + + # If it's None, then we don't know who to send the message to + if message_envelope.sender is not None: + await self._send_error( + RpcMessageDroppedResponse(message_id=message_envelope.message_id), + message_envelope.message_id, + message_envelope.sender, + ) + return message_envelope.message = temp_message diff --git a/python/packages/autogen-core/src/autogen_core/_types.py b/python/packages/autogen-core/src/autogen_core/_types.py index 5e3850ffae8b..d261909b98e6 100644 --- a/python/packages/autogen-core/src/autogen_core/_types.py +++ b/python/packages/autogen-core/src/autogen_core/_types.py @@ -10,3 +10,29 @@ class FunctionCall: arguments: str # Function to call name: str + + +# TODO: Make this xlang friendly +@dataclass +class RpcNoneResponse: + pass + + +@dataclass +class RpcMessageDroppedResponse: + message_id: str + + +@dataclass +class CantHandleMessageResponse: + message_id: str + + +@dataclass +class CancelRpc: + pass + + +@dataclass +class CancelledRpc: + pass diff --git a/python/packages/autogen-core/src/autogen_core/_rpc.py b/python/packages/autogen-core/src/autogen_core/_well_known_topics.py similarity index 52% rename from python/packages/autogen-core/src/autogen_core/_rpc.py rename to python/packages/autogen-core/src/autogen_core/_well_known_topics.py index d6554e71844d..e6f17a8d88fc 100644 --- a/python/packages/autogen-core/src/autogen_core/_rpc.py +++ b/python/packages/autogen-core/src/autogen_core/_well_known_topics.py @@ -7,6 +7,10 @@ def format_rpc_request_topic(rpc_recipient_agent_type: str, rpc_sender_agent_typ return f"{rpc_recipient_agent_type}:rpc_request={rpc_sender_agent_type}" +def format_rpc_cancel_topic(rpc_recipient_agent_type: str, request_id: str) -> str: + return f"{rpc_recipient_agent_type}:rpc_cancel={request_id}" + + def format_rpc_response_topic(rpc_sender_agent_type: str, request_id: str) -> str: return f"{rpc_sender_agent_type}:rpc_response={request_id}" @@ -21,6 +25,16 @@ def is_rpc_response(topic_type: str) -> Optional[str]: return None +# If is an rpc response, return the request id +def is_rpc_cancel(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_cancel= + for segment in topic_segments: + if segment.startswith("rpc_cancel="): + return segment[len("rpc_cancel=") :] + return None + + # If is an rpc response, return the requestor agent type def is_rpc_request(topic_type: str) -> Optional[str]: topic_segments = topic_type.split(":") @@ -29,3 +43,17 @@ def is_rpc_request(topic_type: str) -> Optional[str]: if segment.startswith("rpc_request="): return segment[len("rpc_request=") :] return None + + +# {AgentType}:error={RequestId} - error message that corresponds to a request +def is_error_message(topic_type: str) -> Optional[str]: + topic_segments = topic_type.split(":") + # Find if there is a segment starting with :rpc_response= + for segment in topic_segments: + if segment.startswith("error="): + return segment[len("error=") :] + return None + + +def format_error_topic(error_recipient_agent_type: str, request_id: str) -> str: + return f"{error_recipient_agent_type}:error={request_id}" diff --git a/python/packages/autogen-core/src/autogen_core/base/intervention.py b/python/packages/autogen-core/src/autogen_core/base/intervention.py index 1e7d1937414a..a6356a010091 100644 --- a/python/packages/autogen-core/src/autogen_core/base/intervention.py +++ b/python/packages/autogen-core/src/autogen_core/base/intervention.py @@ -1,4 +1,4 @@ -from typing import Any, Awaitable, Callable, Protocol, final +from typing import Any, Protocol from .._agent_id import AgentId diff --git a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py index 08d8f4b25376..3792d287281b 100644 --- a/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py +++ b/python/packages/autogen-core/src/autogen_core/tool_agent/_tool_agent.py @@ -2,7 +2,8 @@ from dataclasses import dataclass from typing import List -from .. import FunctionCall, MessageContext, RoutedAgent, message_handler +from .. import FunctionCall, MessageContext, RoutedAgent +from .._routed_agent import rpc from ..models import FunctionExecutionResult from ..tools import Tool @@ -16,7 +17,7 @@ @dataclass -class ToolException(BaseException): +class ToolException: call_id: str content: str @@ -58,8 +59,10 @@ def __init__( def tools(self) -> List[Tool]: return self._tools - @message_handler - async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) -> FunctionExecutionResult: + @rpc + async def handle_function_call( + self, message: FunctionCall, ctx: MessageContext + ) -> FunctionExecutionResult | ToolNotFoundException | InvalidToolArgumentsException | ToolExecutionException: """Handles a `FunctionCall` message by executing the requested tool with the provided arguments. Args: @@ -76,16 +79,16 @@ async def handle_function_call(self, message: FunctionCall, ctx: MessageContext) """ tool = next((tool for tool in self._tools if tool.name == message.name), None) if tool is None: - raise ToolNotFoundException(call_id=message.id, content=f"Error: Tool not found: {message.name}") + return ToolNotFoundException(call_id=message.id, content=f"Error: Tool not found: {message.name}") else: try: arguments = json.loads(message.arguments) result = await tool.run_json(args=arguments, cancellation_token=ctx.cancellation_token) result_as_str = tool.return_value_as_string(result) - except json.JSONDecodeError as e: - raise InvalidToolArgumentsException( + except json.JSONDecodeError: + return InvalidToolArgumentsException( call_id=message.id, content=f"Error: Invalid arguments: {message.arguments}" - ) from e + ) except Exception as e: - raise ToolExecutionException(call_id=message.id, content=f"Error: {e}") from e + return ToolExecutionException(call_id=message.id, content=f"Error: {e}") return FunctionExecutionResult(content=result_as_str, call_id=message.id) diff --git a/python/packages/autogen-core/tests/test_cancellation.py b/python/packages/autogen-core/tests/test_cancellation.py index 34a5d7f962c4..4d803813cef2 100644 --- a/python/packages/autogen-core/tests/test_cancellation.py +++ b/python/packages/autogen-core/tests/test_cancellation.py @@ -9,8 +9,8 @@ MessageContext, RoutedAgent, SingleThreadedAgentRuntime, - message_handler, ) +from autogen_core._routed_agent import rpc @dataclass @@ -28,7 +28,7 @@ def __init__(self) -> None: self.called = False self.cancelled = False - @message_handler + @rpc async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True sleep = asyncio.ensure_future(asyncio.sleep(100)) @@ -48,7 +48,7 @@ def __init__(self, nested_agent: AgentId) -> None: self.cancelled = False self._nested_agent = nested_agent - @message_handler + @rpc async def on_new_message(self, message: MessageType, ctx: MessageContext) -> MessageType: self.called = True response = self.send_message(message, self._nested_agent, cancellation_token=ctx.cancellation_token) @@ -74,9 +74,9 @@ async def test_cancellation_with_token() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() + runtime.start() - token.cancel() + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -85,6 +85,7 @@ async def test_cancellation_with_token() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled + await runtime.stop() @pytest.mark.asyncio @@ -107,8 +108,9 @@ async def test_nested_cancellation_only_outer_called() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() - token.cancel() + runtime.start() + + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -120,6 +122,7 @@ async def test_nested_cancellation_only_outer_called() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called is False assert long_running_agent.cancelled is False + await runtime.stop() @pytest.mark.asyncio @@ -143,10 +146,9 @@ async def test_nested_cancellation_inner_called() -> None: while len(runtime.unprocessed_messages) == 0: await asyncio.sleep(0.01) - await runtime.process_next() - # allow the inner agent to process - await runtime.process_next() - token.cancel() + runtime.start() + + await token.cancel() with pytest.raises(asyncio.CancelledError): await response @@ -158,3 +160,4 @@ async def test_nested_cancellation_inner_called() -> None: long_running_agent = await runtime.try_get_underlying_agent_instance(long_running_id, type=LongRunningAgent) assert long_running_agent.called assert long_running_agent.cancelled + await runtime.stop() diff --git a/python/packages/autogen-core/tests/test_intervention.py b/python/packages/autogen-core/tests/test_intervention.py index a046201feff3..f2a58dc72904 100644 --- a/python/packages/autogen-core/tests/test_intervention.py +++ b/python/packages/autogen-core/tests/test_intervention.py @@ -1,17 +1,19 @@ +from typing import Any + import pytest -from autogen_core import AgentId, SingleThreadedAgentRuntime -from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage +from autogen_core import AgentId, DropMessage, MessageContext, SingleThreadedAgentRuntime +from autogen_core._well_known_topics import is_rpc_request, is_rpc_response from autogen_core.exceptions import MessageDroppedException from autogen_test_utils import LoopbackAgent, MessageType @pytest.mark.asyncio async def test_intervention_count_messages() -> None: - class DebugInterventionHandler(DefaultInterventionHandler): + class DebugInterventionHandler: def __init__(self) -> None: self.num_messages = 0 - async def on_send(self, message: MessageType, *, sender: AgentId | None, recipient: AgentId) -> MessageType: + async def __call__(self, message: MessageType, message_context: MessageContext) -> MessageType: self.num_messages += 1 return message @@ -21,24 +23,23 @@ async def on_send(self, message: MessageType, *, sender: AgentId | None, recipie loopback = AgentId("name", key="default") runtime.start() - _response = await runtime.send_message(MessageType(), recipient=loopback) + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() - assert handler.num_messages == 1 + # 2 since request and response + assert handler.num_messages == 2 loopback_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) assert loopback_agent.num_calls == 1 @pytest.mark.asyncio -async def test_intervention_drop_send() -> None: - class DropSendInterventionHandler(DefaultInterventionHandler): - async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId - ) -> MessageType | type[DropMessage]: +async def test_intervention_drop_rpc_request() -> None: + async def handler(message: Any, message_context: MessageContext) -> Any | type[DropMessage]: + if is_rpc_request(message_context.topic_id.type): return DropMessage + return message - handler = DropSendInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) await LoopbackAgent.register(runtime, "name", LoopbackAgent) @@ -46,7 +47,7 @@ async def on_send( runtime.start() with pytest.raises(MessageDroppedException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() @@ -55,74 +56,21 @@ async def on_send( @pytest.mark.asyncio -async def test_intervention_drop_response() -> None: - class DropResponseInterventionHandler(DefaultInterventionHandler): - async def on_response( - self, message: MessageType, *, sender: AgentId, recipient: AgentId | None - ) -> MessageType | type[DropMessage]: +async def test_intervention_drop_rpc_esponse() -> None: + async def handler(message: Any, message_context: MessageContext) -> Any | type[DropMessage]: + # Only drop the response and not the request! + if is_rpc_response(message_context.topic_id.type): return DropMessage - handler = DropResponseInterventionHandler() - runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - - await LoopbackAgent.register(runtime, "name", LoopbackAgent) - loopback = AgentId("name", key="default") - runtime.start() - - with pytest.raises(MessageDroppedException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + return message - await runtime.stop() - - -@pytest.mark.asyncio -async def test_intervention_raise_exception_on_send() -> None: - class InterventionException(Exception): - pass - - class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_send( - self, message: MessageType, *, sender: AgentId | None, recipient: AgentId - ) -> MessageType | type[DropMessage]: # type: ignore - raise InterventionException - - handler = ExceptionInterventionHandler() runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) await LoopbackAgent.register(runtime, "name", LoopbackAgent) loopback = AgentId("name", key="default") runtime.start() - with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=loopback) - - await runtime.stop() - - long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) - assert long_running_agent.num_calls == 0 - - -@pytest.mark.asyncio -async def test_intervention_raise_exception_on_respond() -> None: - class InterventionException(Exception): - pass - - class ExceptionInterventionHandler(DefaultInterventionHandler): # type: ignore - async def on_response( - self, message: MessageType, *, sender: AgentId, recipient: AgentId | None - ) -> MessageType | type[DropMessage]: # type: ignore - raise InterventionException - - handler = ExceptionInterventionHandler() - runtime = SingleThreadedAgentRuntime(intervention_handlers=[handler]) - - await LoopbackAgent.register(runtime, "name", LoopbackAgent) - loopback = AgentId("name", key="default") - runtime.start() - with pytest.raises(InterventionException): - _response = await runtime.send_message(MessageType(), recipient=loopback) + with pytest.raises(MessageDroppedException): + _response = await runtime.send_message(MessageType(), recipient=loopback, timeout=120) await runtime.stop() - - long_running_agent = await runtime.try_get_underlying_agent_instance(loopback, type=LoopbackAgent) - assert long_running_agent.num_calls == 1 diff --git a/python/packages/autogen-core/tests/test_routed_agent.py b/python/packages/autogen-core/tests/test_routed_agent.py index 1256f3103bd2..ec30b575c078 100644 --- a/python/packages/autogen-core/tests/test_routed_agent.py +++ b/python/packages/autogen-core/tests/test_routed_agent.py @@ -14,7 +14,8 @@ message_handler, rpc, ) -from autogen_core._rpc import is_rpc_request +from autogen_core._well_known_topics import is_rpc_request +from autogen_core.exceptions import CantHandleException from autogen_test_utils import LoopbackAgent @@ -71,7 +72,7 @@ async def test_message_handler_router() -> None: # Send an RPC message. runtime.start() - await runtime.send_message(MessageType(), recipient=agent_id) + await runtime.send_message(MessageType(), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=CounterAgent) assert agent.num_calls_broadcast == 1 @@ -114,14 +115,14 @@ async def test_routed_agent_message_matching() -> None: assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True assert agent.handler_two_called is False runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(TestMessage("two"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RoutedAgentMessageCustomMatch) assert agent.handler_one_called is True @@ -167,7 +168,8 @@ async def test_event() -> None: # Send an RPC message, expect no change. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + with pytest.raises(CantHandleException): + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=EventAgent) assert agent.num_calls[0] == 1 @@ -199,7 +201,7 @@ async def test_rpc() -> None: # Send an RPC message. runtime.start() - await runtime.send_message(TestMessage("one"), recipient=agent_id) + await runtime.send_message(TestMessage("one"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 @@ -207,7 +209,7 @@ async def test_rpc() -> None: # Send another RPC message. runtime.start() - await runtime.send_message(TestMessage("two"), recipient=agent_id) + await runtime.send_message(TestMessage("two"), recipient=agent_id, timeout=60) await runtime.stop_when_idle() agent = await runtime.try_get_underlying_agent_instance(agent_id, type=RPCAgent) assert agent.num_calls[0] == 1 diff --git a/python/packages/autogen-core/tests/test_tool_agent.py b/python/packages/autogen-core/tests/test_tool_agent.py index d0d6dec8915b..98ef8e7e30c5 100644 --- a/python/packages/autogen-core/tests/test_tool_agent.py +++ b/python/packages/autogen-core/tests/test_tool_agent.py @@ -63,23 +63,27 @@ async def test_tool_agent() -> None: assert result == FunctionExecutionResult(call_id="1", content="pass") # Test raise function - with pytest.raises(ToolExecutionException): - await runtime.send_message(FunctionCall(id="2", arguments=json.dumps({"input": "raise"}), name="raise"), agent) + response = await runtime.send_message( + FunctionCall(id="2", arguments=json.dumps({"input": "raise"}), name="raise"), agent + ) + assert isinstance(response, ToolExecutionException) # Test invalid tool name - with pytest.raises(ToolNotFoundException): - await runtime.send_message(FunctionCall(id="3", arguments=json.dumps({"input": "pass"}), name="invalid"), agent) + response = await runtime.send_message( + FunctionCall(id="3", arguments=json.dumps({"input": "pass"}), name="invalid"), agent + ) + assert isinstance(response, ToolNotFoundException) # Test invalid arguments - with pytest.raises(InvalidToolArgumentsException): - await runtime.send_message(FunctionCall(id="3", arguments="invalid json /xd", name="pass"), agent) + response = await runtime.send_message(FunctionCall(id="3", arguments="invalid json /xd", name="pass"), agent) + assert isinstance(response, InvalidToolArgumentsException) # Test sleep and cancel. token = CancellationToken() result_future = runtime.send_message( FunctionCall(id="3", arguments=json.dumps({"input": "sleep"}), name="sleep"), agent, cancellation_token=token ) - token.cancel() + await token.cancel() with pytest.raises(asyncio.CancelledError): await result_future