diff --git a/ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py b/ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py index da6cc1ac79c..e9e10e40de1 100644 --- a/ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py +++ b/ext/opentelemetry-ext-grpc/src/opentelemetry/ext/grpc/__init__.py @@ -33,6 +33,8 @@ SimpleExportSpanProcessor, ) + from opentelemetry.sdk.metrics.export import ConsoleMetricsExporter + try: from .gen import helloworld_pb2, helloworld_pb2_grpc except ImportError: @@ -42,7 +44,9 @@ trace.get_tracer_provider().add_span_processor( SimpleExportSpanProcessor(ConsoleSpanExporter()) ) - instrumentor = GrpcInstrumentorClient() + + # Optional - export GRPC specific metrics (latency, bytes in/out, errors) by passing an exporter + instrumentor = GrpcInstrumentorClient(exporter=ConsoleMetricsExporter(), interval=10) instrumentor.instrument() def run(): @@ -109,6 +113,7 @@ def serve(): serve() """ from contextlib import contextmanager +from functools import partial import grpc from wrapt import wrap_function_wrapper as _wrap @@ -139,11 +144,21 @@ def wrapper_fn(self, original_func, instance, args, kwargs): class GrpcInstrumentorClient(BaseInstrumentor): def _instrument(self, **kwargs): + exporter = kwargs.get("exporter", None) + interval = kwargs.get("interval", 30) if kwargs.get("channel_type") == "secure": - _wrap("grpc", "secure_channel", self.wrapper_fn) + _wrap( + "grpc", + "secure_channel", + partial(self.wrapper_fn, exporter, interval), + ) else: - _wrap("grpc", "insecure_channel", self.wrapper_fn) + _wrap( + "grpc", + "insecure_channel", + partial(self.wrapper_fn, exporter, interval), + ) def _uninstrument(self, **kwargs): if kwargs.get("channel_type") == "secure": @@ -152,10 +167,19 @@ def _uninstrument(self, **kwargs): else: unwrap(grpc, "insecure_channel") - @contextmanager - def wrapper_fn(self, original_func, instance, args, kwargs): - with original_func(*args, **kwargs) as channel: - yield intercept_channel(channel, client_interceptor()) + def wrapper_fn( + self, exporter, interval, original_func, instance, args, kwargs + ): + channel = original_func(*args, **kwargs) + tracer_provider = kwargs.get("tracer_provider") + return intercept_channel( + channel, + client_interceptor( + tracer_provider=tracer_provider, + exporter=exporter, + interval=interval, + ), + ) def client_interceptor(tracer_provider=None, exporter=None, interval=30): diff --git a/ext/opentelemetry-ext-grpc/tests/test_client_interceptor.py b/ext/opentelemetry-ext-grpc/tests/test_client_interceptor.py index a5b24f3873b..3d9361b338d 100644 --- a/ext/opentelemetry-ext-grpc/tests/test_client_interceptor.py +++ b/ext/opentelemetry-ext-grpc/tests/test_client_interceptor.py @@ -16,7 +16,7 @@ import opentelemetry.ext.grpc from opentelemetry import trace -from opentelemetry.ext.grpc import client_interceptor +from opentelemetry.ext.grpc import GrpcInstrumentorClient from opentelemetry.ext.grpc.grpcext import intercept_channel from opentelemetry.sdk.metrics.export.aggregate import ( MinMaxSumCountAggregator, @@ -37,23 +37,23 @@ class TestClientProto(TestBase): def setUp(self): super().setUp() - self.server = create_test_server(25565) - self.server.start() - self.interceptor = client_interceptor( + GrpcInstrumentorClient().instrument( exporter=self.memory_metrics_exporter ) - self.channel = intercept_channel( - grpc.insecure_channel("localhost:25565"), self.interceptor - ) + self.server = create_test_server(25565) + self.server.start() + self.channel = grpc.insecure_channel("localhost:25565") self._stub = test_server_pb2_grpc.GRPCTestServerStub(self.channel) def tearDown(self): super().tearDown() + GrpcInstrumentorClient().uninstrument() self.memory_metrics_exporter.clear() self.server.stop(None) def _verify_success_records(self, num_bytes_out, num_bytes_in, method): - self.interceptor.controller.tick() + # pylint: disable=protected-access + self.channel._interceptor.controller.tick() records = self.memory_metrics_exporter.get_exported_metrics() self.assertEqual(len(records), 3) @@ -163,7 +163,8 @@ def test_stream_stream(self): ) def _verify_error_records(self, method): - self.interceptor.controller.tick() + # pylint: disable=protected-access + self.channel._interceptor.controller.tick() records = self.memory_metrics_exporter.get_exported_metrics() self.assertEqual(len(records), 3)