Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions simple_parsing/annotation_utils/get_field_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,29 @@
import types


def evaluate_string_annotation(annotation: str) -> type:
"""Attempts to evaluate the given annotation string, to get a 'live' type annotation back.

Any exceptions that are raised when evaluating are raised directly as-is.

NOTE: This is probably not 100% safe. I mean, if the user code puts urls and stuff in their
type annotations, and then uses simple-parsing, then sure, that code might get executed. But
I don't think it's my job to prevent them from shooting themselves in the foot, you know what I
mean?
"""
# The type of the field might be a string when using `from __future__ import annotations`.
# Get the local and global namespaces to pass to the `get_type_hints` function.
local_ns: Dict[str, Any] = {"typing": typing, **vars(typing)}
local_ns.update(forward_refs_to_types)
# Get the globals in the module where the class was defined.
# global_ns = sys.modules[some_class.__module__].__dict__
# from typing import get_type_hints
if "|" in annotation:
annotation = _get_old_style_annotation(annotation)
evaluated_t: type = eval(annotation, local_ns, {})
return evaluated_t


def _replace_UnionType_with_typing_Union(annotation):
from simple_parsing.utils import builtin_types, is_dict, is_list, is_tuple

Expand Down
1 change: 1 addition & 0 deletions simple_parsing/helpers/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def choice(*choices: T, default: T, **kwargs) -> T:
pass


# TODO: Fix the signature for this.
def choice(*choices: T, default: Union[T, _MISSING_TYPE] = MISSING, **kwargs: Any) -> T:
"""Makes a field which can be chosen from the set of choices from the
command-line.
Expand Down
66 changes: 44 additions & 22 deletions simple_parsing/helpers/serialization/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from dataclasses import Field, fields
from dataclasses import Field
from enum import Enum
from functools import lru_cache, partial
from logging import getLogger
from typing import TypeVar, Any, Dict, Type, Callable, Optional, Union, List, Tuple, Set
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union

from simple_parsing.annotation_utils.get_field_annotations import (
evaluate_string_annotation,
)
from simple_parsing.utils import (
get_bound,
get_type_arguments,
is_dataclass_type,
is_dict,
is_enum,
is_forward_ref,
is_list,
is_set,
is_tuple,
is_union,
is_forward_ref,
is_typevar,
get_bound, is_enum,
is_union,
)

logger = getLogger(__name__)
Expand Down Expand Up @@ -97,25 +100,50 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
# cache_info = get_decoding_fn.cache_info()
# logger.debug(f"called for type {t}! Cache info: {cache_info}")

if isinstance(t, str):
def _get_potential_keys(annotation: str) -> List[str]:
# Type annotation is a string.
# This can happen when the `from __future__ import annotations` feature is used.
potential_keys: List[Type] = []
for key in _decoding_fns:
if inspect.isclass(key):
if key.__qualname__ == t:
if key.__qualname__ == annotation:
# Qualname is more specific, there can't possibly be another match, so break.
potential_keys.append(key)
break
if key.__qualname__ == t:
if key.__qualname__ == annotation:
# For just __name__, there could be more than one match.
potential_keys.append(key)
return potential_keys

if isinstance(t, str):
if t in _decoding_fns:
return _decoding_fns[t]

potential_keys = _get_potential_keys(t)

if not potential_keys:
# Try to replace the new-style annotation str with the old style syntax, and see if we
# find a match.
# try:
try:
evaluated_t = evaluate_string_annotation(t)
# NOTE: We now have a 'live'/runtime type annotation object from the typing module.
except (ValueError, TypeError) as err:
logger.error(f"Unable to evaluate the type annotation string {t}: {err}.")
else:
if evaluated_t in _decoding_fns:
return _decoding_fns[evaluated_t]
# If we still don't have this annotation stored in our dict of known functions, we
# recurse, to try to deconstruct this annotation into its parts, and construct the
# decoding function for the annotation. If this doesn't work, we just raise the
# errors.
return get_decoding_fn(evaluated_t)

raise ValueError(
f"Couldn't find a decoding function for the string annotation '{t}'.\n"
f"This is probably a bug. If it is, please make an issue on GitHub so we can get "
f"to work on fixing it."
f"to work on fixing it.\n"
f"Types with a known decoding function: {list(_decoding_fns.keys())}"
)
if len(potential_keys) == 1:
t = potential_keys[0]
Expand Down Expand Up @@ -172,12 +200,7 @@ def get_decoding_fn(t: Type[T]) -> Callable[[Any], T]:
logger.debug(f"Decoding an Enum field: {t}")
return decode_enum(t)

from .serializable import (
get_dataclass_types_from_forward_ref,
Serializable,
SerializableMixin,
FrozenSerializable,
)
from .serializable import SerializableMixin, get_dataclass_types_from_forward_ref

if is_forward_ref(t):
dcs = get_dataclass_types_from_forward_ref(t)
Expand Down Expand Up @@ -250,9 +273,7 @@ def _try_functions(val: Any) -> Union[T, Any]:
except Exception as ex:
e = ex
else:
logger.debug(
f"Couldn't parse value {val}, returning it as-is. (exception: {e})"
)
logger.debug(f"Couldn't parse value {val}, returning it as-is. (exception: {e})")
return val

return _try_functions
Expand Down Expand Up @@ -330,9 +351,7 @@ def _decode_set(val: List[Any]) -> Set[T]:
return _decode_set


def decode_dict(
K_: Type[K], V_: Type[V]
) -> Callable[[List[Tuple[Any, Any]]], Dict[K, V]]:
def decode_dict(K_: Type[K], V_: Type[V]) -> Callable[[List[Tuple[Any, Any]]], Dict[K, V]]:
"""Creates a decoding function for a dict type. Works with OrderedDict too.

