-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement default sub and topic (#398)
* Implement default sub and topic * format * update test
- Loading branch information
1 parent
8f082ce
commit 4c964fa
Showing
8 changed files
with
175 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from agnext.core.exceptions import CantHandleException | ||
|
||
from ..core import SubscriptionInstantiationContext | ||
from ._type_subscription import TypeSubscription | ||
|
||
|
||
class DefaultSubscription(TypeSubscription): | ||
def __init__(self, topic_type: str = "default", agent_type: str | None = None): | ||
"""The default subscription is designed to be a sensible default for applications that only need global scope for agents. | ||
This topic by default uses the "default" topic type and attempts to detect the agent type to use based on the instantiation context. | ||
Example: | ||
.. code-block:: python | ||
await runtime.register("MyAgent", agent_factory, lambda: [DefaultSubscription()]) | ||
Args: | ||
topic_type (str, optional): The topic type to subscribe to. Defaults to "default". | ||
agent_type (str, optional): The agent type to use for the subscription. Defaults to None, in which case it will attempt to detect the agent type based on the instantiation context. | ||
""" | ||
|
||
if agent_type is None: | ||
try: | ||
agent_type = SubscriptionInstantiationContext.agent_type().type | ||
except RuntimeError as e: | ||
raise CantHandleException( | ||
"If agent_type is not specified DefaultSubscription must be created within the subscription callback in AgentRuntime.register" | ||
) from e | ||
|
||
super().__init__(topic_type, agent_type) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from ..core import MessageHandlerContext, TopicId | ||
|
||
|
||
class DefaultTopicId(TopicId): | ||
def __init__(self, type: str = "default", source: str | None = None) -> None: | ||
"""DefaultTopicId provides a sensible default for the topic_id and source fields of a TopicId. | ||
If created in the context of a message handler, the source will be set to the agent_id of the message handler, otherwise it will be set to "default". | ||
Args: | ||
type (str, optional): Topic type to publish message to. Defaults to "default". | ||
source (str | None, optional): Topic source to publish message to. If None, the source will be set to the agent_id of the message handler if in the context of a message handler, otherwise it will be set to "default". Defaults to None. | ||
""" | ||
if source is None: | ||
try: | ||
source = MessageHandlerContext.agent_id().key | ||
# If we aren't in the context of a message handler, we use the default source | ||
except RuntimeError: | ||
source = "default" | ||
|
||
super().__init__(type, source) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from contextlib import contextmanager | ||
from contextvars import ContextVar | ||
from typing import Any, ClassVar, Generator | ||
|
||
from ._agent_id import AgentId | ||
|
||
|
||
class MessageHandlerContext: | ||
def __init__(self) -> None: | ||
raise RuntimeError( | ||
"MessageHandlerContext cannot be instantiated. It is a static class that provides context management for agent instantiation." | ||
) | ||
|
||
MESSAGE_HANDLER_CONTEXT: ClassVar[ContextVar[AgentId]] = ContextVar("MESSAGE_HANDLER_CONTEXT") | ||
|
||
@classmethod | ||
@contextmanager | ||
def populate_context(cls, ctx: AgentId) -> Generator[None, Any, None]: | ||
token = MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.set(ctx) | ||
try: | ||
yield | ||
finally: | ||
MessageHandlerContext.MESSAGE_HANDLER_CONTEXT.reset(token) | ||
|
||
@classmethod | ||
def agent_id(cls) -> AgentId: | ||
try: | ||
return cls.MESSAGE_HANDLER_CONTEXT.get() | ||
except LookupError as e: | ||
raise RuntimeError("MessageHandlerContext.agent_id() must be called within a message handler.") from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters