Skip to content

Commit

Permalink
Add Datadog propagator
Browse files Browse the repository at this point in the history
  • Loading branch information
majorgreys committed May 20, 2020
1 parent 9e58b8a commit 737c420
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
DD_ORIGIN = "_dd_origin"
AUTO_REJECT = 0
AUTO_KEEP = 1
USER_KEEP = 2
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from opentelemetry.sdk.trace.export import SpanExporter, SpanExportResult
from opentelemetry.trace.status import StatusCanonicalCode

# pylint:disable=relative-beyond-top-level
from .constants import DD_ORIGIN

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -128,6 +131,11 @@ def _translate_to_datadog(self, spans):

datadog_span.set_tags(span.attributes)

# add origin to root span
origin = _get_origin(span)
if origin and parent_id == 0:
datadog_span.set_tag(DD_ORIGIN, origin)

# span events and span links are not supported

datadog_spans.append(datadog_span)
Expand Down Expand Up @@ -202,3 +210,9 @@ def _get_exc_info(span):
"""Parse span status description for exception type and value"""
exc_type, exc_val = span.status.description.split(":", 1)
return exc_type, exc_val.strip()


def _get_origin(span):
ctx = span.get_context()
origin = ctx.trace_state.get(DD_ORIGIN)
return origin
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing

from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.trace.propagation import (
get_span_from_context,
set_span_in_context,
)
from opentelemetry.trace.propagation.httptextformat import (
Getter,
HTTPTextFormat,
HTTPTextFormatT,
Setter,
)

# pylint:disable=relative-beyond-top-level
from . import constants


class DatadogFormat(HTTPTextFormat):
"""Propagator for the Datadog HTTP header format.
"""

TRACE_ID_KEY = "x-datadog-trace-id"
PARENT_ID_KEY = "x-datadog-parent-id"
SAMPLING_PRIORITY_KEY = "x-datadog-sampling-priority"
ORIGIN_KEY = "x-datadog-origin"

def extract(
self,
get_from_carrier: Getter[HTTPTextFormatT],
carrier: HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> Context:
trace_id = extract_first_element(
get_from_carrier(carrier, self.TRACE_ID_KEY)
)

span_id = extract_first_element(
get_from_carrier(carrier, self.PARENT_ID_KEY)
)

sampled = extract_first_element(
get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY)
)

origin = extract_first_element(
get_from_carrier(carrier, self.ORIGIN_KEY)
)

trace_flags = trace.TraceFlags()
if sampled and int(sampled) in (
constants.AUTO_KEEP,
constants.USER_KEEP,
):
trace_flags |= trace.TraceFlags.SAMPLED

if trace_id is None or span_id is None:
return set_span_in_context(trace.INVALID_SPAN, context)

span_context = trace.SpanContext(
trace_id=int(trace_id),
span_id=int(span_id),
is_remote=True,
trace_flags=trace_flags,
trace_state=trace.TraceState({constants.DD_ORIGIN: origin}),
)

return set_span_in_context(trace.DefaultSpan(span_context), context)

def inject(
self,
set_in_carrier: Setter[HTTPTextFormatT],
carrier: HTTPTextFormatT,
context: typing.Optional[Context] = None,
) -> None:
span = get_span_from_context(context=context)
sampled = (trace.TraceFlags.SAMPLED & span.context.trace_flags) != 0
set_in_carrier(
carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id),
)
set_in_carrier(
carrier, self.PARENT_ID_KEY, format_span_id(span.context.span_id)
)
set_in_carrier(
carrier,
self.SAMPLING_PRIORITY_KEY,
str(constants.AUTO_KEEP if sampled else constants.AUTO_REJECT),
)
if constants.DD_ORIGIN in span.context.trace_state:
set_in_carrier(
carrier,
self.ORIGIN_KEY,
span.context.trace_state[constants.DD_ORIGIN],
)


def format_trace_id(trace_id: int) -> str:
"""Format the trace id for Datadog."""
return str(trace_id & 0xFFFFFFFFFFFFFFFF)


def format_span_id(span_id: int) -> str:
"""Format the span id for Datadog."""
return str(span_id)


def extract_first_element(
items: typing.Iterable[HTTPTextFormatT],
) -> typing.Optional[HTTPTextFormatT]:
if items is None:
return None
return next(iter(items), None)
37 changes: 37 additions & 0 deletions ext/opentelemetry-ext-datadog/tests/test_datadog_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,40 @@ def test_span_processor_scheduled_delay(self):
self.assertEqual(len(datadog_spans), 1)

tracer_provider.shutdown()

