Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 16 additions & 22 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import struct
import sys
import warnings
import typing
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
Expand All @@ -22,8 +22,6 @@
get_type_hints,
)

import typing

from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
Expand Down Expand Up @@ -126,11 +124,7 @@ class Casing(enum.Enum):
SNAKE = snake_case


class _PLACEHOLDER:
pass


PLACEHOLDER: Any = _PLACEHOLDER()
PLACEHOLDER: Any = object()


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -261,7 +255,7 @@ class Enum(enum.IntEnum):
def from_string(cls, name: str) -> int:
"""Return the value which corresponds to the string name."""
try:
return cls.__members__[name]
return cls._member_map_[name]
Comment thread
nat-n marked this conversation as resolved.
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

Expand Down Expand Up @@ -349,7 +343,7 @@ def _serialize_single(
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, wraps, value)

output = b""
output = bytearray()
if proto_type in WIRE_VARINT_TYPES:
key = encode_varint(field_number << 3)
output += key + value
Expand All @@ -366,10 +360,10 @@ def _serialize_single(
else:
raise NotImplementedError(proto_type)

return output
return bytes(output)


def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
"""
Decode a single varint value from a byte buffer. Returns the value and the
new position in the buffer.
Expand Down Expand Up @@ -513,11 +507,11 @@ def __post_init__(self) -> None:
all_sentinel = True

# Set current field of each group after `__init__` has already been run.
group_current: Dict[str, str] = {}
group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items():

if meta.group:
group_current.setdefault(meta.group)
if meta.group and group_current.get(meta.group) is None:
group_current[meta.group] = None

if getattr(self, field_name) != PLACEHOLDER:
# Skip anything not set to the sentinel value
Expand Down Expand Up @@ -571,7 +565,7 @@ def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
"""
output = b""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)

Expand Down Expand Up @@ -609,7 +603,7 @@ def __bytes__(self) -> bytes:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = b""
buf = bytearray()
for item in value:
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
Expand Down Expand Up @@ -644,7 +638,8 @@ def __bytes__(self) -> bytes:
wraps=meta.wraps or "",
)

return output + self._unknown_fields
output += self._unknown_fields
return bytes(output)

# For compatibility with other libraries
SerializeToString = __bytes__
Expand Down Expand Up @@ -754,14 +749,14 @@ def parse(self: T, data: bytes) -> T:
"""
# Got some data over the wire
self._serialized_on_wire = True

proto_meta = self._betterproto
for parsed in parse_fields(data):
field_name = self._betterproto.field_name_by_number.get(parsed.number)
field_name = proto_meta.field_name_by_number.get(parsed.number)
if not field_name:
self._unknown_fields += parsed.raw
continue

meta = self._betterproto.meta_by_field_name[field_name]
meta = proto_meta.meta_by_field_name[field_name]

value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
Expand Down Expand Up @@ -907,7 +902,6 @@ def from_dict(self: T, value: dict) -> T:
returns the instance itself and is therefore assignable and chainable.
"""
self._serialized_on_wire = True
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
Expand Down