diff --git a/CHANGELOG.md b/CHANGELOG.md index 3e8f981e0e..1aca02bbe9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix exception in Urllib3 when dealing with filelike body. ([#1399](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1399)) +- Add request and response hooks for GRPC instrumentation (client only) + ([#14](https://github.com/helios/opentelemetry-python-contrib/pull/14)) + ### Added - Add connection attributes to sqlalchemy connect span diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py index 25010e147b..440d1facc8 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/__init__.py @@ -434,6 +434,8 @@ def __init__(self, filter_=None): else: filter_ = any_of(filter_, excluded_service_filter) self._filter = filter_ + self._request_hook = None + self._response_hook = None # Figures out which channel type we need to wrap def _which_channel(self, kwargs): @@ -455,6 +457,8 @@ def instrumentation_dependencies(self) -> Collection[str]: return _instruments def _instrument(self, **kwargs): + self._request_hook = kwargs.get("request_hook") + self._response_hook = kwargs.get("response_hook") for ctype in self._which_channel(kwargs): _wrap( "grpc", @@ -469,11 +473,15 @@ def _uninstrument(self, **kwargs): def wrapper_fn(self, original_func, instance, args, kwargs): channel = original_func(*args, **kwargs) tracer_provider = kwargs.get("tracer_provider") + request_hook = self._request_hook + response_hook = self._response_hook return intercept_channel( channel, client_interceptor( tracer_provider=tracer_provider, filter_=self._filter, + request_hook=request_hook, + response_hook=response_hook, ), ) @@ -499,6 +507,8 @@ def __init__(self, filter_=None): else: filter_ = any_of(filter_, excluded_service_filter) self._filter = filter_ + self._request_hook = None + self._response_hook = None def instrumentation_dependencies(self) -> Collection[str]: return _instruments @@ -507,13 +517,19 @@ def _add_interceptors(self, tracer_provider, kwargs): if "interceptors" in kwargs and kwargs["interceptors"]: kwargs["interceptors"] = ( aio_client_interceptors( - tracer_provider=tracer_provider, filter_=self._filter + tracer_provider=tracer_provider, + filter_=self._filter, + request_hook=self._request_hook, + response_hook=self._response_hook, ) + kwargs["interceptors"] ) else: kwargs["interceptors"] = aio_client_interceptors( - tracer_provider=tracer_provider, filter_=self._filter + tracer_provider=tracer_provider, + filter_=self._filter, + request_hook=self._request_hook, + response_hook=self._response_hook, ) return kwargs @@ -521,6 +537,8 @@ def _add_interceptors(self, tracer_provider, kwargs): def _instrument(self, **kwargs): self._original_insecure = grpc.aio.insecure_channel self._original_secure = grpc.aio.secure_channel + self._request_hook = kwargs.get("request_hook") + self._response_hook = kwargs.get("response_hook") tracer_provider = kwargs.get("tracer_provider") def insecure(*args, **kwargs): @@ -541,7 +559,9 @@ def _uninstrument(self, **kwargs): grpc.aio.secure_channel = self._original_secure -def client_interceptor(tracer_provider=None, filter_=None): +def client_interceptor( + tracer_provider=None, filter_=None, request_hook=None, response_hook=None +): """Create a gRPC client channel interceptor. Args: @@ -558,7 +578,12 @@ def client_interceptor(tracer_provider=None, filter_=None): tracer = trace.get_tracer(__name__, __version__, tracer_provider) - return _client.OpenTelemetryClientInterceptor(tracer, filter_=filter_) + return _client.OpenTelemetryClientInterceptor( + tracer, + filter_=filter_, + request_hook=request_hook, + response_hook=response_hook, + ) def server_interceptor(tracer_provider=None, filter_=None): @@ -581,7 +606,9 @@ def server_interceptor(tracer_provider=None, filter_=None): return _server.OpenTelemetryServerInterceptor(tracer, filter_=filter_) -def aio_client_interceptors(tracer_provider=None, filter_=None): +def aio_client_interceptors( + tracer_provider=None, filter_=None, request_hook=None, response_hook=None +): """Create a gRPC client channel interceptor. Args: @@ -595,10 +622,30 @@ def aio_client_interceptors(tracer_provider=None, filter_=None): tracer = trace.get_tracer(__name__, __version__, tracer_provider) return [ - _aio_client.UnaryUnaryAioClientInterceptor(tracer, filter_=filter_), - _aio_client.UnaryStreamAioClientInterceptor(tracer, filter_=filter_), - _aio_client.StreamUnaryAioClientInterceptor(tracer, filter_=filter_), - _aio_client.StreamStreamAioClientInterceptor(tracer, filter_=filter_), + _aio_client.UnaryUnaryAioClientInterceptor( + tracer, + filter_=filter_, + request_hook=request_hook, + response_hook=response_hook, + ), + _aio_client.UnaryStreamAioClientInterceptor( + tracer, + filter_=filter_, + request_hook=request_hook, + response_hook=response_hook, + ), + _aio_client.StreamUnaryAioClientInterceptor( + tracer, + filter_=filter_, + request_hook=request_hook, + response_hook=response_hook, + ), + _aio_client.StreamStreamAioClientInterceptor( + tracer, + filter_=filter_, + request_hook=request_hook, + response_hook=response_hook, + ), ] diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py index c7630bfe9f..5d5a5ccc46 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_aio_client.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import logging from collections import OrderedDict import grpc @@ -28,8 +29,10 @@ from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace.status import Status, StatusCode +logger = logging.getLogger(__name__) -def _unary_done_callback(span, code, details): + +def _unary_done_callback(span, code, details, response_hook): def callback(call): try: span.set_attribute( @@ -43,6 +46,8 @@ def callback(call): description=details, ) ) + response_hook(span, details) + finally: span.end() @@ -110,7 +115,11 @@ async def _wrap_unary_response(self, continuation, span): code = await call.code() details = await call.details() - call.add_done_callback(_unary_done_callback(span, code, details)) + call.add_done_callback( + _unary_done_callback( + span, code, details, self._call_response_hook + ) + ) return call except grpc.aio.AioRpcError as exc: @@ -120,6 +129,8 @@ async def _wrap_unary_response(self, continuation, span): async def _wrap_stream_response(self, span, call): try: async for response in call: + if self._response_hook: + self._call_response_hook(span, response) yield response except Exception as exc: self.add_error_details_to_span(span, exc) @@ -151,6 +162,9 @@ async def intercept_unary_unary( ) as span: new_details = self.propagate_trace_in_details(client_call_details) + if self._request_hook: + self._call_request_hook(span, request) + continuation_with_args = functools.partial( continuation, new_details, request ) @@ -175,7 +189,8 @@ async def intercept_unary_stream( new_details = self.propagate_trace_in_details(client_call_details) resp = await continuation(new_details, request) - + if self._request_hook: + self._call_request_hook(span, request) return self._wrap_stream_response(span, resp) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py index 55a46d4a49..b966fff4db 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_client.py @@ -19,8 +19,9 @@ """Implementation of the invocation-side open-telemetry interceptor.""" +import logging from collections import OrderedDict -from typing import MutableMapping +from typing import Callable, MutableMapping import grpc @@ -33,6 +34,8 @@ from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace.status import Status, StatusCode +logger = logging.getLogger(__name__) + class _CarrierSetter(Setter): """We use a custom setter in order to be able to lower case @@ -59,12 +62,27 @@ def callback(response_future): return callback +def _safe_invoke(function: Callable, *args): + function_name = "" + try: + function_name = function.__name__ + function(*args) + except Exception as ex: # pylint:disable=broad-except + logger.error( + "Error when invoking function '%s'", function_name, exc_info=ex + ) + + class OpenTelemetryClientInterceptor( grpcext.UnaryClientInterceptor, grpcext.StreamClientInterceptor ): - def __init__(self, tracer, filter_=None): + def __init__( + self, tracer, filter_=None, request_hook=None, response_hook=None + ): self._tracer = tracer self._filter = filter_ + self._request_hook = request_hook + self._response_hook = response_hook def _start_span(self, method, **kwargs): service, meth = method.lstrip("/").split("/", 1) @@ -99,6 +117,8 @@ def _trace_result(self, span, rpc_info, result): if isinstance(result, tuple): response = result[0] rpc_info.response = response + if self._response_hook: + self._call_response_hook(span, response) span.end() return result @@ -127,7 +147,8 @@ def _intercept(self, request, metadata, client_info, invoker): timeout=client_info.timeout, request=request, ) - + if self._request_hook: + self._call_request_hook(span, request) result = invoker(request, metadata) except Exception as exc: if isinstance(exc, grpc.RpcError): @@ -148,6 +169,16 @@ def _intercept(self, request, metadata, client_info, invoker): span.end() return self._trace_result(span, rpc_info, result) + def _call_request_hook(self, span, request): + if not callable(self._request_hook): + return + _safe_invoke(self._request_hook, span, request) + + def _call_response_hook(self, span, response): + if not callable(self._response_hook): + return + _safe_invoke(self._response_hook, span, response) + def intercept_unary(self, request, metadata, client_info, invoker): if self._filter is not None and not self._filter(client_info): return invoker(request, metadata) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_hooks.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_hooks.py new file mode 100644 index 0000000000..fe906b26c1 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_aio_client_interceptor_hooks.py @@ -0,0 +1,120 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +try: + from unittest import IsolatedAsyncioTestCase +except ImportError: + # unittest.IsolatedAsyncioTestCase was introduced in Python 3.8. It's use + # simplifies the following tests. Without it, the amount of test code + # increases significantly, with most of the additional code handling + # the asyncio set up. + from unittest import TestCase + + class IsolatedAsyncioTestCase(TestCase): + def run(self, result=None): + self.skipTest( + "This test requires Python 3.8 for unittest.IsolatedAsyncioTestCase" + ) + + +import grpc +import pytest + +from opentelemetry.instrumentation.grpc import GrpcAioInstrumentorClient +from opentelemetry.test.test_base import TestBase + +from ._aio_client import simple_method +from ._server import create_test_server +from .protobuf import test_server_pb2_grpc # pylint: disable=no-name-in-module + + +def request_hook(span, request): + span.set_attribute("request_data", request.request_data) + + +def response_hook(span, response): + span.set_attribute("response_data", response) + + +def request_hook_with_exception(_span, _request): + raise Exception() + + +def response_hook_with_exception(_span, _response): + raise Exception() + + +@pytest.mark.asyncio +class TestAioClientInterceptorWithHooks(TestBase, IsolatedAsyncioTestCase): + def setUp(self): + super().setUp() + self.server = create_test_server(25565) + self.server.start() + + def tearDown(self): + super().tearDown() + self.server.stop(None) + + async def test_request_and_response_hooks(self): + instrumentor = GrpcAioInstrumentorClient() + + try: + instrumentor.instrument( + request_hook=request_hook, + response_hook=response_hook, + ) + + channel = grpc.aio.insecure_channel( + "localhost:25565", + ) + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertIn("request_data", span.attributes) + self.assertEqual(span.attributes["request_data"], "data") + + self.assertIn("response_data", span.attributes) + self.assertEqual(span.attributes["response_data"], "") + finally: + instrumentor.uninstrument() + + async def test_hooks_with_exception(self): + instrumentor = GrpcAioInstrumentorClient() + + try: + instrumentor.instrument( + request_hook=request_hook_with_exception, + response_hook=response_hook_with_exception, + ) + + channel = grpc.aio.insecure_channel( + "localhost:25565", + ) + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + response = await simple_method(stub) + assert response.response_data == "data" + + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod") + finally: + instrumentor.uninstrument() diff --git a/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor_hooks.py b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor_hooks.py new file mode 100644 index 0000000000..ca649f7bb1 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-grpc/tests/test_client_interceptor_hooks.py @@ -0,0 +1,149 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +from tests.protobuf import ( # pylint: disable=no-name-in-module + test_server_pb2_grpc, +) + +from opentelemetry import trace +from opentelemetry.instrumentation.grpc import GrpcInstrumentorClient +from opentelemetry.test.test_base import TestBase + +from ._client import simple_method +from ._server import create_test_server + + +# User defined interceptor. Is used in the tests along with the opentelemetry client interceptor. +class Interceptor( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): + def __init__(self): + pass + + def intercept_unary_unary( + self, continuation, client_call_details, request + ): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_unary_stream( + self, continuation, client_call_details, request + ): + return self._intercept_call(continuation, client_call_details, request) + + def intercept_stream_unary( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator + ) + + def intercept_stream_stream( + self, continuation, client_call_details, request_iterator + ): + return self._intercept_call( + continuation, client_call_details, request_iterator + ) + + @staticmethod + def _intercept_call( + continuation, client_call_details, request_or_iterator + ): + return continuation(client_call_details, request_or_iterator) + + +def request_hook(span, request): + span.set_attribute("request_data", request.request_data) + + +def response_hook(span, response): + span.set_attribute("response_data", response.response_data) + + +def request_hook_with_exception(_span, _request): + raise Exception() + + +def response_hook_with_exception(_span, _response): + raise Exception() + + +class TestHooks(TestBase): + def setUp(self): + super().setUp() + self.server = create_test_server(25565) + self.server.start() + # use a user defined interceptor along with the opentelemetry client interceptor + self.interceptors = [Interceptor()] + + def tearDown(self): + super().tearDown() + self.server.stop(None) + + def test_response_and_request_hooks(self): + instrumentor = GrpcInstrumentorClient() + + try: + instrumentor.instrument( + request_hook=request_hook, + response_hook=response_hook, + ) + + channel = grpc.insecure_channel("localhost:25565") + channel = grpc.intercept_channel(channel, *self.interceptors) + + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + simple_method(stub) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + + self.assertIn("request_data", span.attributes) + self.assertEqual(span.attributes["request_data"], "data") + + self.assertIn("response_data", span.attributes) + self.assertEqual(span.attributes["response_data"], "data") + finally: + instrumentor.uninstrument() + + def test_hooks_with_exception(self): + instrumentor = GrpcInstrumentorClient() + + try: + instrumentor.instrument( + request_hook=request_hook_with_exception, + response_hook=response_hook_with_exception, + ) + + channel = grpc.insecure_channel("localhost:25565") + channel = grpc.intercept_channel(channel, *self.interceptors) + + stub = test_server_pb2_grpc.GRPCTestServerStub(channel) + + simple_method(stub) + spans = self.memory_exporter.get_finished_spans() + self.assertEqual(len(spans), 1) + span = spans[0] + + self.assertEqual(span.name, "/GRPCTestServer/SimpleMethod") + self.assertIs(span.kind, trace.SpanKind.CLIENT) + finally: + instrumentor.uninstrument()