Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interceptor support #45

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ flake8>=5.0.4
pytest-cov>=4.0.0
pytest-asyncio>=0.15.1
aiounittest==1.4.2

grpc-interceptor==0.15.4
ViridianForge marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 7 additions & 3 deletions src/grpc_requests/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from .client import CredentialsInfo
from .utils import load_data

logger = logging.getLogger(__name__)

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -55,7 +56,7 @@ def reflection_request(channel, requests):

class BaseAsyncClient:
def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_options=None, ssl=False,
compression=None, credentials: Optional[CredentialsInfo] = None, **kwargs):
compression=None, credentials: Optional[CredentialsInfo] = None, interceptors=None, **kwargs):
self.endpoint = endpoint
self._symbol_db = symbol_db or _symbol_database.Default()
self._desc_pool = descriptor_pool or _descriptor_pool.Default()
Expand All @@ -71,10 +72,13 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_optio

self._channel = grpc.aio.secure_channel(endpoint, grpc.ssl_channel_credentials(**_credentials),
options=self.channel_options,
compression=self.compression)
compression=self.compression,
interceptors=interceptors)

else:
self._channel = grpc.aio.insecure_channel(endpoint, options=self.channel_options,
compression=self.compression)
compression=self.compression,
interceptors=interceptors)

@property
def channel(self):
Expand Down
5 changes: 4 additions & 1 deletion src/grpc_requests/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class CredentialsInfo(TypedDict):

class BaseClient:
def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_options=None, ssl=False,
compression=None, credentials: Optional[CredentialsInfo] = None, **kwargs):
compression=None, credentials: Optional[CredentialsInfo] = None, interceptors=None, **kwargs):
self.endpoint = endpoint
self._desc_pool = descriptor_pool or _descriptor_pool.Default()
self.compression = compression
Expand All @@ -86,6 +86,9 @@ def __init__(self, endpoint, symbol_db=None, descriptor_pool=None, channel_optio
else:
self._channel = grpc.insecure_channel(endpoint, options=self.channel_options, compression=self.compression)

if interceptors:
self._channel = grpc.intercept_channel(self._channel, *interceptors)

@property
def channel(self):
return self._channel
Expand Down
10 changes: 10 additions & 0 deletions src/tests/async_reflection_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from grpc_requests.aio import AsyncClient
from google.protobuf.json_format import ParseError

from tests.common import AsyncMetadataClientInterceptor

"""
Test cases for async reflection based client
"""
Expand All @@ -18,6 +20,14 @@ async def test_unary_unary():
assert isinstance(response, dict)
assert response == {"message": "Hello, sinsky!"}

@pytest.mark.asyncio
async def test_unary_unary_interceptor():
client = AsyncClient('localhost:50051', interceptors=[AsyncMetadataClientInterceptor()])
greeter_service = await client.service('helloworld.Greeter')
response = await greeter_service.SayHello({"name": "sinsky"})
assert isinstance(response, dict)
assert response == {"message": "Hello, sinsky, interceptor accepted!"}

@pytest.mark.asyncio
async def test_empty_body_request():
client = AsyncClient('localhost:50051')
Expand Down
37 changes: 37 additions & 0 deletions src/tests/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import grpc
from grpc_interceptor import ClientCallDetails


class MetadataClientInterceptor(grpc.UnaryUnaryClientInterceptor):
def __init__(self):
pass

def intercept_unary_unary(
self,
continuation,
client_call_details,
request,
):
new_details = ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
[("interceptor", "true")],
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression,
)

return continuation(new_details, request)

class AsyncMetadataClientInterceptor(grpc.aio.UnaryUnaryClientInterceptor):

async def intercept_unary_unary(self, continuation, client_call_details, request):
new_details = grpc.aio.ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
[("interceptor", "true")],
client_call_details.credentials,
client_call_details.wait_for_ready,
)

return await continuation(new_details, request)
19 changes: 19 additions & 0 deletions src/tests/reflection_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from grpc_requests.client import Client
from google.protobuf.json_format import ParseError

from tests.common import MetadataClientInterceptor

"""
Test cases for reflection based client
"""
Expand All @@ -18,6 +20,15 @@ def helloworld_reflection_client():
except: # noqa: E722
pytest.fail("Could not connect to local HelloWorld server")

@pytest.fixture(scope="module")
def helloworld_reflection_client_with_interceptor():
try:
# Don't use get_by_endpoint here, because interceptors are not cached. Consider caching kwargs too
client = Client('localhost:50051', interceptors=[MetadataClientInterceptor()])
yield client
except: # noqa: E722
pytest.fail("Could not connect to local HelloWorld server")

@pytest.fixture(scope="module")
def client_tester_reflection_client():
try:
Expand All @@ -35,6 +46,14 @@ def test_metadata_usage(helloworld_reflection_client):
assert isinstance(response, dict)
assert response == {"message": "Hello, sinsky, password accepted!"}

def test_interceptor_usage(helloworld_reflection_client_with_interceptor):
response = helloworld_reflection_client_with_interceptor.request(
'helloworld.Greeter', 'SayHello',
{"name": "sinsky"},
)
assert isinstance(response, dict)
assert response == {"message": "Hello, sinsky, interceptor accepted!"}


def test_unary_unary(helloworld_reflection_client):
response = helloworld_reflection_client.request('helloworld.Greeter', 'SayHello', {"name": "sinsky"})
Expand Down
11 changes: 4 additions & 7 deletions src/tests/test_servers/helloworld/helloworld_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,13 @@ def SayHello(self, request, context):
Unary-Unary
Sends a HelloReply based on a HelloRequest.
"""
authorized = False
if context.invocation_metadata():
for key, value in context.invocation_metadata():
if key == "password" and value == "12345":
authorized = True

if authorized:
return HelloReply(message=f"Hello, {request.name}, password accepted!")
else:
return HelloReply(message=f"Hello, {request.name}!")
return HelloReply(message=f"Hello, {request.name}, password accepted!")
if key == "interceptor" and value == "true":
return HelloReply(message=f"Hello, {request.name}, interceptor accepted!")
return HelloReply(message=f"Hello, {request.name}!")

def SayHelloGroup(self, request, context):
"""
Expand Down