diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 05b8b7cd1..024ea19fd 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -836,7 +836,14 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: return t elif issubclass(t, Enum): # Enums always default to zero. - return int + def default_enum(): + try: + # try to create a python enum instance + return t(0) + except ValueError: + return 0 # if that does not work fallback to int + + return default_enum elif t is datetime: # Offsets are relative to 1970-01-01T00:00:00Z return datetime_default_gen @@ -861,6 +868,13 @@ def _postprocess_single( elif meta.proto_type == TYPE_BOOL: # Booleans use a varint encoding, so convert it to true/false. value = value > 0 + elif meta.proto_type == TYPE_ENUM: + # Convert enum ints to python enum instances + cls = self._betterproto.cls_by_field[field_name] + try: + value = cls(value) + except ValueError: + pass # the received value does not exist in the enum so we have to pass it as raw int elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64): fmt = _pack_fmt(meta.proto_type) value = struct.unpack(fmt, value)[0] diff --git a/tests/inputs/enum/test_enum.py b/tests/inputs/enum/test_enum.py index 3005c43a8..bc3c8c961 100644 --- a/tests/inputs/enum/test_enum.py +++ b/tests/inputs/enum/test_enum.py @@ -82,3 +82,23 @@ def enum_generator(): yield Choice.THREE assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] + + +def test_enum_mapped_on_parse(): + # test default value + b = Test().parse(bytes(Test())) + assert b.choice.name == Choice.ZERO.name + assert b.choices == [] + + # test non default value + a = Test().parse(bytes(Test(choice=Choice.ONE))) + assert a.choice.name == Choice.ONE.name + assert b.choices == [] + + # test repeated + c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR]))) + assert c.choices[0].name == Choice.THREE.name + assert c.choices[1].name == Choice.FOUR.name + + # bonus: defaults after empty init are also mapped + assert Test().choice.name == Choice.ZERO.name