Skip to content

Commit

Permalink
Merge pull request #79 from wesky93/develop
Browse files Browse the repository at this point in the history
Allow configuration of message_to_dict
  • Loading branch information
ViridianForge authored Apr 23, 2024
2 parents d9de93a + f383892 commit 67280d6
Show file tree
Hide file tree
Showing 10 changed files with 490 additions and 207 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.17](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.17) - 2024-04-22

## Added

- Support for custom message parsing in both async and sync clients

## Removed

- Removed singular FileDescriptor getter methods and Method specific field descriptor
methods as laid out previously.

## [0.1.16](https://github.com/wesky93/grpc_requests/releases/tag/v0.1.16) - 2024-03-03

## Added
Expand Down
20 changes: 20 additions & 0 deletions src/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,26 @@ result = await greeter.HelloEveryone(requests_data)
results = [x async for x in await greeter.SayHelloOneByOne(requests_data)]
```

## Setting a Client's message_to_dict behavior

By utilizing `CustomArgumentParsers`, behavioral arguments can be passed to
message_to_dict at time of Client instantiation. This is available for both
synchronous and asynchronous clients.

```python
client = Client(
"localhost:50051",
message_parsers=CustomArgumentParsers(
message_to_dict_kwargs={
"preserving_proto_field_name": True,
"including_default_value_fields": True,
}
),
)
```

[Review the json_format documentation for what kwargs are available to message_to_dict.](https://googleapis.dev/python/protobuf/latest/google/protobuf/json_format.html)

## Retrieving Information about a Server

All forms of clients expose methods to allow a user to query a server about its
Expand Down
2 changes: 1 addition & 1 deletion src/grpc_requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
)
from .client import Client, ReflectionClient, StubClient, get_by_endpoint

__version__ = "0.1.16"
__version__ = "0.1.17"
139 changes: 88 additions & 51 deletions src/grpc_requests/aio.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import sys
import warnings
from enum import Enum
from functools import partial
from typing import (
Expand All @@ -18,10 +17,16 @@
import grpc
from google.protobuf import (
descriptor_pb2,
message_factory,
)
from google.protobuf import (
descriptor_pool as _descriptor_pool,
)
from google.protobuf import (
symbol_database as _symbol_database,
message_factory,
) # noqa: E501
)

# noqa: E501
from google.protobuf.descriptor import MethodDescriptor, ServiceDescriptor
from google.protobuf.descriptor_pb2 import ServiceDescriptorProto
from google.protobuf.json_format import MessageToDict, ParseDict
Expand All @@ -34,11 +39,13 @@

if sys.version_info >= (3, 8):
import importlib.metadata
from typing import Protocol

def get_metadata(package_name: str):
return importlib.metadata.version(package_name)
else:
import pkg_resources
from typing_extensions import Protocol

def get_metadata(package_name: str):
return pkg_resources.get_distribution(package_name).version
Expand Down Expand Up @@ -146,27 +153,67 @@ def __del__(self):
pass


def parse_request_data(reqeust_data, input_type):
_data = reqeust_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type())
else:
request = _data
return request
class MessageParsersProtocol(Protocol):
def parse_request_data(self, request_data, input_type): ...

def parse_stream_requests(self, stream_requests_data: Iterable, input_type): ...

def parse_stream_requests(stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield parse_request_data(request_data or {}, input_type)
async def parse_response(self, response): ...

async def parse_stream_responses(self, responses: AsyncIterable): ...


class MessageParsers(MessageParsersProtocol):
def parse_request_data(self, request_data, input_type):
_data = request_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type())
else:
request = _data
return request

def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield self.parse_request_data(request_data or {}, input_type)

async def parse_response(self, response):
return MessageToDict(response, preserving_proto_field_name=True)

async def parse_stream_responses(self, responses: AsyncIterable):
async for resp in responses:
yield await self.parse_response(resp)


class CustomArgumentParsers(MessageParsersProtocol):
_message_to_dict_kwargs: Dict[str, Any]
_parse_dict_kwargs: Dict[str, Any]

def __init__(
self,
message_to_dict_kwargs: Dict[str, Any] = dict(),
parse_dict_kwargs: Dict[str, Any] = dict(),
):
self._message_to_dict_kwargs = message_to_dict_kwargs or {}
self._parse_dict_kwargs = parse_dict_kwargs or {}

def parse_request_data(self, request_data, input_type):
_data = request_data or {}
if isinstance(_data, dict):
request = ParseDict(_data, input_type(), **self._parse_dict_kwargs)
else:
request = _data
return request

async def parse_response(response):
return MessageToDict(response, preserving_proto_field_name=True)
def parse_stream_requests(self, stream_requests_data: Iterable, input_type):
for request_data in stream_requests_data:
yield self.parse_request_data(request_data or {}, input_type)

async def parse_response(self, response):
return MessageToDict(response, **self._message_to_dict_kwargs)

async def parse_stream_responses(responses: AsyncIterable):
async for resp in responses:
yield await parse_response(resp)
async def parse_stream_responses(self, responses: AsyncIterable):
async for resp in responses:
yield await self.parse_response(resp)


class MethodType(Enum):
Expand All @@ -179,25 +226,32 @@ class MethodType(Enum):
def is_unary_request(self):
return "unary_" in self.value

@property
def request_parser(self):
return parse_request_data if self.is_unary_request else parse_stream_requests

@property
def is_unary_response(self):
return "_unary" in self.value

@property
def response_parser(self):
return parse_response if self.is_unary_response else parse_stream_responses


class MethodMetaData(NamedTuple):
input_type: Any
output_type: Any
method_type: MethodType
handler: Any
descriptor: MethodDescriptor
parsers: MessageParsersProtocol

@property
def request_parser(self):
if self.method_type.is_unary_request:
return self.parsers.parse_request_data
else:
return self.parsers.parse_stream_requests

@property
def response_parser(self):
if self.method_type.is_unary_response:
return self.parsers.parse_response
else:
return self.parsers.parse_stream_responses


IS_REQUEST_STREAM = TypeVar("IS_REQUEST_STREAM")
Expand All @@ -220,6 +274,7 @@ def __init__(
ssl=False,
compression=None,
skip_check_method_available=False,
message_parsers: MessageParsersProtocol = MessageParsers(),
**kwargs,
):
super().__init__(
Expand All @@ -233,6 +288,7 @@ def __init__(
self._service_names: list = None
self.has_server_registered = False
self._skip_check_method_available = skip_check_method_available
self._message_parsers = message_parsers
self._services_module_name = {}
self._service_methods_meta: Dict[str, Dict[str, MethodMetaData]] = {}

Expand Down Expand Up @@ -309,6 +365,7 @@ def _register_methods(
output_type=output_type,
handler=handler,
descriptor=method_desc,
parsers=self._message_parsers,
)
return metadata

Expand Down Expand Up @@ -348,19 +405,17 @@ async def _request(self, service, method, request, raw_output=False, **kwargs):
# does not check request is available
method_meta = self.get_method_meta(service, method)

_request = method_meta.method_type.request_parser(
request, method_meta.input_type
)
_request = method_meta.request_parser(request, method_meta.input_type)
if method_meta.method_type.is_unary_response:
result = await method_meta.handler(_request, **kwargs)

if raw_output:
return result
else:
return await method_meta.method_type.response_parser(result)
return await method_meta.response_parser(result)
else:
result = method_meta.handler(_request, **kwargs)
return method_meta.method_type.response_parser(result)
return method_meta.response_parser(result)

async def request(self, service, method, request=None, raw_output=False, **kwargs):
await self.check_method_available(service, method)
Expand Down Expand Up @@ -427,6 +482,7 @@ def __init__(
descriptor_pool=None,
ssl=False,
compression=None,
message_parsers: MessageParsersProtocol = MessageParsers(),
**kwargs,
):
super().__init__(
Expand All @@ -435,6 +491,7 @@ def __init__(
descriptor_pool,
ssl=ssl,
compression=compression,
message_parsers=message_parsers,
**kwargs,
)
self.reflection_stub = reflection_pb2_grpc.ServerReflectionStub(self.channel)
Expand All @@ -453,26 +510,6 @@ async def _get_service_names(self):
services = tuple([s.name for s in resp.list_services_response.service])
return services

async def get_file_descriptor_by_name(self, name):
warnings.warn(
"This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_name() instead.",
DeprecationWarning,
)
request = reflection_pb2.ServerReflectionRequest(file_by_filename=name)
result = await self._reflection_single_request(request)
proto = result.file_descriptor_response.file_descriptor_proto[0]
return descriptor_pb2.FileDescriptorProto.FromString(proto)

async def get_file_descriptor_by_symbol(self, symbol):
warnings.warn(
"This function is deprecated, and will be removed in the 0.1.17 release. Use get_file_descriptors_by_symbol() instead.",
DeprecationWarning,
)
request = reflection_pb2.ServerReflectionRequest(file_containing_symbol=symbol)
result = await self._reflection_single_request(request)
proto = result.file_descriptor_response.file_descriptor_proto[0]
return descriptor_pb2.FileDescriptorProto.FromString(proto)

async def get_file_descriptors_by_name(self, name):
request = reflection_pb2.ServerReflectionRequest(file_by_filename=name)
result = await self._reflection_single_request(request)
Expand Down
Loading

0 comments on commit 67280d6

Please sign in to comment.