Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/packages/autogen-core/src/autogen_core/_agent.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from typing import Any, Mapping, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable

from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._message_context import MessageContext

# Forward declaration for type checking only
if TYPE_CHECKING:
from ._agent_runtime import AgentRuntime

Check warning on line 9 in python/packages/autogen-core/src/autogen_core/_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_agent.py#L9

Added line #L9 was not covered by tests


@runtime_checkable
class Agent(Protocol):
Expand All @@ -17,6 +21,15 @@
"""ID of the agent."""
...

async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
"""Function used to bind an Agent instance to an `AgentRuntime`.

Args:
agent_id (AgentId): ID of the agent.
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
"""
...

Check warning on line 31 in python/packages/autogen-core/src/autogen_core/_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_agent.py#L31

Added line #L31 was not covered by tests

async def on_message(self, message: Any, ctx: MessageContext) -> Any:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ def current_agent_id(cls) -> AgentId:
raise RuntimeError(
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e

@classmethod
def is_in_runtime(cls) -> bool:
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
return False
return True
54 changes: 54 additions & 0 deletions python/packages/autogen-core/src/autogen_core/_agent_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,60 @@
"""
...

async def register_agent_instance(
self,
agent_instance: T | Awaitable[T],
agent_id: AgentId,
) -> AgentId:
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.
.. note::
This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.
Example:
.. code-block:: python
from dataclasses import dataclass
from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
from autogen_core.models import UserMessage
@dataclass
class MyMessage:
content: str
class MyAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("My core agent")
@event
async def handler(self, message: UserMessage, context: MessageContext) -> None:
print("Event received: ", message.content)
async def main() -> None:
runtime: AgentRuntime = ... # type: ignore
agent: Agent = MyAgent()
await runtime.register_agent_instance(
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
)
import asyncio
asyncio.run(main())
Args:
agent_instance (T | Awaitable[T]): A concrete instance of the agent.
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
"""
...

Check warning on line 185 in python/packages/autogen-core/src/autogen_core/_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_agent_runtime.py#L185

Added line #L185 was not covered by tests

# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.
Expand Down
76 changes: 66 additions & 10 deletions python/packages/autogen-core/src/autogen_core/_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription

T = TypeVar("T", bound=Agent)

Expand Down Expand Up @@ -82,20 +83,25 @@
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)

def __init__(self, description: str) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e

self._runtime: AgentRuntime = runtime
self._id: AgentId = id
if AgentInstantiationContext.is_in_runtime():
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
self._id = AgentInstantiationContext.current_agent_id()
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description

async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
if hasattr(self, "_id"):
if self._id != id:
raise RuntimeError("Agent is already bound to a different ID")

Check warning on line 96 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L95-L96

Added lines #L95 - L96 were not covered by tests

if hasattr(self, "_runtime"):
if self._runtime != runtime:
raise RuntimeError("Agent is already bound to a different runtime")

Check warning on line 100 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L99-L100

Added lines #L99 - L100 were not covered by tests

self._id = id
self._runtime = runtime

@property
def type(self) -> str:
return self.id.type
Expand Down Expand Up @@ -155,6 +161,56 @@
async def close(self) -> None:
pass

async def register_instance(
self,
runtime: AgentRuntime,
agent_id: AgentId,
*,
skip_class_subscriptions: bool = True,
skip_direct_message_subscription: bool = False,
) -> AgentId:
"""
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
"""
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)

id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
await runtime.add_subscription(id_subscription)

if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
subscriptions: List[Subscription] = []
for unbound_subscription in self._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result

Check warning on line 186 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L181-L186

Added lines #L181 - L186 were not covered by tests
else:
subscriptions_list = subscriptions_list_result

Check warning on line 188 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L188

Added line #L188 was not covered by tests

subscriptions.extend(subscriptions_list)
for subscription in subscriptions:
await runtime.add_subscription(subscription)

Check warning on line 192 in python/packages/autogen-core/src/autogen_core/_base_agent.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_base_agent.py#L190-L192

Added lines #L190 - L192 were not covered by tests

if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
try:
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_id.type + ":",
agent_type=agent_id.type,
)
)
except ValueError:
# We don't care if the subscription already exists
pass

# TODO: deduplication
for _message_type, serializer in self._handles_types():
runtime.add_message_serializer(serializer)

return agent_id

@classmethod
async def register(
cls,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
self._agent_instance_types: Dict[str, Type[Agent]] = {}

@property
def unprocessed_messages_count(
Expand Down Expand Up @@ -830,6 +831,33 @@

return type

async def register_agent_instance(
self,
agent_instance: T | Awaitable[T],
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> T:
raise RuntimeError("Agent factory should not be called when registering an agent instance.")

if inspect.isawaitable(agent_instance):
agent_instance = await agent_instance

Check warning on line 843 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L843

Added line #L843 was not covered by tests

if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")

if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")

Check warning on line 853 in python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py

View check run for this annotation

Codecov / codecov/patch

python/packages/autogen-core/src/autogen_core/_single_threaded_agent_runtime.py#L853

Added line #L853 was not covered by tests
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")

await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id

async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
Expand All @@ -851,8 +879,7 @@
raise ValueError("Agent factory must take 0 or 2 arguments.")

if inspect.isawaitable(agent):
return cast(T, await agent)

agent = cast(T, await agent)
return agent

except BaseException as e:
Expand Down
24 changes: 12 additions & 12 deletions python/packages/autogen-core/tests/test_base_agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import pytest
from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime
from autogen_test_utils import NoopAgent
from pytest_mock import MockerFixture
# import pytest
# from autogen_core import AgentId, AgentInstantiationContext, AgentRuntime
# from autogen_test_utils import NoopAgent
# from pytest_mock import MockerFixture


@pytest.mark.asyncio
async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)
# @pytest.mark.asyncio
# async def test_base_agent_create(mocker: MockerFixture) -> None:
# runtime = mocker.Mock(spec=AgentRuntime)

# Shows how to set the context for the agent instantiation in a test context
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")
# # Shows how to set the context for the agent instantiation in a test context
# with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))):
# agent2 = NoopAgent()
# assert agent2.runtime == runtime
# assert agent2.id == AgentId("name2", "namespace2")
54 changes: 54 additions & 0 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,60 @@ def agent_factory() -> NoopAgent:
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)


@pytest.mark.asyncio
async def test_agent_type_register_instance() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent2_id = AgentId(type="name", key="notdefault")
agent1 = NoopAgent()
agent1_dup = NoopAgent()
agent2 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
await agent2.register_instance(runtime=runtime, agent_id=agent2_id)

assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
with pytest.raises(ValueError):
await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id)


@pytest.mark.asyncio
async def test_agent_type_register_instance_different_types() -> None:
runtime = SingleThreadedAgentRuntime()
agent_id1 = AgentId(type="name", key="noop")
agent_id2 = AgentId(type="name", key="loopback")
agent1 = NoopAgent()
agent2 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id1)
with pytest.raises(ValueError):
await agent2.register_instance(runtime=runtime, agent_id=agent_id2)


@pytest.mark.asyncio
async def test_agent_type_register_instance_publish_new_source() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
agent_id = AgentId(type="name", key="default")
agent1 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id)
await runtime.add_subscription(TypeSubscription("notdefault", "name"))

runtime.start()
with pytest.raises(RuntimeError):
await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
await runtime.stop_when_idle()
await runtime.close()


@pytest.mark.asyncio
async def test_register_instance_factory() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent1 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
with pytest.raises(ValueError):
await NoopAgent.register(runtime, "name", lambda: NoopAgent())


@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)
Expand Down
Loading
Loading