diff --git a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_custom_headers.py b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_custom_headers.py index 2d50d0704f..f5701f6cc4 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_custom_headers.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/tests/test_asgi_custom_headers.py @@ -1,6 +1,18 @@ +import typing from unittest import mock import opentelemetry.instrumentation.asgi as otel_asgi +from opentelemetry import trace +from opentelemetry.context import Context +from opentelemetry.propagate import get_global_textmap, set_global_textmap +from opentelemetry.propagators.textmap import ( + CarrierT, + Getter, + Setter, + TextMapPropagator, + default_getter, + default_setter, +) from opentelemetry.test.asgitestutil import AsgiTestBase from opentelemetry.test.test_base import TestBase from opentelemetry.trace import SpanKind @@ -13,6 +25,60 @@ from .test_asgi_middleware import simple_asgi +class MockTextMapPropagator(TextMapPropagator): + """Mock propagator for testing purposes using both getter `get` and `all`.""" + + TRACE_ID_KEY = "mock-traceid" + SPAN_ID_KEY = "mock-spanid" + + def extract( + self, + carrier: CarrierT, + context: typing.Optional[Context] = None, + getter: Getter = default_getter, + ) -> Context: + if context is None: + context = Context() + + trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) + span_id_list = getter.get(carrier, self.SPAN_ID_KEY) + carrier_keys = getter.keys(carrier) + + if not trace_id_list or not span_id_list: + assert not any(key in carrier_keys for key in self.fields) + return context + + assert all(key in carrier_keys for key in self.fields) + return trace.set_span_in_context( + trace.NonRecordingSpan( + trace.SpanContext( + trace_id=int(trace_id_list[0]), + span_id=int(span_id_list[0]), + is_remote=True, + ) + ), + context, + ) + + def inject( + self, + carrier: CarrierT, + context: typing.Optional[Context] = None, + setter: Setter = default_setter, + ) -> None: + span = trace.get_current_span(context) + setter.set( + carrier, self.TRACE_ID_KEY, str(span.get_span_context().trace_id) + ) + setter.set( + carrier, self.SPAN_ID_KEY, str(span.get_span_context().span_id) + ) + + @property + def fields(self): + return {self.TRACE_ID_KEY, self.SPAN_ID_KEY} + + async def http_app_with_custom_headers(scope, receive, send): message = await receive() assert scope["type"] == "http" @@ -34,6 +100,8 @@ async def http_app_with_custom_headers(scope, receive, send): b"my-custom-regex-value-3,my-custom-regex-value-4", ), (b"my-secret-header", b"my-secret-value"), + (MockTextMapPropagator.TRACE_ID_KEY.encode(), b"1"), + (MockTextMapPropagator.SPAN_ID_KEY.encode(), b"2"), ], } ) @@ -60,6 +128,8 @@ async def websocket_app_with_custom_headers(scope, receive, send): b"my-custom-regex-value-3,my-custom-regex-value-4", ), (b"my-secret-header", b"my-secret-value"), + (MockTextMapPropagator.TRACE_ID_KEY.encode(), b"1"), + (MockTextMapPropagator.SPAN_ID_KEY.encode(), b"2"), ], } ) @@ -88,6 +158,11 @@ def setUp(self): self.app = otel_asgi.OpenTelemetryMiddleware( simple_asgi, tracer_provider=self.tracer_provider ) + self.previous_propagator = get_global_textmap() + set_global_textmap(MockTextMapPropagator()) + + def tearDown(self): + set_global_textmap(self.previous_propagator) def test_http_custom_request_headers_in_span_attributes(self): self.scope["headers"].extend(