Args:
Expand Down Expand Up @@ -376,8 +395,10 @@ def decode_enum(item_type: Type[Enum]) -> Callable[[str], Enum]:
Returns:
Callable[[str], Enum]: A function that returns the enum member for the given name.
"""

def _decode_enum(val: str) -> Enum:
return item_type[val]

return _decode_enum


Expand All @@ -402,6 +423,7 @@ def try_constructor(t: Type[T]) -> Callable[[Any], Union[T, Any]]:
Returns:
Callable[[Any], Union[T, Any]]: A decoding function that might return nothing.
"""

def constructor(val):
if isinstance(val, Mapping):
return t(**val)
Expand Down
3 changes: 2 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from simple_parsing import choice
from simple_parsing.helpers import Serializable

collect_ignore = []
collect_ignore: List[str] = []
if sys.version_info < (3, 7):
collect_ignore.append("test_future_annotations.py")
collect_ignore.append("test_issue_144.py")


# List of simple attributes to use in test:
Expand Down
69 changes: 69 additions & 0 deletions test/test_issue_144.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
""" Tests for issue 144: https://github.com/lebrice/SimpleParsing/issues/144 """
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union

import pytest

from simple_parsing.helpers.serialization.serializable import Serializable


class TestOptional:
@dataclass
class Foo(Serializable):
foo: Optional[int] = 123

@pytest.mark.parametrize("d", [{"foo": None}, {"foo": 1}])
def test_round_trip(self, d: dict):
# NOTE: this double round-trip makes the comparison agnostic to any conversion that may
# happen between the raw dict values and the arguments of the dataclasses.
assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d)


class TestUnion:
@dataclass
class Foo(Serializable):
foo: Union[int, dict[int, bool]] = 123

@pytest.mark.parametrize("d", [{"foo": None}, {"foo": {1: "False"}}])
def test_round_trip(self, d: dict):
# NOTE: this double round-trip makes the comparison agnostic to any conversion that may
# happen between the raw dict values and the arguments of the dataclasses.
assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d)


class TestList:
@dataclass
class Foo(Serializable):
foo: List[int] = field(default_factory=list)

@pytest.mark.parametrize("d", [{"foo": []}, {"foo": [123, 456]}])
def test_round_trip(self, d: dict):
# NOTE: this double round-trip makes the comparison agnostic to any conversion that may
# happen between the raw dict values and the arguments of the dataclasses.
assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d)


class TestTuple:
@dataclass
class Foo(Serializable):
foo: Tuple[int, float, bool]

@pytest.mark.parametrize("d", [{"foo": (1, 1.2, False)}, {"foo": ("1", "1.2", "True")}])
def test_round_trip(self, d: dict):
# NOTE: this double round-trip makes the comparison agnostic to any conversion that may
# happen between the raw dict values and the arguments of the dataclasses.
assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d)


class TestDict:
@dataclass
class Foo(Serializable):
foo: Dict[int, float] = field(default_factory=dict)

@pytest.mark.parametrize("d", [{"foo": {}}, {"foo": {"123": "4.56"}}])
def test_round_trip(self, d: dict):
# NOTE: this double round-trip makes the comparison agnostic to any conversion that may
# happen between the raw dict values and the arguments of the dataclasses.
assert self.Foo.from_dict(self.Foo.from_dict(d).to_dict()) == self.Foo.from_dict(d)