Skip to content

Commit a4067f6

Browse files
authored
Migrate Python distributed runtime to use cloud events for event (#4407)
* Cloud event publishing * Implement cloud event receiving * impl host servicer and
1 parent bd77ccb commit a4067f6

File tree

11 files changed

+210
-65
lines changed

11 files changed

+210
-65
lines changed

protos/agent_worker.proto

+1-2
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,11 @@ message Message {
117117
oneof message {
118118
RpcRequest request = 1;
119119
RpcResponse response = 2;
120-
Event event = 3;
120+
cloudevent.CloudEvent cloudEvent = 3;
121121
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
122122
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
123123
AddSubscriptionRequest addSubscriptionRequest = 6;
124124
AddSubscriptionResponse addSubscriptionResponse = 7;
125-
cloudevent.CloudEvent cloudEvent = 8;
126125
}
127126
}
128127

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
GRPC_IMPORT_ERROR_STR = (
2+
"Distributed runtime features require additional dependencies. Install them with: pip install autogen-core[grpc]"
3+
)
4+
5+
DATA_CONTENT_TYPE_ATTR = "datacontenttype"
6+
DATA_SCHEMA_ATTR = "dataschema"
7+
AGENT_SENDER_TYPE_ATTR = "agagentsendertype"
8+
AGENT_SENDER_KEY_ATTR = "agagentsenderkey"
9+
MESSAGE_KIND_ATTR = "agmsgkind"
10+
MESSAGE_KIND_VALUE_PUBLISH = "publish"
11+
MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"
12+
MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"
13+
MESSAGE_KIND_VALUE_RPC_ERROR = "error"

python/packages/autogen-core/src/autogen_core/application/_utils.py

-3
This file was deleted.

python/packages/autogen-core/src/autogen_core/application/_worker_runtime.py

+107-32
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,15 @@
2828
cast,
2929
)
3030

31+
from google.protobuf import any_pb2
3132
from opentelemetry.trace import TracerProvider
3233
from typing_extensions import Self, deprecated
3334

