-
Notifications
You must be signed in to change notification settings - Fork 504
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
gRPC integration and aio interceptors (#2369)
Automatically add client and server interceptors to gRPC calls. Make it work with async gRPC servers and async gRPC client channels. --------- Co-authored-by: ali.sorouramini <[email protected]> Co-authored-by: Anton Pirker <[email protected]> Co-authored-by: Anton Pirker <[email protected]>
- Loading branch information
1 parent
2cb232e
commit 76af9d2
Showing
19 changed files
with
934 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
[tool.black] | ||
# 'extend-exclude' excludes files or directories in addition to the defaults | ||
extend-exclude = ''' | ||
# A regex preceded with ^/ will apply only to files and directories | ||
# in the root of the project. | ||
( | ||
.*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the project | ||
| .*_pb2_grpc.py # exclude autogenerated Protocol Buffer files anywhere in the project | ||
) | ||
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,152 @@ | ||
from .server import ServerInterceptor # noqa: F401 | ||
from .client import ClientInterceptor # noqa: F401 | ||
from functools import wraps | ||
|
||
import grpc | ||
from grpc import Channel, Server, intercept_channel | ||
from grpc.aio import Channel as AsyncChannel | ||
from grpc.aio import Server as AsyncServer | ||
|
||
from sentry_sdk.integrations import Integration | ||
from sentry_sdk._types import TYPE_CHECKING | ||
|
||
from .client import ClientInterceptor | ||
from .server import ServerInterceptor | ||
from .aio.server import ServerInterceptor as AsyncServerInterceptor | ||
from .aio.client import ( | ||
SentryUnaryUnaryClientInterceptor as AsyncUnaryUnaryClientInterceptor, | ||
) | ||
from .aio.client import ( | ||
SentryUnaryStreamClientInterceptor as AsyncUnaryStreamClientIntercetor, | ||
) | ||
|
||
from typing import Any, Optional, Sequence | ||
|
||
# Hack to get new Python features working in older versions | ||
# without introducing a hard dependency on `typing_extensions` | ||
# from: https://stackoverflow.com/a/71944042/300572 | ||
if TYPE_CHECKING: | ||
from typing import ParamSpec, Callable | ||
else: | ||
# Fake ParamSpec | ||
class ParamSpec: | ||
def __init__(self, _): | ||
self.args = None | ||
self.kwargs = None | ||
|
||
# Callable[anything] will return None | ||
class _Callable: | ||
def __getitem__(self, _): | ||
return None | ||
|
||
# Make instances | ||
Callable = _Callable() | ||
|
||
P = ParamSpec("P") | ||
|
||
|
||
def _wrap_channel_sync(func: Callable[P, Channel]) -> Callable[P, Channel]: | ||
"Wrapper for synchronous secure and insecure channel." | ||
|
||
@wraps(func) | ||
def patched_channel(*args: Any, **kwargs: Any) -> Channel: | ||
channel = func(*args, **kwargs) | ||
if not ClientInterceptor._is_intercepted: | ||
ClientInterceptor._is_intercepted = True | ||
return intercept_channel(channel, ClientInterceptor()) | ||
else: | ||
return channel | ||
|
||
return patched_channel | ||
|
||
|
||
def _wrap_intercept_channel(func: Callable[P, Channel]) -> Callable[P, Channel]: | ||
@wraps(func) | ||
def patched_intercept_channel( | ||
channel: Channel, *interceptors: grpc.ServerInterceptor | ||
) -> Channel: | ||
if ClientInterceptor._is_intercepted: | ||
interceptors = tuple( | ||
[ | ||
interceptor | ||
for interceptor in interceptors | ||
if not isinstance(interceptor, ClientInterceptor) | ||
] | ||
) | ||
else: | ||
interceptors = interceptors | ||
return intercept_channel(channel, *interceptors) | ||
|
||
return patched_intercept_channel # type: ignore | ||
|
||
|
||
def _wrap_channel_async(func: Callable[P, AsyncChannel]) -> Callable[P, AsyncChannel]: | ||
"Wrapper for asynchronous secure and insecure channel." | ||
|
||
@wraps(func) | ||
def patched_channel( | ||
*args: P.args, | ||
interceptors: Optional[Sequence[grpc.aio.ClientInterceptor]] = None, | ||
**kwargs: P.kwargs, | ||
) -> Channel: | ||
sentry_interceptors = [ | ||
AsyncUnaryUnaryClientInterceptor(), | ||
AsyncUnaryStreamClientIntercetor(), | ||
] | ||
interceptors = [*sentry_interceptors, *(interceptors or [])] | ||
return func(*args, interceptors=interceptors, **kwargs) # type: ignore | ||
|
||
return patched_channel # type: ignore | ||
|
||
|
||
def _wrap_sync_server(func: Callable[P, Server]) -> Callable[P, Server]: | ||
"""Wrapper for synchronous server.""" | ||
|
||
@wraps(func) | ||
def patched_server( | ||
*args: P.args, | ||
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, | ||
**kwargs: P.kwargs, | ||
) -> Server: | ||
interceptors = [ | ||
interceptor | ||
for interceptor in interceptors or [] | ||
if not isinstance(interceptor, ServerInterceptor) | ||
] | ||
server_interceptor = ServerInterceptor() | ||
interceptors = [server_interceptor, *(interceptors or [])] | ||
return func(*args, interceptors=interceptors, **kwargs) # type: ignore | ||
|
||
return patched_server # type: ignore | ||
|
||
|
||
def _wrap_async_server(func: Callable[P, AsyncServer]) -> Callable[P, AsyncServer]: | ||
"""Wrapper for asynchronous server.""" | ||
|
||
@wraps(func) | ||
def patched_aio_server( | ||
*args: P.args, | ||
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None, | ||
**kwargs: P.kwargs, | ||
) -> Server: | ||
server_interceptor = AsyncServerInterceptor() | ||
interceptors = [server_interceptor, *(interceptors or [])] | ||
return func(*args, interceptors=interceptors, **kwargs) # type: ignore | ||
|
||
return patched_aio_server # type: ignore | ||
|
||
|
||
class GRPCIntegration(Integration): | ||
identifier = "grpc" | ||
|
||
@staticmethod | ||
def setup_once() -> None: | ||
import grpc | ||
|
||
grpc.insecure_channel = _wrap_channel_sync(grpc.insecure_channel) | ||
grpc.secure_channel = _wrap_channel_sync(grpc.secure_channel) | ||
grpc.intercept_channel = _wrap_intercept_channel(grpc.intercept_channel) | ||
|
||
grpc.aio.insecure_channel = _wrap_channel_async(grpc.aio.insecure_channel) | ||
grpc.aio.secure_channel = _wrap_channel_async(grpc.aio.secure_channel) | ||
|
||
grpc.server = _wrap_sync_server(grpc.server) | ||
grpc.aio.server = _wrap_async_server(grpc.aio.server) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .server import ServerInterceptor # noqa: F401 | ||
from .client import ClientInterceptor # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
from typing import Callable, Union, AsyncIterable, Any | ||
|
||
from grpc.aio import ( | ||
UnaryUnaryClientInterceptor, | ||
UnaryStreamClientInterceptor, | ||
ClientCallDetails, | ||
UnaryUnaryCall, | ||
UnaryStreamCall, | ||
) | ||
from google.protobuf.message import Message | ||
|
||
from sentry_sdk import Hub | ||
from sentry_sdk.consts import OP | ||
|
||
|
||
class ClientInterceptor: | ||
@staticmethod | ||
def _update_client_call_details_metadata_from_hub( | ||
client_call_details: ClientCallDetails, hub: Hub | ||
) -> ClientCallDetails: | ||
metadata = ( | ||
list(client_call_details.metadata) if client_call_details.metadata else [] | ||
) | ||
for key, value in hub.iter_trace_propagation_headers(): | ||
metadata.append((key, value)) | ||
|
||
client_call_details = ClientCallDetails( | ||
method=client_call_details.method, | ||
timeout=client_call_details.timeout, | ||
metadata=metadata, | ||
credentials=client_call_details.credentials, | ||
wait_for_ready=client_call_details.wait_for_ready, | ||
) | ||
|
||
return client_call_details | ||
|
||
|
||
class SentryUnaryUnaryClientInterceptor(ClientInterceptor, UnaryUnaryClientInterceptor): # type: ignore | ||
async def intercept_unary_unary( | ||
self, | ||
continuation: Callable[[ClientCallDetails, Message], UnaryUnaryCall], | ||
client_call_details: ClientCallDetails, | ||
request: Message, | ||
) -> Union[UnaryUnaryCall, Message]: | ||
hub = Hub.current | ||
method = client_call_details.method | ||
|
||
with hub.start_span( | ||
op=OP.GRPC_CLIENT, description="unary unary call to %s" % method.decode() | ||
) as span: | ||
span.set_data("type", "unary unary") | ||
span.set_data("method", method) | ||
|
||
client_call_details = self._update_client_call_details_metadata_from_hub( | ||
client_call_details, hub | ||
) | ||
|
||
response = await continuation(client_call_details, request) | ||
status_code = await response.code() | ||
span.set_data("code", status_code.name) | ||
|
||
return response | ||
|
||
|
||
class SentryUnaryStreamClientInterceptor( | ||
ClientInterceptor, UnaryStreamClientInterceptor # type: ignore | ||
): | ||
async def intercept_unary_stream( | ||
self, | ||
continuation: Callable[[ClientCallDetails, Message], UnaryStreamCall], | ||
client_call_details: ClientCallDetails, | ||
request: Message, | ||
) -> Union[AsyncIterable[Any], UnaryStreamCall]: | ||
hub = Hub.current | ||
method = client_call_details.method | ||
|
||
with hub.start_span( | ||
op=OP.GRPC_CLIENT, description="unary stream call to %s" % method.decode() | ||
) as span: | ||
span.set_data("type", "unary stream") | ||
span.set_data("method", method) | ||
|
||
client_call_details = self._update_client_call_details_metadata_from_hub( | ||
client_call_details, hub | ||
) | ||
|
||
response = await continuation(client_call_details, request) | ||
# status_code = await response.code() | ||
# span.set_data("code", status_code) | ||
|
||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from sentry_sdk import Hub | ||
from sentry_sdk._types import MYPY | ||
from sentry_sdk.consts import OP | ||
from sentry_sdk.integrations import DidNotEnable | ||
from sentry_sdk.tracing import Transaction, TRANSACTION_SOURCE_CUSTOM | ||
from sentry_sdk.utils import event_from_exception | ||
|
||
if MYPY: | ||
from collections.abc import Awaitable, Callable | ||
from typing import Any | ||
|
||
|
||
try: | ||
import grpc | ||
from grpc import HandlerCallDetails, RpcMethodHandler | ||
from grpc.aio import ServicerContext | ||
except ImportError: | ||
raise DidNotEnable("grpcio is not installed") | ||
|
||
|
||
class ServerInterceptor(grpc.aio.ServerInterceptor): # type: ignore | ||
def __init__(self, find_name=None): | ||
# type: (ServerInterceptor, Callable[[ServicerContext], str] | None) -> None | ||
self._find_method_name = find_name or self._find_name | ||
|
||
super(ServerInterceptor, self).__init__() | ||
|
||
async def intercept_service(self, continuation, handler_call_details): | ||
# type: (ServerInterceptor, Callable[[HandlerCallDetails], Awaitable[RpcMethodHandler]], HandlerCallDetails) -> Awaitable[RpcMethodHandler] | ||
self._handler_call_details = handler_call_details | ||
handler = await continuation(handler_call_details) | ||
|
||
if not handler.request_streaming and not handler.response_streaming: | ||
handler_factory = grpc.unary_unary_rpc_method_handler | ||
|
||
async def wrapped(request, context): | ||
# type: (Any, ServicerContext) -> Any | ||
name = self._find_method_name(context) | ||
if not name: | ||
return await handler(request, context) | ||
|
||
hub = Hub.current | ||
|
||
# What if the headers are empty? | ||
transaction = Transaction.continue_from_headers( | ||
dict(context.invocation_metadata()), | ||
op=OP.GRPC_SERVER, | ||
name=name, | ||
source=TRANSACTION_SOURCE_CUSTOM, | ||
) | ||
|
||
with hub.start_transaction(transaction=transaction): | ||
try: | ||
return await handler.unary_unary(request, context) | ||
except Exception as exc: | ||
event, hint = event_from_exception( | ||
exc, | ||
mechanism={"type": "grpc", "handled": False}, | ||
) | ||
hub.capture_event(event, hint=hint) | ||
raise | ||
|
||
elif not handler.request_streaming and handler.response_streaming: | ||
handler_factory = grpc.unary_stream_rpc_method_handler | ||
|
||
async def wrapped(request, context): # type: ignore | ||
# type: (Any, ServicerContext) -> Any | ||
async for r in handler.unary_stream(request, context): | ||
yield r | ||
|
||
elif handler.request_streaming and not handler.response_streaming: | ||
handler_factory = grpc.stream_unary_rpc_method_handler | ||
|
||
async def wrapped(request, context): | ||
# type: (Any, ServicerContext) -> Any | ||
response = handler.stream_unary(request, context) | ||
return await response | ||
|
||
elif handler.request_streaming and handler.response_streaming: | ||
handler_factory = grpc.stream_stream_rpc_method_handler | ||
|
||
async def wrapped(request, context): # type: ignore | ||
# type: (Any, ServicerContext) -> Any | ||
async for r in handler.stream_stream(request, context): | ||
yield r | ||
|
||
return handler_factory( | ||
wrapped, | ||
request_deserializer=handler.request_deserializer, | ||
response_serializer=handler.response_serializer, | ||
) | ||
|
||
def _find_name(self, context): | ||
# type: (ServicerContext) -> str | ||
return self._handler_call_details.method |
Oops, something went wrong.