Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 147 additions & 123 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ packages = [
[tool.poetry.dependencies]
python = "^3.6"
backports-datetime-fromisoformat = { version = "^1.0.0", python = "<3.7" }
black = { version = "^19.10b0", optional = true }
black = { version = "^20.8b1", optional = true }
dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpclib = "^0.3.1"
jinja2 = { version = "^2.11.2", optional = true }
protobuf = { version = "^3.12.2", optional = true }

[tool.poetry.dev-dependencies]
black = "^19.10b0"
black = "^20.8b1"
bpython = "^0.19"
grpcio-tools = "^1.30.0"
jinja2 = "^2.11.2"
Expand Down
38 changes: 16 additions & 22 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import struct
import sys
import warnings
import typing
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
Expand All @@ -22,8 +22,6 @@
get_type_hints,
)

import typing

from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
Expand Down Expand Up @@ -126,11 +124,7 @@ class Casing(enum.Enum):
SNAKE = snake_case


class _PLACEHOLDER:
pass


PLACEHOLDER: Any = _PLACEHOLDER()
PLACEHOLDER: Any = object()


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -261,7 +255,7 @@ class Enum(enum.IntEnum):
def from_string(cls, name: str) -> int:
"""Return the value which corresponds to the string name."""
try:
return cls.__members__[name]
return cls._member_map_[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

Expand Down Expand Up @@ -349,7 +343,7 @@ def _serialize_single(
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, wraps, value)

output = b""
output = bytearray()
if proto_type in WIRE_VARINT_TYPES:
key = encode_varint(field_number << 3)
output += key + value
Expand All @@ -366,10 +360,10 @@ def _serialize_single(
else:
raise NotImplementedError(proto_type)

return output
return bytes(output)


def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
"""
Decode a single varint value from a byte buffer. Returns the value and the
new position in the buffer.
Expand Down Expand Up @@ -513,7 +507,7 @@ def __post_init__(self) -> None:
all_sentinel = True

# Set current field of each group after `__init__` has already been run.
group_current: Dict[str, str] = {}
group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items():

if meta.group:
Expand Down Expand Up @@ -549,7 +543,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
self._group_current[group] = field.name
else:
super().__setattr__(
field.name, self._get_field_default(field.name),
field.name, self._get_field_default(field.name)
)

super().__setattr__(attr, value)
Expand All @@ -571,7 +565,7 @@ def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
"""
output = b""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)

Expand Down Expand Up @@ -609,7 +603,7 @@ def __bytes__(self) -> bytes:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = b""
buf = bytearray()
for item in value:
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
Expand Down Expand Up @@ -644,7 +638,8 @@ def __bytes__(self) -> bytes:
wraps=meta.wraps or "",
)

return output + self._unknown_fields
output += self._unknown_fields
return bytes(output)

# For compatibility with other libraries
SerializeToString = __bytes__
Expand Down Expand Up @@ -754,14 +749,14 @@ def parse(self: T, data: bytes) -> T:
"""
# Got some data over the wire
self._serialized_on_wire = True

proto_meta = self._betterproto
for parsed in parse_fields(data):
field_name = self._betterproto.field_name_by_number.get(parsed.number)
field_name = proto_meta.field_name_by_number.get(parsed.number)
if not field_name:
self._unknown_fields += parsed.raw
continue

meta = self._betterproto.meta_by_field_name[field_name]
meta = proto_meta.meta_by_field_name[field_name]

value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
Expand Down Expand Up @@ -857,7 +852,7 @@ def to_dict(
field_name=field_name, meta=meta
)
):
output[cased_name] = value.to_dict(casing, include_default_values,)
output[cased_name] = value.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
for k in value:
if hasattr(value[k], "to_dict"):
Expand Down Expand Up @@ -907,7 +902,6 @@ def from_dict(self: T, value: dict) -> T:
returns the instance itself and is therefore assignable and chainable.
"""
self._serialized_on_wire = True
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
Expand Down
2 changes: 1 addition & 1 deletion src/betterproto/compile/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def parse_source_type_name(field_type_name):


def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True,
package: str, imports: set, source_type: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
Expand Down
4 changes: 1 addition & 3 deletions src/betterproto/grpc/util/async_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ class AsyncChannel(AsyncIterable[T]):
or immediately if no source is provided.
"""

def __init__(
self, *, buffer_limit: int = 0, close: bool = False,
):
def __init__(self, *, buffer_limit: int = 0, close: bool = False):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._closed = False
self._waiting_receivers: int = 0
Expand Down
56 changes: 26 additions & 30 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
# print(list(sci.path), path, file=sys.stderr)
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent,
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
)

if path[-2] == 2 and path[-4] != 6:
Expand All @@ -153,6 +153,7 @@ class ProtoContentBase:

path: List[int]
comment_indent: int = 4
parent: Union["Messsage", "OutputTemplate"]

def __post_init__(self):
"""Checks that no fake default fields were left as placeholders."""
Expand Down Expand Up @@ -187,7 +188,7 @@ def comment(self) -> str:
for this object.
"""
return get_comment(
proto_file=self.proto_file, path=self.path, indent=self.comment_indent,
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
)


