Skip to content

Commit

Permalink
gRPC integration and aio interceptors (#2369)
Browse files Browse the repository at this point in the history
Automatically add client and server interceptors to gRPC calls. Make it work with async gRPC servers and async gRPC client channels.

---------

Co-authored-by: ali.sorouramini <[email protected]>
Co-authored-by: Anton Pirker <[email protected]>
Co-authored-by: Anton Pirker <[email protected]>
  • Loading branch information
4 people authored Nov 8, 2023
1 parent 2cb232e commit 76af9d2
Show file tree
Hide file tree
Showing 19 changed files with 934 additions and 111 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repos:
rev: 22.6.0
hooks:
- id: black
exclude: ^(.*_pb2.py|.*_pb2_grpc.py)

- repo: https://github.com/pycqa/flake8
rev: 5.0.4
Expand Down
1 change: 1 addition & 0 deletions linter-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mypy
black
flake8==5.0.4 # flake8 depends on pyflakes>=3.0.0 and this dropped support for Python 2 "# type:" comments
types-certifi
types-protobuf
types-redis
types-setuptools
pymongo # There is no separate types module.
Expand Down
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[tool.black]
# 'extend-exclude' excludes files or directories in addition to the defaults
extend-exclude = '''
# A regex preceded with ^/ will apply only to files and directories
# in the root of the project.
(
.*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the project
| .*_pb2_grpc.py # exclude autogenerated Protocol Buffer files anywhere in the project
)
'''
154 changes: 152 additions & 2 deletions sentry_sdk/integrations/grpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,152 @@
from .server import ServerInterceptor # noqa: F401
from .client import ClientInterceptor # noqa: F401
from functools import wraps

import grpc
from grpc import Channel, Server, intercept_channel
from grpc.aio import Channel as AsyncChannel
from grpc.aio import Server as AsyncServer

from sentry_sdk.integrations import Integration
from sentry_sdk._types import TYPE_CHECKING

from .client import ClientInterceptor
from .server import ServerInterceptor
from .aio.server import ServerInterceptor as AsyncServerInterceptor
from .aio.client import (
SentryUnaryUnaryClientInterceptor as AsyncUnaryUnaryClientInterceptor,
)
from .aio.client import (
SentryUnaryStreamClientInterceptor as AsyncUnaryStreamClientIntercetor,
)

from typing import Any, Optional, Sequence

# Hack to get new Python features working in older versions
# without introducing a hard dependency on `typing_extensions`
# from: https://stackoverflow.com/a/71944042/300572
if TYPE_CHECKING:
from typing import ParamSpec, Callable
else:
# Fake ParamSpec
class ParamSpec:
def __init__(self, _):
self.args = None
self.kwargs = None

# Callable[anything] will return None
class _Callable:
def __getitem__(self, _):
return None

# Make instances
Callable = _Callable()

P = ParamSpec("P")


def _wrap_channel_sync(func: Callable[P, Channel]) -> Callable[P, Channel]:
"Wrapper for synchronous secure and insecure channel."

@wraps(func)
def patched_channel(*args: Any, **kwargs: Any) -> Channel:
channel = func(*args, **kwargs)
if not ClientInterceptor._is_intercepted:
ClientInterceptor._is_intercepted = True
return intercept_channel(channel, ClientInterceptor())
else:
return channel

return patched_channel


def _wrap_intercept_channel(func: Callable[P, Channel]) -> Callable[P, Channel]:
@wraps(func)
def patched_intercept_channel(
channel: Channel, *interceptors: grpc.ServerInterceptor
) -> Channel:
if ClientInterceptor._is_intercepted:
interceptors = tuple(
[
interceptor
for interceptor in interceptors
if not isinstance(interceptor, ClientInterceptor)
]
)
else:
interceptors = interceptors
return intercept_channel(channel, *interceptors)

return patched_intercept_channel # type: ignore


def _wrap_channel_async(func: Callable[P, AsyncChannel]) -> Callable[P, AsyncChannel]:
"Wrapper for asynchronous secure and insecure channel."

@wraps(func)
def patched_channel(
*args: P.args,
interceptors: Optional[Sequence[grpc.aio.ClientInterceptor]] = None,
**kwargs: P.kwargs,
) -> Channel:
sentry_interceptors = [
AsyncUnaryUnaryClientInterceptor(),
AsyncUnaryStreamClientIntercetor(),
]
interceptors = [*sentry_interceptors, *(interceptors or [])]
return func(*args, interceptors=interceptors, **kwargs) # type: ignore

return patched_channel # type: ignore


def _wrap_sync_server(func: Callable[P, Server]) -> Callable[P, Server]:
"""Wrapper for synchronous server."""

@wraps(func)
def patched_server(
*args: P.args,
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
**kwargs: P.kwargs,
) -> Server:
interceptors = [
interceptor
for interceptor in interceptors or []
if not isinstance(interceptor, ServerInterceptor)
]
server_interceptor = ServerInterceptor()
interceptors = [server_interceptor, *(interceptors or [])]
return func(*args, interceptors=interceptors, **kwargs) # type: ignore

return patched_server # type: ignore


def _wrap_async_server(func: Callable[P, AsyncServer]) -> Callable[P, AsyncServer]:
"""Wrapper for asynchronous server."""

@wraps(func)
def patched_aio_server(
*args: P.args,
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None,
**kwargs: P.kwargs,
) -> Server:
server_interceptor = AsyncServerInterceptor()
interceptors = [server_interceptor, *(interceptors or [])]
return func(*args, interceptors=interceptors, **kwargs) # type: ignore

return patched_aio_server # type: ignore


class GRPCIntegration(Integration):
identifier = "grpc"

@staticmethod
def setup_once() -> None:
import grpc

grpc.insecure_channel = _wrap_channel_sync(grpc.insecure_channel)
grpc.secure_channel = _wrap_channel_sync(grpc.secure_channel)
grpc.intercept_channel = _wrap_intercept_channel(grpc.intercept_channel)

grpc.aio.insecure_channel = _wrap_channel_async(grpc.aio.insecure_channel)
grpc.aio.secure_channel = _wrap_channel_async(grpc.aio.secure_channel)

grpc.server = _wrap_sync_server(grpc.server)
grpc.aio.server = _wrap_async_server(grpc.aio.server)
2 changes: 2 additions & 0 deletions sentry_sdk/integrations/grpc/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .server import ServerInterceptor # noqa: F401
from .client import ClientInterceptor # noqa: F401
91 changes: 91 additions & 0 deletions sentry_sdk/integrations/grpc/aio/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Callable, Union, AsyncIterable, Any

from grpc.aio import (
UnaryUnaryClientInterceptor,
UnaryStreamClientInterceptor,
ClientCallDetails,
UnaryUnaryCall,
UnaryStreamCall,
)
from google.protobuf.message import Message

from sentry_sdk import Hub
from sentry_sdk.consts import OP


class ClientInterceptor:
@staticmethod
def _update_client_call_details_metadata_from_hub(
client_call_details: ClientCallDetails, hub: Hub
) -> ClientCallDetails:
metadata = (
list(client_call_details.metadata) if client_call_details.metadata else []
)
for key, value in hub.iter_trace_propagation_headers():
metadata.append((key, value))

client_call_details = ClientCallDetails(
method=client_call_details.method,
timeout=client_call_details.timeout,
metadata=metadata,
credentials=client_call_details.credentials,
wait_for_ready=client_call_details.wait_for_ready,
)

return client_call_details


class SentryUnaryUnaryClientInterceptor(ClientInterceptor, UnaryUnaryClientInterceptor): # type: ignore
async def intercept_unary_unary(
self,
continuation: Callable[[ClientCallDetails, Message], UnaryUnaryCall],
client_call_details: ClientCallDetails,
request: Message,
) -> Union[UnaryUnaryCall, Message]:
hub = Hub.current
method = client_call_details.method

with hub.start_span(
op=OP.GRPC_CLIENT, description="unary unary call to %s" % method.decode()
) as span:
span.set_data("type", "unary unary")
span.set_data("method", method)

client_call_details = self._update_client_call_details_metadata_from_hub(
client_call_details, hub
)

response = await continuation(client_call_details, request)
status_code = await response.code()
span.set_data("code", status_code.name)

return response


class SentryUnaryStreamClientInterceptor(
ClientInterceptor, UnaryStreamClientInterceptor # type: ignore
):
async def intercept_unary_stream(
self,
continuation: Callable[[ClientCallDetails, Message], UnaryStreamCall],
client_call_details: ClientCallDetails,
request: Message,
) -> Union[AsyncIterable[Any], UnaryStreamCall]:
hub = Hub.current
method = client_call_details.method

with hub.start_span(
op=OP.GRPC_CLIENT, description="unary stream call to %s" % method.decode()
) as span:
span.set_data("type", "unary stream")
span.set_data("method", method)

client_call_details = self._update_client_call_details_metadata_from_hub(
client_call_details, hub
)

response = await continuation(client_call_details, request)
# status_code = await response.code()
# span.set_data("code", status_code)

return response
95 changes: 95 additions & 0 deletions sentry_sdk/integrations/grpc/aio/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from sentry_sdk import Hub
from sentry_sdk._types import MYPY
from sentry_sdk.consts import OP
from sentry_sdk.integrations import DidNotEnable
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM
from sentry_sdk.utils import event_from_exception

if MYPY:
from collections.abc import Awaitable, Callable
from typing import Any


try:
import grpc
from grpc import HandlerCallDetails, RpcMethodHandler
from grpc.aio import ServicerContext
except ImportError:
raise DidNotEnable("grpcio is not installed")


class ServerInterceptor(grpc.aio.ServerInterceptor): # type: ignore
def __init__(self, find_name=None):
# type: (ServerInterceptor, Callable[[ServicerContext], str] | None) -> None
self._find_method_name = find_name or self._find_name

super(ServerInterceptor, self).__init__()

async def intercept_service(self, continuation, handler_call_details):
# type: (ServerInterceptor, Callable[[HandlerCallDetails], Awaitable[RpcMethodHandler]], HandlerCallDetails) -> Awaitable[RpcMethodHandler]
self._handler_call_details = handler_call_details
handler = await continuation(handler_call_details)

if not handler.request_streaming and not handler.response_streaming:
handler_factory = grpc.unary_unary_rpc_method_handler

async def wrapped(request, context):
# type: (Any, ServicerContext) -> Any
name = self._find_method_name(context)
if not name:
return await handler(request, context)

hub = Hub.current

# What if the headers are empty?
transaction = Transaction.continue_from_headers(
dict(context.invocation_metadata()),
op=OP.GRPC_SERVER,
name=name,
source=TRANSACTION_SOURCE_CUSTOM,
)

with hub.start_transaction(transaction=transaction):
try:
return await handler.unary_unary(request, context)
except Exception as exc:
event, hint = event_from_exception(
exc,
mechanism={"type": "grpc", "handled": False},
)
hub.capture_event(event, hint=hint)
raise

elif not handler.request_streaming and handler.response_streaming:
handler_factory = grpc.unary_stream_rpc_method_handler

async def wrapped(request, context): # type: ignore
# type: (Any, ServicerContext) -> Any
async for r in handler.unary_stream(request, context):
yield r

elif handler.request_streaming and not handler.response_streaming:
handler_factory = grpc.stream_unary_rpc_method_handler

async def wrapped(request, context):
# type: (Any, ServicerContext) -> Any
response = handler.stream_unary(request, context)
return await response

elif handler.request_streaming and handler.response_streaming:
handler_factory = grpc.stream_stream_rpc_method_handler

async def wrapped(request, context): # type: ignore
# type: (Any, ServicerContext) -> Any
async for r in handler.stream_stream(request, context):
yield r

return handler_factory(
wrapped,
request_deserializer=handler.request_deserializer,
response_serializer=handler.response_serializer,
)

def _find_name(self, context):
# type: (ServicerContext) -> str
return self._handler_call_details.method
Loading

0 comments on commit 76af9d2

Please sign in to comment.