Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions python/packages/autogen-core/src/autogen_core/_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,57 @@
"""
...

async def register_agent_instance(
self,
agent_instance: T | Awaitable[T],
) -> AgentId:
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.

.. note::

This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.

Example:

.. code-block:: python

from dataclasses import dataclass
from typing import Optional

from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
from autogen_core.models import UserMessage


@dataclass
class MyMessage:
content: str


class MyAgent(RoutedAgent):
def __init__(self, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None) -> None:
super().__init__("My core agent", runtime, agent_id)

@event
async def handler(self, message: UserMessage, context: MessageContext) -> None:
print("Event received: ", message.content)


async def main() -> None:
runtime: AgentRuntime = ... # type: ignore
agent = MyAgent(runtime=runtime, agent_id=AgentId(type="my_agent", key="default"))
await runtime.register_agent_instance(agent)


import asyncio

asyncio.run(main())


Args:
agent_instance (T | Awaitable[T]): A concrete instance of the agent.
"""
...

Check warning on line 182 in python/packages/autogen-core/src/autogen_core/_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_agent_runtime.py#L182

Added line #L182 was not covered by tests

# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
Expand Down
82 changes: 72 additions & 10 deletions python/packages/autogen-core/src/autogen_core/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Tuple, Type, TypeVar, final
from typing import Any, Awaitable, Callable, ClassVar, List, Mapping, Optional, Tuple, Type, TypeVar, final

from typing_extensions import Self

Expand Down Expand Up @@ -81,17 +81,33 @@
assert self._id is not None
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)

def __init__(self, description: str) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e
def __init__(
self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None
) -> None:
param_count = 0
if runtime is not None:
param_count += 1
if agent_id is not None:
param_count += 1

if param_count != 0 and param_count != 2:
raise ValueError("BaseAgent must be instantiated with both runtime and agent_id or neither.")
if param_count == 0:
try:
runtime = AgentInstantiationContext.current_runtime()
agent_id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(

Check warning on line 100 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L99-L100

Added lines #L99 - L100 were not covered by tests
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e
else:
if not isinstance(runtime, AgentRuntime):
raise ValueError("Agent must be initialized with runtime of type AgentRuntime")

Check warning on line 105 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L105

Added line #L105 was not covered by tests
if not isinstance(agent_id, AgentId):
raise ValueError("Agent must be initialized with agent_id of type AgentId")

Check warning on line 107 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L107

Added line #L107 was not covered by tests

self._runtime: AgentRuntime = runtime
self._id: AgentId = id
self._id: AgentId = agent_id
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description
Expand Down Expand Up @@ -155,6 +171,52 @@
async def close(self) -> None:
pass

async def register_instance(
self,
*,
skip_class_subscriptions: bool = False,
skip_direct_message_subscription: bool = False,
) -> AgentId:
runtime = self.runtime
agent_id = await runtime.register_agent_instance(agent_instance=self)
if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
subscriptions: List[Subscription] = []
for unbound_subscription in self._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result

Check warning on line 188 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L186-L188

Added lines #L186 - L188 were not covered by tests
else:
subscriptions_list = subscriptions_list_result

Check warning on line 190 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L190

Added line #L190 was not covered by tests

subscriptions.extend(subscriptions_list)

Check warning on line 192 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L192

Added line #L192 was not covered by tests
try:
for subscription in subscriptions:
await runtime.add_subscription(subscription)
except ValueError:

Check warning on line 196 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L195-L196

Added lines #L195 - L196 were not covered by tests
# We don't care if the subscription already exists
pass

Check warning on line 198 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L198

Added line #L198 was not covered by tests

if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
try:
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_id.type + ":",
agent_type=agent_id.type,
)
)
except ValueError:
# We don't care if the subscription already exists
pass

# TODO: deduplication
for _message_type, serializer in self._handles_types():
runtime.add_message_serializer(serializer)

return agent_id

@classmethod
async def register(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
DefaultDict,
List,
Literal,
Optional,
Protocol,
Sequence,
Tuple,
Expand All @@ -18,6 +19,8 @@
runtime_checkable,
)

