Skip to content

Commit

Permalink
format (#593)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekzhu authored Sep 19, 2024
1 parent 2073305 commit 46ca778
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,9 @@ async def register_factory(
agent_factory: Callable[[], T | Awaitable[T]],
expected_class: type[T],
) -> AgentType:
if type.type in self._agent_factories:
raise ValueError(f"Agent with type {type} already exists.")

async def factory_wrapper() -> T:
maybe_agent_instance = agent_factory()
if inspect.isawaitable(maybe_agent_instance):
Expand Down
99 changes: 58 additions & 41 deletions python/packages/autogen-core/tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@
from autogen_core.base import (
AgentId,
AgentInstantiationContext,
AgentType,
Subscription,
SubscriptionInstantiationContext,
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.components import (
DefaultSubscription,
DefaultTopicId,
TypeSubscription,
default_subscription,
type_subscription,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
from opentelemetry.sdk.trace import TracerProvider
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent
from test_utils.telemetry_test_utils import TestExporter, get_test_tracer_provider
Expand All @@ -24,7 +32,7 @@ def tracer_provider() -> TracerProvider:


@pytest.mark.asyncio
async def test_agent_names_must_be_unique() -> None:
async def test_agent_type_must_be_unique() -> None:
runtime = SingleThreadedAgentRuntime()

def agent_factory() -> NoopAgent:
Expand All @@ -34,29 +42,30 @@ def agent_factory() -> NoopAgent:
assert agent.id == id
return agent

await runtime.register("name1", agent_factory)
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)

with pytest.raises(ValueError):
await runtime.register("name1", NoopAgent)
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)

await runtime.register("name3", NoopAgent)
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)


@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)

await runtime.register("name", LoopbackAgent)
runtime.start()
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
await runtime.register_factory(
type=AgentType("name"), agent_factory=lambda: LoopbackAgent(), expected_class=LoopbackAgent
)
await runtime.add_subscription(TypeSubscription("default", "name"))
agent_id = AgentId("name", key="default")
topic_id = TopicId("default", "default")
await runtime.publish_message(MessageType(), topic_id=topic_id)

runtime.start()
await runtime.publish_message(MessageType(), topic_id=TopicId("default", "default"))
await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(AgentId("name", "default"), type=LoopbackAgent)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
Expand All @@ -77,24 +86,24 @@ async def test_register_receives_publish(tracer_provider: TracerProvider) -> Non

@pytest.mark.asyncio
async def test_register_receives_publish_cascade() -> None:
runtime = SingleThreadedAgentRuntime()
num_agents = 5
num_initial_messages = 5
max_rounds = 5
total_num_calls_expected = 0
for i in range(0, max_rounds):
total_num_calls_expected += num_initial_messages * ((num_agents - 1) ** i)

runtime = SingleThreadedAgentRuntime()

# Register agents
for i in range(num_agents):
await runtime.register(f"name{i}", lambda: CascadingAgent(max_rounds), lambda: [DefaultSubscription()])
await CascadingAgent.register(runtime, f"name{i}", lambda: CascadingAgent(max_rounds))

runtime.start()

# Publish messages
topic_id = TopicId("default", "default")
for _ in range(num_initial_messages):
await runtime.publish_message(CascadingMessageType(round=1), topic_id)
await runtime.publish_message(CascadingMessageType(round=1), DefaultTopicId())

# Process until idle.
await runtime.stop_when_idle()
Expand Down Expand Up @@ -206,64 +215,72 @@ async def test_register_factory_direct_list() -> None:
@pytest.mark.asyncio
async def test_default_subscription() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(
agent_id, type=LoopbackAgentWithDefaultSubscription
)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
async def test_type_subscription() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
runtime.start()
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))

@type_subscription(topic_type="Other")
class LoopbackAgentWithSubscription(LoopbackAgent): ...

await LoopbackAgentWithSubscription.register(runtime, "name", LoopbackAgentWithSubscription)

agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=TopicId("Other", "default"))
await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgentWithSubscription)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithSubscription
)
assert other_long_running_agent.num_calls == 0


@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
async def test_default_subscription_publish_to_other_source() -> None:
runtime = SingleThreadedAgentRuntime()

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
runtime.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", LoopbackAgentWithDefaultSubscription)

agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

