Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,9 @@ def __init__(
self._name = name
# Create a new SpanContext if none provided or if Context is provided
if context is None or isinstance(context, Context):
trace_id = uuid.uuid4().int & ((1 << 128) - 1)
span_id = uuid.uuid4().int & ((1 << 64) - 1)
# Generate non-zero IDs per OTel spec (uuid4 is automatically non-zero)
trace_id = uuid.uuid4().int
span_id = uuid.uuid4().int >> 64
self._context = SpanContext(
trace_id=trace_id,
span_id=span_id,
Expand Down
28 changes: 22 additions & 6 deletions src/nat/builder/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton):
def __init__(self):
self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
Expand Down Expand Up @@ -120,14 +122,14 @@ def __init__(self, context: ContextState):
@property
def input_message(self):
"""
Retrieves the input message from the context state.
Retrieves the input message from the context state.

The input_message property is used to access the message stored in the
context state. This property returns the message as it is currently
maintained in the context.
The input_message property is used to access the message stored in the
context state. This property returns the message as it is currently
maintained in the context.

Returns:
str: The input message retrieved from the context state.
Returns:
str: The input message retrieved from the context state.
"""
return self._context_state.input_message.get()

Expand Down Expand Up @@ -196,6 +198,20 @@ def user_message_id(self) -> str | None:
"""
return self._context_state.user_message_id.get()

@property
def workflow_run_id(self) -> str | None:
"""
Returns a stable identifier for the current workflow/agent invocation (UUID string).
"""
return self._context_state.workflow_run_id.get()

@property
def workflow_trace_id(self) -> int | None:
"""
Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
"""
return self._context_state.workflow_trace_id.get()

@contextmanager
def push_active_function(self,
function_name: str,
Expand Down
44 changes: 41 additions & 3 deletions src/nat/data_models/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,48 @@ class SpanStatus(BaseModel):
message: str | None = Field(default=None, description="The status message of the span.")


def _generate_nonzero_trace_id() -> int:
"""Generate a non-zero 128-bit trace ID."""
return uuid.uuid4().int


def _generate_nonzero_span_id() -> int:
"""Generate a non-zero 64-bit span ID."""
return uuid.uuid4().int >> 64


class SpanContext(BaseModel):
trace_id: int = Field(default_factory=lambda: uuid.uuid4().int, description="The 128-bit trace ID of the span.")
span_id: int = Field(default_factory=lambda: uuid.uuid4().int & ((1 << 64) - 1),
description="The 64-bit span ID of the span.")
trace_id: int = Field(default_factory=_generate_nonzero_trace_id,
description="The OTel-syle 128-bit trace ID of the span.")
span_id: int = Field(default_factory=_generate_nonzero_span_id,
description="The OTel-syle 64-bit span ID of the span.")

@field_validator("trace_id", mode="before")
@classmethod
def _validate_trace_id(cls, v: int | str | None) -> int:
"""Regenerate if trace_id is None; raise an exception if trace_id is invalid;"""
if isinstance(v, str):
v = uuid.UUID(v).int
if isinstance(v, type(None)):
v = _generate_nonzero_trace_id()
if v <= 0 or v >> 128:
raise ValueError(f"Invalid trace_id: must be a non-zero 128-bit integer, got {v}")
return v

@field_validator("span_id", mode="before")
@classmethod
def _validate_span_id(cls, v: int | str | None) -> int:
"""Regenerate if span_id is None; raise an exception if span_id is invalid;"""
if isinstance(v, str):
try:
v = int(v, 16)
except ValueError:
raise ValueError(f"span_id unable to be parsed: {v}")
if isinstance(v, type(None)):
v = _generate_nonzero_span_id()
if v <= 0 or v >> 64:
raise ValueError(f"Invalid span_id: must be a non-zero 64-bit integer, got {v}")
return v


class Span(BaseModel):
Expand Down
48 changes: 34 additions & 14 deletions src/nat/observability/exporter/span_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def _process_start_event(self, event: IntermediateStep):

parent_span = None
span_ctx = None
workflow_trace_id = self._context_state.workflow_trace_id.get()

# Look up the parent span to establish hierarchy
# event.parent_id is the UUID of the last START step with a different UUID from current step
Expand All @@ -141,6 +142,9 @@ def _process_start_event(self, event: IntermediateStep):
parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None
if parent_span and parent_span.context:
span_ctx = SpanContext(trace_id=parent_span.context.trace_id)
# No parent: adopt workflow trace id if available to keep all spans in the same trace
if span_ctx is None and workflow_trace_id:
span_ctx = SpanContext(trace_id=workflow_trace_id)

# Extract start/end times from the step
# By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended.
Expand All @@ -154,23 +158,39 @@ def _process_start_event(self, event: IntermediateStep):
else:
sub_span_name = f"{event.payload.event_type}"

# Prefer parent/context trace id for attribute, else workflow trace id
_attr_trace_id = None
if span_ctx is not None:
_attr_trace_id = span_ctx.trace_id
elif parent_span and parent_span.context:
_attr_trace_id = parent_span.context.trace_id
elif workflow_trace_id:
_attr_trace_id = workflow_trace_id

attributes = {
f"{self._span_prefix}.event_type":
event.payload.event_type.value,
f"{self._span_prefix}.function.id":
event.function_ancestry.function_id if event.function_ancestry else "unknown",
f"{self._span_prefix}.function.name":
event.function_ancestry.function_name if event.function_ancestry else "unknown",
f"{self._span_prefix}.subspan.name":
event.payload.name or "",
f"{self._span_prefix}.event_timestamp":
event.event_timestamp,
f"{self._span_prefix}.framework":
event.payload.framework.value if event.payload.framework else "unknown",
f"{self._span_prefix}.conversation.id":
self._context_state.conversation_id.get() or "unknown",
f"{self._span_prefix}.workflow.run_id":
self._context_state.workflow_run_id.get() or "unknown",
f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"),
}

sub_span = Span(name=sub_span_name,
parent=parent_span,
context=span_ctx,
attributes={
f"{self._span_prefix}.event_type":
event.payload.event_type.value,
f"{self._span_prefix}.function.id":
event.function_ancestry.function_id if event.function_ancestry else "unknown",
f"{self._span_prefix}.function.name":
event.function_ancestry.function_name if event.function_ancestry else "unknown",
f"{self._span_prefix}.subspan.name":
event.payload.name or "",
f"{self._span_prefix}.event_timestamp":
event.event_timestamp,
f"{self._span_prefix}.framework":
event.payload.framework.value if event.payload.framework else "unknown",
},
attributes=attributes,
start_time=start_ns)

span_kind = event_type_to_span_kind(event.event_type)
Expand Down
109 changes: 103 additions & 6 deletions src/nat/runtime/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,16 @@

import logging
import typing
import uuid
from enum import Enum

from nat.builder.context import Context
from nat.builder.context import ContextState
from nat.builder.function import Function
from nat.data_models.intermediate_step import IntermediateStepPayload
from nat.data_models.intermediate_step import IntermediateStepType
from nat.data_models.intermediate_step import StreamEventData
from nat.data_models.intermediate_step import TraceMetadata
from nat.data_models.invocation_node import InvocationNode
from nat.observability.exporter_manager import ExporterManager
from nat.utils.reactive.subject import Subject
Expand Down Expand Up @@ -130,17 +135,59 @@ async def result(self, to_type: type | None = None):
if (self._state != RunnerState.INITIALIZED):
raise ValueError("Cannot run the workflow without entering the context")

token_run_id = None
token_trace_id = None
try:
self._state = RunnerState.RUNNING

if (not self._entry_fn.has_single_output):
raise ValueError("Workflow does not support single output")

# Establish workflow run and trace identifiers
existing_run_id = self._context_state.workflow_run_id.get()
existing_trace_id = self._context_state.workflow_trace_id.get()

workflow_run_id = existing_run_id or str(uuid.uuid4())

workflow_trace_id = existing_trace_id or uuid.uuid4().int

token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)

# Prepare workflow-level intermediate step identifiers
workflow_step_uuid = str(uuid.uuid4())
workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"

async with self._exporter_manager.start(context_state=self._context_state):
# Run the workflow
result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)
# Emit WORKFLOW_START
start_metadata = TraceMetadata(
provided_metadata={
"workflow_run_id": workflow_run_id,
"workflow_trace_id": f"{workflow_trace_id:032x}",
"conversation_id": self._context_state.conversation_id.get(),
})
self._context.intermediate_step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_START,
name=workflow_name,
metadata=start_metadata))

result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type) # type: ignore

# Emit WORKFLOW_END with output
end_metadata = TraceMetadata(
provided_metadata={
"workflow_run_id": workflow_run_id,
"workflow_trace_id": f"{workflow_trace_id:032x}",
"conversation_id": self._context_state.conversation_id.get(),
})
self._context.intermediate_step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_END,
name=workflow_name,
metadata=end_metadata,
data=StreamEventData(output=result)))

# Close the intermediate stream
event_stream = self._context_state.event_stream.get()
if event_stream:
event_stream.on_complete()
Expand All @@ -155,25 +202,71 @@ async def result(self, to_type: type | None = None):
if event_stream:
event_stream.on_complete()
self._state = RunnerState.FAILED

raise
finally:
if token_run_id is not None:
self._context_state.workflow_run_id.reset(token_run_id)
if token_trace_id is not None:
self._context_state.workflow_trace_id.reset(token_trace_id)

async def result_stream(self, to_type: type | None = None):

if (self._state != RunnerState.INITIALIZED):
raise ValueError("Cannot run the workflow without entering the context")

token_run_id = None
token_trace_id = None
try:
self._state = RunnerState.RUNNING

if (not self._entry_fn.has_streaming_output):
raise ValueError("Workflow does not support streaming output")

# Establish workflow run and trace identifiers
existing_run_id = self._context_state.workflow_run_id.get()
existing_trace_id = self._context_state.workflow_trace_id.get()

workflow_run_id = existing_run_id or str(uuid.uuid4())

workflow_trace_id = existing_trace_id or uuid.uuid4().int

token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)

# Prepare workflow-level intermediate step identifiers
workflow_step_uuid = str(uuid.uuid4())
workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"

# Run the workflow
async with self._exporter_manager.start(context_state=self._context_state):
async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
# Emit WORKFLOW_START
start_metadata = TraceMetadata(
provided_metadata={
"workflow_run_id": workflow_run_id,
"workflow_trace_id": f"{workflow_trace_id:032x}",
"conversation_id": self._context_state.conversation_id.get(),
})
self._context.intermediate_step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_START,
name=workflow_name,
metadata=start_metadata))

async for m in self._entry_fn.astream(self._input_message, to_type=to_type): # type: ignore
yield m

# Emit WORKFLOW_END
end_metadata = TraceMetadata(
provided_metadata={
"workflow_run_id": workflow_run_id,
"workflow_trace_id": f"{workflow_trace_id:032x}",
"conversation_id": self._context_state.conversation_id.get(),
})
self._context.intermediate_step_manager.push_intermediate_step(
IntermediateStepPayload(UUID=workflow_step_uuid,
event_type=IntermediateStepType.WORKFLOW_END,
name=workflow_name,
metadata=end_metadata))
self._state = RunnerState.COMPLETED

# Close the intermediate stream
Expand All @@ -187,8 +280,12 @@ async def result_stream(self, to_type: type | None = None):
if event_stream:
event_stream.on_complete()
self._state = RunnerState.FAILED

raise
finally:
if token_run_id is not None:
self._context_state.workflow_run_id.reset(token_run_id)
if token_trace_id is not None:
self._context_state.workflow_trace_id.reset(token_trace_id)


# Compatibility aliases with previous releases
Expand Down
26 changes: 26 additions & 0 deletions src/nat/runtime/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import asyncio
import contextvars
import typing
import uuid
from collections.abc import Awaitable
from collections.abc import Callable
from contextlib import asynccontextmanager
Expand Down Expand Up @@ -161,6 +162,31 @@ def set_metadata_from_http_request(self, request: Request) -> None:
if request.headers.get("user-message-id"):
self._context_state.user_message_id.set(request.headers["user-message-id"])

# W3C Trace Context header: traceparent: 00-<trace-id>-<span-id>-<flags>
traceparent = request.headers.get("traceparent")
if traceparent:
try:
parts = traceparent.split("-")
if len(parts) >= 4:
trace_id_hex = parts[1]
if len(trace_id_hex) == 32:
trace_id_int = uuid.UUID(trace_id_hex).int
self._context_state.workflow_trace_id.set(trace_id_int)
except Exception:
pass

if not self._context_state.workflow_trace_id.get():
workflow_trace_id = request.headers.get("workflow-trace-id")
if workflow_trace_id:
try:
self._context_state.workflow_trace_id.set(uuid.UUID(workflow_trace_id).int)
except Exception:
pass

workflow_run_id = request.headers.get("workflow-run-id")
if workflow_run_id:
self._context_state.workflow_run_id.set(workflow_run_id)

def set_metadata_from_websocket(self,
websocket: WebSocket,
user_message_id: str | None,
Expand Down
Loading