Skip to content

Commit

Permalink
Test fix generating ASGI keys
Browse files Browse the repository at this point in the history
  • Loading branch information
hasier committed Nov 10, 2022
1 parent a0b4c4f commit afcf3c5
Showing 1 changed file with 75 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import typing
from unittest import mock

import opentelemetry.instrumentation.asgi as otel_asgi
from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.propagate import get_global_textmap, set_global_textmap
from opentelemetry.propagators.textmap import (
CarrierT,
Getter,
Setter,
TextMapPropagator,
default_getter,
default_setter,
)
from opentelemetry.test.asgitestutil import AsgiTestBase
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind
Expand All @@ -13,6 +25,60 @@
from .test_asgi_middleware import simple_asgi


class MockTextMapPropagator(TextMapPropagator):
"""Mock propagator for testing purposes using both getter `get` and `all`."""

TRACE_ID_KEY = "mock-traceid"
SPAN_ID_KEY = "mock-spanid"

def extract(
self,
carrier: CarrierT,
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
if context is None:
context = Context()

trace_id_list = getter.get(carrier, self.TRACE_ID_KEY)
span_id_list = getter.get(carrier, self.SPAN_ID_KEY)
carrier_keys = getter.keys(carrier)

if not trace_id_list or not span_id_list:
assert not any(key in carrier_keys for key in self.fields)
return context

assert all(key in carrier_keys for key in self.fields)
return trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(trace_id_list[0]),
span_id=int(span_id_list[0]),
is_remote=True,
)
),
context,
)

def inject(
self,
carrier: CarrierT,
context: typing.Optional[Context] = None,
setter: Setter = default_setter,
) -> None:
span = trace.get_current_span(context)
setter.set(
carrier, self.TRACE_ID_KEY, str(span.get_span_context().trace_id)
)
setter.set(
carrier, self.SPAN_ID_KEY, str(span.get_span_context().span_id)
)

@property
def fields(self):
return {self.TRACE_ID_KEY, self.SPAN_ID_KEY}


async def http_app_with_custom_headers(scope, receive, send):
message = await receive()
assert scope["type"] == "http"
Expand All @@ -34,6 +100,8 @@ async def http_app_with_custom_headers(scope, receive, send):
b"my-custom-regex-value-3,my-custom-regex-value-4",
),
(b"my-secret-header", b"my-secret-value"),
(MockTextMapPropagator.TRACE_ID_KEY.encode(), b"1"),
(MockTextMapPropagator.SPAN_ID_KEY.encode(), b"2"),
],
}
)
Expand All @@ -60,6 +128,8 @@ async def websocket_app_with_custom_headers(scope, receive, send):
b"my-custom-regex-value-3,my-custom-regex-value-4",
),
(b"my-secret-header", b"my-secret-value"),
(MockTextMapPropagator.TRACE_ID_KEY.encode(), b"1"),
(MockTextMapPropagator.SPAN_ID_KEY.encode(), b"2"),
],
}
)
Expand Down Expand Up @@ -88,6 +158,11 @@ def setUp(self):
self.app = otel_asgi.OpenTelemetryMiddleware(
simple_asgi, tracer_provider=self.tracer_provider
)
self.previous_propagator = get_global_textmap()
set_global_textmap(MockTextMapPropagator())

def tearDown(self):
set_global_textmap(self.previous_propagator)

def test_http_custom_request_headers_in_span_attributes(self):
self.scope["headers"].extend(
Expand Down

0 comments on commit afcf3c5

Please sign in to comment.