diff --git a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py
index dbd3e4e33..2c1420503 100644
--- a/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py
+++ b/packages/nvidia_nat_opentelemetry/src/nat/plugins/opentelemetry/otel_span.py
@@ -86,8 +86,9 @@ def __init__(
         self._name = name
         # Create a new SpanContext if none provided or if Context is provided
         if context is None or isinstance(context, Context):
-            trace_id = uuid.uuid4().int & ((1 << 128) - 1)
-            span_id = uuid.uuid4().int & ((1 << 64) - 1)
+            # Generate non-zero IDs per OTel spec (uuid4 is automatically non-zero)
+            trace_id = uuid.uuid4().int
+            span_id = uuid.uuid4().int >> 64
             self._context = SpanContext(
                 trace_id=trace_id,
                 span_id=span_id,
diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py
index 6d333b688..256e19c13 100644
--- a/src/nat/builder/context.py
+++ b/src/nat/builder/context.py
@@ -67,6 +67,8 @@ class ContextState(metaclass=Singleton):
     def __init__(self):
         self.conversation_id: ContextVar[str | None] = ContextVar("conversation_id", default=None)
         self.user_message_id: ContextVar[str | None] = ContextVar("user_message_id", default=None)
+        self.workflow_run_id: ContextVar[str | None] = ContextVar("workflow_run_id", default=None)
+        self.workflow_trace_id: ContextVar[int | None] = ContextVar("workflow_trace_id", default=None)
         self.input_message: ContextVar[typing.Any] = ContextVar("input_message", default=None)
         self.user_manager: ContextVar[typing.Any] = ContextVar("user_manager", default=None)
         self._metadata: ContextVar[RequestAttributes | None] = ContextVar("request_attributes", default=None)
@@ -120,14 +122,14 @@ def __init__(self, context: ContextState):
     @property
     def input_message(self):
         """
-        Retrieves the input message from the context state. 
+        Retrieves the input message from the context state.
 
-        The input_message property is used to access the message stored in the 
-        context state. This property returns the message as it is currently 
-        maintained in the context. 
+        The input_message property is used to access the message stored in the
+        context state. This property returns the message as it is currently
+        maintained in the context.
 
-        Returns: 
-            str: The input message retrieved from the context state. 
+        Returns:
+            str: The input message retrieved from the context state.
         """
         return self._context_state.input_message.get()
 
@@ -196,6 +198,20 @@ def user_message_id(self) -> str | None:
         """
         return self._context_state.user_message_id.get()
 
+    @property
+    def workflow_run_id(self) -> str | None:
+        """
+        Returns a stable identifier for the current workflow/agent invocation (UUID string).
+        """
+        return self._context_state.workflow_run_id.get()
+
+    @property
+    def workflow_trace_id(self) -> int | None:
+        """
+        Returns the 128-bit trace identifier for the current run, used as the OpenTelemetry trace_id.
+ """ + return self._context_state.workflow_trace_id.get() + @contextmanager def push_active_function(self, function_name: str, diff --git a/src/nat/data_models/span.py b/src/nat/data_models/span.py index ae8fff231..5470fa9dd 100644 --- a/src/nat/data_models/span.py +++ b/src/nat/data_models/span.py @@ -128,10 +128,48 @@ class SpanStatus(BaseModel): message: str | None = Field(default=None, description="The status message of the span.") +def _generate_nonzero_trace_id() -> int: + """Generate a non-zero 128-bit trace ID.""" + return uuid.uuid4().int + + +def _generate_nonzero_span_id() -> int: + """Generate a non-zero 64-bit span ID.""" + return uuid.uuid4().int >> 64 + + class SpanContext(BaseModel): - trace_id: int = Field(default_factory=lambda: uuid.uuid4().int, description="The 128-bit trace ID of the span.") - span_id: int = Field(default_factory=lambda: uuid.uuid4().int & ((1 << 64) - 1), - description="The 64-bit span ID of the span.") + trace_id: int = Field(default_factory=_generate_nonzero_trace_id, + description="The OTel-syle 128-bit trace ID of the span.") + span_id: int = Field(default_factory=_generate_nonzero_span_id, + description="The OTel-syle 64-bit span ID of the span.") + + @field_validator("trace_id", mode="before") + @classmethod + def _validate_trace_id(cls, v: int | str | None) -> int: + """Regenerate if trace_id is None; raise an exception if trace_id is invalid;""" + if isinstance(v, str): + v = uuid.UUID(v).int + if isinstance(v, type(None)): + v = _generate_nonzero_trace_id() + if v <= 0 or v >> 128: + raise ValueError(f"Invalid trace_id: must be a non-zero 128-bit integer, got {v}") + return v + + @field_validator("span_id", mode="before") + @classmethod + def _validate_span_id(cls, v: int | str | None) -> int: + """Regenerate if span_id is None; raise an exception if span_id is invalid;""" + if isinstance(v, str): + try: + v = int(v, 16) + except ValueError: + raise ValueError(f"span_id unable to be parsed: {v}") + if isinstance(v, type(None)): + v = _generate_nonzero_span_id() + if v <= 0 or v >> 64: + raise ValueError(f"Invalid span_id: must be a non-zero 64-bit integer, got {v}") + return v class Span(BaseModel): diff --git a/src/nat/observability/exporter/span_exporter.py b/src/nat/observability/exporter/span_exporter.py index 0960359e3..14cbfc93a 100644 --- a/src/nat/observability/exporter/span_exporter.py +++ b/src/nat/observability/exporter/span_exporter.py @@ -126,6 +126,7 @@ def _process_start_event(self, event: IntermediateStep): parent_span = None span_ctx = None + workflow_trace_id = self._context_state.workflow_trace_id.get() # Look up the parent span to establish hierarchy # event.parent_id is the UUID of the last START step with a different UUID from current step @@ -141,6 +142,9 @@ def _process_start_event(self, event: IntermediateStep): parent_span = parent_span.model_copy() if isinstance(parent_span, Span) else None if parent_span and parent_span.context: span_ctx = SpanContext(trace_id=parent_span.context.trace_id) + # No parent: adopt workflow trace id if available to keep all spans in the same trace + if span_ctx is None and workflow_trace_id: + span_ctx = SpanContext(trace_id=workflow_trace_id) # Extract start/end times from the step # By convention, `span_event_timestamp` is the time we started, `event_timestamp` is the time we ended. 
@@ -154,23 +158,39 @@ def _process_start_event(self, event: IntermediateStep):
         else:
             sub_span_name = f"{event.payload.event_type}"
 
+        # Prefer parent/context trace id for attribute, else workflow trace id
+        _attr_trace_id = None
+        if span_ctx is not None:
+            _attr_trace_id = span_ctx.trace_id
+        elif parent_span and parent_span.context:
+            _attr_trace_id = parent_span.context.trace_id
+        elif workflow_trace_id:
+            _attr_trace_id = workflow_trace_id
+
+        attributes = {
+            f"{self._span_prefix}.event_type":
+                event.payload.event_type.value,
+            f"{self._span_prefix}.function.id":
+                event.function_ancestry.function_id if event.function_ancestry else "unknown",
+            f"{self._span_prefix}.function.name":
+                event.function_ancestry.function_name if event.function_ancestry else "unknown",
+            f"{self._span_prefix}.subspan.name":
+                event.payload.name or "",
+            f"{self._span_prefix}.event_timestamp":
+                event.event_timestamp,
+            f"{self._span_prefix}.framework":
+                event.payload.framework.value if event.payload.framework else "unknown",
+            f"{self._span_prefix}.conversation.id":
+                self._context_state.conversation_id.get() or "unknown",
+            f"{self._span_prefix}.workflow.run_id":
+                self._context_state.workflow_run_id.get() or "unknown",
+            f"{self._span_prefix}.workflow.trace_id": (f"{_attr_trace_id:032x}" if _attr_trace_id else "unknown"),
+        }
+
         sub_span = Span(name=sub_span_name,
                         parent=parent_span,
                         context=span_ctx,
-                        attributes={
-                            f"{self._span_prefix}.event_type":
-                                event.payload.event_type.value,
-                            f"{self._span_prefix}.function.id":
-                                event.function_ancestry.function_id if event.function_ancestry else "unknown",
-                            f"{self._span_prefix}.function.name":
-                                event.function_ancestry.function_name if event.function_ancestry else "unknown",
-                            f"{self._span_prefix}.subspan.name":
-                                event.payload.name or "",
-                            f"{self._span_prefix}.event_timestamp":
-                                event.event_timestamp,
-                            f"{self._span_prefix}.framework":
-                                event.payload.framework.value if event.payload.framework else "unknown",
-                        },
+                        attributes=attributes,
                         start_time=start_ns)
 
         span_kind = event_type_to_span_kind(event.event_type)
diff --git a/src/nat/runtime/runner.py b/src/nat/runtime/runner.py
index caef11cd2..ea5843b8d 100644
--- a/src/nat/runtime/runner.py
+++ b/src/nat/runtime/runner.py
@@ -15,11 +15,16 @@
 
 import logging
 import typing
+import uuid
 from enum import Enum
 
 from nat.builder.context import Context
 from nat.builder.context import ContextState
 from nat.builder.function import Function
+from nat.data_models.intermediate_step import IntermediateStepPayload
+from nat.data_models.intermediate_step import IntermediateStepType
+from nat.data_models.intermediate_step import StreamEventData
+from nat.data_models.intermediate_step import TraceMetadata
 from nat.data_models.invocation_node import InvocationNode
 from nat.observability.exporter_manager import ExporterManager
 from nat.utils.reactive.subject import Subject
@@ -130,17 +135,59 @@ async def result(self, to_type: type | None = None):
         if (self._state != RunnerState.INITIALIZED):
             raise ValueError("Cannot run the workflow without entering the context")
 
+        token_run_id = None
+        token_trace_id = None
         try:
             self._state = RunnerState.RUNNING
 
             if (not self._entry_fn.has_single_output):
                 raise ValueError("Workflow does not support single output")
 
+            # Establish workflow run and trace identifiers
+            existing_run_id = self._context_state.workflow_run_id.get()
+            existing_trace_id = self._context_state.workflow_trace_id.get()
+
+            workflow_run_id = existing_run_id or str(uuid.uuid4())
+
+            workflow_trace_id = existing_trace_id or uuid.uuid4().int
+
+            token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
+            token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
+
+            # Prepare workflow-level intermediate step identifiers
+            workflow_step_uuid = str(uuid.uuid4())
+            workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
+
             async with self._exporter_manager.start(context_state=self._context_state):
-                # Run the workflow
-                result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)
+                # Emit WORKFLOW_START
+                start_metadata = TraceMetadata(
+                    provided_metadata={
+                        "workflow_run_id": workflow_run_id,
+                        "workflow_trace_id": f"{workflow_trace_id:032x}",
+                        "conversation_id": self._context_state.conversation_id.get(),
+                    })
+                self._context.intermediate_step_manager.push_intermediate_step(
+                    IntermediateStepPayload(UUID=workflow_step_uuid,
+                                            event_type=IntermediateStepType.WORKFLOW_START,
+                                            name=workflow_name,
+                                            metadata=start_metadata))
+
+                result = await self._entry_fn.ainvoke(self._input_message, to_type=to_type)  # type: ignore
+
+                # Emit WORKFLOW_END with output
+                end_metadata = TraceMetadata(
+                    provided_metadata={
+                        "workflow_run_id": workflow_run_id,
+                        "workflow_trace_id": f"{workflow_trace_id:032x}",
+                        "conversation_id": self._context_state.conversation_id.get(),
+                    })
+                self._context.intermediate_step_manager.push_intermediate_step(
+                    IntermediateStepPayload(UUID=workflow_step_uuid,
+                                            event_type=IntermediateStepType.WORKFLOW_END,
+                                            name=workflow_name,
+                                            metadata=end_metadata,
+                                            data=StreamEventData(output=result)))
 
-                # Close the intermediate stream
                 event_stream = self._context_state.event_stream.get()
                 if event_stream:
                     event_stream.on_complete()
@@ -155,25 +202,72 @@ async def result(self, to_type: type | None = None):
             if event_stream:
                 event_stream.on_complete()
 
             self._state = RunnerState.FAILED
             raise
+        finally:
+            if token_run_id is not None:
+                self._context_state.workflow_run_id.reset(token_run_id)
+            if token_trace_id is not None:
+                self._context_state.workflow_trace_id.reset(token_trace_id)
 
     async def result_stream(self, to_type: type | None = None):
 
         if (self._state != RunnerState.INITIALIZED):
             raise ValueError("Cannot run the workflow without entering the context")
 
+        token_run_id = None
+        token_trace_id = None
         try:
             self._state = RunnerState.RUNNING
 
             if (not self._entry_fn.has_streaming_output):
                 raise ValueError("Workflow does not support streaming output")
 
+            # Establish workflow run and trace identifiers
+            existing_run_id = self._context_state.workflow_run_id.get()
+            existing_trace_id = self._context_state.workflow_trace_id.get()
+
+            workflow_run_id = existing_run_id or str(uuid.uuid4())
+
+            workflow_trace_id = existing_trace_id or uuid.uuid4().int
+
+            token_run_id = self._context_state.workflow_run_id.set(workflow_run_id)
+            token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id)
+
+            # Prepare workflow-level intermediate step identifiers
+            workflow_step_uuid = str(uuid.uuid4())
+            workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow"
+
             # Run the workflow
             async with self._exporter_manager.start(context_state=self._context_state):
-                async for m in self._entry_fn.astream(self._input_message, to_type=to_type):
+                # Emit WORKFLOW_START
+                start_metadata = TraceMetadata(
+                    provided_metadata={
+                        "workflow_run_id": workflow_run_id,
+                        "workflow_trace_id": f"{workflow_trace_id:032x}",
+                        "conversation_id": self._context_state.conversation_id.get(),
+                    })
+                self._context.intermediate_step_manager.push_intermediate_step(
+                    IntermediateStepPayload(UUID=workflow_step_uuid,
+                                            event_type=IntermediateStepType.WORKFLOW_START,
+                                            name=workflow_name,
+                                            metadata=start_metadata))
+
+                async for m in self._entry_fn.astream(self._input_message, to_type=to_type):  # type: ignore
                     yield m
+                # Emit WORKFLOW_END
+                end_metadata = TraceMetadata(
+                    provided_metadata={
+                        "workflow_run_id": workflow_run_id,
+                        "workflow_trace_id": f"{workflow_trace_id:032x}",
+                        "conversation_id": self._context_state.conversation_id.get(),
+                    })
+                self._context.intermediate_step_manager.push_intermediate_step(
+                    IntermediateStepPayload(UUID=workflow_step_uuid,
+                                            event_type=IntermediateStepType.WORKFLOW_END,
+                                            name=workflow_name,
+                                            metadata=end_metadata))
 
             self._state = RunnerState.COMPLETED
 
             # Close the intermediate stream
@@ -187,8 +281,13 @@ async def result_stream(self, to_type: type | None = None):
             if event_stream:
                 event_stream.on_complete()
 
             self._state = RunnerState.FAILED
             raise
+        finally:
+            if token_run_id is not None:
+                self._context_state.workflow_run_id.reset(token_run_id)
+            if token_trace_id is not None:
+                self._context_state.workflow_trace_id.reset(token_trace_id)
 
 
 # Compatibility aliases with previous releases
diff --git a/src/nat/runtime/session.py b/src/nat/runtime/session.py
index 5e70fb09f..08720dafb 100644
--- a/src/nat/runtime/session.py
+++ b/src/nat/runtime/session.py
@@ -16,6 +16,7 @@
 import asyncio
 import contextvars
 import typing
+import uuid
 from collections.abc import Awaitable
 from collections.abc import Callable
 from contextlib import asynccontextmanager
@@ -161,6 +162,31 @@ def set_metadata_from_http_request(self, request: Request) -> None:
         if request.headers.get("user-message-id"):
             self._context_state.user_message_id.set(request.headers["user-message-id"])
 
+        # W3C Trace Context header: traceparent: 00-<trace-id>-<parent-id>-<trace-flags>
+        traceparent = request.headers.get("traceparent")
+        if traceparent:
+            try:
+                parts = traceparent.split("-")
+                if len(parts) >= 4:
+                    trace_id_hex = parts[1]
+                    if len(trace_id_hex) == 32:
+                        trace_id_int = uuid.UUID(trace_id_hex).int
+                        self._context_state.workflow_trace_id.set(trace_id_int)
+            except Exception:
+                pass
+
+        if not self._context_state.workflow_trace_id.get():
+            workflow_trace_id = request.headers.get("workflow-trace-id")
+            if workflow_trace_id:
+                try:
+                    self._context_state.workflow_trace_id.set(uuid.UUID(workflow_trace_id).int)
+                except Exception:
+                    pass
+
+        workflow_run_id = request.headers.get("workflow-run-id")
+        if workflow_run_id:
+            self._context_state.workflow_run_id.set(workflow_run_id)
+
     def set_metadata_from_websocket(self,
                                     websocket: WebSocket,
                                     user_message_id: str | None,
diff --git a/tests/nat/opentelemetry/test_otel_span_ids.py b/tests/nat/opentelemetry/test_otel_span_ids.py
new file mode 100644
index 000000000..f0eb90a13
--- /dev/null
+++ b/tests/nat/opentelemetry/test_otel_span_ids.py
@@ -0,0 +1,25 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nat.plugins.opentelemetry.otel_span import OtelSpan
+
+
+def test_otel_span_ids_are_non_zero():
+    s = OtelSpan(name="test", context=None, parent=None, attributes={})
+    ctx = s.get_span_context()
+    assert ctx.trace_id != 0
+    assert ctx.span_id != 0
+    assert len(f"{ctx.trace_id:032x}") == 32
+    assert len(f"{ctx.span_id:016x}") == 16
diff --git a/tests/nat/runner/test_runner.py b/tests/nat/runtime/test_runner.py
similarity index 100%
rename from tests/nat/runner/test_runner.py
rename to tests/nat/runtime/test_runner.py
diff --git a/tests/nat/runtime/test_runner_trace_ids.py b/tests/nat/runtime/test_runner_trace_ids.py
new file mode 100644
index 000000000..d5485bd0f
--- /dev/null
+++ b/tests/nat/runtime/test_runner_trace_ids.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+
+import pytest
+
+from nat.builder.context import Context
+from nat.builder.context import ContextState
+from nat.builder.function import Function
+from nat.observability.exporter_manager import ExporterManager
+from nat.runtime.runner import Runner
+
+
+class _DummyFunction:
+    has_single_output = True
+    has_streaming_output = True
+    instance_name = "workflow"
+
+    def convert(self, v, to_type):
+        return v
+
+    async def ainvoke(self, _message, to_type=None):
+        ctx = Context.get()
+        assert isinstance(ctx.workflow_trace_id, int) and ctx.workflow_trace_id != 0
+        return {"ok": True}
+
+    async def astream(self, _message, to_type=None):
+        ctx = Context.get()
+        assert isinstance(ctx.workflow_trace_id, int) and ctx.workflow_trace_id != 0
+        yield "chunk-1"
+
+
+class _DummyExporterManager:
+
+    def start(self, context_state=None):
+
+        class _Ctx:
+
+            async def __aenter__(self):
+                return self
+
+            async def __aexit__(self, exc_type, exc, tb):
+                return False
+
+        return _Ctx()
+
+
+@pytest.mark.parametrize("method", ["result", "result_stream"])  # result vs stream
+@pytest.mark.parametrize("existing_run", [True, False])
+@pytest.mark.parametrize("existing_trace", [True, False])
+@pytest.mark.asyncio
+async def test_runner_trace_and_run_ids(existing_trace: bool, existing_run: bool, method: str):
+    ctx_state = ContextState.get()
+
+    # Seed existing values according to parameters
+    seeded_trace = int("f" * 32, 16) if existing_trace else None
+    seeded_run = "existing-run-id" if existing_run else None
+
+    tkn_trace = ctx_state.workflow_trace_id.set(seeded_trace)
+    tkn_run = ctx_state.workflow_run_id.set(seeded_run)
+
+    try:
+        runner = Runner(
+            "msg",
+            typing.cast(Function, _DummyFunction()),
+            ctx_state,
+            typing.cast(ExporterManager, _DummyExporterManager()),
+        )
+        async with runner:
+            if method == "result":
+                out = await runner.result()
+                assert out == {"ok": True}
+            else:
+                chunks: list[str] = []
+                async for c in runner.result_stream():
+                    chunks.append(c)
+                assert chunks == ["chunk-1"]
+
+        # After run, context should be restored to seeded values
+        assert ctx_state.workflow_trace_id.get() == seeded_trace
+        assert ctx_state.workflow_run_id.get() == seeded_run
+    finally:
+        ctx_state.workflow_trace_id.reset(tkn_trace)
+        ctx_state.workflow_run_id.reset(tkn_run)
diff --git a/tests/nat/runtime/test_session_traceparent.py b/tests/nat/runtime/test_session_traceparent.py
new file mode 100644
index 000000000..c6332ac07
--- /dev/null
+++ b/tests/nat/runtime/test_session_traceparent.py
@@ -0,0 +1,159 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import typing
+import uuid
+
+import pytest
+from starlette.requests import Request
+
+from nat.builder.context import ContextState
+from nat.builder.workflow import Workflow
+from nat.runtime.session import SessionManager
+
+
+class _DummyWorkflow:
+    config = None
+
+
+# Build parameter sets at import time to keep test bodies simple
+_random_trace_hex = uuid.uuid4().hex
+_random_workflow_uuid_hex = uuid.uuid4().hex
+_random_workflow_uuid_str = str(uuid.uuid4())
+
+TRACE_ID_CASES: list[tuple[list[tuple[bytes, bytes]], int | None]] = [
+    # traceparent valid cases
+    ([(b"traceparent", f"00-{'a'*32}-{'b'*16}-01".encode())], int("a" * 32, 16)),
+    ([(b"traceparent", f"00-{'A'*32}-{'b'*16}-01".encode())], int("A" * 32, 16)),
+    ([(b"traceparent", f"00-{_random_trace_hex}-{'b'*16}-01".encode())], int(_random_trace_hex, 16)),
+    # workflow-trace-id valid cases (hex and hyphenated)
+    ([(b"workflow-trace-id", _random_workflow_uuid_hex.encode())], uuid.UUID(_random_workflow_uuid_hex).int),
+    ([(b"workflow-trace-id", _random_workflow_uuid_str.encode())], uuid.UUID(_random_workflow_uuid_str).int),
+    # invalid traceparent falls back to workflow-trace-id
+    ([
+        (b"traceparent", f"00-{'a'*31}-{'b'*16}-01".encode()),
+        (b"workflow-trace-id", _random_workflow_uuid_str.encode()),
+    ],
+     uuid.UUID(_random_workflow_uuid_str).int),
+    # invalid both -> None
+    ([
+        (b"traceparent", f"00-{'g'*32}-{'b'*16}-01".encode()),
+        (b"workflow-trace-id", b"z" * 32),
+    ], None),
+    # prefer traceparent when both valid
+    ([
+        (b"traceparent", f"00-{'c'*32}-{'d'*16}-01".encode()),
+        (b"workflow-trace-id", str(uuid.uuid4()).encode()),
+    ],
+     int("c" * 32, 16)),
+    # zero values
+    ([(b"traceparent", f"00-{'0'*32}-{'b'*16}-01".encode())], 0),
+    ([(b"workflow-trace-id", ("0" * 32).encode())], 0),
+    # malformed span id but valid trace id
+    ([(b"traceparent", f"00-{'a'*32}-XYZ-01".encode())], int("a" * 32, 16)),
+    # too few parts -> ignore
+    ([(b"traceparent", f"00-{'a'*32}".encode())], None),
+    # extra parts -> still ok
+    ([(b"traceparent", f"00-{'b'*32}-{'c'*16}-01-extra".encode())], int("b" * 32, 16)),
+    # negative and overflow workflow-trace-id -> ignore
+    ([(b"workflow-trace-id", b"-1")], None),
+    ([(b"workflow-trace-id", ("f" * 33).encode())], None),
+]
+
+
+@pytest.mark.parametrize(
"headers,expected_trace_id", + TRACE_ID_CASES, +) +@pytest.mark.asyncio +async def test_session_trace_id_from_headers_parameterized(headers: list[tuple[bytes, bytes]], + expected_trace_id: int | None): + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": headers, + "client": ("127.0.0.1", 1234), + "scheme": "http", + "server": ("testserver", 80), + "query_string": b"", + } + request = Request(scope) + + ctx_state = ContextState.get() + token = ctx_state.workflow_trace_id.set(None) + try: + sm = SessionManager(workflow=typing.cast(Workflow, _DummyWorkflow()), max_concurrency=0) + sm.set_metadata_from_http_request(request) + assert ctx_state.workflow_trace_id.get() == expected_trace_id + finally: + ctx_state.workflow_trace_id.reset(token) + + +METADATA_CASES: list[tuple[list[tuple[bytes, bytes]], str | None, str | None, str | None]] = [ + ([(b"conversation-id", b"conv-123")], "conv-123", None, None), + ([(b"user-message-id", b"msg-456")], None, "msg-456", None), + ([(b"workflow-run-id", b"run-789")], None, None, "run-789"), + ( + [ + (b"conversation-id", b"conv-123"), + (b"user-message-id", b"msg-456"), + (b"workflow-run-id", b"run-789"), + (b"traceparent", f"00-{'e'*32}-{'f'*16}-01".encode()), + ], + "conv-123", + "msg-456", + "run-789", + ), +] + + +@pytest.mark.parametrize( + "headers,expected_conv,expected_msg,expected_run", + METADATA_CASES, +) +@pytest.mark.asyncio +async def test_session_metadata_headers_parameterized(headers: list[tuple[bytes, bytes]], + expected_conv: str | None, + expected_msg: str | None, + expected_run: str | None): + scope = { + "type": "http", + "method": "GET", + "path": "/", + "headers": headers, + "client": ("127.0.0.1", 1234), + "scheme": "http", + "server": ("testserver", 80), + "query_string": b"", + } + request = Request(scope) + + ctx_state = ContextState.get() + tkn_conv = ctx_state.conversation_id.set(None) + tkn_msg = ctx_state.user_message_id.set(None) + tkn_run = ctx_state.workflow_run_id.set(None) + tkn_trace = ctx_state.workflow_trace_id.set(None) + try: + sm = SessionManager(workflow=typing.cast(Workflow, _DummyWorkflow()), max_concurrency=0) + sm.set_metadata_from_http_request(request) + assert ctx_state.conversation_id.get() == expected_conv + assert ctx_state.user_message_id.get() == expected_msg + assert ctx_state.workflow_run_id.get() == expected_run + finally: + ctx_state.conversation_id.reset(tkn_conv) + ctx_state.user_message_id.reset(tkn_msg) + ctx_state.workflow_run_id.reset(tkn_run) + ctx_state.workflow_trace_id.reset(tkn_trace)