Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT: mid change commit #4372

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,11 @@ message SaveStateResponse {

message Message {
oneof message {
RpcRequest request = 1;
RpcResponse response = 2;
Event event = 3;
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
AddSubscriptionRequest addSubscriptionRequest = 6;
AddSubscriptionResponse addSubscriptionResponse = 7;
cloudevent.CloudEvent cloudEvent = 8;
cloudevent.CloudEvent cloudEvent = 1;
RegisterAgentTypeRequest registerAgentTypeRequest = 2;
RegisterAgentTypeResponse registerAgentTypeResponse = 3;
AddSubscriptionRequest addSubscriptionRequest = 4;
AddSubscriptionResponse addSubscriptionResponse = 5;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ async def publish_message(
topic_id: TopicId,
*,
sender: AgentId | None = None,
# TODO: handle request_id being passed in
request_id: str | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
with self._tracer_helper.trace_block(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
TypeVar,
cast,
)
from uuid import uuid4

import grpc
from google.protobuf import any_pb2
from grpc.aio import StreamStreamCall
from opentelemetry.trace import TracerProvider
from typing_extensions import Self, deprecated
Expand All @@ -52,7 +54,7 @@
)
from ..components import TypeSubscription
from ._helpers import SubscriptionManager, get_impl
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .protos import agent_worker_pb2, agent_worker_pb2_grpc, cloudevent_pb2
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata

if TYPE_CHECKING:
Expand Down Expand Up @@ -184,6 +186,7 @@ def __init__(
self._read_task: None | Task[None] = None
self._running = False
self._pending_requests: Dict[str, Future[Any]] = {}
self._pending_callbacks: Dict[TopicId, Callable[[Any], None]] = {}
self._pending_requests_lock = asyncio.Lock()
self._next_request_id = 0
self._host_connection: HostConnection | None = None
Expand Down Expand Up @@ -314,63 +317,66 @@ async def _send_message(
with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata):
await self._host_connection.send(runtime_message)

async def send_message(
self,
message: Any,
recipient: AgentId,
*,
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> Any:
if not self._running:
raise ValueError("Runtime must be running when sending message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
data_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", recipient, parent=None, extraAttributes={"message_type": data_type}
):
# create a new future for the result
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future
serialized_message = self._serialization_registry.serialize(
message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
telemetry_metadata = get_telemetry_grpc_metadata()
runtime_message = agent_worker_pb2.Message(
request=agent_worker_pb2.RpcRequest(
request_id=request_id,
target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
metadata=telemetry_metadata,
payload=agent_worker_pb2.Payload(
data_type=data_type,
data=serialized_message,
data_content_type=JSON_DATA_CONTENT_TYPE,
),
)
)

# TODO: Find a way to handle timeouts/errors
task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
return await future
# async def send_message(
# self,
# message: Any,
# recipient: AgentId,
# *,
# sender: AgentId | None = None,
# cancellation_token: CancellationToken | None = None,
# ) -> Any:
# if not self._running:
# raise ValueError("Runtime must be running when sending message.")
# if self._host_connection is None:
# raise RuntimeError("Host connection is not set.")
# data_type = self._serialization_registry.type_name(message)
# with self._trace_helper.trace_block(
# "create", recipient, parent=None, extraAttributes={"message_type": data_type}
# ):
# # create a new future for the result
# future = asyncio.get_event_loop().create_future()
# request_id = await self._get_new_request_id()
# self._pending_requests[request_id] = future
# serialized_message = self._serialization_registry.serialize(
# message, type_name=data_type, data_content_type=JSON_DATA_CONTENT_TYPE
# )
# telemetry_metadata = get_telemetry_grpc_metadata()
# runtime_message = agent_worker_pb2.Message(
# request=agent_worker_pb2.RpcRequest(
# request_id=request_id,
# target=agent_worker_pb2.AgentId(type=recipient.type, key=recipient.key),
# source=agent_worker_pb2.AgentId(type=sender.type, key=sender.key) if sender is not None else None,
# metadata=telemetry_metadata,
# payload=agent_worker_pb2.Payload(
# data_type=data_type,
# data=serialized_message,
# data_content_type=JSON_DATA_CONTENT_TYPE,
# ),
# )
# )

# # TODO: Find a way to handle timeouts/errors
# task = asyncio.create_task(self._send_message(runtime_message, "send", recipient, telemetry_metadata))
# self._background_tasks.add(task)
# task.add_done_callback(self._raise_on_exception)
# task.add_done_callback(self._background_tasks.discard)
# return await future

async def publish_message(
self,
message: Any,
topic_id: TopicId,
*,
sender: AgentId | None = None,
request_id: str | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
if not self._running:
raise ValueError("Runtime must be running when publishing message.")
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if not request_id:
request_id = self.get_new_request_id()
message_type = self._serialization_registry.type_name(message)
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
Expand Down Expand Up @@ -413,10 +419,8 @@ async def agent_save_state(self, agent: AgentId) -> Mapping[str, Any]:
async def agent_load_state(self, agent: AgentId, state: Mapping[str, Any]) -> None:
raise NotImplementedError("Agent load_state is not yet implemented.")

async def _get_new_request_id(self) -> str:
async with self._pending_requests_lock:
self._next_request_id += 1
return str(self._next_request_id)
def get_new_request_id(self) -> str:
return str(uuid4())

async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
assert self._host_connection is not None
Expand Down Expand Up @@ -489,35 +493,39 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
# Send the response.
await self._host_connection.send(response_message)

async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
with self._trace_helper.trace_block(
"ack",
None,
parent=response.metadata,
attributes={"request_id": response.request_id},
extraAttributes={"message_type": response.payload.data_type},
):
# Deserialize the result.
result = self._serialization_registry.deserialize(
response.payload.data,
type_name=response.payload.data_type,
data_content_type=response.payload.data_content_type,
)
# Get the future and set the result.
future = self._pending_requests.pop(response.request_id)
if len(response.error) > 0:
future.set_exception(Exception(response.error))
else:
future.set_result(result)

async def _process_event(self, event: agent_worker_pb2.Event) -> None:
message = self._serialization_registry.deserialize(
event.payload.data, type_name=event.payload.data_type, data_content_type=event.payload.data_content_type
)
async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
oneofcase = cloudevent_pb2.CloudEvent.WhichOneof(event, "data")
match oneofcase:
case "binary_data":
message = self._serialization_registry.deserialize(
event.binary_data, type_name=event.metadata["dataschema"], data_content_type=event.metadata["datacontenttype"],
)
case "text_data":
logger.error(f"Unsupported data type: {oneofcase}")
case "proto_data":
# TODO: Need to actually create the protobuf message instance
# in order to decode it... For now, everything gets put into
# binary_data
pass
sender: AgentId | None = None
if event.HasField("source"):
sender = AgentId(event.source.type, event.source.key)
topic_id = TopicId(event.topic_type, event.topic_source)

# If there's a callback for this topic, use the callback and return
if topic_id in self._pending_callbacks:
# This is processing an rpc response.
# TODO: process rpc request...
with self._trace_helper.trace_block(
"ack",
None,
parent=event.metadata,
attributes={"request_id": event.id},
extraAttributes={"message_type": event.metadata["dataschema"]},
):
self._pending_callbacks[topic_id](message)
return

# Get the recipients for the topic.
recipients = await self._subscription_manager.get_subscribed_recipients(topic_id)
# Send the message to each recipient.
Expand Down Expand Up @@ -571,7 +579,7 @@ async def register(

# Create a future for the registration response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
request_id = self.get_new_request_id()
self._pending_requests[request_id] = future

# Send the registration request message to the host.
Expand Down Expand Up @@ -627,7 +635,7 @@ async def factory_wrapper() -> T:

# Create a future for the registration response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
request_id = self.get_new_request_id()
self._pending_requests[request_id] = future

# Send the registration request message to the host.
Expand Down Expand Up @@ -697,17 +705,22 @@ async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = A

return agent_instance

async def add_subscription(self, subscription: Subscription) -> None:
async def add_subscription(self, subscription: Subscription, callback: Callable[[Any], None] = None) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
if not isinstance(subscription, TypeSubscription):
raise ValueError("Only TypeSubscription is supported.")
# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

# Register the callback
if callback is not None:
callback_topic = TopicId(type=subscription.topic_type, source=subscription.agent_type)
self._pending_callbacks[callback_topic] = callback

# Create a future for the subscription response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
request_id = self.get_new_request_id()
self._pending_requests[request_id] = future

# Send the subscription to the host.
Expand Down
Loading