diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 8a34a4d52..0130dc10b 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1,5 +1,4 @@ import dataclasses -import enum import inspect import json import struct @@ -24,6 +23,7 @@ import typing +from .enum import IntEnum as Enum, Enum as _Enum from ._types import T from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub @@ -119,7 +119,7 @@ def datetime_default_gen(): DATETIME_ZERO = datetime_default_gen() -class Casing(enum.Enum): +class Casing(_Enum): """Casing constants for serialization.""" CAMEL = camel_case @@ -254,18 +254,6 @@ def map_field( ) -class Enum(enum.IntEnum): - """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" - - @classmethod - def from_string(cls, name: str) -> int: - """Return the value which corresponds to the string name.""" - try: - return cls.__members__[name] - except KeyError as e: - raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e - - def _pack_fmt(proto_type: str) -> str: """Returns a little-endian format string for reading/writing binary.""" return { @@ -783,7 +771,7 @@ def FromString(cls: Type[T], data: bytes) -> T: return cls().parse(data) def to_dict( - self, casing: Casing = Casing.CAMEL, include_default_values: bool = False + self, casing: Optional[Casing] = None, include_default_values: bool = False ) -> Dict[str, Any]: """ Returns a dict representation of this message instance which can be @@ -795,6 +783,7 @@ def to_dict( not be in returned dict if `include_default_values` is set to `False`. """ + casing = casing or Casing.CAMEL output: Dict[str, Any] = {} field_types = self._type_hints() for field_name, meta in self._betterproto.meta_by_field_name.items(): @@ -909,9 +898,9 @@ def from_dict(self: T, value: dict) -> T: elif meta.proto_type == TYPE_ENUM: enum_cls = self._betterproto.cls_by_field[field_name] if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] + v = [enum_cls[e] for e in v] elif isinstance(v, str): - v = enum_cls.from_string(v) + v = enum_cls[v] if v is not None: setattr(self, field_name, v) @@ -1007,7 +996,7 @@ class _WrappedMessage(Message): value: Any - def to_dict(self, casing: Casing = Casing.CAMEL) -> Any: + def to_dict(self, casing: Optional[Casing] = None) -> Any: return self.value def from_dict(self: T, value: Any) -> T: diff --git a/src/betterproto/enum.py b/src/betterproto/enum.py new file mode 100644 index 000000000..8d79e6c2f --- /dev/null +++ b/src/betterproto/enum.py @@ -0,0 +1,201 @@ +from enum import EnumMeta as _EnumMeta, _is_dunder, _is_descriptor +from types import MappingProxyType +from typing import Any, Dict, Iterable, List, Mapping, NoReturn, Tuple + +from .casing import camel_case, snake_case + + +class EnumMember: + _enum_cls_: "Enum" + name: str + value: Any + + def __new__(cls, **kwargs: Dict[str, Any]) -> "EnumMember": + self = super().__new__(cls) + try: + self.name = kwargs["name"] + self.value = kwargs["value"] + except KeyError: + pass + finally: + return self + + def __repr__(self): + return f"<{self._enum_cls_.__name__}.{self.name}: {self.value!r}>" + + def __str__(self): + return f"{self._enum_cls_.__name__}.{self.name}" + + @classmethod + def __call__(cls, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: + try: + kwargs["_is_enum__call__"] + except KeyError: + return cls.value(*args, **kwargs) + else: + return cls.__new__(cls, name=kwargs["name"], value=kwargs["value"]) + + def __hash__(self): + return hash((self.name, self.value)) + + def __eq__(self, other: Any): + try: + if other._enum_cls_ is self._enum_cls_: + return self.value == other.value + return False + except AttributeError: + return NotImplemented + + +class IntEnumMember(int, EnumMember): + _enum_cls_: "IntEnum" + value: int + + def __new__(cls, **kwargs: Dict[str, Any]) -> "IntEnumMember": + try: + value = kwargs["value"] + self = super().__new__(cls, value) + self.name = kwargs["name"] + self.value = value + return self + except KeyError: + return super().__new__(cls) + + +class EnumMeta(type): + _enum_value_map_: Dict[Any, EnumMember] + _enum_member_map_: Dict[str, EnumMember] + _enum_member_names_: List[str] + + def __new__( + mcs, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any] + ) -> "EnumMeta": + value_mapping: Dict[Any, EnumMember] = {} + member_mapping: Dict[str, EnumMember] = {} + member_names: List[str] = [] + try: + value_cls = ( + IntEnumMember(name=name) if IntEnum in bases else EnumMember(name=name) + ) + except NameError: + value_cls = EnumMember(name=name) + + for key, value in tuple(attrs.items()): + is_descriptor = _is_descriptor(value) + if key[0] == "_" and not is_descriptor: + continue + + if is_descriptor: + if value not in (camel_case, snake_case): + setattr(value_cls, key, value) + del attrs[key] + continue + + try: + new_value = value_mapping[value] + except KeyError: + new_value = value_cls(name=key, value=value, _is_enum__call__=True) + value_mapping[value] = new_value + member_names.append(key) + + member_mapping[key] = new_value + attrs[key] = new_value + + attrs["_enum_value_map_"] = value_mapping + attrs["_enum_member_map_"] = member_mapping + attrs["_enum_member_names_"] = member_names + enum_class: "EnumMeta" = super().__new__(mcs, name, bases, attrs) + for member in member_mapping.values(): + member._enum_cls_ = enum_class + return enum_class + + def __call__(cls, value: Any) -> "EnumMember": + if isinstance(value, cls): + return value + try: + return cls._enum_value_map_[value] + except (KeyError, TypeError): + raise ValueError(f"{value!r} is not a valid {cls.__name__}") + + def __repr__(cls): + return f"" + + def __iter__(cls) -> Iterable["EnumMember"]: + return (cls._enum_member_map_[name] for name in cls._enum_member_names_) + + def __reversed__(cls) -> Iterable["EnumMember"]: + return ( + cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_) + ) + + def __len__(cls): + return len(cls._enum_member_names_) + + def __getitem__(cls, key: Any) -> "EnumMember": + return cls._enum_member_map_[key] + + def __getattr__(cls, name: str): + if _is_dunder(name): + raise AttributeError(name) + try: + return cls._enum_value_map_[name] + except KeyError: + raise AttributeError(name) from None + + def __setattr__(cls, name: str, value: Any) -> NoReturn: + if name in cls._enum_member_map_: + raise AttributeError("Cannot reassign members.") + super().__setattr__(name, value) + + def __delattr__(cls, attr: Any) -> NoReturn: + if attr in cls._enum_member_map_: + raise AttributeError(f"{cls.__name__}: cannot delete Enum member.") + super().__delattr__(attr) + + def __instancecheck__(self, instance: Any): + try: + cls = instance._enum_cls_ + return cls is self or issubclass(cls, self) + except AttributeError: + return False + + def __dir__(cls): + return [ + "__class__", + "__doc__", + "__members__", + "__module__", + ] + cls._enum_member_names_ + + def __contains__(cls, member: "EnumMember"): + if not isinstance(member, EnumMember): + raise TypeError( + "unsupported operand type(s) for 'in':" + f" '{member.__class__.__qualname__}' and '{cls.__class__.__qualname__}'" + ) + return member.name in cls._enum_member_map_ + + def __bool__(self): + return True + + @property + def __members__(cls) -> Mapping[str, "EnumMember"]: + return MappingProxyType(cls._enum_member_map_) + + +class Enum(metaclass=EnumMeta): + """Protocol buffers enumeration base base class. Acts like `enum.Enum`.""" + + +class IntEnum(int, Enum): + """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" + + +def patched_instance_check(self: _EnumMeta, instance: Any) -> bool: + if isinstance(instance, (EnumMeta, EnumMember)): + return True + + return type.__instancecheck__(self, instance) + + +_EnumMeta.__instancecheck__ = patched_instance_check # fake it till you make it diff --git a/tests/test_enum.py b/tests/test_enum.py new file mode 100644 index 000000000..9c6b1dc6a --- /dev/null +++ b/tests/test_enum.py @@ -0,0 +1,230 @@ +import pytest + +from betterproto.enum import Enum, IntEnum + + +class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + + +class Grades(IntEnum): + A = 5 + B = 4 + C = 3 + D = 2 + F = 0 + + +class Directional(Enum): + EAST = "east" + WEST = "west" + NORTH = "north" + SOUTH = "south" + + +def test_dir_on_class(): + assert set(dir(Season)) == { + "__class__", + "__doc__", + "__members__", + "__module__", + "SPRING", + "SUMMER", + "AUTUMN", + "WINTER", + } + + +def test_enum_in_enum_out(): + assert Season(Season.WINTER) is Season.WINTER + + +def test_enum_value(): + assert Season.SPRING.value == 1 + + +def test_enum(): + lst = list(Season) + assert len(lst) == len(Season) + assert len(Season) == 4 + assert [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER] == lst + + for i, season in enumerate("SPRING SUMMER AUTUMN WINTER".split(), 1): + e = Season(i) + assert e == getattr(Season, season) + assert e.value == i + assert e != i + assert e.name == season + assert e in Season + assert isinstance(e, Season) + assert str(e) == f"Season.{season}" + assert repr(e) == f"" + + +def test_value_name(): + assert Season.SPRING.name == "SPRING" + assert Season.SPRING.value == 1 + + +def test_changing_member(): + with pytest.raises(AttributeError): + Season.WINTER = "really cold" + + +def test_attribute_deletion(): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = 3 + WINTER = 4 + + def spam(cls): + pass + + with pytest.raises(AttributeError): + del Season.DRY + + +def test_bool_of_class(): + class Empty(Enum): + pass + + assert bool(Empty) + + +def test_bool_of_member(): + class Count(Enum): + zero = 0 + one = 1 + two = 2 + + for member in Count: + assert bool(member) + + +def test_bool(): + # plain Enum members are always True + class Logic(Enum): + true = True + false = False + + assert Logic.true + assert Logic.false + + # unless overridden + class RealLogic(Enum): + true = True + false = False + + def __bool__(self): + return bool(self.value) + + assert RealLogic.true + assert not RealLogic.false + + +def test_contains(): + assert Season.AUTUMN in Season + with pytest.raises(TypeError): + 3 in Season + with pytest.raises(TypeError): + "AUTUMN" in Season + + val = Season(3) + assert val in Season + + class OtherEnum(Enum): + one = 1 + two = 2 + + assert OtherEnum.two not in Season + + +def test_comparisons(): + with pytest.raises(TypeError): + Season.SPRING < Season.WINTER + with pytest.raises(TypeError): + Season.SPRING > 4 + + class Part(Enum): + SPRING = 1 + CLIP = 2 + BARREL = 3 + + assert Season.SPRING != Part.SPRING + with pytest.raises(TypeError): + Season.SPRING < Part.CLIP + + +def test_enum_duplicates(): + class Season(Enum): + SPRING = 1 + SUMMER = 2 + AUTUMN = FALL = 3 + WINTER = 4 + ANOTHER_SPRING = 1 + + lst = list(Season) + assert lst == [Season.SPRING, Season.SUMMER, Season.AUTUMN, Season.WINTER] + assert Season.FALL is Season.AUTUMN + assert Season.FALL.value == 3 + assert Season.AUTUMN.value == 3 + assert Season(3) is Season.AUTUMN + assert Season(1) is Season.SPRING + assert Season.FALL.name == "AUTUMN" + assert [k for k, v in Season.__members__.items() if v.name != k] == [ + "FALL", + "ANOTHER_SPRING", + ] + + +def test_enum_with_value_name(): + class Huh(Enum): + name = 1 + value = 2 + + assert list(Huh) == [Huh.name, Huh.value] + assert Huh.name.name == "name" + assert Huh.name.value == 1 + + +def test_hash(): + dates = {} + dates[Season.WINTER] = "1225" + dates[Season.SPRING] = "0315" + dates[Season.SUMMER] = "0704" + dates[Season.AUTUMN] = "1031" + assert dates[Season.AUTUMN] == "1031" + + +def test_intenum(): + class WeekDay(IntEnum): + SUNDAY = 1 + MONDAY = 2 + TUESDAY = 3 + WEDNESDAY = 4 + THURSDAY = 5 + FRIDAY = 6 + SATURDAY = 7 + + assert ["a", "b", "c"][WeekDay.MONDAY] == "c" + assert [i for i in range(WeekDay.TUESDAY)] == [0, 1, 2] + + lst = list(WeekDay) + assert len(lst) == len(WeekDay) + assert len(WeekDay) == 7 + target = "SUNDAY MONDAY TUESDAY WEDNESDAY THURSDAY FRIDAY SATURDAY" + target = target.split() + for i, weekday in enumerate(target, 1): + e = WeekDay(i) + assert e == i + assert int(e) == i + assert e.name == weekday + assert e in WeekDay + assert lst.index(e) + 1 == i + assert 0 < e < 8 + assert isinstance(e, int) + assert isinstance(e, Enum)