def test_origin(self):
context = trace_api.SpanContext(
trace_id=0x000000000000000000000000DEADBEEF,
span_id=trace_api.INVALID_SPAN,
is_remote=True,
trace_state=trace_api.TraceState(
{datadog.constants.DD_ORIGIN: "origin-service"}
),
)

root_span = trace.Span(name="root", context=context, parent=None)
child_span = trace.Span(
name="child", context=context, parent=root_span
)
root_span.start()
child_span.start()
child_span.end()
root_span.end()

# pylint: disable=protected-access
exporter = datadog.DatadogSpanExporter()
datadog_spans = [
span.to_dict()
for span in exporter._translate_to_datadog([root_span, child_span])
]

self.assertEqual(len(datadog_spans), 2)

actual = [
span["meta"].get(datadog.constants.DD_ORIGIN)
if "meta" in span
else None
for span in datadog_spans
]
expected = ["origin-service", None]
self.assertListEqual(actual, expected)
173 changes: 173 additions & 0 deletions ext/opentelemetry-ext-datadog/tests/test_datadog_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from opentelemetry import trace as trace_api
from opentelemetry.ext.datadog import constants, propagator
from opentelemetry.sdk import trace
from opentelemetry.trace.propagation import (
get_span_from_context,
set_span_in_context,
)

FORMAT = propagator.DatadogFormat()


def get_as_list(dict_object, key):
value = dict_object.get(key)
return [value] if value is not None else []


class TestDatadogFormat(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.serialized_trace_id = propagator.format_trace_id(
trace.generate_trace_id()
)
cls.serialized_parent_id = propagator.format_span_id(
trace.generate_span_id()
)
cls.serialized_origin = "origin-service"

def test_malformed_headers(self):
"""Test with no Datadog headers"""
malformed_trace_id_key = FORMAT.TRACE_ID_KEY + "-x"
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_span_from_context(
FORMAT.extract(
get_as_list,
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
},
)
).get_context()

self.assertNotEqual(context.trace_id, int(self.serialized_trace_id))
self.assertNotEqual(context.span_id, int(self.serialized_parent_id))
self.assertFalse(context.is_remote)

def test_missing_trace_id(self):
"""If a trace id is missing, populate an invalid trace id."""
carrier = {
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = get_span_from_context(ctx).get_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)

def test_missing_parent_id(self):
"""If a parent id is missing, populate an invalid trace id."""
carrier = {
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}

ctx = FORMAT.extract(get_as_list, carrier)
span_context = get_span_from_context(ctx).get_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
parent_context = get_span_from_context(
FORMAT.extract(
get_as_list,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
FORMAT.SAMPLING_PRIORITY_KEY: str(constants.AUTO_KEEP),
FORMAT.ORIGIN_KEY: self.serialized_origin,
},
)
).get_context()

self.assertEqual(
parent_context.trace_id, int(self.serialized_trace_id)
)
self.assertEqual(
parent_context.span_id, int(self.serialized_parent_id)
)
self.assertEqual(parent_context.trace_flags, constants.AUTO_KEEP)
self.assertEqual(
parent_context.trace_state.get(constants.DD_ORIGIN),
self.serialized_origin,
)
self.assertTrue(parent_context.is_remote)

child = trace.Span(
"child",
trace_api.SpanContext(
parent_context.trace_id,
trace.generate_span_id(),
is_remote=False,
trace_flags=parent_context.trace_flags,
trace_state=parent_context.trace_state,
),
parent=parent_context,
)

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
)
self.assertEqual(
child_carrier[FORMAT.PARENT_ID_KEY], str(child.context.span_id)
)
self.assertEqual(
child_carrier[FORMAT.SAMPLING_PRIORITY_KEY],
str(constants.AUTO_KEEP),
)
self.assertEqual(
child_carrier.get(FORMAT.ORIGIN_KEY), self.serialized_origin
)

def test_sampling_priority_auto_reject(self):
"""Test sampling priority rejected."""
parent_context = get_span_from_context(
FORMAT.extract(
get_as_list,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
FORMAT.SAMPLING_PRIORITY_KEY: str(constants.AUTO_REJECT),
},
)
).get_context()

self.assertEqual(parent_context.trace_flags, constants.AUTO_REJECT)

child = trace.Span(
"child",
trace_api.SpanContext(
parent_context.trace_id,
trace.generate_span_id(),
is_remote=False,
trace_flags=parent_context.trace_flags,
trace_state=parent_context.trace_state,
),
parent=parent_context,
)

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.SAMPLING_PRIORITY_KEY],
str(constants.AUTO_REJECT),
)

0 comments on commit 737c420

Please sign in to comment.