From 848066fdbcfe00b24101601491ef662eb49d675d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 7 Jun 2022 17:06:51 -0400 Subject: [PATCH] Add the globals of the dataclass when evaluating Signed-off-by: Fabrice Normandin --- .../annotation_utils/get_field_annotations.py | 14 +++--- .../helpers/serialization/decoding.py | 5 +- .../helpers/serialization/serializable.py | 48 ++++++------------- 3 files changed, 26 insertions(+), 41 deletions(-) diff --git a/simple_parsing/annotation_utils/get_field_annotations.py b/simple_parsing/annotation_utils/get_field_annotations.py index 7586b391..9cc64e8b 100644 --- a/simple_parsing/annotation_utils/get_field_annotations.py +++ b/simple_parsing/annotation_utils/get_field_annotations.py @@ -2,7 +2,7 @@ import sys import typing from logging import getLogger as get_logger -from typing import Any, Dict, get_type_hints +from typing import Any, Dict, Optional, get_type_hints logger = get_logger(__name__) @@ -19,7 +19,7 @@ import types -def evaluate_string_annotation(annotation: str) -> type: +def evaluate_string_annotation(annotation: str, containing_class: Optional[type] = None) -> type: """Attempts to evaluate the given annotation string, to get a 'live' type annotation back. Any exceptions that are raised when evaluating are raised directly as-is. @@ -33,12 +33,14 @@ def evaluate_string_annotation(annotation: str) -> type: # Get the local and global namespaces to pass to the `get_type_hints` function. local_ns: Dict[str, Any] = {"typing": typing, **vars(typing)} local_ns.update(forward_refs_to_types) - # Get the globals in the module where the class was defined. - # global_ns = sys.modules[some_class.__module__].__dict__ - # from typing import get_type_hints + global_ns = {} + if containing_class: + # Get the globals in the module where the class was defined. + global_ns = sys.modules[containing_class.__module__].__dict__ + if "|" in annotation: annotation = _get_old_style_annotation(annotation) - evaluated_t: type = eval(annotation, local_ns, {}) + evaluated_t: type = eval(annotation, local_ns, global_ns) return evaluated_t diff --git a/simple_parsing/helpers/serialization/decoding.py b/simple_parsing/helpers/serialization/decoding.py index 9934d1c5..99788134 100644 --- a/simple_parsing/helpers/serialization/decoding.py +++ b/simple_parsing/helpers/serialization/decoding.py @@ -50,7 +50,7 @@ def decode_bool(v: Any) -> bool: _decoding_fns[bool] = decode_bool -def decode_field(field: Field, raw_value: Any) -> Any: +def decode_field(field: Field, raw_value: Any, containing_dataclass: Optional[type] = None) -> Any: """Converts a "raw" value (e.g. from json file) to the type of the `field`. When serializing a dataclass to json, all objects are converted to dicts. @@ -76,6 +76,9 @@ def decode_field(field: Field, raw_value: Any) -> Any: if custom_decoding_fn is not None: return custom_decoding_fn(raw_value) + if isinstance(field_type, str) and containing_dataclass: + field_type = evaluate_string_annotation(field_type, containing_dataclass) + return get_decoding_fn(field_type)(raw_value) diff --git a/simple_parsing/helpers/serialization/serializable.py b/simple_parsing/helpers/serialization/serializable.py index dd0754ec..a770eacf 100644 --- a/simple_parsing/helpers/serialization/serializable.py +++ b/simple_parsing/helpers/serialization/serializable.py @@ -1,5 +1,4 @@ import json -import sys import warnings from collections import OrderedDict from dataclasses import MISSING, Field, dataclass, fields, is_dataclass @@ -8,7 +7,8 @@ from pathlib import Path from typing import IO, Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar, Union -from simple_parsing.utils import is_optional, get_args, get_forward_arg +from simple_parsing.utils import get_args, get_forward_arg, is_optional + from .decoding import decode_field, register_decoding_fn from .encoding import SimpleJsonEncoder, encode @@ -25,9 +25,7 @@ def ordered_dict_constructor(loader: yaml.Loader, node: yaml.Node): value = loader.construct_sequence(node, deep=True) return OrderedDict(*value) - def ordered_dict_representer( - dumper: yaml.Dumper, instance: OrderedDict - ) -> yaml.Node: + def ordered_dict_representer(dumper: yaml.Dumper, instance: OrderedDict) -> yaml.Node: # NOTE(ycho): nested list for compatibility with PyYAML's representer node = dumper.represent_sequence("OrderedDict", [list(instance.items())]) return node @@ -69,9 +67,7 @@ class SerializableMixin: subclasses: ClassVar[List[Type[D]]] = [] decode_into_subclasses: ClassVar[bool] = False - def __init_subclass__( - cls, decode_into_subclasses: bool = None, add_variants: bool = True - ): + def __init_subclass__(cls, decode_into_subclasses: bool = None, add_variants: bool = True): logger.debug(f"Registering a new Serializable subclass: {cls}") super().__init_subclass__() if decode_into_subclasses is None: @@ -229,13 +225,9 @@ def load( if load_fn is None and isinstance(path, Path): if path.name.endswith((".yml", ".yaml")): - return cls.load_yaml( - path, drop_extra_fields=drop_extra_fields, **kwargs - ) + return cls.load_yaml(path, drop_extra_fields=drop_extra_fields, **kwargs) elif path.name.endswith(".json"): - return cls.load_json( - path, drop_extra_fields=drop_extra_fields, **kwargs - ) + return cls.load_json(path, drop_extra_fields=drop_extra_fields, **kwargs) elif path.name.endswith(".pth"): import torch @@ -265,9 +257,7 @@ def load( if isinstance(path, Path): path = path.open() - return cls._load( - path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs - ) + return cls._load(path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs) @classmethod def _load( @@ -300,9 +290,7 @@ def load_json( Returns: D: an instance of the dataclass. """ - return cls.load( - path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs - ) + return cls.load(path, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) @classmethod def load_yaml( @@ -327,9 +315,7 @@ def load_yaml( if load_fn is None: load_fn = yaml.safe_load - return cls.load( - path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs - ) + return cls.load(path, load_fn=load_fn, drop_extra_fields=drop_extra_fields, **kwargs) def save(self, path: Union[str, Path], dump_fn=None, **kwargs) -> None: if not isinstance(path, Path): @@ -405,9 +391,7 @@ def loads_json( load_fn=json.loads, **kwargs, ) -> D: - return cls.loads( - s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs - ) + return cls.loads(s, drop_extra_fields=drop_extra_fields, load_fn=load_fn, **kwargs) @classmethod def loads_yaml( @@ -549,7 +533,7 @@ def from_dict(cls: Type[Dataclass], d: Dict[str, Any], drop_extra_fields: bool = continue raw_value = obj_dict.pop(name) - field_value = decode_field(field, raw_value) + field_value = decode_field(field, raw_value, containing_dataclass=cls) if field.init: init_args[name] = field_value @@ -585,9 +569,7 @@ def from_dict(cls: Type[Dataclass], d: Dict[str, Any], drop_extra_fields: bool = derived_classes.sort(key=lambda dc: len(get_init_fields(dc))) for child_class in derived_classes: - logger.debug( - f"child class: {child_class.__name__}, mro: {child_class.mro()}" - ) + logger.debug(f"child class: {child_class.__name__}, mro: {child_class.mro()}") child_init_fields: Dict[str, Field] = get_init_fields(child_class) child_init_field_names = set(child_init_fields.keys()) @@ -623,7 +605,7 @@ def get_first_non_None_type(optional_type: Union[Type, Tuple[Type, ...]]) -> Opt if not isinstance(optional_type, tuple): optional_type = get_args(optional_type) for arg in optional_type: - if arg is not Union and arg is not type(None): + if arg is not Union and arg is not type(None): # noqa: E721 logger.debug(f"arg: {arg} is not union? {arg is not Union}") logger.debug(f"arg is not type(None)? {arg is not type(None)}") return arg @@ -632,6 +614,4 @@ def get_first_non_None_type(optional_type: Union[Type, Tuple[Type, ...]]) -> Opt def is_dataclass_or_optional_dataclass_type(t: Type) -> bool: """Returns whether `t` is a dataclass type or an Optional[].""" - return is_dataclass(t) or ( - is_optional(t) and is_dataclass(get_args(t)[0]) - ) + return is_dataclass(t) or (is_optional(t) and is_dataclass(get_args(t)[0]))