From 52d60a07a55816ba15b8208b593e4aff87b09602 Mon Sep 17 00:00:00 2001 From: "ivan.prado" Date: Wed, 11 May 2022 16:48:38 +0200 Subject: [PATCH 1/6] Fix ser/deserialization for enumerations and Paths --- .../helpers/serialization/decoding.py | 12 +++-- .../helpers/serialization/encoding.py | 6 +++ test/utils/test_serialization.py | 51 +++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 709e0e2a..169424ad 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -6,8 +6,8 @@ from dataclasses import Field, fields from functools import lru_cache, partial from logging import getLogger -from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set - +from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set, Mapping, \ + Iterable from simple_parsing.utils import ( get_type_arguments, @@ -382,4 +382,10 @@ def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]: Returns: Callable[[Any], Union[T, Any]]: A decoding function that might return nothing. 
""" - return try_functions(lambda val: t(**val)) + def constructor(val): + if isinstance(val, Mapping): + return t(**val) + else: + return t(val) + + return try_functions(constructor) diff --git a/simple_parsing/helpers/serialization/encoding.py b/simple_parsing/helpers/serialization/encoding.py index 012d6433..2c24469d 100644 --- a/simple_parsing/helpers/serialization/encoding.py +++ b/simple_parsing/helpers/serialization/encoding.py @@ -11,6 +11,7 @@ def encode_ndarray(obj: np.ndarray) -> str: import copy import json from dataclasses import fields, is_dataclass +from enum import Enum from functools import singledispatch from logging import getLogger from os import PathLike @@ -134,3 +135,8 @@ def encode_path(obj: PathLike) -> str: @encode.register(Namespace) def encode_namespace(obj: Namespace) -> Any: return encode(vars(obj)) + + +@encode.register(Enum) +def encode_enum(obj: Enum) -> str: + return obj.value \ No newline at end of file diff --git a/test/utils/test_serialization.py b/test/utils/test_serialization.py index 25611c4d..8a93fc03 100644 --- a/test/utils/test_serialization.py +++ b/test/utils/test_serialization.py @@ -2,6 +2,7 @@ """ from collections import OrderedDict from dataclasses import dataclass +from enum import Enum from pathlib import Path from simple_parsing.helpers.serialization.serializable import SerializableMixin from test.testutils import raises, TestSetup @@ -539,3 +540,53 @@ class Bob(FrozenSerializable if frozen else Serializable, TestSetup): assert b.func(10) == 20 assert b.to_dict() == {"func": "double"} assert Bob.from_dict(b.to_dict()) == b + + +def test_enum(frozen: bool): + class AnimalType(Enum): + CAT = "cat" + DOG = "dog" + + @dataclass(frozen=frozen) + class Animal(FrozenSerializable if frozen else Serializable): + animal_type: AnimalType + name: str + + animal = Animal(AnimalType.CAT, "Fluffy") + assert animal.loads(animal.dumps()) == animal + + d = animal.to_dict() + assert d["animal_type"] == "cat" + assert 
animal.from_dict(d) == animal + + +def test_enum_with_ints(frozen: bool): + class AnimalType(Enum): + CAT = 1 + DOG = 2 + + @dataclass(frozen=frozen) + class Animal(FrozenSerializable if frozen else Serializable): + animal_type: AnimalType + name: str + + animal = Animal(AnimalType.CAT, "Fluffy") + assert animal.loads(animal.dumps()) == animal + + d = animal.to_dict() + assert d["animal_type"] == 1 + assert animal.from_dict(d) == animal + + +def test_path(frozen: bool): + @dataclass(frozen=frozen) + class Foo(FrozenSerializable if frozen else Serializable): + path: Path + + foo = Foo(Path("/tmp/foo")) + assert foo.loads(foo.dumps()) == foo + + d = foo.to_dict() + assert isinstance(d["path"], str) + assert foo.from_dict(d) == foo + assert isinstance(foo.from_dict(d).path, Path) \ No newline at end of file From 4102210486aaeedce9c796a00b4d52cb3b82aca2 Mon Sep 17 00:00:00 2001 From: "ivan.prado" Date: Wed, 11 May 2022 16:51:04 +0200 Subject: [PATCH 2/6] Return carriage --- simple_parsing/helpers/serialization/encoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simple_parsing/helpers/serialization/encoding.py b/simple_parsing/helpers/serialization/encoding.py index 2c24469d..17a6d97e 100644 --- a/simple_parsing/helpers/serialization/encoding.py +++ b/simple_parsing/helpers/serialization/encoding.py @@ -139,4 +139,4 @@ def encode_namespace(obj: Namespace) -> Any: @encode.register(Enum) def encode_enum(obj: Enum) -> str: - return obj.value \ No newline at end of file + return obj.value From 56ccea897be72833aeb5ecceccfa36523cc71c80 Mon Sep 17 00:00:00 2001 From: "ivan.prado" Date: Wed, 11 May 2022 16:52:18 +0200 Subject: [PATCH 3/6] Return carriage --- test/utils/test_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/utils/test_serialization.py b/test/utils/test_serialization.py index 8a93fc03..e9d2ae9e 100644 --- a/test/utils/test_serialization.py +++ b/test/utils/test_serialization.py @@ -589,4 +589,4 @@ 
class Foo(FrozenSerializable if frozen else Serializable): d = foo.to_dict() assert isinstance(d["path"], str) assert foo.from_dict(d) == foo - assert isinstance(foo.from_dict(d).path, Path) \ No newline at end of file + assert isinstance(foo.from_dict(d).path, Path) From f78f44212b577ae7ef734c75e54a744f54d4c03b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20de=20Prado?= Date: Thu, 12 May 2022 09:24:10 +0200 Subject: [PATCH 4/6] Use the class method from_dict instead Co-authored-by: Fabrice Normandin --- test/utils/test_serialization.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/utils/test_serialization.py b/test/utils/test_serialization.py index e9d2ae9e..b570da0d 100644 --- a/test/utils/test_serialization.py +++ b/test/utils/test_serialization.py @@ -553,11 +553,11 @@ class Animal(FrozenSerializable if frozen else Serializable): name: str animal = Animal(AnimalType.CAT, "Fluffy") - assert animal.loads(animal.dumps()) == animal + assert Animal.loads(animal.dumps()) == animal d = animal.to_dict() assert d["animal_type"] == "cat" - assert animal.from_dict(d) == animal + assert Animal.from_dict(d) == animal def test_enum_with_ints(frozen: bool): @@ -571,11 +571,11 @@ class Animal(FrozenSerializable if frozen else Serializable): name: str animal = Animal(AnimalType.CAT, "Fluffy") - assert animal.loads(animal.dumps()) == animal + assert Animal.loads(animal.dumps()) == animal d = animal.to_dict() assert d["animal_type"] == 1 - assert animal.from_dict(d) == animal + assert Animal.from_dict(d) == animal def test_path(frozen: bool): @@ -584,9 +584,9 @@ class Foo(FrozenSerializable if frozen else Serializable): path: Path foo = Foo(Path("/tmp/foo")) - assert foo.loads(foo.dumps()) == foo + assert Foo.loads(foo.dumps()) == foo d = foo.to_dict() assert isinstance(d["path"], str) - assert foo.from_dict(d) == foo - assert isinstance(foo.from_dict(d).path, Path) + assert Foo.from_dict(d) == foo + assert 
isinstance(Foo.from_dict(d).path, Path) From 752a57c1030b08f64f3e755a1bd8b2ce1e701a3b Mon Sep 17 00:00:00 2001 From: "ivan.prado" Date: Thu, 12 May 2022 09:27:32 +0200 Subject: [PATCH 5/6] Switch to Mapping from collections --- simple_parsing/helpers/serialization/decoding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 169424ad..0e4007d2 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -3,11 +3,11 @@ import inspect import warnings from collections import OrderedDict +from collections.abc import Mapping from dataclasses import Field, fields from functools import lru_cache, partial from logging import getLogger -from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set, Mapping, \ - Iterable +from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set from simple_parsing.utils import ( get_type_arguments, From 6296025dff72d413545a0ee91eb6a251f57475d6 Mon Sep 17 00:00:00 2001 From: "ivan.prado" Date: Thu, 12 May 2022 11:25:49 +0200 Subject: [PATCH 6/6] Serializing enum name instead of value: value might not be unique and can be less representative --- .../helpers/serialization/decoding.py | 22 ++++++++++++++++++- .../helpers/serialization/encoding.py | 2 +- test/utils/test_serialization.py | 4 ++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 0e4007d2..40de3865 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -5,6 +5,7 @@ from collections import OrderedDict from collections.abc import Mapping from dataclasses import Field, fields +from enum import Enum from functools import lru_cache, partial from logging import 
getLogger from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set @@ -19,7 +20,7 @@ is_union, is_forward_ref, is_typevar, - get_bound, + get_bound, is_enum, ) logger = getLogger(__name__) @@ -167,6 +168,10 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]: args = get_type_arguments(t) return decode_union(*args) + if is_enum(t): + logger.debug(f"Decoding an Enum field: {t}") + return decode_enum(t) + from .serializable import ( get_dataclass_types_from_forward_ref, Serializable, @@ -361,6 +366,21 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V return _decode_dict +def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]: + """ + Creates a decoding function for an enum type. + + Args: + item_type (Type[Enum]): the enum class whose members are looked up by name. + + Returns: + Callable[[str], Enum]: A function that returns the enum member for the given name. + """ + def _decode_enum(val: str) -> Enum: + return item_type[val] + return _decode_enum + + def no_op(v: T) -> T: """Decoding function that gives back the value as-is. 
diff --git a/simple_parsing/helpers/serialization/encoding.py b/simple_parsing/helpers/serialization/encoding.py index 17a6d97e..67f72a8d 100644 --- a/simple_parsing/helpers/serialization/encoding.py +++ b/simple_parsing/helpers/serialization/encoding.py @@ -139,4 +139,4 @@ def encode_namespace(obj: Namespace) -> Any: @encode.register(Enum) def encode_enum(obj: Enum) -> str: - return obj.value + return obj.name diff --git a/test/utils/test_serialization.py b/test/utils/test_serialization.py index b570da0d..866132ac 100644 --- a/test/utils/test_serialization.py +++ b/test/utils/test_serialization.py @@ -556,7 +556,7 @@ class Animal(FrozenSerializable if frozen else Serializable): assert Animal.loads(animal.dumps()) == animal d = animal.to_dict() - assert d["animal_type"] == "cat" + assert d["animal_type"] == "CAT" assert Animal.from_dict(d) == animal @@ -574,7 +574,7 @@ class Animal(FrozenSerializable if frozen else Serializable): assert Animal.loads(animal.dumps()) == animal d = animal.to_dict() - assert d["animal_type"] == 1 + assert d["animal_type"] == "CAT" assert Animal.from_dict(d) == animal