28
28
cast ,
29
29
)
30
30
31
+ from google .protobuf import any_pb2
31
32
from opentelemetry .trace import TracerProvider
32
33
from typing_extensions import Self , deprecated
33
34
35
+ from autogen_core .application .protos import cloudevent_pb2
36
+
34
37
from ..base import (
35
38
JSON_DATA_CONTENT_TYPE ,
39
+ PROTOBUF_DATA_CONTENT_TYPE ,
36
40
Agent ,
37
41
AgentId ,
38
42
AgentInstantiationContext ,
49
53
from ..base ._serialization import MessageSerializer , SerializationRegistry
50
54
from ..base ._type_helpers import ChannelArgumentType
51
55
from ..components import TypePrefixSubscription , TypeSubscription
56
+ from . import _constants
57
+ from ._constants import GRPC_IMPORT_ERROR_STR
52
58
from ._helpers import SubscriptionManager , get_impl
53
- from ._utils import GRPC_IMPORT_ERROR_STR
54
59
from .protos import agent_worker_pb2 , agent_worker_pb2_grpc
55
60
from .telemetry import MessageRuntimeTracingConfig , TraceHelper , get_telemetry_grpc_metadata
56
61
@@ -178,6 +183,7 @@ def __init__(
178
183
host_address : str ,
179
184
tracer_provider : TracerProvider | None = None ,
180
185
extra_grpc_config : ChannelArgumentType | None = None ,
186
+ payload_serialization_format : str = JSON_DATA_CONTENT_TYPE ,
181
187
) -> None :
182
188
self ._host_address = host_address
183
189
self ._trace_helper = TraceHelper (tracer_provider , MessageRuntimeTracingConfig ("Worker Runtime" ))
@@ -198,6 +204,11 @@ def __init__(
198
204
self ._serialization_registry = SerializationRegistry ()
199
205
self ._extra_grpc_config = extra_grpc_config or []
200
206
207
+ if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE , PROTOBUF_DATA_CONTENT_TYPE }:
208
+ raise ValueError (f"Unsupported payload serialization format: { payload_serialization_format } " )
209
+
210
+ self ._payload_serialization_format = payload_serialization_format
211
+
201
212
def start (self ) -> None :
202
213
"""Start the runtime in a background task."""
203
214
if self ._running :
@@ -236,8 +247,10 @@ async def _run_read_loop(self) -> None:
236
247
self ._background_tasks .add (task )
237
248
task .add_done_callback (self ._raise_on_exception )
238
249
task .add_done_callback (self ._background_tasks .discard )
239
- case "event" :
240
- task = asyncio .create_task (self ._process_event (message .event ))
250
+ case "cloudEvent" :
251
+ # The proto typing doesnt resolve this one
252
+ cloud_event = cast (cloudevent_pb2 .CloudEvent , message .cloudEvent ) # type: ignore
253
+ task = asyncio .create_task (self ._process_event (cloud_event ))
241
254
self ._background_tasks .add (task )
242
255
task .add_done_callback (self ._raise_on_exception )
243
256
task .add_done_callback (self ._background_tasks .discard )
@@ -257,8 +270,6 @@ async def _run_read_loop(self) -> None:
257
270
task .add_done_callback (self ._background_tasks .discard )
258
271
case None :
259
272
logger .warning ("No message" )
260
- case other :
261
- logger .error (f"Unknown message type: { other } " )
262
273
except Exception as e :
263
274
logger .error ("Error in read loop" , exc_info = e )
264
275
@@ -381,30 +392,64 @@ async def publish_message(
381
392
if message_id is None :
382
393
message_id = str (uuid .uuid4 ())
383
394
384
- # TODO: consume message_id
385
-
386
395
message_type = self ._serialization_registry .type_name (message )
387
396
with self ._trace_helper .trace_block (
388
397
"create" , topic_id , parent = None , extraAttributes = {"message_type" : message_type }
389
398
):
390
399
serialized_message = self ._serialization_registry .serialize (
391
- message , type_name = message_type , data_content_type = JSON_DATA_CONTENT_TYPE
400
+ message , type_name = message_type , data_content_type = self . _payload_serialization_format
392
401
)
393
- telemetry_metadata = get_telemetry_grpc_metadata ()
394
- runtime_message = agent_worker_pb2 .Message (
395
- event = agent_worker_pb2 .Event (
396
- topic_type = topic_id .type ,
397
- topic_source = topic_id .source ,
398
- source = agent_worker_pb2 .AgentId (type = sender .type , key = sender .key ) if sender is not None else None ,
399
- metadata = telemetry_metadata ,
400
- payload = agent_worker_pb2 .Payload (
401
- data_type = message_type ,
402
- data = serialized_message ,
403
- data_content_type = JSON_DATA_CONTENT_TYPE ,
404
- ),
402
+
403
+ sender_id = sender or AgentId ("unknown" , "unknown" )
404
+ attributes = {
405
+ _constants .DATA_CONTENT_TYPE_ATTR : cloudevent_pb2 .CloudEvent .CloudEventAttributeValue (
406
+ ce_string = self ._payload_serialization_format
407
+ ),
408
+ _constants .DATA_SCHEMA_ATTR : cloudevent_pb2 .CloudEvent .CloudEventAttributeValue (ce_string = message_type ),
409
+ _constants .AGENT_SENDER_TYPE_ATTR : cloudevent_pb2 .CloudEvent .CloudEventAttributeValue (
410
+ ce_string = sender_id .type
411
+ ),
412
+ _constants .AGENT_SENDER_KEY_ATTR : cloudevent_pb2 .CloudEvent .CloudEventAttributeValue (
413
+ ce_string = sender_id .key
414
+ ),
415
+ _constants .MESSAGE_KIND_ATTR : cloudevent_pb2 .CloudEvent .CloudEventAttributeValue (
416
+ ce_string = _constants .MESSAGE_KIND_VALUE_PUBLISH
417
+ ),
418
+ }
419
+
420
+ # If sending JSON we fill text_data with the serialized message
421
+ # If sending Protobuf we fill proto_data with the serialized message
422
+ # TODO: add an encoding field for serializer
423
+
424
+ if self ._payload_serialization_format == JSON_DATA_CONTENT_TYPE :
425
+ runtime_message = agent_worker_pb2 .Message (
426
+ cloudEvent = cloudevent_pb2 .CloudEvent (
427
+ id = message_id ,
428
+ spec_version = "1.0" ,
429
+ type = topic_id .type ,
430
+ source = topic_id .source ,
431
+ attributes = attributes ,
432
+ # TODO: use text, or proto fields appropriately
433
+ binary_data = serialized_message ,
434
+ )
435
+ )
436
+ else :
437
+ # We need to unpack the serialized proto back into an Any
438
+ # TODO: find a way to prevent the roundtrip serialization
439
+ any_proto = any_pb2 .Any ()
440
+ any_proto .ParseFromString (serialized_message )
441
+ runtime_message = agent_worker_pb2 .Message (
442
+ cloudEvent = cloudevent_pb2 .CloudEvent (
443
+ id = message_id ,
444
+ spec_version = "1.0" ,
445
+ type = topic_id .type ,
446
+ source = topic_id .source ,
447
+ attributes = attributes ,
448
+ proto_data = any_proto ,
449
+ )
405
450
)
406
- )
407
451
452
+ telemetry_metadata = get_telemetry_grpc_metadata ()
408
453
task = asyncio .create_task (self ._send_message (runtime_message , "publish" , topic_id , telemetry_metadata ))
409
454
self ._background_tasks .add (task )
410
455
task .add_done_callback (self ._raise_on_exception )
@@ -523,28 +568,58 @@ async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> Non
523
568
else :
524
569
future .set_result (result )
525
570
526
- async def _process_event (self , event : agent_worker_pb2 .Event ) -> None :
527
- message = self ._serialization_registry .deserialize (
528
- event .payload .data , type_name = event .payload .data_type , data_content_type = event .payload .data_content_type
529
- )
571
+ async def _process_event (self , event : cloudevent_pb2 .CloudEvent ) -> None :
572
+ event_attributes = event .attributes
530
573
sender : AgentId | None = None
531
- if event .HasField ("source" ):
532
- sender = AgentId (event .source .type , event .source .key )
533
- topic_id = TopicId (event .topic_type , event .topic_source )
574
+ if (
575
+ _constants .AGENT_SENDER_TYPE_ATTR in event_attributes
576
+ and _constants .AGENT_SENDER_KEY_ATTR in event_attributes
577
+ ):
578
+ sender = AgentId (
579
+ event_attributes [_constants .AGENT_SENDER_TYPE_ATTR ].ce_string ,
580
+ event_attributes [_constants .AGENT_SENDER_KEY_ATTR ].ce_string ,
581
+ )
582
+ topic_id = TopicId (event .type , event .source )
534
583
# Get the recipients for the topic.
535
584
recipients = await self ._subscription_manager .get_subscribed_recipients (topic_id )
585
+
586
+ message_content_type = event_attributes [_constants .DATA_CONTENT_TYPE_ATTR ].ce_string
587
+ message_type = event_attributes [_constants .DATA_SCHEMA_ATTR ].ce_string
588
+
589
+ if message_content_type == JSON_DATA_CONTENT_TYPE :
590
+ message = self ._serialization_registry .deserialize (
591
+ event .binary_data , type_name = message_type , data_content_type = message_content_type
592
+ )
593
+ elif message_content_type == PROTOBUF_DATA_CONTENT_TYPE :
594
+ # TODO: find a way to prevent the roundtrip serialization
595
+ proto_binary_data = event .proto_data .SerializeToString ()
596
+ message = self ._serialization_registry .deserialize (
597
+ proto_binary_data , type_name = message_type , data_content_type = message_content_type
598
+ )
599
+ else :
600
+ raise ValueError (f"Unsupported message content type: { message_content_type } " )
601
+
602
+ # TODO: dont read these values in the runtime
603
+ topic_type_suffix = topic_id .type .split (":" , maxsplit = 1 )[1 ] if ":" in topic_id .type else ""
604
+ is_rpc = topic_type_suffix == _constants .MESSAGE_KIND_VALUE_RPC_REQUEST
605
+ is_marked_rpc_type = (
606
+ _constants .MESSAGE_KIND_ATTR in event_attributes
607
+ and event_attributes [_constants .MESSAGE_KIND_ATTR ].ce_string == _constants .MESSAGE_KIND_VALUE_RPC_REQUEST
608
+ )
609
+ if is_rpc and not is_marked_rpc_type :
610
+ warnings .warn ("Received RPC request with topic type suffix but not marked as RPC request." , stacklevel = 2 )
611
+
536
612
# Send the message to each recipient.
537
613
responses : List [Awaitable [Any ]] = []
538
614
for agent_id in recipients :
539
615
if agent_id == sender :
540
616
continue
541
- # TODO: consume message_id
542
617
message_context = MessageContext (
543
618
sender = sender ,
544
619
topic_id = topic_id ,
545
- is_rpc = False ,
620
+ is_rpc = is_rpc ,
546
621
cancellation_token = CancellationToken (),
547
- message_id = "NOT_DEFINED_TODO_FIX" ,
622
+ message_id = event . id ,
548
623
)
549
624
agent = await self ._get_agent (agent_id )
550
625
with MessageHandlerContext .populate_context (agent .id ):
@@ -554,7 +629,7 @@ async def send_message(agent: Agent, message_context: MessageContext) -> Any:
554
629
"process" ,
555
630
agent .id ,
556
631
parent = event .metadata ,
557
- extraAttributes = {"message_type" : event . payload . data_type },
632
+ extraAttributes = {"message_type" : message_type },
558
633
):
559
634
await agent .on_message (message , ctx = message_context )
560
635
0 commit comments