await runtime.stop_when_idle()

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
long_running_agent = await runtime.try_get_underlying_agent_instance(
agent_id, type=LoopbackAgentWithDefaultSubscription
)
assert long_running_agent.num_calls == 0

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
)
assert other_long_running_agent.num_calls == 1
110 changes: 70 additions & 40 deletions python/packages/autogen-core/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
TopicId,
try_get_known_serializers_for_type,
)
from autogen_core.components import DefaultSubscription, DefaultTopicId, TypeSubscription
from autogen_core.components import (
DefaultSubscription,
DefaultTopicId,
TypeSubscription,
default_subscription,
type_subscription,
)
from test_utils import CascadingAgent, CascadingMessageType, LoopbackAgent, MessageType, NoopAgent


Expand Down Expand Up @@ -190,83 +196,107 @@ async def test_default_subscription() -> None:
host_address = "localhost:50054"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(worker, "name", lambda: LoopbackAgentWithDefaultSubscription())

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId())
await publisher.publish_message(MessageType(), topic_id=DefaultTopicId())

await asyncio.sleep(2)

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
# Agent in default topic source should have received the message.
long_running_agent = await worker.try_get_underlying_agent_instance(
AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
# Agent in other topic source should not have received the message.
other_long_running_agent = await worker.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
)
assert other_long_running_agent.num_calls == 0

await runtime.stop()
await worker.stop()
await publisher.stop()
await host.stop()


@pytest.mark.asyncio
async def test_non_default_default_subscription() -> None:
host_address = "localhost:50055"
async def test_default_subscription_other_source() -> None:
host_address = "localhost:50056"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()

@default_subscription
class LoopbackAgentWithDefaultSubscription(LoopbackAgent): ...

await LoopbackAgentWithDefaultSubscription.register(runtime, "name", lambda: LoopbackAgentWithDefaultSubscription())

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription(topic_type="Other")])
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(type="Other"))
await publisher.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))

await asyncio.sleep(2)

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 1
long_running_agent = await runtime.try_get_underlying_agent_instance(
AgentId("name", "default"), type=LoopbackAgentWithDefaultSubscription
)
assert long_running_agent.num_calls == 0

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
other_long_running_agent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithDefaultSubscription
)
assert other_long_running_agent.num_calls == 0
assert other_long_running_agent.num_calls == 1

await runtime.stop()
await publisher.stop()
await host.stop()


@pytest.mark.asyncio
async def test_non_publish_to_other_source() -> None:
host_address = "localhost:50056"
async def test_type_subscription() -> None:
host_address = "localhost:50055"
host = WorkerAgentRuntimeHost(address=host_address)
host.start()
runtime = WorkerAgentRuntime(host_address=host_address)
runtime.add_message_serializer(try_get_known_serializers_for_type(MessageType))
runtime.start()
worker = WorkerAgentRuntime(host_address=host_address)
worker.start()
publisher = WorkerAgentRuntime(host_address=host_address)
publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
publisher.start()

@type_subscription("Other")
class LoopbackAgentWithSubscription(LoopbackAgent): ...

await runtime.register("name", LoopbackAgent, lambda: [DefaultSubscription()])
agent_id = AgentId("name", key="default")
await runtime.publish_message(MessageType(), topic_id=DefaultTopicId(source="other"))
await LoopbackAgentWithSubscription.register(worker, "name", lambda: LoopbackAgentWithSubscription())

await publisher.publish_message(MessageType(), topic_id=TopicId(type="Other", source="default"))

await asyncio.sleep(2)

# Agent in default namespace should have received the message
long_running_agent = await runtime.try_get_underlying_agent_instance(agent_id, type=LoopbackAgent)
assert long_running_agent.num_calls == 0
# Agent in default topic source should have received the message.
long_running_agent = await worker.try_get_underlying_agent_instance(
AgentId("name", "default"), type=LoopbackAgentWithSubscription
)
assert long_running_agent.num_calls == 1

# Agent in other namespace should not have received the message
other_long_running_agent: LoopbackAgent = await runtime.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgent
# Agent in other topic source should not have received the message.
other_long_running_agent = await worker.try_get_underlying_agent_instance(
AgentId("name", key="other"), type=LoopbackAgentWithSubscription
)
assert other_long_running_agent.num_calls == 1
assert other_long_running_agent.num_calls == 0

await runtime.stop()
await worker.stop()
await publisher.stop()
await host.stop()

0 comments on commit 46ca778

Please sign in to comment.