Expand Down Expand Up @@ -262,8 +263,7 @@ def python_module_imports(self) -> Set[str]:

@dataclass
class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message.
"""
"""Representation of a protobuf message."""

parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
Expand Down Expand Up @@ -307,8 +307,7 @@ def deprecated_fields(self) -> Iterator[str]:
def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
"""True if proto_field_obj is a map, otherwise False.
"""
"""True if proto_field_obj is a map, otherwise False."""
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
# This might be a map...
message_type = proto_field_obj.type_name.split(".").pop().lower()
Expand All @@ -322,8 +321,7 @@ def is_map(


def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False.
"""
"""True if proto_field_obj is a OneOf, otherwise False."""
if proto_field_obj.HasField("oneof_index"):
return True
return False
Expand Down Expand Up @@ -355,24 +353,26 @@ def get_field_string(self, indent: int = 4) -> str:
"""Construct string representation of this field as a field."""
name = f"{self.py_name}"
annotations = f": {self.annotation}"
field_args = ", ".join(
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
)
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
+ f"{self.betterproto_field_args}"
+ field_args
+ ")"
)
return name + annotations + " = " + betterproto_field_type

@property
def betterproto_field_args(self):
args = ""
def betterproto_field_args(self) -> List[str]:
args = []
if self.field_wraps:
args = args + f", wraps={self.field_wraps}"
args.append(f"wraps={self.field_wraps}")
return args

@property
def field_wraps(self) -> Union[str, None]:
"""Returns betterproto wrapped field type or None.
"""
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
)
Expand Down Expand Up @@ -405,8 +405,7 @@ def field_type(self) -> str:

@property
def default_value_string(self) -> Union[Text, None, float, int]:
"""Python representation of the default proto value.
"""
"""Python representation of the default proto value."""
if self.repeated:
return "[]"
if self.py_type == "int":
Expand Down Expand Up @@ -473,10 +472,10 @@ def annotation(self) -> str:
@dataclass
class OneOfFieldCompiler(FieldCompiler):
@property
def betterproto_field_args(self) -> "str":
def betterproto_field_args(self) -> List[str]:
args = super().betterproto_field_args
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
args = args + f', group="{group}"'
args.append(f'group="{group}"')
return args


Expand All @@ -495,26 +494,23 @@ def __post_init__(self):
if nested.options.map_entry:
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0], # key
parent=self, proto_obj=nested.field[0] # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1], # value
parent=self, proto_obj=nested.field[1] # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__

def get_field_string(self, indent: int = 4) -> str:
"""Construct string representation of this field."""
name = f"{self.py_name}"
annotations = f": {self.annotation}"
betterproto_field_type = (
f"betterproto.map_field("
f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, "
f"betterproto.{self.proto_v_type})"
)
return name + annotations + " = " + betterproto_field_type
@property
def betterproto_field_args(self) -> List[str]:
return [f"betterproto.{self.proto_k_type}", f"betterproto.{self.proto_v_type}"]

@property
def field_type(self) -> str:
return "map"

@property
def annotation(self):
Expand Down
4 changes: 2 additions & 2 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ def read_protobuf_service(
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
) -> None:
service_data = ServiceCompiler(
parent=output_package, proto_obj=service, path=[6, index],
parent=output_package, proto_obj=service, path=[6, index]
)
for j, method in enumerate(service.method):
ServiceMethodCompiler(
parent=service_data, proto_obj=method, path=[6, index, 2, j],
parent=service_data, proto_obj=method, path=[6, index, 2, j]
)
6 changes: 3 additions & 3 deletions tests/grpc/test_grpclib_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def test_service_call_with_upfront_request_params():
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
Expand All @@ -117,7 +117,7 @@ async def test_service_call_with_upfront_request_params():
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
await _test_client(
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
Expand All @@ -134,7 +134,7 @@ async def test_service_call_lower_level_with_overrides():
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata))]
) as channel:
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await client._unary_unary(
Expand Down
2 changes: 1 addition & 1 deletion tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def protoc(
*[p.as_posix() for p in path.glob("*.proto")],
]
proc = await asyncio.create_subprocess_exec(
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
*command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
stdout, stderr = await proc.communicate()
return stdout, stderr, proc.returncode
Expand Down