from ._agent_id import AgentId
from ._agent_runtime import AgentRuntime
from ._base_agent import BaseAgent
from ._message_context import MessageContext
from ._serialization import MessageSerializer, try_get_known_serializers_for_type
Expand Down Expand Up @@ -457,7 +460,9 @@ async def handle_special_rpc_message(self, message: MessageWithContent, ctx: Mes
return Response()
"""

def __init__(self, description: str) -> None:
def __init__(
self, description: str, runtime: Optional[AgentRuntime] = None, agent_id: Optional[AgentId] = None
) -> None:
# Self is already bound to the handlers
self._handlers: DefaultDict[
Type[Any],
Expand All @@ -469,7 +474,7 @@ def __init__(self, description: str) -> None:
for target_type in message_handler.target_types:
self._handlers[target_type].append(message_handler)

super().__init__(description)
super().__init__(description, runtime, agent_id)

async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any | None:
"""Handle a message by routing it to the appropriate message handler.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from ._agent_metadata import AgentMetadata
from ._agent_runtime import AgentRuntime
from ._agent_type import AgentType
from ._base_agent import BaseAgent
from ._cancellation_token import CancellationToken
from ._intervention import DropMessage, InterventionHandler
from ._message_context import MessageContext
Expand Down Expand Up @@ -265,6 +266,7 @@
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
self._agent_instance_types: Dict[str, Type[Agent]] = {}

@property
def unprocessed_messages_count(
Expand Down Expand Up @@ -830,6 +832,37 @@

return type

async def register_agent_instance(
self,
agent_instance: T | Awaitable[T],
) -> AgentId:
def agent_factory() -> T:
raise RuntimeError("Agent factory should not be called when registering an agent instance.")

if inspect.isawaitable(agent_instance):
agent_instance = await agent_instance

Check warning on line 843 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L843

Added line #L843 was not covered by tests

# Agent type does not have the concept of a runtime
if isinstance(agent_instance, BaseAgent):
if agent_instance.runtime is not self:
raise ValueError("Agent instance is associated with a different runtime.")

Check warning on line 848 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L848

Added line #L848 was not covered by tests
agent_id = agent_instance.id

if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")

if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")

Check warning on line 859 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L859

Added line #L859 was not covered by tests
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")

self._instantiated_agents[agent_id] = agent_instance
return agent_id

async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
Expand All @@ -851,7 +884,7 @@
raise ValueError("Agent factory must take 0 or 2 arguments.")

if inspect.isawaitable(agent):
return cast(T, await agent)
agent = cast(T, await agent)

return agent

Expand Down
17 changes: 13 additions & 4 deletions python/packages/autogen-core/tests/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,17 @@
async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)

agent1 = NoopAgent(runtime=runtime, agent_id=AgentId("name1", "namespace1"))
assert agent1.runtime == runtime
assert agent1.id == AgentId("name1", "namespace1")

with pytest.raises(ValueError):
NoopAgent(runtime=runtime, agent_id=None)
with pytest.raises(ValueError):
NoopAgent(runtime=None, agent_id=AgentId("name_fail", "namespace_fail"))

# Shows how to set the context for the agent instantiation in a test context
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")
with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))):
agent2 = NoopAgent()
assert agent2.runtime == runtime
assert agent2.id == AgentId("name2", "namespace2")
53 changes: 53 additions & 0 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,59 @@ def agent_factory() -> NoopAgent:
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)


@pytest.mark.asyncio
async def test_agent_type_register_instance() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent2_id = AgentId(type="name", key="notdefault")
agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id)
agent1_dup = NoopAgent(runtime=runtime, agent_id=agent1_id)
agent2 = NoopAgent(runtime=runtime, agent_id=agent2_id)
await agent1.register_instance()
await agent2.register_instance()

assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
with pytest.raises(ValueError):
await agent1_dup.register_instance()


@pytest.mark.asyncio
async def test_agent_type_register_instance_different_types() -> None:
runtime = SingleThreadedAgentRuntime()
agent_id1 = AgentId(type="name", key="noop")
agent_id2 = AgentId(type="name", key="loopback")
agent1 = NoopAgent(runtime=runtime, agent_id=agent_id1)
agent2 = LoopbackAgent(runtime=runtime, agent_id=agent_id2)
await agent1.register_instance()
with pytest.raises(ValueError):
await agent2.register_instance()


@pytest.mark.asyncio
async def test_agent_type_register_instance_publish_new_source() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
agent_id = AgentId(type="name", key="default")
agent1 = LoopbackAgent(runtime=runtime, agent_id=agent_id)
await agent1.register_instance()
await runtime.add_subscription(TypeSubscription("notdefault", "name"))

with pytest.raises(RuntimeError):
runtime.start()
await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
await runtime.stop_when_idle()


@pytest.mark.asyncio
async def test_register_instance_factory() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent1 = NoopAgent(runtime=runtime, agent_id=agent1_id)
await agent1.register_instance()
with pytest.raises(ValueError):
await NoopAgent.register(runtime, "name", lambda: NoopAgent())


@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
Expand Down
Loading
Loading