diff --git a/HISTORY.rst b/HISTORY.rst index b0bd5fe7..05b76f13 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -11,10 +11,11 @@ History (`#218 `_) * Fix structuring bare ``typing.Tuple`` on Pythons lower than 3.9. (`#218 `_) - * Fix a wrong ``AttributeError`` of an missing ``__parameters__`` attribute. This could happen when inheriting certain generic classes – for example ``typing.*`` classes are affected. (`#217 `_) +* Fix structuring of ``enum.Enum`` instances in ``typing.Literal`` types. + (`#231 `_) 1.10.0 (2022-01-04) ------------------- diff --git a/src/cattr/converters.py b/src/cattr/converters.py index 6bf4af8a..bdc0f448 100644 --- a/src/cattr/converters.py +++ b/src/cattr/converters.py @@ -81,6 +81,12 @@ def is_optional(typ): return is_union_type(typ) and NoneType in typ.__args__ and len(typ.__args__) == 2 +def is_literal_containing_enums(typ): + return is_literal(typ) and any( + isinstance(val, Enum) for val in typ.__args__ + ) + + class Converter: """Converts between structured and unstructured data.""" @@ -146,7 +152,8 @@ def __init__( [ (lambda cl: cl is Any or cl is Optional or cl is None, lambda v, _: v), (is_generic_attrs, self._gen_structure_generic, True), - (is_literal, self._structure_literal), + (is_literal, self._structure_simple_literal), + (is_literal_containing_enums, self._structure_enum_literal), (is_sequence, self._structure_list), (is_mutable_set, self._structure_set), (is_frozenset, self._structure_frozenset), @@ -375,11 +382,21 @@ def _structure_call(obj, cl): return cl(obj) @staticmethod - def _structure_literal(val, type): + def _structure_simple_literal(val, type): if val not in type.__args__: raise Exception(f"{val} not in literal {type}") return val + @staticmethod + def _structure_enum_literal(val, type): + vals = { + (x.value if isinstance(x, Enum) else x): x for x in type.__args__ + } + try: + return vals[val] + except KeyError: + raise Exception(f"{val} not in literal {type}") from None + # Attrs classes. def structure_attrs_fromtuple(self, obj: Tuple[Any, ...], cl: Type[T]) -> T: diff --git a/tests/test_structure_attrs.py b/tests/test_structure_attrs.py index 6675faa5..deb6baff 100644 --- a/tests/test_structure_attrs.py +++ b/tests/test_structure_attrs.py @@ -1,4 +1,5 @@ """Loading of attrs classes.""" +from enum import Enum from ipaddress import IPv4Address, IPv6Address, ip_address from typing import Union from unittest.mock import Mock @@ -153,6 +154,27 @@ class ClassWithLiteral: ) == ClassWithLiteral(4) +@pytest.mark.skipif(is_py37, reason="Not supported on 3.7") +@pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) +def test_structure_literal_enum(converter_cls): + """Structuring a class with a literal field works.""" + from typing import Literal + + converter = converter_cls() + + class Foo(Enum): + FOO = 1 + BAR = 2 + + @define + class ClassWithLiteral: + literal_field: Literal[Foo.FOO] = Foo.FOO + + assert converter.structure( + {"literal_field": 1}, ClassWithLiteral + ) == ClassWithLiteral(Foo.FOO) + + @pytest.mark.skipif(is_py37, reason="Not supported on 3.7") @pytest.mark.parametrize("converter_cls", [Converter, GenConverter]) def test_structure_literal_multiple(converter_cls): @@ -161,9 +183,17 @@ def test_structure_literal_multiple(converter_cls): converter = converter_cls() + class Foo(Enum): + FOO = 7 + FOOFOO = 77 + + class Bar(int, Enum): + BAR = 8 + BARBAR = 88 + @define class ClassWithLiteral: - literal_field: Literal[4, 5] = 4 + literal_field: Literal[4, 5, Foo.FOO, Bar.BARBAR] = 4 assert converter.structure( {"literal_field": 4}, ClassWithLiteral @@ -172,6 +202,14 @@ class ClassWithLiteral: {"literal_field": 5}, ClassWithLiteral ) == ClassWithLiteral(5) + assert converter.structure( + {"literal_field": 7}, ClassWithLiteral + ) == ClassWithLiteral(Foo.FOO) + + cwl = converter.structure({"literal_field": 88}, ClassWithLiteral) + assert cwl == ClassWithLiteral(Bar.BARBAR) + assert isinstance(cwl.literal_field, Bar) + @pytest.mark.skipif(is_py37, reason="Not supported on 3.7") @pytest.mark.parametrize("converter_cls", [Converter, GenConverter])