Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix proto message str in plugin and add list generaiton plugin #283

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions django_socio_grpc/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from django_socio_grpc.protobuf.generation_plugin import (
BaseGenerationPlugin,
RequestAsListGenerationPlugin,
ResponseAsListGenerationPlugin,
ListGenerationPlugin,
)
from django_socio_grpc.protobuf.message_name_constructor import MessageNameConstructor
from django_socio_grpc.settings import grpc_settings
Expand All @@ -20,18 +19,14 @@ def _maintain_compat(use_request_list, use_response_list, use_generation_plugins
"""
internal_plugins = [] if use_generation_plugins is None else use_generation_plugins
warning_message = "You are using {0} argument in grpc_action. This argument is deprecated and has been remplaced by a specific GenerationPlugin. Please update following the documentation: https://django-socio-grpc.readthedocs.io/en/stable/features/proto-generation.html#proto-generation-plugins"
if use_request_list:
logger.warning(warning_message.format("use_request_list"))
if use_request_list or use_response_list:
log_text = "use_request_list" if use_request_list else "use_response_list"
if use_request_list and use_response_list:
log_text = "use_request_list and use_response_list"
logger.warning(warning_message.format(log_text))
internal_plugins.insert(
0,
RequestAsListGenerationPlugin(list_field_name="results"),
)

if use_response_list:
logger.warning(warning_message.format("use_response_list"))
internal_plugins.insert(
0,
ResponseAsListGenerationPlugin(list_field_name="results"),
ListGenerationPlugin(request=use_request_list, response=use_response_list),
)

return internal_plugins
Expand Down
8 changes: 6 additions & 2 deletions django_socio_grpc/grpc_actions/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,14 @@ def make_proto_rpc(self, action_name: str, service: Type["Service"]) -> ProtoRpc

# INFO - AM - 22/02/2024 - Get the actual request name
request_name: str = message_name_constructor.construct_request_name()
request: ProtoMessage = req_class.create(value=self.request, name=request_name)
request: Union[ProtoMessage, str] = req_class.create(
value=self.request, name=request_name
)

response_name: str = message_name_constructor.construct_response_name()
response: ProtoMessage = res_class.create(value=self.response, name=response_name)
response: Union[ProtoMessage, str] = res_class.create(
value=self.response, name=response_name
)

for generation_plugin in self.use_generation_plugins:
request, response = generation_plugin.run_validation_and_transform(
Expand Down
70 changes: 54 additions & 16 deletions django_socio_grpc/protobuf/generation_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class BaseGenerationPlugin:
def check_condition(
self,
service: Type["Service"],
request_message: ProtoMessage,
response_message: ProtoMessage,
request_message: Union[ProtoMessage, str],
response_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> bool:
"""
Expand All @@ -32,7 +32,7 @@ def check_condition(
def transform_request_message(
self,
service: Type["Service"],
proto_message: ProtoMessage,
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
"""
Expand All @@ -43,7 +43,7 @@ def transform_request_message(
def transform_response_message(
self,
service: Type["Service"],
proto_message: ProtoMessage,
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
"""
Expand All @@ -54,8 +54,8 @@ def transform_response_message(
def run_validation_and_transform(
self,
service: Type["Service"],
request_message: ProtoMessage,
response_message: ProtoMessage,
request_message: Union[ProtoMessage, str],
response_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> Tuple[ProtoMessage, ProtoMessage]:
"""
Expand Down Expand Up @@ -92,9 +92,13 @@ def __init__(self, *args, **kwargs):
def transform_request_message(
self,
service: Type["Service"],
proto_message: ProtoMessage,
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
):
if isinstance(proto_message, str):
logger.warning(
f"Plugin {self.__class__.__name__} can't be used with a string message. Please use the plugin directly on the grpc_action that generate the message"
)
proto_message.fields.append(
ProtoField.from_field_dict(
{
Expand Down Expand Up @@ -123,8 +127,8 @@ def __init__(self, display_warning_message=True):
def check_condition(
self,
service: Type["Service"],
request_message: ProtoMessage,
response_message: ProtoMessage,
request_message: Union[ProtoMessage, str],
response_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> bool:
# INFO - AM - 20/02/2024 - If service don't support filtering we do not add filter field
Expand Down Expand Up @@ -165,8 +169,8 @@ def __init__(self, display_warning_message=True):
def check_condition(
self,
service: Type["Service"],
request_message: ProtoMessage,
response_message: ProtoMessage,
request_message: Union[ProtoMessage, str],
response_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> bool:
# INFO - AM - 20/02/2024 - If service don't support filtering we do not add filter field
Expand Down Expand Up @@ -204,7 +208,7 @@ class AsListGenerationPlugin(BaseGenerationPlugin):
list_field_name: str = "results"

def transform_message_to_list(
self, service: Type["Service"], proto_message: ProtoMessage, list_name: str
self, service: Type["Service"], proto_message: Union[ProtoMessage, str], list_name: str
) -> ProtoMessage:
try:
list_field_name = proto_message.serializer.Meta.message_list_attr
Expand All @@ -219,7 +223,7 @@ def transform_message_to_list(
),
]

if hasattr(service, "pagination_class"):
if getattr(service, "pagination_class", None):
fields.append(
ProtoField(
name="count",
Expand All @@ -233,7 +237,7 @@ def transform_message_to_list(
)

# INFO - AM - If the original proto message is a serializer then we keep the comment at the serializer level. Else we put them at the list level
if not proto_message.serializer:
if not isinstance(proto_message, str) and not proto_message.serializer:
list_message.comments = proto_message.comments
proto_message.comments = None

Expand All @@ -248,7 +252,7 @@ class RequestAsListGenerationPlugin(AsListGenerationPlugin):
def transform_request_message(
self,
service: Type["Service"],
proto_message: ProtoMessage,
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
list_name = message_name_constructor.construct_request_list_name()
Expand All @@ -263,7 +267,7 @@ class ResponseAsListGenerationPlugin(AsListGenerationPlugin):
def transform_response_message(
self,
service: Type["Service"],
proto_message: ProtoMessage,
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
list_name = message_name_constructor.construct_response_list_name()
Expand All @@ -278,3 +282,37 @@ class RequestAndResponseAsListGenerationPlugin(
"""

...


@dataclass
class ListGenerationPlugin(RequestAsListGenerationPlugin, ResponseAsListGenerationPlugin):
"""
Transform both request and response ProtoMessage in list ProtoMessage
"""

request: bool = False
response: bool = False

def transform_response_message(
self,
service: Type["Service"],
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
if self.response:
return super().transform_response_message(
service, proto_message, message_name_constructor
)
return proto_message

def transform_request_message(
self,
service: Type["Service"],
proto_message: Union[ProtoMessage, str],
message_name_constructor: MessageNameConstructor,
) -> ProtoMessage:
if self.request:
return super().transform_request_message(
service, proto_message, message_name_constructor
)
return proto_message
13 changes: 6 additions & 7 deletions django_socio_grpc/tests/fakeapp/services/basic_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from django_socio_grpc import generics
from django_socio_grpc.decorators import grpc_action
from django_socio_grpc.protobuf.generation_plugin import (
RequestAndResponseAsListGenerationPlugin,
ResponseAsListGenerationPlugin,
ListGenerationPlugin,
)

from .basic_mixins import ListIdsMixin, ListNameMixin
Expand Down Expand Up @@ -56,7 +55,7 @@ async def TestEmptyMethod(self, request, context): ...
@grpc_action(
request=[],
response=BasicServiceSerializer,
use_generation_plugins=[ResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(response=True)],
)
async def GetMultiple(self, request, context):
# INFO - AM - 14/01/2022 - Do something here as filter user with the user name
Expand Down Expand Up @@ -99,7 +98,7 @@ async def MyMethod(self, request, context):
request=[{"name": "user_name", "type": "string"}],
response=[{"name": "user_name", "type": "string"}],
request_name="CustomMixParamForRequest",
use_generation_plugins=[RequestAndResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(request=True, response=True)],
)
async def MixParam(self, request, context):
pass
Expand All @@ -108,23 +107,23 @@ async def MixParam(self, request, context):
request=BasicServiceSerializer,
response="google.protobuf.Struct",
request_name="BasicParamWithSerializerRequest",
use_generation_plugins=[RequestAndResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(request=True, response=True)],
)
async def MixParamWithSerializer(self, request, context):
pass

@grpc_action(
request=BaseProtoExampleSerializer,
response=BaseProtoExampleSerializer,
use_generation_plugins=[ResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(response=True)],
)
async def TestBaseProtoSerializer(self, request, context):
pass

@grpc_action(
request=BasicProtoListChildSerializer,
response=BasicProtoListChildSerializer,
use_generation_plugins=[RequestAndResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(request=True, response=True)],
)
async def BasicList(self, request, context):
serializer = BasicProtoListChildSerializer(message=request, many=True)
Expand Down
4 changes: 2 additions & 2 deletions django_socio_grpc/tests/fakeapp/services/stream_in_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django_socio_grpc import generics
from django_socio_grpc.decorators import grpc_action
from django_socio_grpc.exceptions import NotFound
from django_socio_grpc.protobuf.generation_plugin import ResponseAsListGenerationPlugin
from django_socio_grpc.protobuf.generation_plugin import ListGenerationPlugin


class StreamInService(generics.GenericService):
Expand All @@ -14,7 +14,7 @@ class StreamInService(generics.GenericService):
request=[{"name": "name", "type": "string"}],
response=[{"name": "count", "type": "int32"}],
request_stream=True,
use_generation_plugins=[ResponseAsListGenerationPlugin()],
use_generation_plugins=[ListGenerationPlugin(response=True)],
)
async def StreamIn(self, request, context):
messages = [message async for message in request]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from django_socio_grpc.decorators import grpc_action
from django_socio_grpc.protobuf.generation_plugin import (
FilterGenerationPlugin,
ListGenerationPlugin,
PaginationGenerationPlugin,
ResponseAsListGenerationPlugin,
)


Expand Down Expand Up @@ -44,7 +44,7 @@ class UnitTestModelWithStructFilterService(
request=[],
response=UnitTestModelWithStructFilterSerializer,
use_generation_plugins=[
ResponseAsListGenerationPlugin(),
ListGenerationPlugin(response=True),
FilterGenerationPluginForce(),
PaginationGenerationPluginForce(),
],
Expand Down
Loading
Loading