Skip to content

Commit

Permalink
Add messaging attributes to telemetry spans (#490)
Browse files Browse the repository at this point in the history
* Downgrade protobuf from v5 to v4

* Refactor telemetry and include attributes

* Update

* Remove unused vars

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
heyitsaamir and ekzhu authored Sep 13, 2024
1 parent 1ba7a68 commit e25bd2c
Show file tree
Hide file tree
Showing 7 changed files with 429 additions and 240 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast

from opentelemetry.trace import NoOpTracerProvider, TracerProvider
from opentelemetry.trace import TracerProvider

from ..base import (
Agent,
Expand All @@ -30,7 +30,7 @@
from ..base.exceptions import MessageDroppedException
from ..base.intervention import DropMessage, InterventionHandler
from ._helpers import SubscriptionManager, get_impl
from .telemetry import EnvelopeMetadata, get_telemetry_envelope_metadata, trace_block
from .telemetry import EnvelopeMetadata, MessageRuntimeTracingConfig, TraceHelper, get_telemetry_envelope_metadata

logger = logging.getLogger("autogen_core")
event_logger = logging.getLogger("autogen_core.events")
Expand Down Expand Up @@ -151,7 +151,7 @@ def __init__(
intervention_handlers: List[InterventionHandler] | None = None,
tracer_provider: TracerProvider | None = None,
) -> None:
self._tracer = (tracer_provider if tracer_provider else NoOpTracerProvider()).get_tracer(__name__)
self._tracer_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("SingleThreadedAgentRuntime"))
self._message_queue: List[PublishMessageEnvelope | SendMessageEnvelope | ResponseMessageEnvelope] = []
# (namespace, type) -> List[AgentId]
self._agent_factories: Dict[
Expand Down Expand Up @@ -200,7 +200,12 @@ async def send_message(
# )
# )

with trace_block(self._tracer, "create", recipient, parent=None):
with self._tracer_helper.trace_block(
"create",
recipient,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
future = asyncio.get_event_loop().create_future()
if recipient.type not in self._known_agent_names:
future.set_exception(Exception("Recipient not found"))
Expand Down Expand Up @@ -231,7 +236,12 @@ async def publish_message(
sender: AgentId | None = None,
cancellation_token: CancellationToken | None = None,
) -> None:
with trace_block(self._tracer, "create", topic_id, parent=None):
with self._tracer_helper.trace_block(
"create",
topic_id,
parent=None,
extraAttributes={"message_type": type(message).__name__},
):
if cancellation_token is None:
cancellation_token = CancellationToken()
content = message.__dict__ if hasattr(message, "__dict__") else message
Expand Down Expand Up @@ -270,7 +280,7 @@ async def load_state(self, state: Mapping[str, Any]) -> None:
(await self._get_agent(agent_id)).load_state(state[str(agent_id)])

async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
with trace_block(self._tracer, "send", message_envelope.recipient, parent=message_envelope.metadata):
with self._tracer_helper.trace_block("send", message_envelope.recipient, parent=message_envelope.metadata):
recipient = message_envelope.recipient
# todo: check if recipient is in the known namespaces
# assert recipient in self._agents
Expand Down Expand Up @@ -319,7 +329,7 @@ async def _process_send(self, message_envelope: SendMessageEnvelope) -> None:
self._outstanding_tasks.decrement()

async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> None:
with trace_block(self._tracer, "publish", message_envelope.topic_id, parent=message_envelope.metadata):
with self._tracer_helper.trace_block("publish", message_envelope.topic_id, parent=message_envelope.metadata):
responses: List[Awaitable[Any]] = []
recipients = await self._subscription_manager.get_subscribed_recipients(message_envelope.topic_id)
for agent_id in recipients:
Expand Down Expand Up @@ -352,7 +362,7 @@ async def _process_publish(self, message_envelope: PublishMessageEnvelope) -> No
agent = await self._get_agent(agent_id)

async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
with trace_block(self._tracer, "process", agent.id, parent=None):
with self._tracer_helper.trace_block("process", agent.id, parent=None):
return await agent.on_message(
message_envelope.message,
ctx=message_context,
Expand All @@ -375,7 +385,7 @@ async def _on_message(agent: Agent, message_context: MessageContext) -> Any:
# TODO if responses are given for a publish

async def _process_response(self, message_envelope: ResponseMessageEnvelope) -> None:
with trace_block(self._tracer, "ack", message_envelope.recipient, parent=message_envelope.metadata):
with self._tracer_helper.trace_block("ack", message_envelope.recipient, parent=message_envelope.metadata):
content = (
message_envelope.message.__dict__
if hasattr(message_envelope.message, "__dict__")
Expand Down Expand Up @@ -409,8 +419,8 @@ async def process_next(self) -> None:
case SendMessageEnvelope(message=message, sender=sender, recipient=recipient, future=future):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with trace_block(
self._tracer, "intercept", handler.__class__.__name__, parent=message_envelope.metadata
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_send(message, sender=sender, recipient=recipient)
Expand All @@ -432,8 +442,8 @@ async def process_next(self) -> None:
):
if self._intervention_handlers is not None:
for handler in self._intervention_handlers:
with trace_block(
self._tracer, "intercept", handler.__class__.__name__, parent=message_envelope.metadata
with self._tracer_helper.trace_block(
"intercept", handler.__class__.__name__, parent=message_envelope.metadata
):
try:
temp_message = await handler.on_publish(message, sender=sender)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from ..components import TypeSubscription
from ._helpers import SubscriptionManager, get_impl
from .protos import agent_worker_pb2, agent_worker_pb2_grpc
from .telemetry import get_telemetry_grpc_metadata, trace_block
from .telemetry import MessageRuntimeTracingConfig, TraceHelper, get_telemetry_grpc_metadata

if TYPE_CHECKING:
from .protos.agent_worker_pb2_grpc import AgentRpcAsyncStub
Expand Down Expand Up @@ -157,7 +157,7 @@ async def recv(self) -> agent_worker_pb2.Message:

class WorkerAgentRuntime(AgentRuntime):
def __init__(self, tracer_provider: TracerProvider | None = None) -> None:
self._tracer = (tracer_provider if tracer_provider else NoOpTracerProvider()).get_tracer(__name__)
self._trace_helper = TraceHelper(tracer_provider, MessageRuntimeTracingConfig("Worker Runtime"))
self._per_type_subscribers: DefaultDict[tuple[str, str], Set[AgentId]] = defaultdict(set)
self._agent_factories: Dict[
str, Callable[[], Agent | Awaitable[Agent]] | Callable[[AgentRuntime, AgentId], Agent | Awaitable[Agent]]
Expand Down Expand Up @@ -241,7 +241,7 @@ async def _send_message(
) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
with trace_block(self._tracer, send_type, recipient, parent=telemetry_metadata):
with self._trace_helper.trace_block(send_type, recipient, parent=telemetry_metadata):
await self._host_connection.send(runtime_message)

async def send_message(
Expand All @@ -257,7 +257,9 @@ async def send_message(
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
data_type = MESSAGE_TYPE_REGISTRY.type_name(message)
with trace_block(self._tracer, "create", recipient, parent=None, attributes={"message_type": data_type}):
with self._trace_helper.trace_block(
"create", recipient, parent=None, extraAttributes={"message_type": data_type, "message_size": len(message)}
):
# create a new future for the result
future = asyncio.get_event_loop().create_future()
async with self._pending_requests_lock:
Expand Down Expand Up @@ -304,7 +306,9 @@ async def publish_message(
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message_type = MESSAGE_TYPE_REGISTRY.type_name(message)
with trace_block(self._tracer, "create", topic_id, parent=None, attributes={"message_type": message_type}):
with self._trace_helper.trace_block(
"create", topic_id, parent=None, extraAttributes={"message_type": message_type}
):
serialized_message = MESSAGE_TYPE_REGISTRY.serialize(
message, type_name=message_type, data_content_type=JSON_DATA_CONTENT_TYPE
)
Expand Down Expand Up @@ -368,12 +372,12 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
# Call the target agent.
try:
with MessageHandlerContext.populate_context(target_agent.id):
with trace_block(
self._tracer,
with self._trace_helper.trace_block(
"process",
target_agent.id,
parent=request.metadata,
attributes={"request_id": request.request_id},
extraAttributes={"message_type": request.payload.data_type},
):
result = await target_agent.on_message(message, ctx=message_context)
except BaseException as e:
Expand Down Expand Up @@ -411,8 +415,12 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest) -> None:
await self._host_connection.send(response_message)

async def _process_response(self, response: agent_worker_pb2.RpcResponse) -> None:
with trace_block(
self._tracer, "ack", None, parent=response.metadata, attributes={"request_id": response.request_id}
with self._trace_helper.trace_block(
"ack",
None,
parent=response.metadata,
attributes={"request_id": response.request_id},
extraAttributes={"message_type": response.payload.data_type},
):
# Deserialize the result.
result = MESSAGE_TYPE_REGISTRY.deserialize(
Expand Down Expand Up @@ -448,7 +456,12 @@ async def _process_event(self, event: agent_worker_pb2.Event) -> None:
with MessageHandlerContext.populate_context(agent.id):

async def send_message(agent: Agent, message_context: MessageContext) -> Any:
with trace_block(self._tracer, "process", agent.id, parent=event.metadata):
with self._trace_helper.trace_block(
"process",
agent.id,
parent=event.metadata,
extraAttributes={"message_type": event.payload.data_type},
):
await agent.on_message(message, ctx=message_context)

future = send_message(agent, message_context)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from ._tracing import (
from ._propagation import (
EnvelopeMetadata,
TelemetryMetadataContainer,
get_telemetry_envelope_metadata,
get_telemetry_grpc_metadata,
trace_block,
)
from ._tracing import TraceHelper
from ._tracing_config import MessageRuntimeTracingConfig

__all__ = [
"EnvelopeMetadata",
"get_telemetry_envelope_metadata",
"get_telemetry_grpc_metadata",
"TelemetryMetadataContainer",
"trace_block",
"TraceHelper",
"MessageRuntimeTracingConfig",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Namespace string constant for this package.
# NOTE(review): presumably used as a prefix for telemetry span/attribute
# names; its consumers are not visible from this module, so confirm.
NAMESPACE = "autogen"
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from dataclasses import dataclass
from typing import Dict, Mapping, Optional

from opentelemetry.context import Context
from opentelemetry.propagate import extract
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator


@dataclass(kw_only=True)
class EnvelopeMetadata:
"""Metadata for an envelope."""

traceparent: Optional[str] = None
tracestate: Optional[str] = None


def _get_carrier_for_envelope_metadata(envelope_metadata: EnvelopeMetadata) -> Dict[str, str]:
    """Build a propagation carrier dict from an ``EnvelopeMetadata``.

    Args:
        envelope_metadata: The metadata whose trace fields should be copied.

    Returns:
        A new dict holding a "traceparent" and/or "tracestate" entry for each
        field that is not ``None``; empty when both fields are ``None``.
    """
    fields = (
        ("traceparent", envelope_metadata.traceparent),
        ("tracestate", envelope_metadata.tracestate),
    )
    # Only ``None`` is filtered out; an empty string is still copied.
    return {header: value for header, value in fields if value is not None}


def get_telemetry_envelope_metadata() -> EnvelopeMetadata:
    """Capture trace headers as envelope metadata.

    A ``TraceContextTextMapPropagator`` injects into a fresh, empty carrier
    dict; whatever "traceparent" and "tracestate" entries it writes are then
    copied into the result.

    Returns:
        EnvelopeMetadata: Metadata whose fields are the injected header
        values, or ``None`` for any header the propagator did not write.
    """
    injected: Dict[str, str] = {}
    propagator = TraceContextTextMapPropagator()
    propagator.inject(injected)

    traceparent = injected.get("traceparent")
    tracestate = injected.get("tracestate")
    return EnvelopeMetadata(traceparent=traceparent, tracestate=tracestate)


def _get_carrier_for_remote_call_metadata(remote_call_metadata: Mapping[str, str]) -> Dict[str, str]:
    """Build a propagation carrier dict from remote-call metadata.

    Args:
        remote_call_metadata: A string-to-string mapping that may contain
            "traceparent" and "tracestate" entries among others.

    Returns:
        A new dict containing only the "traceparent" and "tracestate"
        entries whose values are truthy; all other keys are ignored.
    """
    selected: Dict[str, str] = {}
    for header in ("traceparent", "tracestate"):
        value = remote_call_metadata.get(header)
        # Truthiness check: missing keys and empty strings are both skipped.
        if value:
            selected[header] = value
    return selected


def get_telemetry_grpc_metadata(existingMetadata: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Build gRPC metadata that includes the injected trace headers.

    Args:
        existingMetadata (Optional[Mapping[str, str]]): Entries to carry over
            into the result. The mapping itself is not modified.

    Returns:
        Dict[str, str]: A new dict holding every entry of ``existingMetadata``
        plus the "traceparent" and "tracestate" values written by the
        propagator, when present. Injected values replace same-named entries
        from ``existingMetadata``.
    """
    injected: Dict[str, str] = {}
    TraceContextTextMapPropagator().inject(injected)

    merged: Dict[str, str] = {}
    if existingMetadata is not None:
        merged.update(existingMetadata.items())

    # Trace headers are applied last so they win over caller-supplied values.
    for header in ("traceparent", "tracestate"):
        value = injected.get(header)
        if value is not None:
            merged[header] = value
    return merged


# The forms in which trace metadata may be handed to this module: an
# ``EnvelopeMetadata`` instance, a string-to-string mapping, or ``None``.
TelemetryMetadataContainer = Optional[EnvelopeMetadata] | Mapping[str, str]


def get_telemetry_context(metadata: TelemetryMetadataContainer) -> Context:
    """Turn trace metadata into a telemetry ``Context``.

    Args:
        metadata (TelemetryMetadataContainer): ``None``, an
            ``EnvelopeMetadata``, or an object exposing ``__getitem__``
            (handled as remote-call metadata).

    Returns:
        Context: A new empty ``Context`` when ``metadata`` is ``None``;
        otherwise the result of ``extract`` on the carrier built from it.

    Raises:
        ValueError: If ``metadata`` is none of the supported forms.
    """
    if metadata is None:
        return Context()

    # Pick the carrier builder that matches the metadata's form.
    if isinstance(metadata, EnvelopeMetadata):
        carrier = _get_carrier_for_envelope_metadata(metadata)
    elif hasattr(metadata, "__getitem__"):
        carrier = _get_carrier_for_remote_call_metadata(metadata)
    else:
        raise ValueError(f"Unknown metadata type: {type(metadata)}")
    return extract(carrier)
Loading

0 comments on commit e25bd2c

Please sign in to comment.