Skip to content

Commit

Permalink
Implement default sub and topic (#398)
Browse files Browse the repository at this point in the history
* Implement default sub and topic

* format

* update test
  • Loading branch information
jackgerrits authored Aug 23, 2024
1 parent 8f082ce commit 4c964fa
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 14 deletions.
19 changes: 11 additions & 8 deletions python/src/agnext/application/_single_threaded_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
Expand Down Expand Up @@ -264,10 +265,11 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
is_rpc=True,
cancellation_token=message_envelope.cancellation_token,
)
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
with MessageHandlerContext.populate_context(recipient_agent.id):
response = await recipient_agent.on_message(
message_envelope.message,
ctx=message_context,
)
except BaseException as e:
message_envelope.future.set_exception(e)
return
Expand Down Expand Up @@ -313,10 +315,11 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No
cancellation_token=message_envelope.cancellation_token,
)
agent = await self._get_agent(agent_id)
future = agent.on_message(
message_envelope.message,
ctx=message_context,
)
with MessageHandlerContext.populate_context(agent.id):
future = agent.on_message(
message_envelope.message,
ctx=message_context,
)
responses.append(future)

try:
Expand Down
7 changes: 5 additions & 2 deletions python/src/agnext/application/_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
AgentType,
CancellationToken,
MessageContext,
MessageHandlerContext,
Subscription,
SubscriptionInstantiationContext,
TopicId,
Expand Down Expand Up @@ -323,7 +324,8 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:

# Call the target agent.
try:
result = await target_agent.on_message(message, ctx=message_context)
with MessageHandlerContext.populate_context(target_agent.id):
result = await target_agent.on_message(message, ctx=message_context)
except BaseException as e:
response_message = agent_worker_pb2.Message(
response=agent_worker_pb2.RpcResponse(
Expand Down Expand Up @@ -377,7 +379,8 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
cancellation_token=CancellationToken(),
)
agent = await self._get_agent(agent_id)
future = agent.on_message(message, ctx=message_context)
with MessageHandlerContext.populate_context(agent.id):
future = agent.on_message(message, ctx=message_context)
responses.append(future)
# Wait for all responses.
try:
Expand Down
13 changes: 12 additions & 1 deletion python/src/agnext/components/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
"""

from ._closure_agent import ClosureAgent
from ._default_subscription import DefaultSubscription
from ._default_topic import DefaultTopicId
from ._image import Image
from ._type_routed_agent import TypeRoutedAgent, message_handler
from ._type_subscription import TypeSubscription
from ._types import FunctionCall

__all__ = ["Image", "TypeRoutedAgent", "ClosureAgent", "message_handler", "FunctionCall", "TypeSubscription"]
__all__ = [
"Image",
"TypeRoutedAgent",
"ClosureAgent",
"message_handler",
"FunctionCall",
"TypeSubscription",
"DefaultSubscription",
"DefaultTopicId",
]
32 changes: 32 additions & 0 deletions python/src/agnext/components/_default_subscription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from agnext.core.exceptions import CantHandleException

from ..core import SubscriptionInstantiationContext
from ._type_subscription import TypeSubscription


class DefaultSubscription(TypeSubscription):
def __init__(self, topic_type: str = "default", agent_type: str | None = None):
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents.
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context.
Example:
.. code-block:: python
await runtime.register("MyAgent", agent_factory, lambda: [DefaultSubscription()])
Args:
topic_type (str, optional): The topic type to subscribe to. Defaults to "default".
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context.
"""

if agent_type is None:
try:
agent_type = SubscriptionInstantiationContext.agent_type().type
except RuntimeError as e:
raise CantHandleException(
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register"
) from e

super().__init__(topic_type, agent_type)
21 changes: 21 additions & 0 deletions python/src/agnext/components/_default_topic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from ..core import MessageHandlerContext, TopicId


class DefaultTopicId(TopicId):
def __init__(self, type: str = "default", source: str | None = None) -> None:
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId.
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default".
Args:
type (str, optional): Topic type to publish message to. Defaults to "default".
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None.
"""
if source is None:
try:
source = MessageHandlerContext.agent_id().key
# If we aren't in the context of a message handler, we use the default source
except RuntimeError:
source = "default"

super().__init__(type, source)
2 changes: 2 additions & 0 deletions python/src/agnext/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._message_context import MessageContext
from ._message_handler_context import MessageHandlerContext
from ._serialization import MESSAGE_TYPE_REGISTRY, Serialization, TypeDeserializer, TypeSerializer
from ._subscription import Subscription
from ._subscription_context import SubscriptionInstantiationContext
Expand All @@ -37,4 +38,5 @@
"Serialization",
"AgentType",
"SubscriptionInstantiationContext",
"MessageHandlerContext",
]
30 changes: 30 additions & 0 deletions python/src/agnext/core/_message_handler_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, ClassVar, Generator

from ._agent_id import AgentId


class MessageHandlerContext:
def __init__(self) -> None:
raise RuntimeError(
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
)

MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("MESSAGE_HANDLER_CONTEXT")

@classmethod
@contextmanager
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]:
token = MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.set(ctx)
try:
yield
finally:
MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.reset(token)

@classmethod
def agent_id(cls) -> AgentId:
try:
return cls.MESSAGE_HANDLER_CONTEXT.get()
except LookupError as e:
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e
65 changes: 62 additions & 3 deletions python/tests/test_runtime.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import pytest
from agnext.application import SingleThreadedAgentRuntime
from agnext.components._type_subscription import TypeSubscription
from agnext.components import TypeSubscription, DefaultTopicId, DefaultSubscription
from agnext.core import AgentId, AgentInstantiationContext
from agnext.core import TopicId
from agnext.core._subscription import Subscription
from agnext.core._subscription_context import SubscriptionInstantiationContext
from agnext.core import Subscription
from agnext.core import SubscriptionInstantiationContext
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent


Expand Down Expand Up @@ -163,3 +163,62 @@ async def test_register_factory_direct_list() -> None:
# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0

@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 0

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(AgentId("name", key="other"), type=LoopbackAgent)
assert other_long_running_agent.num_calls == 1

0 comments on commit 4c964fa

Please sign in to comment.