Skip to content

Commit

Permalink
feat: support ThreadPool and update document
Browse files Browse the repository at this point in the history
  • Loading branch information
changemyminds committed Apr 10, 2024
1 parent 7079971 commit 57bd9c0
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ OpenTelemetry threading Instrumentation
:target: https://pypi.org/project/opentelemetry-instrumentation-threading/

This library provides instrumentation for the `threading` module to ensure that
the OpenTelemetry context is propagated across threads.
the OpenTelemetry context is propagated across threads. It is important to note
that this instrumentation does not produce any telemetry data on its own. It
merely ensures that the context is correctly propagated when threads are used.

Installation
------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "hatchling.build"
[project]
name = "opentelemetry-instrumentation-threading"
dynamic = ["version"]
description = "Threading tracing for OpenTelemetry"
description = "Thread context propagation support for OpenTelemetry"
readme = "README.rst"
license = "Apache-2.0"
requires-python = ">=3.8"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,72 +24,103 @@
ThreadingInstrumentor().instrument()
This library provides instrumentation for the `threading` module to ensure that
the OpenTelemetry context is propagated across threads.
the OpenTelemetry context is propagated across threads. It is important to note
that this instrumentation does not produce any telemetry data on its own. It
merely ensures that the context is correctly propagated when threads are used.
When instrumented, new threads created using `threading.Thread` or `threading.Timer`
will have the current OpenTelemetry context attached, and this context will be
re-activated in the thread's run method.
When instrumented, new threads created using threading.Thread, threading.Timer,
or within futures.ThreadPoolExecutor will have the current OpenTelemetry
context attached, and this context will be re-activated in the thread's
run method or the executor's worker thread."
"""

import threading
from concurrent import futures
from typing import Collection

from wrapt import wrap_function_wrapper

from opentelemetry import context, trace
from opentelemetry import context
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.instrumentation.threading.package import _instruments
from opentelemetry.instrumentation.utils import unwrap


class ThreadingInstrumentor(BaseInstrumentor):
__WRAPPER_START_METHOD = "start"
__WRAPPER_RUN_METHOD = "run"
__WRAPPER_SUBMIT_METHOD = "submit"
__WRAPPER_KWARGS = "kwargs"
__WRAPPER_CONTEXT = "_otel_context"

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
self._instrument_thread()
self._instrument_timer()
self._instrument_thread_pool()

def _uninstrument(self, **kwargs):
self._uninstrument_thread()
self._uninstrument_timer()
self._uninstrument_thread_pool()

@staticmethod
def _instrument_thread():
wrap_function_wrapper(
threading.Thread,
"start",
ThreadingInstrumentor.__WRAPPER_START_METHOD,
ThreadingInstrumentor.__wrap_threading_start,
)
wrap_function_wrapper(
threading.Thread, "run", ThreadingInstrumentor.__wrap_threading_run
threading.Thread,
ThreadingInstrumentor.__WRAPPER_RUN_METHOD,
ThreadingInstrumentor.__wrap_threading_run,
)

@staticmethod
def _instrument_timer():
wrap_function_wrapper(
threading.Timer,
"start",
ThreadingInstrumentor.__WRAPPER_START_METHOD,
ThreadingInstrumentor.__wrap_threading_start,
)
wrap_function_wrapper(
threading.Timer, "run", ThreadingInstrumentor.__wrap_threading_run
threading.Timer,
ThreadingInstrumentor.__WRAPPER_RUN_METHOD,
ThreadingInstrumentor.__wrap_threading_run,
)

@staticmethod
def _instrument_thread_pool():
wrap_function_wrapper(
futures.ThreadPoolExecutor,
ThreadingInstrumentor.__WRAPPER_SUBMIT_METHOD,
ThreadingInstrumentor.__wrap_thread_pool_submit,
)

@staticmethod
def _uninstrument_thread():
unwrap(threading.Thread, "start")
unwrap(threading.Thread, "run")
unwrap(threading.Thread, ThreadingInstrumentor.__WRAPPER_START_METHOD)
unwrap(threading.Thread, ThreadingInstrumentor.__WRAPPER_RUN_METHOD)

@staticmethod
def _uninstrument_timer():
unwrap(threading.Timer, "start")
unwrap(threading.Timer, "run")
unwrap(threading.Timer, ThreadingInstrumentor.__WRAPPER_START_METHOD)
unwrap(threading.Timer, ThreadingInstrumentor.__WRAPPER_RUN_METHOD)

@staticmethod
def _uninstrument_thread_pool():
unwrap(
futures.ThreadPoolExecutor,
ThreadingInstrumentor.__WRAPPER_SUBMIT_METHOD,
)

@staticmethod
def __wrap_threading_start(call_wrapped, instance, args, kwargs):
span = trace.get_current_span()
instance._otel_context = trace.set_span_in_context(span)
instance._otel_context = context.get_current()
return call_wrapped(*args, **kwargs)

@staticmethod
Expand All @@ -100,3 +131,30 @@ def __wrap_threading_run(call_wrapped, instance, args, kwargs):
return call_wrapped(*args, **kwargs)
finally:
context.detach(token)

@staticmethod
def __wrap_thread_pool_submit(call_wrapped, instance, args, kwargs):
# obtain the original function and wrapped kwargs
original_func = args[0]
wrapped_kwargs = {
ThreadingInstrumentor.__WRAPPER_KWARGS: kwargs,
ThreadingInstrumentor.__WRAPPER_CONTEXT: context.get_current(),
}

def wrapped_func(*func_args, **func_kwargs):
original_kwargs = func_kwargs.pop(
ThreadingInstrumentor.__WRAPPER_KWARGS
)
otel_context = func_kwargs.pop(
ThreadingInstrumentor.__WRAPPER_CONTEXT
)
token = None
try:
token = context.attach(otel_context)
return original_func(*func_args, **original_kwargs)
finally:
context.detach(token)

# replace the original function with the wrapped function
new_args = (wrapped_func,) + args[1:]
return call_wrapped(*new_args, **wrapped_kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,27 @@
# limitations under the License.

import threading
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from typing import List

from opentelemetry import trace
from opentelemetry.instrumentation.threading import ThreadingInstrumentor
from opentelemetry.test.test_base import TestBase


@dataclass
class MockContext:
span_context: trace.SpanContext = None
trace_id: int = None
span_id: int = None


class TestThreading(TestBase):
def setUp(self):
super().setUp()
self._tracer = self.tracer_provider.get_tracer(__name__)
self.global_context = None
self.global_trace_id = None
self.global_span_id = None
self._mock_contexts: List[MockContext] = []
ThreadingInstrumentor().instrument()

def tearDown(self):
Expand Down Expand Up @@ -54,31 +61,48 @@ def run_threading_test(self, thread: threading.Thread):
thread.join()

# check result
self.assertEqual(self.global_context, expected_context)
self.assertEqual(self.global_trace_id, expected_trace_id)
self.assertEqual(self.global_span_id, expected_span_id)
self.assertEqual(len(self._mock_contexts), 1)

def test_trace_context_propagation_in_thread_pool(self):
with self.get_root_span() as span:
span_context = span.get_span_context()
expected_context = span_context
expected_trace_id = span_context.trace_id
expected_span_id = span_context.span_id
current_mock_context = self._mock_contexts[0]
self.assertEqual(
current_mock_context.span_context, expected_context
)
self.assertEqual(current_mock_context.trace_id, expected_trace_id)
self.assertEqual(current_mock_context.span_id, expected_span_id)

with futures.ThreadPoolExecutor(max_workers=1) as executor:
def test_trace_context_propagation_in_thread_pool(self):
max_workers = 10
executor = ThreadPoolExecutor(max_workers=max_workers)

expected_contexts: List[trace.SpanContext] = []
futures_list = []
for num in range(max_workers):
with self._tracer.start_as_current_span(f"trace_{num}") as span:
span_context = span.get_span_context()
expected_contexts.append(span_context)
future = executor.submit(self.fake_func)
future.result()
futures_list.append(future)

for future in as_completed(futures_list):
future.result()

# check result
self.assertEqual(self.global_context, expected_context)
self.assertEqual(self.global_trace_id, expected_trace_id)
self.assertEqual(self.global_span_id, expected_span_id)
# check result
self.assertEqual(len(self._mock_contexts), max_workers)
self.assertEqual(len(self._mock_contexts), len(expected_contexts))
for index, mock_context in enumerate(self._mock_contexts):
span_context = expected_contexts[index]
self.assertEqual(mock_context.span_context, span_context)
self.assertEqual(mock_context.trace_id, span_context.trace_id)
self.assertEqual(mock_context.span_id, span_context.span_id)

def fake_func(self):
span_context = trace.get_current_span().get_span_context()
self.global_context = span_context
self.global_trace_id = span_context.trace_id
self.global_span_id = span_context.span_id
mock_context = MockContext(
span_context=span_context,
trace_id=span_context.trace_id,
span_id=span_context.span_id,
)
self._mock_contexts.append(mock_context)

def print_square(self, num):
with self._tracer.start_as_current_span("square"):
Expand Down

0 comments on commit 57bd9c0

Please sign in to comment.