diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py
index 6741236a7e5f5..12a29246888f7 100644
--- a/pytorch_lightning/core/saving.py
+++ b/pytorch_lightning/core/saving.py
@@ -17,16 +17,19 @@
 import inspect
 import os
 from argparse import Namespace
-from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
+from copy import deepcopy
+from functools import partial
+from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
 from warnings import warn

 import torch
 import yaml

 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, OMEGACONF_AVAILABLE
-from pytorch_lightning.utilities.cloud_io import load as pl_load
+from pytorch_lightning.utilities import AttributeDict, OMEGACONF_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.parsing import parse_class_init_keys

 PRIMITIVE_TYPES = (bool, int, float, str)
@@ -34,6 +37,9 @@

 if OMEGACONF_AVAILABLE:
     from omegaconf import OmegaConf
+    from omegaconf.dictconfig import DictConfig
+    from omegaconf.errors import UnsupportedValueType, ValidationError
+

 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = (
@@ -321,9 +327,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
             writer.writerow({"key": k, "value": v})


-def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
+def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
     """Load hparams from a file.

+    Args:
+        config_yaml: Path to config yaml file
+        use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
+            the hparams will be converted to `DictConfig` if possible
+
     >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
     >>> path_yaml = './testing-hparams.yaml'
     >>> save_hparams_to_yaml(path_yaml, hparams)
@@ -338,9 +349,15 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
         return {}

     with fs.open(config_yaml, "r") as fp:
-        tags = yaml.full_load(fp)
+        hparams = yaml.full_load(fp)

-    return tags
+    if OMEGACONF_AVAILABLE:
+        if use_omegaconf:
+            try:
+                return OmegaConf.create(hparams)
+            except (UnsupportedValueType, ValidationError):
+                pass
+    return hparams


 def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
@@ -361,15 +378,16 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:

     # saving with OmegaConf objects
     if OMEGACONF_AVAILABLE:
-        if OmegaConf.is_config(hparams):
-            with fs.open(config_yaml, "w", encoding="utf-8") as fp:
-                OmegaConf.save(hparams, fp, resolve=True)
-            return
-        for v in hparams.values():
-            if OmegaConf.is_config(v):
-                with fs.open(config_yaml, "w", encoding="utf-8") as fp:
-                    OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
+        # deepcopy: hparams from user shouldn't be resolved
+        hparams = deepcopy(hparams)
+        to_container = partial(OmegaConf.to_container, resolve=True)
+        hparams = apply_to_collection(hparams, DictConfig, to_container)
+        with fs.open(config_yaml, "w", encoding="utf-8") as fp:
+            try:
+                OmegaConf.save(hparams, fp)
                 return
+            except (UnsupportedValueType, ValidationError):
+                pass

     assert isinstance(hparams, dict)
     hparams_allowed = {}
diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py
index e5641337cc8d2..c5dade86c348a 100644
--- a/pytorch_lightning/utilities/__init__.py
+++ b/pytorch_lightning/utilities/__init__.py
@@ -23,36 +23,16 @@

 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.package_utils import _module_available
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
 from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils

-
-def _module_available(module_path: str) -> bool:
-    """Testing if given module is avalaible in your env
-
-    >>> _module_available('os')
-    True
-    >>> _module_available('bla.bla')
-    False
-    """
-    # todo: find a better way than try / except
-    try:
-        mods = module_path.split('.')
-        assert mods, 'nothing given to test'
-        # it has to be tested as per partets
-        for i in range(len(mods)):
-            module_path = '.'.join(mods[:i + 1])
-            if importlib.util.find_spec(module_path) is None:
-                return False
-        return True
-    except AttributeError:
-        return False
-
-
+OMEGACONF_AVAILABLE = _module_available("omegaconf")
 APEX_AVAILABLE = _module_available("apex.amp")
 NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
 OMEGACONF_AVAILABLE = _module_available("omegaconf")
 HYDRA_AVAILABLE = _module_available("hydra")
+HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
 HOROVOD_AVAILABLE = _module_available("horovod.torch")
 BOLTS_AVAILABLE = _module_available("pl_bolts")

diff --git a/pytorch_lightning/utilities/package_utils.py b/pytorch_lightning/utilities/package_utils.py
new file mode 100644
index 0000000000000..99fd6fcc7ebb5
--- /dev/null
+++ b/pytorch_lightning/utilities/package_utils.py
@@ -0,0 +1,36 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+
+
+def _module_available(module_path: str) -> bool:
+    """Testing if given module is available in your env
+
+    >>> _module_available('os')
+    True
+    >>> _module_available('bla.bla')
+    False
+    """
+    # todo: find a better way than try / except
+    try:
+        mods = module_path.split('.')
+        assert mods, 'nothing given to test'
+        # the dotted path has to be tested part by part
+        for i in range(len(mods)):
+            module_path = '.'.join(mods[:i + 1])
+            if importlib.util.find_spec(module_path) is None:
+                return False
+        return True
+    except AttributeError:
+        return False
diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py
index b207320c25ccc..5d90583345b4a 100644
--- a/pytorch_lightning/utilities/parsing.py
+++ b/pytorch_lightning/utilities/parsing.py
@@ -15,9 +15,11 @@
 import inspect
 import pickle
 from argparse import Namespace
-from typing import Dict, Union, Tuple
+from typing import Dict, Tuple, Union

 from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.package_utils import _module_available


 def str_to_bool_or_str(val: str) -> Union[str, bool]:
@@ -115,7 +117,6 @@ def get_init_args(frame) -> dict:
     self_var, args_var, kwargs_var = parse_class_init_keys(cls)
     filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
     exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')
-    # only collect variables that appear in the signature
     local_args = {k: local_vars[k] for k in init_parameters.keys()}
     local_args.update(local_args.get(kwargs_var, {}))

diff --git a/tests/models/conf/config.yaml b/tests/models/conf/config.yaml
new file mode 100644
index 0000000000000..faf751c24f6cb
--- /dev/null
+++ b/tests/models/conf/config.yaml
@@ -0,0 +1,17 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+defaults:
+  - training: default
+
+log: ${training.log}
diff --git a/tests/models/conf/training/default.yaml b/tests/models/conf/training/default.yaml
new file mode 100644
index 0000000000000..2c35b22365420
--- /dev/null
+++ b/tests/models/conf/training/default.yaml
@@ -0,0 +1,2 @@
+# @package training
+log: "Something"
diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py
index 7df78d9760bd9..e354c6e708d95 100644
--- a/tests/models/test_hparams.py
+++ b/tests/models/test_hparams.py
@@ -15,19 +15,25 @@
 import os
 import pickle
 from argparse import Namespace
+from copy import deepcopy

 import cloudpickle
 import pytest
 import torch
 from fsspec.implementations.local import LocalFileSystem
-from omegaconf import OmegaConf, Container
+from omegaconf import Container, OmegaConf
+from omegaconf.dictconfig import DictConfig
 from torch.nn import functional as F
 from torch.utils.data import DataLoader

-from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
-from pytorch_lightning.utilities import AttributeDict, is_picklable
-from tests.base import EvalModelTemplate, TrialMNIST, BoringModel
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
+from pytorch_lightning.utilities import AttributeDict, HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable
+from tests.base import BoringModel, EvalModelTemplate, TrialMNIST
+
+if HYDRA_EXPERIMENTAL_AVAILABLE:
+    from hydra.experimental import compose, initialize


 class SaveHparamsModel(BoringModel):
@@ -483,13 +489,13 @@
     path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')

     save_hparams_to_yaml(path_yaml, hparams)
-    assert load_hparams_from_yaml(path_yaml) == hparams
+    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

     save_hparams_to_yaml(path_yaml, Namespace(**hparams))
-    assert load_hparams_from_yaml(path_yaml) == hparams
+    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

     save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
-    assert load_hparams_from_yaml(path_yaml) == hparams
+    assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

     save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
     assert load_hparams_from_yaml(path_yaml) == hparams
@@ -642,3 +648,46 @@ def test_model_with_fsspec_as_parameter(tmpdir):
     )
     trainer.fit(model)
     trainer.test()
+
+
+@pytest.mark.skipif(not HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available")
+def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
+    """
+    This test relies on configuration saved under tests/models/conf/config.yaml
+    """
+
+    class TestHydraModel(BoringModel):
+
+        def __init__(self, args_0, args_1, args_2, kwarg_1=None):
+            self.save_hyperparameters()
+            self.test_hparams()
+            config_file = f"{tmpdir}/hparams.yaml"
+            save_hparams_to_yaml(config_file, self.hparams)
+            self.hparams = load_hparams_from_yaml(config_file)
+            self.test_hparams()
+            super().__init__()
+
+        def test_hparams(self):
+            assert self.hparams.args_0.log == "Something"
+            assert self.hparams.args_1['cfg'].log == "Something"
+            assert self.hparams.args_2[0].log == "Something"
+            assert self.hparams.kwarg_1['cfg'][0].log == "Something"
+
+    with initialize(config_path="conf"):
+        args_0 = compose(config_name="config")
+        args_1 = {"cfg": compose(config_name="config")}
+        args_2 = [compose(config_name="config")]
+        kwarg_1 = {"cfg": [compose(config_name="config")]}
+        model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
+        epochs = 2
+        checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
+        trainer = Trainer(
+            default_root_dir=tmpdir,
+            callbacks=[checkpoint_callback],
+            limit_train_batches=10,
+            limit_val_batches=10,
+            max_epochs=epochs,
+            logger=False,
+        )
+        trainer.fit(model)
+        _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)
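
A minimal usage sketch of the loading/saving behaviour changed above (not part of the patch). It assumes omegaconf is installed; the file path and the example keys are illustrative only.

from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

# hypothetical hparams containing an OmegaConf interpolation
hparams = OmegaConf.create({"log": "Something", "alias": "${log}"})

# per the patch, save_hparams_to_yaml deep-copies the user's object and writes
# resolved values, so the in-memory `hparams` keeps its interpolation
save_hparams_to_yaml("/tmp/hparams.yaml", hparams)

loaded = load_hparams_from_yaml("/tmp/hparams.yaml")  # DictConfig when omegaconf is importable
assert isinstance(loaded, DictConfig) and loaded.alias == "Something"

plain = load_hparams_from_yaml("/tmp/hparams.yaml", use_omegaconf=False)  # plain dict, old behaviour
assert isinstance(plain, dict) and plain["alias"] == "Something"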
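
Similarly, a short sketch of the relocated availability helper and the new HYDRA_EXPERIMENTAL_AVAILABLE flag (again an illustration, not part of the patch):

from pytorch_lightning.utilities import HYDRA_EXPERIMENTAL_AVAILABLE
from pytorch_lightning.utilities.package_utils import _module_available

# _module_available walks the dotted path piece by piece via importlib.util.find_spec
assert _module_available("os")
assert not _module_available("bla.bla")

if HYDRA_EXPERIMENTAL_AVAILABLE:
    from hydra.experimental import compose, initialize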