14 changes: 8 additions & 6 deletions simple_parsing/annotation_utils/get_field_annotations.py
@@ -2,7 +2,7 @@
 import sys
 import typing
 from logging import getLogger as get_logger
-from typing import Any, Dict, get_type_hints
+from typing import Any, Dict, Optional, get_type_hints

 logger = get_logger(__name__)

@@ -19,7 +19,7 @@
 import types


-def evaluate_string_annotation(annotation: str) -> type:
+def evaluate_string_annotation(annotation: str, containing_class: Optional[type] = None) -> type:
     """Attempts to evaluate the given annotation string, to get a 'live' type annotation back.

     Any exceptions that are raised when evaluating are raised directly as-is.
@@ -33,12 +33,14 @@ def evaluate_string_annotation(annotation: str) -> type:
     # Get the local and global namespaces to pass to the `get_type_hints` function.
     local_ns: Dict[str, Any] = {"typing": typing, **vars(typing)}
     local_ns.update(forward_refs_to_types)
-    # Get the globals in the module where the class was defined.
-    # global_ns = sys.modules[some_class.__module__].__dict__
-    # from typing import get_type_hints
+    global_ns = {}
+    if containing_class:
+        # Get the globals in the module where the class was defined.
+        global_ns = sys.modules[containing_class.__module__].__dict__
+
     if "|" in annotation:
         annotation = _get_old_style_annotation(annotation)
-    evaluated_t: type = eval(annotation, local_ns, {})
+    evaluated_t: type = eval(annotation, local_ns, global_ns)
     return evaluated_t


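The net effect of this hunk: when a containing class is supplied, a string annotation is evaluated against the globals of the module that defined that class, so forward references to module-level types resolve. A minimal sketch of that resolution logic in isolation (the names resolve_annotation and Point are hypothetical, not the library's code):

import sys
from dataclasses import dataclass

@dataclass
class Point:
    x: int = 0

def resolve_annotation(annotation: str, containing_class: type) -> type:
    # Same idea as the patched evaluate_string_annotation: evaluate the string
    # in the namespace of the module where the class was defined, so that a
    # name like "Point" resolves to the actual type object.
    global_ns = sys.modules[containing_class.__module__].__dict__
    return eval(annotation, {}, global_ns)

assert resolve_annotation("Point", Point) is Point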
5 changes: 4 additions & 1 deletion simple_parsing/helpers/serialization/decoding.py
@@ -50,7 +50,7 @@ def decode_bool(v: Any) -> bool:
 _decoding_fns[bool] = decode_bool


-def decode_field(field: Field, raw_value: Any) -> Any:
+def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[type] = None) -> Any:
     """Converts a "raw" value (e.g. from json file) to the type of the `field`.

     When serializing a dataclass to json, all objects are converted to dicts.
@@ -76,6 +76,9 @@ def decode_field(field: Field, raw_value: Any) -> Any:
     if custom_decoding_fn is not None:
         return custom_decoding_fn(raw_value)

+    if isinstance(field_type, str) and containing_dataclass:
+        field_type = evaluate_string_annotation(field_type, containing_dataclass)
+
     return get_decoding_fn(field_type)(raw_value)


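With this hook in place, decode_field can handle fields whose annotations are still raw strings (e.g. written as string literals, or postponed via from __future__ import annotations) by evaluating them in the context of the declaring dataclass. A hedged usage sketch, assuming the module path shown in this diff and that the registered decoding function for int accepts the raw value unchanged:

from dataclasses import dataclass, fields

from simple_parsing.helpers.serialization.decoding import decode_field

@dataclass
class Config:
    seed: "int" = 0  # the annotation here is the *string* "int", not the type

seed_field = fields(Config)[0]
# Without containing_dataclass the string "int" could not be resolved to a
# type; with it, the annotation is evaluated and normal decoding applies.
value = decode_field(seed_field, 42, containing_dataclass=Config)
assert value == 42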
48 changes: 14 additions & 34 deletions simple_parsing/helpers/serialization/serializable.py
@@ -1,5 +1,4 @@
 import json
-import sys
 import warnings
 from collections import OrderedDict
 from dataclasses import MISSING, Field, dataclass, fields, is_dataclass
@@ -8,7 +7,8 @@
 from pathlib import Path
 from typing import IO, Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union

-from simple_parsing.utils import is_optional, get_args, get_forward_arg
+from simple_parsing.utils import get_args, get_forward_arg, is_optional
+
 from .decoding import decode_field, register_decoding_fn
 from .encoding import SimpleJsonEncoder, encode

@@ -25,9 +25,7 @@ def ordered_dict_constructor(loader: yaml.Loader, node: yaml.Node):
     value = loader.construct_sequence(node, deep=True)
     return OrderedDict(*value)

-def ordered_dict_representer(
-    dumper: yaml.Dumper, instance: OrderedDict
-) -> yaml.Node:
+def ordered_dict_representer(dumper: yaml.Dumper, instance: OrderedDict) -> yaml.Node:
     # NOTE(ycho): nested list for compatibility with PyYAML's representer
     node = dumper.represent_sequence("OrderedDict", [list(instance.items())])
     return node
@@ -69,9 +67,7 @@ class SerializableMixin:
     subclasses: ClassVar[List[Type[D]]] = []
     decode_into_subclasses: ClassVar[bool] = False

-    def __init_subclass__(
-        cls, decode_into_subclasses: bool = None, add_variants: bool = True
-    ):
+    def __init_subclass__(cls, decode_into_subclasses: bool = None, add_variants: bool = True):
         logger.debug(f"Registering a new Serializable subclass: {cls}")
         super().__init_subclass__()
         if decode_into_subclasses is None:
@@ -229,13 +225,9 @@ def load(

         if load_fn is None and isinstance(path, Path):
             if path.name.endswith((".yml", ".yaml")):
-                return cls.load_yaml(
-                    path, drop_extra_fields=drop_extra_fields, **kwargs
-                )
+                return cls.load_yaml(path, drop_extra_fields=drop_extra_fields, **kwargs)
             elif path.name.endswith(".json"):
-                return cls.load_json(
-                    path, drop_extra_fields=drop_extra_fields, **kwargs
-                )
+                return cls.load_json(path, drop_extra_fields=drop_extra_fields, **kwargs)
             elif path.name.endswith(".pth"):
                 import torch

@@ -265,9 +257,7 @@ def load(

         if isinstance(path, Path):
             path = path.open()
-        return cls._load(
-            path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs
-        )
+        return cls._load(path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs)

     @classmethod
     def _load(
@@ -300,9 +290,7 @@ def load_json(
         Returns:
             D: an instance of the dataclass.
         """
-        return cls.load(
-            path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs
-        )
+        return cls.load(path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs)

     @classmethod
     def load_yaml(
@@ -327,9 +315,7 @@ def load_yaml(

         if load_fn is None:
             load_fn = yaml.safe_load
-        return cls.load(
-            path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs
-        )
+        return cls.load(path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs)

     def save(self, path: Union[str, Path], dump_fn=None, **kwargs) -> None:
         if not isinstance(path, Path):
@@ -405,9 +391,7 @@ def loads_json(
         load_fn=json.loads,
         **kwargs,
     ) -> D:
-        return cls.loads(
-            s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs
-        )
+        return cls.loads(s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs)

     @classmethod
     def loads_yaml(
@@ -549,7 +533,7 @@ def from_dict(cls: Type[Dataclass], d: Dict[str, Any], drop_extra_fields: bool =
             continue

         raw_value = obj_dict.pop(name)
-        field_value = decode_field(field, raw_value)
+        field_value = decode_field(field, raw_value, containing_dataclass=cls)

         if field.init:
             init_args[name] = field_value
@@ -585,9 +569,7 @@ def from_dict(cls: Type[Dataclass], d: Dict[str, Any], drop_extra_fields: bool =
        derived_classes.sort(key=lambda dc: len(get_init_fields(dc)))

        for child_class in derived_classes:
-            logger.debug(
-                f"child class: {child_class.__name__}, mro: {child_class.mro()}"
-            )
+            logger.debug(f"child class: {child_class.__name__}, mro: {child_class.mro()}")
            child_init_fields: Dict[str, Field] = get_init_fields(child_class)
            child_init_field_names = set(child_init_fields.keys())

@@ -623,7 +605,7 @@ def get_first_non_None_type(optional_type: Union[Type, Tuple[Type, ...]]) -> Opt
    if not isinstance(optional_type, tuple):
        optional_type = get_args(optional_type)
    for arg in optional_type:
-        if arg is not Union and arg is not type(None):
+        if arg is not Union and arg is not type(None):  # noqa: E721
            logger.debug(f"arg: {arg} is not union? {arg is not Union}")
            logger.debug(f"arg is not type(None)? {arg is not type(None)}")
            return arg
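(The only change here is the # noqa: E721 comment silencing the linter's type-comparison warning; the identity check against type(None) is intentional. For readers unfamiliar with the helper, here is a self-contained restatement of what get_first_non_None_type computes, using typing.get_args; the name first_non_none_type is hypothetical:)

from typing import Optional, Union, get_args

def first_non_none_type(optional_type):
    # Return the first argument of a Union/Optional that is not NoneType,
    # mirroring get_first_non_None_type above.
    for arg in get_args(optional_type):
        if arg is not type(None):  # noqa: E721
            return arg
    return None

assert first_non_none_type(Optional[int]) is int
assert first_non_none_type(Union[str, None]) is str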
@@ -632,6 +614,4 @@ def get_first_non_None_type(optional_type: Union[Type, Tuple[Type, ...]]) -> Opt

def is_dataclass_or_optional_dataclass_type(t: Type) -> bool:
    """Returns whether `t` is a dataclass type or an Optional[<dataclass type>]."""
-    return is_dataclass(t) or (
-        is_optional(t) and is_dataclass(get_args(t)[0])
-    )
+    return is_dataclass(t) or (is_optional(t) and is_dataclass(get_args(t)[0]))
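Taken together, these changes thread the containing dataclass from from_dict through decode_field into evaluate_string_annotation, so dataclasses with postponed (string) annotations can be decoded. A hedged end-to-end sketch of the scenario this enables, assuming the module-level from_dict shown in this diff is importable from its module and that nested dataclass fields are decoded recursively:

from __future__ import annotations  # every annotation below is stored as a string

from dataclasses import dataclass, field

from simple_parsing.helpers.serialization.serializable import from_dict

@dataclass
class Inner:
    value: int = 0

@dataclass
class Outer:
    inner: Inner = field(default_factory=Inner)

# Before this change, decoding Outer could fail: the string "Inner" was
# evaluated without the globals of the module that defines it. Passing
# containing_dataclass=cls, as from_dict now does, makes the lookup succeed.
outer = from_dict(Outer, {"inner": {"value": 3}})
assert outer.inner.value == 3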