diff --git a/docs/source/pytorch_lightning_example.rst b/docs/source/pytorch_lightning_example.rst index 5024ffafb..58c74196e 100644 --- a/docs/source/pytorch_lightning_example.rst +++ b/docs/source/pytorch_lightning_example.rst @@ -52,7 +52,7 @@ The following is the boilerplate-free code. | """ y = sum(V sigmoid(X W + b))""" | class ExperimentConfig: | | | optim: Any = builds( | | def __init__( | optim.Adam, | -| self, | zen_partial=True, | +| self, | zen_partial=True, | | num_neurons: int, | populate_full_signature=True, | | optim: Type[optim.Optimizer], | ) | | dataloader: Type[DataLoader], | | @@ -61,7 +61,7 @@ The following is the boilerplate-free code. | ): | batch_size=25, | | super().__init__() | shuffle=True, | | self.optim = optim | drop_last=True, | -| self.dataloader = dataloader | zen_partial=True, | +| self.dataloader = dataloader | zen_partial=True, | | self.training_domain = training_domain | ) | | self.target_fn = target_fn | | | | lightning_module: Any = builds( | diff --git a/src/hydra_zen/funcs.py b/src/hydra_zen/funcs.py index 7055a57f6..20d28d1a1 100644 --- a/src/hydra_zen/funcs.py +++ b/src/hydra_zen/funcs.py @@ -5,18 +5,23 @@ Simple helper functions used to implement `just` and `builds`. This module is designed specifically so that these functions have a legible module-path when they appear in configuration files. """ - import functools as _functools import typing as _typing from hydra._internal import utils as _hydra_internal_utils from hydra.utils import log as _log +from hydra_zen.structured_configs._utils import ( + is_interpolated_string as _is_interpolated_string, +) from hydra_zen.typing import Partial as _Partial -_T = _typing.TypeVar("_T") +__all__ = ["partial", "get_obj", "zen_processing"] -__all__ = ["partial", "get_obj"] +_T = _typing.TypeVar("_T") +_WrapperConf = _typing.Union[ + str, _typing.Callable[[_typing.Callable], _typing.Callable] +] def partial( @@ -45,10 +50,53 @@ def zen_processing( _zen_target: str, _zen_partial: bool = False, _zen_exclude: _typing.Sequence[str] = tuple(), + _zen_wrappers: _typing.Union[ + _WrapperConf, _typing.Sequence[_WrapperConf] + ] = tuple(), **kwargs, ): + if isinstance(_zen_wrappers, str) or not isinstance( + _zen_wrappers, _typing.Sequence + ): + unresolved_wrappers: _typing.Sequence[_WrapperConf] = (_zen_wrappers,) # type: ignore + else: + unresolved_wrappers: _typing.Sequence[_WrapperConf] = _zen_wrappers + del _zen_wrappers + + resolved_wrappers = [] + + for _unresolved in unresolved_wrappers: + if _unresolved is None: + # We permit interpolated fields to resolve to `None`; this is + # a nice pattern for enabling people to ergonomically toggle + # wrappers off. + continue + if isinstance(_unresolved, str): + # Hydra will have already raised on missing interpolation + # keys by here + assert not _is_interpolated_string(_unresolved) + _unresolved = get_obj(path=_unresolved) + + if not callable(_unresolved): + raise TypeError( + f"Instantiating {_zen_target}: `zen_wrappers` was passed a non-callable object: {_unresolved}" + ) + else: + resolved = _unresolved + del _unresolved + resolved_wrappers.append(resolved) + obj = get_obj(path=_zen_target) + # first wrapper listed should be called first + # [f1, f2, f3, ...] -> + # target = f1(target) + # target = f2(target) + # target = f3(target) + # ... + for wrapper in resolved_wrappers: + obj = wrapper(obj) + if _zen_exclude: excluded_set = set(_zen_exclude) kwargs = {k: v for k, v in kwargs.items() if k not in excluded_set} diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index 9d26f4880..d04747f99 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -14,6 +14,7 @@ List, Mapping, Optional, + Sequence, Set, Tuple, Type, @@ -24,6 +25,7 @@ overload, ) +from omegaconf import II from typing_extensions import Final, Literal, TypeGuard from hydra_zen.errors import HydraZenDeprecationWarning @@ -38,18 +40,19 @@ except ImportError: # pragma: no cover ufunc = None - _T = TypeVar("_T") +_T2 = TypeVar("_T2", bound=Callable) +_Wrapper = Callable[[_T2], _T2] +ZenWrapper = Union[ + Builds[_Wrapper], PartialBuilds[_Wrapper], Just[_Wrapper], _Wrapper, str +] -_ZEN_PROCESSING_LOCATION: Final[str] = _utils.get_obj_path(zen_processing) +# Hydra-specific fields _TARGET_FIELD_NAME: Final[str] = "_target_" _RECURSIVE_FIELD_NAME: Final[str] = "_recursive_" _CONVERT_FIELD_NAME: Final[str] = "_convert_" -_PARTIAL_TARGET_FIELD_NAME: Final[str] = "_zen_partial" -_META_FIELD_NAME: Final[str] = "_zen_exclude" -_ZEN_TARGET_FIELD_NAME: Final[str] = "_zen_target" _POS_ARG_FIELD_NAME: Final[str] = "_args_" -_JUST_FIELD_NAME: Final[str] = "path" + _HYDRA_FIELD_NAMES: FrozenSet[str] = frozenset( ( _TARGET_FIELD_NAME, @@ -59,6 +62,16 @@ ) ) +# hydra-zen-specific fields +_ZEN_PROCESSING_LOCATION: Final[str] = _utils.get_obj_path(zen_processing) +_ZEN_TARGET_FIELD_NAME: Final[str] = "_zen_target" +_PARTIAL_TARGET_FIELD_NAME: Final[str] = "_zen_partial" +_META_FIELD_NAME: Final[str] = "_zen_exclude" +_ZEN_WRAPPERS_FIELD_NAME: Final[str] = "_zen_wrappers" +_JUST_FIELD_NAME: Final[str] = "path" +# TODO: add _JUST_Target + +# signature param-types _POSITIONAL_ONLY: Final = inspect.Parameter.POSITIONAL_ONLY _POSITIONAL_OR_KEYWORD: Final = inspect.Parameter.POSITIONAL_OR_KEYWORD _VAR_POSITIONAL: Final = inspect.Parameter.VAR_POSITIONAL @@ -72,10 +85,6 @@ def _get_target(x): return getattr(x, _TARGET_FIELD_NAME) -_T = TypeVar("_T") -_T2 = TypeVar("_T2", bound=Callable) - - def _target_as_kwarg_deprecation(func: _T2) -> Callable[..., _T2]: @wraps(func) def wrapped(*args, **kwargs): @@ -186,6 +195,7 @@ def hydrated_dataclass( target: Callable, *pos_args: Any, zen_partial: bool = False, + zen_wrappers: Union[Optional[ZenWrapper], Sequence[Optional[ZenWrapper]]] = None, zen_meta: Optional[Mapping[str, Any]] = None, populate_full_signature: bool = False, hydra_recursive: Optional[bool] = None, @@ -193,7 +203,7 @@ def hydrated_dataclass( frozen: bool = False, **_kw, # reserved to deprecate hydra_partial ) -> Callable[[Type[_T]], Type[_T]]: - """A decorator that uses `hydra_zen.builds` to create a dataclass with the appropriate + """A decorator that uses `builds` to create a dataclass with the appropriate hydra-specific fields for specifying a structured config [1]_. Parameters @@ -210,6 +220,21 @@ def hydrated_dataclass( zen_partial : Optional[bool] (default=False) If True, then hydra-instantiation produces ``functools.partial(target, **kwargs)`` + zen_wrappers : Optional[Union[ZenWrapper, Sequence[ZenWrapper]]] + One or more wrappers, which will wrap `hydra_target` prior to instantiation. + E.g. specifying ``[f1, f2, f3]`` will instantiate as:: + + ``f3(f2(f1(hydra_target)))(*args, **kwargs)`` + + Wrappers can also be specified as interpolated strings [2]_ or targeted structured + configs. + + zen_meta: Optional[Mapping[str, Any]] + Specifies field-names and corresponding values that will be included in the + resulting dataclass, but that will *not* be used to build ``hydra_target`` + via instantiation. These are called "meta" fields. + + populate_full_signature : bool, optional (default=False) If True, then the resulting dataclass's ``__init__`` signature and fields will be populated according to the signature of `target`. @@ -217,19 +242,15 @@ def hydrated_dataclass( Values specified in ``**kwargs_for_target`` take precedent over the corresponding default values from the signature. - zen_meta: Optional[Mapping[str, Any]] - Specifies field-names and corresponding values that will be included in the - resulting dataclass, but that will *not* be used to build ``hydra_target`` - via instantiation. These are called "meta" fields. hydra_recursive : bool, optional (default=True) If True, then upon hydra will recursively instantiate all other - hydra-config objects nested within this dataclass [2]_. + hydra-config objects nested within this dataclass [3]_. If ``None``, the ``_recursive_`` attribute is not set on the resulting dataclass. hydra_convert: Optional[Literal["none", "partial", "all"]] (default="none") - Determines how hydra handles the non-primitive objects passed to `target` [3]_. + Determines how hydra handles the non-primitive objects passed to `target` [4]_. - ``"none"``: Passed objects are DictConfig and ListConfig, default - ``"partial"``: Passed objects are converted to dict and list, with @@ -251,8 +272,9 @@ def hydrated_dataclass( References ---------- .. [1] https://hydra.cc/docs/next/tutorials/structured_config/intro/ - .. [2] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#recursive-instantiation - .. [3] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#parameter-conversion-strategies + .. [2] https://omegaconf.readthedocs.io/en/2.1_branch/usage.html#variable-interpolation + .. [3] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#recursive-instantiation + .. [4] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#parameter-conversion-strategies Examples -------- @@ -302,6 +324,8 @@ def hydrated_dataclass( TypeError: Building: AdamW .. The following unexpected keyword argument(s) for torch.optim.adamw.AdamW was specified via inheritance from a base class: wieght_decay + + For more detailed examples, refer to `builds`. """ if "hydra_partial" in _kw: @@ -350,6 +374,7 @@ def wrapper(decorated_obj: Any) -> Any: populate_full_signature=populate_full_signature, hydra_recursive=hydra_recursive, hydra_convert=hydra_convert, + zen_wrappers=zen_wrappers, zen_partial=zen_partial, zen_meta=zen_meta, builds_bases=(decorated_obj,), @@ -457,6 +482,7 @@ def builds( hydra_target: Importable, *pos_args: Any, zen_partial: Literal[False] = False, + zen_wrappers: Union[Optional[ZenWrapper], Sequence[Optional[ZenWrapper]]] = None, zen_meta: Optional[Mapping[str, Any]] = None, populate_full_signature: bool = False, hydra_recursive: Optional[bool] = None, @@ -475,6 +501,7 @@ def builds( hydra_target: Importable, *pos_args: Any, zen_partial: Literal[True], + zen_wrappers: Union[Optional[ZenWrapper], Sequence[Optional[ZenWrapper]]] = None, zen_meta: Optional[Mapping[str, Any]] = None, populate_full_signature: bool = False, hydra_recursive: Optional[bool] = None, @@ -493,6 +520,7 @@ def builds( hydra_target: Importable, *pos_args: Any, zen_partial: bool, + zen_wrappers: Union[Optional[ZenWrapper], Sequence[Optional[ZenWrapper]]] = None, zen_meta: Optional[Mapping[str, Any]] = None, populate_full_signature: bool = False, hydra_recursive: Optional[bool] = None, @@ -512,6 +540,7 @@ def builds( def builds( *pos_args: Any, zen_partial: bool = False, + zen_wrappers: Union[Optional[ZenWrapper], Sequence[Optional[ZenWrapper]]] = None, zen_meta: Optional[Mapping[str, Any]] = None, populate_full_signature: bool = False, hydra_recursive: Optional[bool] = None, @@ -531,7 +560,7 @@ def builds( Parameters ---------- - hydra_target : Union[Instantiable, Callable] + hydra_target : Instantiable | Callable The object to be instantiated/called. This is a required, positional-only argument. *pos_args: Any @@ -555,6 +584,15 @@ def builds( user or that have default values specified in the target's signature. I.e. it is presumed that un-specified parameters are to be excluded from the partial configuration. + zen_wrappers : Optional[Callable | Builds | InterpStr | Sequence[Callable | Builds | InterpStr] + One or more wrappers, which will wrap `hydra_target` prior to instantiation. + E.g. specifying the wrappers ``[f1, f2, f3]`` will instantiate as:: + + f3(f2(f1(hydra_target)))(*args, **kwargs) + + Wrappers can also be specified as interpolated strings [2]_ or targeted structured + configs. + zen_meta: Optional[Mapping[str, Any]] Specifies field-names and corresponding values that will be included in the resulting dataclass, but that will *not* be used to build ``hydra_target`` @@ -572,12 +610,12 @@ def builds( hydra_recursive : Optional[bool], optional (default=True) If ``True``, then Hydra will recursively instantiate all other - hydra-config objects nested within this dataclass [2]_. + hydra-config objects nested within this dataclass [3]_. If ``None``, the ``_recursive_`` attribute is not set on the resulting dataclass. hydra_convert: Optional[Literal["none", "partial", "all"]], optional (default="none") - Determines how hydra handles the non-primitive objects passed to `target` [3]_. + Determines how hydra handles the non-primitive objects passed to `target` [4]_. - ``"none"``: Passed objects are DictConfig and ListConfig, default - ``"partial"``: Passed objects are converted to dict and list, with @@ -625,14 +663,15 @@ def builds( the target's signature. This helps to ensure that typos in field names fail early and explicitly. - Mutable values are automatically transformed to use a default factory [4]_. + Mutable values are automatically transformed to use a default factory [5]_. References ---------- .. [1] https://hydra.cc/docs/next/tutorials/structured_config/intro/ - .. [2] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#recursive-instantiation - .. [3] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#parameter-conversion-strategies - .. [4] https://docs.python.org/3/library/dataclasses.html#mutable-default-values + .. [2] https://omegaconf.readthedocs.io/en/2.1_branch/usage.html#variable-interpolation + .. [3] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#recursive-instantiation + .. [4] https://hydra.cc/docs/next/advanced/instantiate_objects/overview/#parameter-conversion-strategies + .. [5] https://docs.python.org/3/library/dataclasses.html#mutable-default-values Examples -------- @@ -684,6 +723,25 @@ def builds( {'a': -10, 'b': -10} >>> instantiate(Conf, s=2) {'a': 2, 'b': 2} + + Leveraging zen-wrappers to inject unit-conversion capabilities. Let's take + a function that converts Farenheit to Celcius, and wrap it so that it converts + to Kelvin instead. + + >>> def faren_to_celsius(temp_f): + ... return ((temp_f - 32) * 5) / 9 + + >>> def change_celcius_to_kelvin(celc_func): + ... def wraps(*args, **kwargs): + ... return 273.15 + celc_func(*args, **kwargs) + ... return wraps + + >>> AsCelcius = builds(faren_to_celsius) + >>> AsKelvin = builds(faren_to_celsius, zen_wrappers=change_celcius_to_kelvin) + >>> instantiate(AsCelcius, temp_f=32) + 0.0 + >>> instantiate(AsKelvin, temp_f=32) + 273.15 """ if not pos_args and not kwargs_for_target: @@ -742,11 +800,58 @@ def builds( f"`zen_meta` must be a mapping (e.g. a dictionary), got: {zen_meta}" ) - for _key in zen_meta: - if not isinstance(_key, str): - raise TypeError( - f"`zen_meta` must be a mapping whose keys are strings, got key: {_key}" - ) + if any(not isinstance(_key, str) for _key in zen_meta): + raise TypeError( + f"`zen_meta` must be a mapping whose keys are strings, got key(s):" + f" {','.join(str(_key) for _key in zen_meta if not isinstance(_key, str))}" + ) + + if zen_wrappers is not None: + if not isinstance(zen_wrappers, Sequence) or isinstance(zen_wrappers, str): + zen_wrappers = (zen_wrappers,) + + validated_wrappers: Sequence[Union[str, Builds]] = [] + for wrapper in zen_wrappers: + if wrapper is None: + continue + # We are intentionally keeping each condition branched + # so that test-coverage will be checked for each one + if is_builds(wrapper): + # If Hydra's locate function starts supporting importing literals + # – or if we decide to ship our own locate function – + # then we should get the target of `wrapper` and make sure it is callable + if is_just(wrapper): + # `zen_wrappers` handles importing string; we can + # elimintate the indirection of Just and "flatten" this + # config + validated_wrappers.append(getattr(wrapper, _JUST_FIELD_NAME)) + else: + if hydra_recursive is False: + warnings.warn( + "A structured config was supplied for `zen_wrappers`. Its parent config has " + "`hydra_recursive=False`.\n If this value is not toggled to `True`, the config's " + "instantiation will result in an error" + ) + validated_wrappers.append(wrapper) + + elif callable(wrapper): + validated_wrappers.append(_utils.get_obj_path(wrapper)) + + elif isinstance(wrapper, str): + # Assumed that wrapper is either a valid omegaconf-style interpolation string + # or a "valid" path for importing an object. The latter seems hopeless for validating: + # https://stackoverflow.com/a/47538106/6592114 + # so we can't make any assurances here. + validated_wrappers.append(wrapper) + else: + raise TypeError( + f"`zen_wrappers` requires a callable, targeted config, or a string, got: {wrapper}" + ) + + del zen_wrappers + validated_wrappers = tuple(validated_wrappers) + else: + validated_wrappers = () # Check for reserved names for _name in chain(kwargs_for_target, zen_meta): @@ -767,7 +872,7 @@ def builds( target_field: List[Union[Tuple[str, Type[Any]], Tuple[str, Type[Any], Field[Any]]]] - if zen_partial or zen_meta: + if zen_partial or zen_meta or validated_wrappers: target_field = [ ( _TARGET_FIELD_NAME, @@ -780,6 +885,7 @@ def builds( _utils.field(default=_utils.get_obj_path(target), init=False), ), ] + if zen_partial: target_field.append( ( @@ -788,14 +894,45 @@ def builds( _utils.field(default=True, init=False), ), ) + if zen_meta: target_field.append( ( _META_FIELD_NAME, - bool, + _utils.sanitized_type(Tuple[str, ...]), _utils.field(default=tuple(zen_meta), init=False), ), ) + + if validated_wrappers: + if zen_meta: + # Check to see + tuple( + _utils.check_suspicious_interpolations( + validated_wrappers, zen_meta=zen_meta, target=target + ) + ) + if len(validated_wrappers) == 1: + # we flatten the config to avoid unnecessary list + target_field.append( + ( + _ZEN_WRAPPERS_FIELD_NAME, + _utils.sanitized_type( + Union[Union[str, Builds], Tuple[Union[str, Builds], ...]] + ), + _utils.field(default=validated_wrappers[0], init=False), + ), + ) + else: + target_field.append( + ( + _ZEN_WRAPPERS_FIELD_NAME, + _utils.sanitized_type( + Union[Union[str, Builds], Tuple[Union[str, Builds], ...]] + ), + _utils.field(default=validated_wrappers, init=False), + ), + ) else: target_field = [ ( @@ -1020,7 +1157,6 @@ def builds( # We don't check for collisions between `zen_meta` names and the # names of inherited fields. Thus `zen_meta` can effectively be used # to "delete" names from a config, via inheritance. - user_specified_named_params.update( { name: (name, Any, sanitized_default_value(value)) @@ -1174,7 +1310,7 @@ def _is_old_partial_builds(x: Any) -> bool: # pragma: no cover return False -def uses_zen_processing(x: Any) -> bool: +def uses_zen_processing(x: Any) -> TypeGuard[Builds]: if not is_builds(x) or not hasattr(x, _ZEN_TARGET_FIELD_NAME): return False attr = _get_target(x) diff --git a/src/hydra_zen/structured_configs/_utils.py b/src/hydra_zen/structured_configs/_utils.py index 0f906940a..64d8b4170 100644 --- a/src/hydra_zen/structured_configs/_utils.py +++ b/src/hydra_zen/structured_configs/_utils.py @@ -1,7 +1,7 @@ # Copyright (c) 2021 Massachusetts Institute of Technology # SPDX-License-Identifier: MIT - import sys +import warnings from dataclasses import MISSING, Field, field as _field, is_dataclass from enum import Enum from typing import ( @@ -11,6 +11,7 @@ List, Mapping, Optional, + Sequence, Tuple, TypeVar, Union, @@ -18,7 +19,10 @@ overload, ) -from typing_extensions import Final +from omegaconf import II +from typing_extensions import Final, TypeGuard + +from hydra_zen.typing._implementations import InterpStr try: from typing import get_args, get_origin @@ -335,3 +339,39 @@ def sanitized_type( return type_ return Any + + +def is_interpolated_string(x: Any) -> TypeGuard[InterpStr]: + # This is only a necessary check – not a sufficient one – that `x` + # is a valid interpolated string. We do not verify that it rigorously + # satisfies omegaconf's grammar + return isinstance(x, str) and len(x) > 3 and x.startswith("${") and x.endswith("}") + + +def check_suspicious_interpolations( + validated_wrappers: Sequence[Any], zen_meta: Mapping[str, Any], target: Any +): + """Looks for patterns among zen_meta fields and interpolated fields in + wrappers. Relative interpolations pointing to the wrong level will produce + a warning""" + for _w in validated_wrappers: + if is_interpolated_string(_w): + _lvl = _w.count(".") # level of relative-interp + _field_name = _w.replace(".", "")[2:-1] + if ( + _lvl + and _field_name in zen_meta + and _lvl != (1 if len(validated_wrappers) == 1 else 2) + ): + _expected = II( + "." * (1 if len(validated_wrappers) == 1 else 2) + _field_name + ) + + warnings.warn( + building_error_prefix(target) + + f"A zen-wrapper is specified via the interpolated field, {_w}," + f" along with the meta-field name {_field_name}, however it " + f"appears to point to the wrong level. It is likely you should " + f"change {_w} to {_expected}" + ) + yield _expected diff --git a/src/hydra_zen/typing/_implementations.py b/src/hydra_zen/typing/_implementations.py index 2451d0715..631276d32 100644 --- a/src/hydra_zen/typing/_implementations.py +++ b/src/hydra_zen/typing/_implementations.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: MIT from dataclasses import Field -from typing import Any, Callable, Dict, Generic, Tuple, TypeVar +from typing import Any, Callable, Dict, Generic, NewType, Tuple, TypeVar from typing_extensions import Protocol, runtime_checkable @@ -16,7 +16,6 @@ _T = TypeVar("_T", covariant=True) -_T2 = TypeVar("_T2") class Partial(Generic[_T]): @@ -33,10 +32,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> _T: # pragma: no cover ... -class _Importable(Protocol): - __module__: str - __name__: str - +InterpStr = NewType("InterpStr", str) Importable = TypeVar("Importable") diff --git a/tests/annotations/declarations.py b/tests/annotations/declarations.py index 053c7c2c1..728a1e726 100644 --- a/tests/annotations/declarations.py +++ b/tests/annotations/declarations.py @@ -118,7 +118,6 @@ def f7(): ) b4: Literal["(x: int) -> int"] = reveal_type(get_target(just(f))) - # get_target(Builds[T]) -> T c1: Literal["Type[str]"] = reveal_type(get_target(builds(str)())) c2: Literal["Type[str]"] = reveal_type(get_target(builds(str, zen_partial=False)())) c3: Literal["Type[str]"] = reveal_type(get_target(builds(str, zen_partial=True)())) @@ -129,3 +128,42 @@ def f8(): @dataclass class A: x: List[int] = mutable_value([1, 2]) + + +def zen_wrappers(): + def f(obj): + return obj + + J = just(f) + B = builds(f, zen_partial=True) + PB = builds(f, zen_partial=True) + a1: Literal["Type[Builds[Type[str]]]"] = reveal_type(builds(str, zen_wrappers=f)) + a2: Literal["Type[Builds[Type[str]]]"] = reveal_type(builds(str, zen_wrappers=J)) + a3: Literal["Type[Builds[Type[str]]]"] = reveal_type(builds(str, zen_wrappers=B)) + a4: Literal["Type[Builds[Type[str]]]"] = reveal_type(builds(str, zen_wrappers=PB)) + a5: Literal["Type[Builds[Type[str]]]"] = reveal_type( + builds(str, zen_wrappers=(None,)) + ) + + a6: Literal["Type[Builds[Type[str]]]"] = reveal_type( + builds(str, zen_wrappers=(f, J, B, PB, None)) + ) + + b1: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=f) + ) + b2: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=J) + ) + b3: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=B) + ) + b4: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=PB) + ) + b5: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=(None,)) + ) + b6: Literal["Type[PartialBuilds[Type[str]]]"] = reveal_type( + builds(str, zen_partial=True, zen_wrappers=(f, J, B, PB, None)) + ) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8811136a0..de705501e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import collections.abc as abc import enum import random +import string import sys from dataclasses import dataclass, field as dataclass_field from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union @@ -10,7 +11,7 @@ import hypothesis.strategies as st import pytest from hypothesis import given -from omegaconf import OmegaConf, ValidationError +from omegaconf import II, OmegaConf, ValidationError from omegaconf.errors import ( ConfigIndexError, ConfigTypeError, @@ -20,8 +21,14 @@ from typing_extensions import Final, Literal from hydra_zen import builds, instantiate, mutable_value -from hydra_zen.structured_configs._utils import field, safe_name, sanitized_type +from hydra_zen.structured_configs._utils import ( + field, + is_interpolated_string, + safe_name, + sanitized_type, +) from hydra_zen.typing import Builds +from tests import everything_except T = TypeVar("T") @@ -207,6 +214,8 @@ def test_bare_generics(func, value): def test_vendored_field(): + # Test that our implementation of `field` matches that of `dataclasses.field + # The case where `default` is specified instead of `default_factory` # is already covered via our other tests @@ -229,3 +238,40 @@ class B: def test_builds_random_regression(): # was broken in `0.3.0rc3` assert 1 <= instantiate(builds(random.uniform, 1, 2)) <= 2 + + +def f_for_interp(*args, **kwargs): + return args[0] + + +# II renders a string in omegaconf's interpolated-field format +@given(st.text(alphabet=string.ascii_lowercase, min_size=1).map(II)) +def test_is_interpolated_against_omegaconf_generated_interpolated_strs(text): + assert is_interpolated_string(text) + + # ensure interpolation actually works + assert instantiate(builds(f_for_interp, text), **{text[2:-1]: 1}) == 1 + + +@given(everything_except(str)) +def test_non_strings_are_not_interpolated_strings(not_a_str): + assert not is_interpolated_string(not_a_str) + + +@given(st.text(alphabet=string.printable)) +def test_strings_that_fail_to_interpolate_are_not_interpolated_strings(any_text): + c = builds( + f_for_interp, any_text + ) # any_text is an attempt at an interpolated field + kwargs = {any_text[2:-1]: 1} + try: + # Interpreter raises if `any_text` is not a valid field name + # omegaconf raises if `any_text` causes a grammar error + out = instantiate(c, **kwargs) + except Exception: + # either fail case means `any_text` is not a valid interpolated string + assert not is_interpolated_string(any_text) + return + + # If `any_text` is a valid interpolated string, then `out == 1` + assert out == 1 or not is_interpolated_string(any_text) diff --git a/tests/test_zen_wrappers.py b/tests/test_zen_wrappers.py new file mode 100644 index 000000000..de2c99240 --- /dev/null +++ b/tests/test_zen_wrappers.py @@ -0,0 +1,300 @@ +# Copyright (c) 2021 Massachusetts Institute of Technology +# SPDX-License-Identifier: MIT +import string +from typing import Any, Callable, Dict, List, TypeVar, Union + +import hypothesis.strategies as st +import pytest +from hypothesis import given, settings +from omegaconf import OmegaConf +from omegaconf.errors import InterpolationKeyError, InterpolationResolutionError +from typing_extensions import Protocol + +from hydra_zen import builds, get_target, hydrated_dataclass, instantiate, just, to_yaml +from hydra_zen.structured_configs._implementations import is_builds +from hydra_zen.structured_configs._utils import is_interpolated_string +from hydra_zen.typing import Just, PartialBuilds +from hydra_zen.typing._implementations import InterpStr + +T = TypeVar("T", bound=Callable) + + +class TrackedFunc(Protocol): + tracked_id: int + + def __call__(self, obj: T) -> T: + ... + + +def _coordinate_meta_fields_for_interpolation(wrappers, zen_meta): + # Utility for testing + # + # Check if any of the wrappers are interpolation strings. + # If so: attach corresponding meta-fields so that the + # interpolated strings map to the named decorators + if is_interpolated_string(wrappers): + # change level of interpolation + wrappers = wrappers.replace("..", ".") # type: ignore + dec_name: str = wrappers[3:-1] + item = decorators_by_name[dec_name] + zen_meta[dec_name] = item if item is None else just(item) + elif isinstance(wrappers, list): + num_none = wrappers.count(None) + for n, wrapper in enumerate(wrappers): + if is_interpolated_string(wrapper): + if len(wrappers) - num_none == 1: + wrappers[n] = wrapper.replace("..", ".") + dec_name = wrapper[4:-1] + item = decorators_by_name[dec_name] + zen_meta[dec_name] = item if item is None else just(item) + return wrappers, zen_meta + + +def _resolve_wrappers(wrappers) -> List[TrackedFunc]: + # Utility for testing + if not isinstance(wrappers, list): + wrappers = [wrappers] + + # None and interp-none can be skipped - no wrapping happened + wrappers = [w for w in wrappers if w is not None] + wrappers = [w for w in wrappers if not (isinstance(w, str) and w.endswith("none}"))] + + # get wrappers from builds + wrappers = [get_target(w) if is_builds(w) else w for w in wrappers] + + # get wrappers from interpolated strings + wrappers = [ + decorators_by_name[w[2:-1].replace(".", "")] if isinstance(w, str) else w + for w in wrappers + ] + return wrappers # type: ignore + + +def tracked_decorator(obj): + if hasattr(obj, "num_decorated"): + obj.num_decorated = obj.num_decorated + 1 + else: + obj.num_decorated = 1 + return obj + + +# We will append the tracking-id of each wrapper function +# that is used. +TRACKED = [] + + +def f1(obj): + TRACKED.append(f1.tracked_id) + return tracked_decorator(obj) + + +def f2(obj): + TRACKED.append(f2.tracked_id) + return tracked_decorator(obj) + + +def f3(obj): + TRACKED.append(f3.tracked_id) + return tracked_decorator(obj) + + +f1.tracked_id = 1 +f2.tracked_id = 2 +f3.tracked_id = 3 + +decorators_by_name = dict(f1=f1, f2=f2, f3=f3, none=None) + + +def target(*args, **kwargs): + return args, kwargs + + +# prepare all variety of valid decorators to be tested +tracked_funcs = [f1, f2, f3, None] # adds TrackedFunc +tracked_funcs.extend(just(f) for f in [f1, f2, f3]) # adds Just[TrackedFunc] +tracked_funcs.extend(builds(f, zen_partial=True) for f in [f1, f2, f3]) +tracked_funcs.extend(["${..f1}", "${..f2}", "${..f3}", "${..none}"]) + +a_tracked_wrapper = st.sampled_from(tracked_funcs) + + +@settings(max_examples=500) # ensures coverage of various branches +@given( + wrappers=a_tracked_wrapper | st.lists(a_tracked_wrapper), + args=st.lists(st.integers()), + kwargs=st.dictionaries( + st.text(string.ascii_lowercase, min_size=1, max_size=1), st.integers() + ), + zen_partial=st.booleans(), + zen_meta=st.dictionaries( + st.text(string.ascii_lowercase, min_size=1, max_size=1).map(lambda x: "_" + x), + st.integers(), + max_size=2, + ), + as_yaml=st.booleans(), +) +def test_zen_wrappers_expected_behavior( + wrappers: Union[ # type: ignore + Union[TrackedFunc, Just[TrackedFunc], PartialBuilds[TrackedFunc], InterpStr], + List[ + Union[TrackedFunc, Just[TrackedFunc], PartialBuilds[TrackedFunc], InterpStr] + ], + ], + args: List[int], + kwargs: Dict[str, int], + zen_partial: bool, + zen_meta: Dict[str, Any], + as_yaml: bool, +): + """ + Tests: + - wrappers as functions + - wrappers as PartialBuilds + - wrappers as Just + - wrappers as interpolated strings + - zero or more wrappers + - that each wrapper is called once, in order, from left to right + - that each wrapper is passed the output of the previous wrapper + - that the args and kwargs passed to the target are passed as-expected + - that things interact as-expected with `zen_partial=True` + - that things interact as-expected with `zen_meta` + - that confs are serializable and produce the correct behavior + """ + TRACKED.clear() + if hasattr(target, "num_decorated"): + del target.num_decorated + + wrappers, zen_meta = _coordinate_meta_fields_for_interpolation(wrappers, zen_meta) # type: ignore + + args = tuple(args) # type: ignore + conf = builds( + target, + *args, + **kwargs, + zen_wrappers=wrappers, + zen_partial=zen_partial, + zen_meta=zen_meta + ) + if not as_yaml: + instantiated = instantiate(conf) + else: + # ensure serializable + conf = OmegaConf.create(to_yaml(conf)) + instantiated = instantiate(conf) + + out_args, out_kwargs = instantiated() if zen_partial else instantiated # type: ignore + + # ensure arguments passed-through as-expected + assert out_args == args + assert out_kwargs == kwargs + + # ensure zen_meta works as-expected + for meta_key, meta_val in zen_meta.items(): + assert getattr(conf, meta_key) == meta_val + + resolved_wrappers = _resolve_wrappers(wrappers) + + # ensure wrappers called in expected order and that + # each one wrapped the expected target + if resolved_wrappers: + assert len(resolved_wrappers) == target.num_decorated + assert TRACKED == [w.tracked_id for w in resolved_wrappers] + else: + assert not hasattr(target, "num_decorated") + assert not TRACKED + + +def test_wrapper_for_hydrated_dataclass(): + TRACKED.clear() + if hasattr(target, "num_decorated"): + del target.num_decorated + + @hydrated_dataclass(target, zen_wrappers=f1) + class A: + pass + + instantiate(A) + assert target.num_decorated == 1 + assert TRACKED == [f1.tracked_id] + + +class NotAWrapper: + pass + + +@pytest.mark.parametrize( + "bad_wrapper", + [ + 1, # not callable, + (1,), # not callable in sequence + (tracked_decorator, 1), # 1st ok, 2nd bad + ], +) +def test_zen_wrappers_validation_during_builds(bad_wrapper): + with pytest.raises(TypeError): + builds(int, zen_wrappers=bad_wrapper) + + +@pytest.mark.parametrize( + "bad_wrapper", + [ + NotAWrapper, # doesn't instantiate to a callable + (NotAWrapper,), + (builds(NotAWrapper),), + (tracked_decorator, builds(NotAWrapper)), + ], +) +def test_zen_wrappers_validation_during_instantiation(bad_wrapper): + conf = builds(int, zen_wrappers=bad_wrapper) + with pytest.raises(TypeError): + instantiate(conf) + + +@pytest.mark.parametrize( + "bad_wrapper", + ["${unresolved}", (tracked_decorator, "${unresolved}")], +) +def test_unresolved_interpolated_value_gets_caught(bad_wrapper): + conf = builds(int, zen_wrappers=bad_wrapper) + with pytest.raises(InterpolationKeyError): + instantiate(conf) + + +@pytest.mark.filterwarnings( + "ignore:A structured config was supplied for `zen_wrappers`" +) +def test_wrapper_via_builds_with_recusive_False(): + with pytest.warns(Warning): + builds(int, zen_wrappers=builds(f1, zen_partial=True), hydra_recursive=False) + + +@pytest.mark.filterwarnings( + "ignore:A zen-wrapper is specified via the interpolated field" +) +@pytest.mark.parametrize( + "wrappers, expected_match", + [ + ("${..s}", r"to \$\{.s\}"), + ("${...s}", r"to \$\{.s\}"), + (["${..s}"], r"to \$\{.s\}"), + (["${..s}", "${...s}"], r"to \$\{..s\}"), + (["${.s}", "${..s}", "${...s}"], r"to \$\{..s\}"), + ], +) +def test_bad_relative_interp_warns(wrappers, expected_match): + with pytest.warns(UserWarning, match=expected_match): + conf = builds(dict, zen_wrappers=wrappers, zen_meta=dict(s=1)) + + with pytest.raises(InterpolationResolutionError): + # make sure interpolation is actually bad + instantiate(conf) + + +@pytest.mark.parametrize( + "wrappers", + ["${s}", "${.s}", ["${.s}"], ["${..s}", "${..s}"], ["${..s}", "${..s}", "${..s}"]], +) +def test_interp_doesnt_warn(wrappers): + with pytest.warns(None) as record: + builds(dict, zen_wrappers=wrappers, zen_meta=dict(s=1)) + assert not record