Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ protobuf = "^3.12.2"
pytest = "^5.4.2"
pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
tox = "^3.15.1"

[tool.poetry.scripts]
Expand Down
20 changes: 20 additions & 0 deletions src/betterproto/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,10 @@ def lookup_method_input_type(method, types):
return known_type


def is_mutable_field_type(field_type: str) -> bool:
return field_type.startswith("List[") or field_type.startswith("Dict[")


def read_protobuf_service(
service: ServiceDescriptorProto, index, proto_file, content, output_types
):
Expand All @@ -384,8 +388,23 @@ def read_protobuf_service(
for j, method in enumerate(service.method):
method_input_message = lookup_method_input_type(method, output_types)

# This section ensures that method arguments having a default
# value that is initialised as a List/Dict (mutable) is replaced
# with None and initialisation is deferred to the beginning of the
# method definition. This is done so to avoid any side-effects.
# Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
mutable_default_args = []

if method_input_message:
for field in method_input_message["properties"]:
if (
not method.client_streaming
and field["zero"] != "None"
and is_mutable_field_type(field["type"])
):
mutable_default_args.append((field["py_name"], field["zero"]))
field["zero"] = "None"

if field["zero"] == "None":
template_data["typing_imports"].add("Optional")

Expand All @@ -407,6 +426,7 @@ def read_protobuf_service(
),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
"mutable_default_args": mutable_default_args,
}
)

Expand Down
4 changes: 4 additions & 0 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{{ method.comment }}

{% endif %}
{%- for py_name, zero in method.mutable_default_args %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}

{% if not method.client_streaming %}
request = {{ method.input }}()
{% for field in method.input_message.properties %}
Expand Down
19 changes: 18 additions & 1 deletion tests/grpc/test_grpclib_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import asyncio
import sys

from tests.output_betterproto.service.service import (
DoThingResponse,
DoThingRequest,
DoThingResponse,
GetThingRequest,
TestStub as ThingServiceClient,
)
import grpclib
import grpclib.metadata
from grpclib.testing import ChannelFor
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
Expand Down Expand Up @@ -35,6 +38,20 @@ async def test_simple_service_call():
await _test_client(ThingServiceClient(channel))


@pytest.mark.asyncio
@pytest.mark.skipif(
sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"
)
async def test_service_call_mutable_defaults(mocker):
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
spy = mocker.spy(client, "_unary_unary")
await _test_client(client)
comments = spy.call_args_list[-1].args[1].comments
await _test_client(client)
assert spy.call_args_list[-1].args[1].comments is not comments


@pytest.mark.asyncio
async def test_service_call_with_upfront_request_params():
# Setting deadline
Expand Down
1 change: 1 addition & 0 deletions tests/inputs/service/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package service;

message DoThingRequest {
string name = 1;
repeated string comments = 2;
}

message DoThingResponse {
Expand Down