Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions simple_parsing/helpers/serialization/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import inspect
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import Field, fields
from enum import Enum
from functools import lru_cache, partial
from logging import getLogger
from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set


from simple_parsing.utils import (
get_type_arguments,
is_dataclass_type,
Expand All @@ -19,7 +20,7 @@
is_union,
is_forward_ref,
is_typevar,
get_bound,
get_bound, is_enum,
)

logger = getLogger(__name__)
Expand Down Expand Up @@ -167,6 +168,10 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
args = get_type_arguments(t)
return decode_union(*args)

if is_enum(t):
logger.debug(f"Decoding an Enum field: {t}")
return decode_enum(t)

from .serializable import (
get_dataclass_types_from_forward_ref,
Serializable,
Expand Down Expand Up @@ -361,6 +366,21 @@ def _decode_dict(val: Union[Dict[Any, Any], List[Tuple[Any, Any]]]) -> Dict[K, V
return _decode_dict


def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]:
"""
Creates a decoding function for an enum type.

Args:
item_type (Type[Enum]): the type of the items in the set.

Returns:
Callable[[str], Enum]: A function that returns the enum member for the given name.
"""
def _decode_enum(val: str) -> Enum:
return item_type[val]
return _decode_enum


def no_op(v: T) -> T:
"""Decoding function that gives back the value as-is.

Expand All @@ -382,4 +402,10 @@ def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]:
Returns:
Callable[[Any], Union[T, Any]]: A decoding function that might return nothing.
"""
return try_functions(lambda val: t(**val))
def constructor(val):
if isinstance(val, Mapping):
return t(**val)
else:
return t(val)

return try_functions(constructor)
6 changes: 6 additions & 0 deletions simple_parsing/helpers/serialization/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def encode_ndarray(obj: np.ndarray) -> str:
import copy
import json
from dataclasses import fields, is_dataclass
from enum import Enum
from functools import singledispatch
from logging import getLogger
from os import PathLike
Expand Down Expand Up @@ -134,3 +135,8 @@ def encode_path(obj: PathLike) -> str:
@encode.register(Namespace)
def encode_namespace(obj: Namespace) -> Any:
return encode(vars(obj))


@encode.register(Enum)
def encode_enum(obj: Enum) -> str:
return obj.name
51 changes: 51 additions & 0 deletions test/utils/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from simple_parsing.helpers.serialization.serializable import SerializableMixin
from test.testutils import raises, TestSetup
Expand Down Expand Up @@ -539,3 +540,53 @@ class Bob(FrozenSerializable if frozen else Serializable, TestSetup):
assert b.func(10) == 20
assert b.to_dict() == {"func": "double"}
assert Bob.from_dict(b.to_dict()) == b


def test_enum(frozen: bool):
class AnimalType(Enum):
CAT = "cat"
DOG = "dog"

@dataclass(frozen=frozen)
class Animal(FrozenSerializable if frozen else Serializable):
animal_type: AnimalType
name: str

animal = Animal(AnimalType.CAT, "Fluffy")
assert Animal.loads(animal.dumps()) == animal

d = animal.to_dict()
assert d["animal_type"] == "CAT"
assert Animal.from_dict(d) == animal


def test_enum_with_ints(frozen: bool):
class AnimalType(Enum):
CAT = 1
DOG = 2

@dataclass(frozen=frozen)
class Animal(FrozenSerializable if frozen else Serializable):
animal_type: AnimalType
name: str

animal = Animal(AnimalType.CAT, "Fluffy")
assert Animal.loads(animal.dumps()) == animal

d = animal.to_dict()
assert d["animal_type"] == "CAT"
assert Animal.from_dict(d) == animal


def test_path(frozen: bool):
@dataclass(frozen=frozen)
class Foo(FrozenSerializable if frozen else Serializable):
path: Path

foo = Foo(Path("/tmp/foo"))
assert Foo.loads(foo.dumps()) == foo

d = foo.to_dict()
assert isinstance(d["path"], str)
assert Foo.from_dict(d) == foo
assert isinstance(Foo.from_dict(d).path, Path)