35+
from autogen_core.application.protos import cloudevent_pb2
36+
3437
from ..base import (
3538
JSON_DATA_CONTENT_TYPE,
39+
PROTOBUF_DATA_CONTENT_TYPE,
3640
Agent,
3741
AgentId,
3842
AgentInstantiationContext,
@@ -49,8 +53,9 @@
4953
from ..base._serialization import MessageSerializer, SerializationRegistry
5054
from ..base._type_helpers import ChannelArgumentType
5155
from ..components import TypePrefixSubscription, TypeSubscription
56+
from . import _constants
57+
from ._constants import GRPC_IMPORT_ERROR_STR
5258
from ._helpers import SubscriptionManager, get_impl
53-
from ._utils import GRPC_IMPORT_ERROR_STR
5459
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
5560
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata
5661

@@ -178,6 +183,7 @@ def __init__(
178183
host_address: str,
179184
tracer_provider: TracerProvider | None = None,
180185
extra_grpc_config: ChannelArgumentType | None = None,
186+
payload_serialization_format: str = JSON_DATA_CONTENT_TYPE,
181187
) -> None:
182188
self._host_address = host_address
183189
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
@@ -198,6 +204,11 @@ def __init__(
198204
self._serialization_registry = SerializationRegistry()
199205
self._extra_grpc_config = extra_grpc_config or []
200206

207+
if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
208+
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
209+
210+
self._payload_serialization_format = payload_serialization_format
211+
201212
def start(self) -> None:
202213
"""Start the runtime in a background task."""
203214
if self._running:
@@ -236,8 +247,10 @@ async def _run_read_loop(self) -> None:
236247
self._background_tasks.add(task)
237248
task.add_done_callback(self._raise_on_exception)
238249
task.add_done_callback(self._background_tasks.discard)
239-
case "event":
240-
task = asyncio.create_task(self._process_event(message.event))
250+
case "cloudEvent":
251+
# The proto typing doesnt resolve this one
252+
cloud_event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
253+
task = asyncio.create_task(self._process_event(cloud_event))
241254
self._background_tasks.add(task)
242255
task.add_done_callback(self._raise_on_exception)
243256
task.add_done_callback(self._background_tasks.discard)
@@ -257,8 +270,6 @@ async def _run_read_loop(self) -> None:
257270
task.add_done_callback(self._background_tasks.discard)
258271
case None:
259272
logger.warning("No message")
260-
case other:
261-
logger.error(f"Unknown message type: {other}")
262273
except Exception as e:
263274
logger.error("Error in read loop", exc_info=e)
264275

@@ -381,30 +392,64 @@ async def publish_message(
381392
if message_id is None:
382393
message_id = str(uuid.uuid4())
383394

384-
# TODO: consume message_id
385-
386395
message_type = self._serialization_registry.type_name(message)
387396
with self._trace_helper.trace_block(
388397
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
389398
):
390399
serialized_message = self._serialization_registry.serialize(
391-
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
400+
message, type_name=message_type, data_content_type=self._payload_serialization_format
392401
)
393-
telemetry_metadata = get_telemetry_grpc_metadata()
394-
runtime_message = agent_worker_pb2.Message(
395-
event=agent_worker_pb2.Event(
396-
topic_type=topic_id.type,
397-
topic_source=topic_id.source,
398-
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
399-
metadata=telemetry_metadata,
400-
payload=agent_worker_pb2.Payload(
401-
data_type=message_type,
402-
data=serialized_message,
403-
data_content_type=JSON_DATA_CONTENT_TYPE,
404-
),
402+
403+
sender_id = sender or AgentId("unknown", "unknown")
404+
attributes = {
405+
_constants.DATA_CONTENT_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
406+
ce_string=self._payload_serialization_format
407+
),
408+
_constants.DATA_SCHEMA_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(ce_string=message_type),
409+
_constants.AGENT_SENDER_TYPE_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
410+
ce_string=sender_id.type
411+
),
412+
_constants.AGENT_SENDER_KEY_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
413+
ce_string=sender_id.key
414+
),
415+
_constants.MESSAGE_KIND_ATTR: cloudevent_pb2.CloudEvent.CloudEventAttributeValue(
416+
ce_string=_constants.MESSAGE_KIND_VALUE_PUBLISH
417+
),
418+
}
419+
420+
# If sending JSON we fill text_data with the serialized message
421+
# If sending Protobuf we fill proto_data with the serialized message
422+
# TODO: add an encoding field for serializer
423+
424+
if self._payload_serialization_format == JSON_DATA_CONTENT_TYPE:
425+
runtime_message = agent_worker_pb2.Message(
426+
cloudEvent=cloudevent_pb2.CloudEvent(
427+
id=message_id,
428+
spec_version="1.0",
429+
type=topic_id.type,
430+
source=topic_id.source,
431+
attributes=attributes,
432+
# TODO: use text, or proto fields appropriately
433+
binary_data=serialized_message,
434+
)
435+
)
436+
else:
437+
# We need to unpack the serialized proto back into an Any
438+
# TODO: find a way to prevent the roundtrip serialization
439+
any_proto = any_pb2.Any()
440+
any_proto.ParseFromString(serialized_message)
441+
runtime_message = agent_worker_pb2.Message(
442+
cloudEvent=cloudevent_pb2.CloudEvent(
443+
id=message_id,
444+
spec_version="1.0",
445+
type=topic_id.type,
446+
source=topic_id.source,
447+
attributes=attributes,
448+
proto_data=any_proto,
449+
)
405450
)
406-
)
407451

452+
telemetry_metadata = get_telemetry_grpc_metadata()
408453
task = asyncio.create_task(self._send_message(runtime_message, "publish", topic_id, telemetry_metadata))
409454
self._background_tasks.add(task)
410455
task.add_done_callback(self._raise_on_exception)
@@ -523,28 +568,58 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> Non
523568
else:
524569
future.set_result(result)
525570

526-
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
527-
message = self._serialization_registry.deserialize(
528-
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
529-
)
571+
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
572+
event_attributes = event.attributes
530573
sender: AgentId | None = None
531-
if event.HasField("source"):
532-
sender = AgentId(event.source.type, event.source.key)
533-
topic_id = TopicId(event.topic_type, event.topic_source)
574+
if (
575+
_constants.AGENT_SENDER_TYPE_ATTR in event_attributes
576+
and _constants.AGENT_SENDER_KEY_ATTR in event_attributes
577+
):
578+
sender = AgentId(
579+
event_attributes[_constants.AGENT_SENDER_TYPE_ATTR].ce_string,
580+
event_attributes[_constants.AGENT_SENDER_KEY_ATTR].ce_string,
581+
)
582+
topic_id = TopicId(event.type, event.source)
534583
# Get the recipients for the topic.
535584
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
585+
586+
message_content_type = event_attributes[_constants.DATA_CONTENT_TYPE_ATTR].ce_string
587+
message_type = event_attributes[_constants.DATA_SCHEMA_ATTR].ce_string
588+
589+
if message_content_type == JSON_DATA_CONTENT_TYPE:
590+
message = self._serialization_registry.deserialize(
591+
event.binary_data, type_name=message_type, data_content_type=message_content_type
592+
)
593+
elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE:
594+
# TODO: find a way to prevent the roundtrip serialization
595+
proto_binary_data = event.proto_data.SerializeToString()
596+
message = self._serialization_registry.deserialize(
597+
proto_binary_data, type_name=message_type, data_content_type=message_content_type
598+
)
599+
else:
600+
raise ValueError(f"Unsupported message content type: {message_content_type}")
601+
602+
# TODO: dont read these values in the runtime
603+
topic_type_suffix = topic_id.type.split(":", maxsplit=1)[1] if ":" in topic_id.type else ""
604+
is_rpc = topic_type_suffix == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
605+
is_marked_rpc_type = (
606+
_constants.MESSAGE_KIND_ATTR in event_attributes
607+
and event_attributes[_constants.MESSAGE_KIND_ATTR].ce_string == _constants.MESSAGE_KIND_VALUE_RPC_REQUEST
608+
)
609+
if is_rpc and not is_marked_rpc_type:
610+
warnings.warn("Received RPC request with topic type suffix but not marked as RPC request.", stacklevel=2)
611+
536612
# Send the message to each recipient.
537613
responses: List[Awaitable[Any]] = []
538614
for agent_id in recipients:
539615
if agent_id == sender:
540616
continue
541-
# TODO: consume message_id
542617
message_context = MessageContext(
543618
sender=sender,
544619
topic_id=topic_id,
545-
is_rpc=False,
620+
is_rpc=is_rpc,
546621
cancellation_token=CancellationToken(),
547-
message_id="NOT_DEFINED_TODO_FIX",
622+
message_id=event.id,
548623
)
549624
agent = await self._get_agent(agent_id)
550625
with MessageHandlerContext.populate_context(agent.id):
@@ -554,7 +629,7 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any:
554629
"process",
555630
agent.id,
556631
parent=event.metadata,
557-
extraAttributes={"message_type": event.payload.data_type},
632+
extraAttributes={"message_type": message_type},
558633
):
559634
await agent.on_message(message, ctx=message_context)
560635

python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional, Sequence
55

66
from ..base._type_helpers import ChannelArgumentType
7-
from ._utils import GRPC_IMPORT_ERROR_STR
7+
from ._constants import GRPC_IMPORT_ERROR_STR
88
from ._worker_runtime_host_servicer import WorkerAgentRuntimeHostServicer
99

1010
try:

python/packages/autogen-core/src/autogen_core/application/_worker_runtime_host_servicer.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
import logging
33
from _collections_abc import AsyncIterator, Iterator
44
from asyncio import Future, Task
5-
from typing import Any, Dict, Set
5+
from typing import Any, Dict, Set, cast
66

77
from autogen_core.base._type_prefix_subscription import TypePrefixSubscription
88

99
from ..base import Subscription, TopicId
1010
from ..components import TypeSubscription
11+
from ._constants import GRPC_IMPORT_ERROR_STR
1112
from ._helpers import SubscriptionManager
12-
from ._utils import GRPC_IMPORT_ERROR_STR
1313

1414
try:
1515
import grpc
1616
except ImportError as e:
1717
raise ImportError(GRPC_IMPORT_ERROR_STR) from e
1818

19-
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
19+
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
2020

2121
logger = logging.getLogger("autogen_core")
2222
event_logger = logging.getLogger("autogen_core.events")
@@ -84,7 +84,7 @@ async def _on_client_disconnect(self, client_id: int) -> None:
8484
for agent_type in agent_types:
8585
logger.info(f"Removing agent type {agent_type} from agent type to client id mapping")
8686
del self._agent_type_to_client_id[agent_type]
87-
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, []):
87+
for sub_id in self._client_id_to_subscription_id_mapping.get(client_id, set()):
8888
logger.info(f"Client id {client_id} disconnected. Removing corresponding subscription with id {id}")
8989
await self._subscription_manager.remove_subscription(sub_id)
9090
logger.info(f"Client {client_id} disconnected successfully")
@@ -114,8 +114,9 @@ async def _receive_messages(
114114
self._background_tasks.add(task)
115115
task.add_done_callback(self._raise_on_exception)
116116
task.add_done_callback(self._background_tasks.discard)
117-
case "event":
118-
event: agent_worker_pb2.Event = message.event
117+
case "cloudEvent":
118+
# The proto typing doesnt resolve this one
119+
event = cast(cloudevent_pb2.CloudEvent, message.cloudEvent) # type: ignore
119120
task = asyncio.create_task(self._process_event(event))
120121
self._background_tasks.add(task)
121122
task.add_done_callback(self._raise_on_exception)
@@ -138,8 +139,6 @@ async def _receive_messages(
138139
logger.warning(f"Received unexpected message type: {oneofcase}")
139140
case None:
140141
logger.warning("Received empty message")
141-
case other:
142-
logger.error(f"Received unexpected message: {other}")
143142

144143
async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: int) -> None:
145144
# Deliver the message to a client given the target agent type.
@@ -178,8 +177,8 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse, client
178177
future = self._pending_responses[client_id].pop(response.request_id)
179178
future.set_result(response)
180179

181-
async def _process_event(self, event: agent_worker_pb2.Event) -> None:
182-
topic_id = TopicId(type=event.topic_type, source=event.topic_source)
180+
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
181+
topic_id = TopicId(type=event.type, source=event.source)
183182
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
184183
# Get the client ids of the recipients.
185184
async with self._agent_type_to_client_id_lock:
@@ -192,7 +191,7 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
192191
logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.")
193192
# Deliver the event to clients.
194193
for client_id in client_ids:
195-
await self._send_queues[client_id].put(agent_worker_pb2.Message(event=event))
194+
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))
196195

197196
async def _process_register_agent_type_request(
198197
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: int

0 commit comments

Comments
 (0)