From e57a876177327e042f15b782104433b17f0ebd8a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 16:41:28 +0100 Subject: [PATCH 01/15] resolve bug --- pytorch_lightning/core/lightning.py | 1 - pytorch_lightning/utilities/parsing.py | 13 ++++++-- tests/models/conf/config.yaml | 17 ++++++++++ tests/models/conf/training/default.yaml | 15 +++++++++ tests/models/test_hparams.py | 41 ++++++++++++++++++++++--- 5 files changed, 79 insertions(+), 8 deletions(-) create mode 100644 tests/models/conf/config.yaml create mode 100644 tests/models/conf/training/default.yaml diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 421fc5e5cf2ac..beb4af9d9f1bb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -34,7 +34,6 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index b207320c25ccc..208d4a476b6cd 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -15,9 +15,13 @@ import inspect import pickle from argparse import Namespace -from typing import Dict, Union, Tuple +from typing import Dict, Tuple, Union + +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection def str_to_bool_or_str(val: str) -> Union[str, bool]: @@ -106,6 +110,11 @@ def _get_first_if_any(params, param_type): return n_self, n_args, n_kwargs +def resolve_dict_config(data): + data = OmegaConf.to_container(data, resolve=True) + return OmegaConf.create(data) + + def get_init_args(frame) -> dict: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: @@ -115,11 +124,11 @@ def get_init_args(frame) -> dict: self_var, args_var, kwargs_var = parse_class_init_keys(cls) filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n] exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') - # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} + local_args = apply_to_collection(local_args, DictConfig, resolve_dict_config) return local_args diff --git a/tests/models/conf/config.yaml b/tests/models/conf/config.yaml new file mode 100644 index 0000000000000..faf751c24f6cb --- /dev/null +++ b/tests/models/conf/config.yaml @@ -0,0 +1,17 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +defaults: + - training: default + +log: ${training.log} diff --git a/tests/models/conf/training/default.yaml b/tests/models/conf/training/default.yaml new file mode 100644 index 0000000000000..1c26d966a14f5 --- /dev/null +++ b/tests/models/conf/training/default.yaml @@ -0,0 +1,15 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# @package training +log: "Something" diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 7df78d9760bd9..de67b45d90a5b 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -20,14 +20,18 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem -from omegaconf import OmegaConf, Container +from omegaconf import Container, OmegaConf +from omegaconf.dictconfig import DictConfig from torch.nn import functional as F from torch.utils.data import DataLoader -from pytorch_lightning import Trainer, LightningModule -from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml -from pytorch_lightning.utilities import AttributeDict, is_picklable -from tests.base import EvalModelTemplate, TrialMNIST, BoringModel +from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml +from pytorch_lightning.utilities import AttributeDict, HYDRA_AVAILABLE, is_picklable +from tests.base import BoringModel, EvalModelTemplate, TrialMNIST + +if HYDRA_AVAILABLE: + from hydra.experimental import compose, initialize class SaveHparamsModel(BoringModel): @@ -642,3 +646,30 @@ def test_model_with_fsspec_as_parameter(tmpdir): ) trainer.fit(model) trainer.test() + + +@pytest.mark.skipif(not HYDRA_AVAILABLE, reason="Hydra is not available") +def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): + + initialize(config_path="conf") + + cfg = compose(config_name="config") + + class TestModel(BoringModel): + + def __init__(self, cfg): + self.save_hyperparameters() + assert isinstance(self.hparams.cfg, DictConfig) + assert self.hparams.cfg.log == "Something" + super().__init__() + + model = TestModel(cfg) + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=2, + max_epochs=1, + ) + trainer.fit(model) + trainer.test() From 3f2e9d7aafa49e9b48a59151b6c0e647df6727e7 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 7 Jan 2021 16:50:46 +0100 Subject: [PATCH 02/15] Apply suggestions from code review --- tests/models/test_hparams.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index de67b45d90a5b..bffd4a7a9491b 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -30,8 +30,6 @@ from pytorch_lightning.utilities import AttributeDict, HYDRA_AVAILABLE, is_picklable from tests.base import BoringModel, 
EvalModelTemplate, TrialMNIST -if HYDRA_AVAILABLE: - from hydra.experimental import compose, initialize class SaveHparamsModel(BoringModel): From a861180b1c958481ad63bf91ef18d74e3710abf1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 18:05:13 +0100 Subject: [PATCH 03/15] resolve package import --- pytorch_lightning/utilities/__init__.py | 24 +------------ pytorch_lightning/utilities/package_utils.py | 36 ++++++++++++++++++++ pytorch_lightning/utilities/parsing.py | 13 ++++--- 3 files changed, 46 insertions(+), 27 deletions(-) create mode 100644 pytorch_lightning/utilities/package_utils.py diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index e5641337cc8d2..d98123f73a673 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -23,32 +23,10 @@ from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities.package_utils import _module_available from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils - -def _module_available(module_path: str) -> bool: - """Testing if given module is avalaible in your env - - >>> _module_available('os') - True - >>> _module_available('bla.bla') - False - """ - # todo: find a better way than try / except - try: - mods = module_path.split('.') - assert mods, 'nothing given to test' - # it has to be tested as per partets - for i in range(len(mods)): - module_path = '.'.join(mods[:i + 1]) - if importlib.util.find_spec(module_path) is None: - return False - return True - except AttributeError: - return False - - APEX_AVAILABLE = _module_available("apex.amp") NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") OMEGACONF_AVAILABLE = _module_available("omegaconf") diff --git a/pytorch_lightning/utilities/package_utils.py b/pytorch_lightning/utilities/package_utils.py new file mode 100644 index 0000000000000..99fd6fcc7ebb5 --- /dev/null +++ b/pytorch_lightning/utilities/package_utils.py @@ -0,0 +1,36 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import importlib + + +def _module_available(module_path: str) -> bool: + """Testing if given module is avalaible in your env + + >>> _module_available('os') + True + >>> _module_available('bla.bla') + False + """ + # todo: find a better way than try / except + try: + mods = module_path.split('.') + assert mods, 'nothing given to test' + # it has to be tested as per partets + for i in range(len(mods)): + module_path = '.'.join(mods[:i + 1]) + if importlib.util.find_spec(module_path) is None: + return False + return True + except AttributeError: + return False diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 208d4a476b6cd..a91c7dad5fbf8 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -17,11 +17,15 @@ from argparse import Namespace from typing import Dict, Tuple, Union -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig - from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.package_utils import _module_available + +OMEGACONF_AVAILABLE = _module_available("omegaconf") + +if OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig def str_to_bool_or_str(val: str) -> Union[str, bool]: @@ -128,7 +132,8 @@ def get_init_args(frame) -> dict: local_args = {k: local_vars[k] for k in init_parameters.keys()} local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} - local_args = apply_to_collection(local_args, DictConfig, resolve_dict_config) + if OMEGACONF_AVAILABLE: + local_args = apply_to_collection(local_args, DictConfig, resolve_dict_config) return local_args From 3891278dc3a63147a53a62f1fea511f73f0ca8c1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 18:09:57 +0100 Subject: [PATCH 04/15] resolve import --- tests/models/test_hparams.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index bffd4a7a9491b..79c997dd656ca 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -30,6 +30,11 @@ from pytorch_lightning.utilities import AttributeDict, HYDRA_AVAILABLE, is_picklable from tests.base import BoringModel, EvalModelTemplate, TrialMNIST +if HYDRA_AVAILABLE: + try: + from hydra.experimental import compose, initialize + except Exception: + HYDRA_AVAILABLE = False class SaveHparamsModel(BoringModel): From 1c3e0779ee76eab169b3df4c121ba1bb14f975aa Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 8 Jan 2021 14:06:13 +0100 Subject: [PATCH 05/15] update on comments --- pytorch_lightning/core/saving.py | 8 ++++---- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/parsing.py | 13 ------------- 3 files changed, 5 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 6741236a7e5f5..4036f331b5b58 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -17,16 +17,16 @@ import inspect import os from argparse import Namespace -from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO +from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union from warnings import warn import torch import yaml from pytorch_lightning import _logger as log -from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, 
OMEGACONF_AVAILABLE -from pytorch_lightning.utilities.cloud_io import load as pl_load +from pytorch_lightning.utilities import AttributeDict, OMEGACONF_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem +from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.parsing import parse_class_init_keys PRIMITIVE_TYPES = (bool, int, float, str) @@ -368,7 +368,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: for v in hparams.values(): if OmegaConf.is_config(v): with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True) + OmegaConf.save(OmegaConf.create(v), fp, resolve=True) return assert isinstance(hparams, dict) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index d98123f73a673..b256b39acb6de 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -27,6 +27,7 @@ from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils +OMEGACONF_AVAILABLE = _module_available("omegaconf") APEX_AVAILABLE = _module_available("apex.amp") NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") OMEGACONF_AVAILABLE = _module_available("omegaconf") diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index a91c7dad5fbf8..5d90583345b4a 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -21,12 +21,6 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.package_utils import _module_available -OMEGACONF_AVAILABLE = _module_available("omegaconf") - -if OMEGACONF_AVAILABLE: - from omegaconf import OmegaConf - from omegaconf.dictconfig import DictConfig - def str_to_bool_or_str(val: str) -> Union[str, bool]: """Possibly convert a string representation of truth to bool. 
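
(Aside: the hunks above fold parsing.py's local availability check into the shared flag exported from pytorch_lightning/utilities/__init__.py. A condensed, standalone sketch of that guarded-import pattern, standard library only and with illustrative names, not the exact file contents of this series:)

    import importlib.util

    def _module_available(module_path: str) -> bool:
        """Return True if ``module_path`` (e.g. "omegaconf") is importable in this env.

        >>> _module_available('os')
        True
        >>> _module_available('bla.bla')
        False
        """
        try:
            # test each parent package in turn so a missing root fails cleanly
            parts = module_path.split('.')
            for i in range(len(parts)):
                if importlib.util.find_spec('.'.join(parts[:i + 1])) is None:
                    return False
            return True
        except AttributeError:
            return False

    OMEGACONF_AVAILABLE = _module_available("omegaconf")

    if OMEGACONF_AVAILABLE:
        from omegaconf import OmegaConf
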
@@ -114,11 +108,6 @@ def _get_first_if_any(params, param_type): return n_self, n_args, n_kwargs -def resolve_dict_config(data): - data = OmegaConf.to_container(data, resolve=True) - return OmegaConf.create(data) - - def get_init_args(frame) -> dict: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: @@ -132,8 +121,6 @@ def get_init_args(frame) -> dict: local_args = {k: local_vars[k] for k in init_parameters.keys()} local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} - if OMEGACONF_AVAILABLE: - local_args = apply_to_collection(local_args, DictConfig, resolve_dict_config) return local_args From ce2b1316d1bc1248f7c17fe81530d605e85a4f26 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 8 Jan 2021 19:42:15 +0100 Subject: [PATCH 06/15] update on comments --- pytorch_lightning/utilities/__init__.py | 1 + tests/models/test_hparams.py | 23 +++++++---------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index b256b39acb6de..c5dade86c348a 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -32,6 +32,7 @@ NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast") OMEGACONF_AVAILABLE = _module_available("omegaconf") HYDRA_AVAILABLE = _module_available("hydra") +HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental") HOROVOD_AVAILABLE = _module_available("horovod.torch") BOLTS_AVAILABLE = _module_available("pl_bolts") diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 79c997dd656ca..531a67d4e05a2 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -27,14 +27,11 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import AttributeDict, HYDRA_AVAILABLE, is_picklable +from pytorch_lightning.utilities import AttributeDict, HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable from tests.base import BoringModel, EvalModelTemplate, TrialMNIST -if HYDRA_AVAILABLE: - try: - from hydra.experimental import compose, initialize - except Exception: - HYDRA_AVAILABLE = False +if HYDRA_EXPERIMENTAL_AVAILABLE: + from hydra.experimental import compose, initialize class SaveHparamsModel(BoringModel): @@ -651,8 +648,11 @@ def test_model_with_fsspec_as_parameter(tmpdir): trainer.test() -@pytest.mark.skipif(not HYDRA_AVAILABLE, reason="Hydra is not available") +@pytest.mark.skipif(not HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available") def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): + """ + This test relies on configuration saved under tests/models/conf/config.yaml + """ initialize(config_path="conf") @@ -667,12 +667,3 @@ def __init__(self, cfg): super().__init__() model = TestModel(cfg) - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - limit_test_batches=2, - max_epochs=1, - ) - trainer.fit(model) - trainer.test() From 03efd241e28bd216eff552cada3474e55e3df61f Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:13:34 +0100 Subject: [PATCH 07/15] hacky fix --- pytorch_lightning/core/saving.py | 27 +++++++++++------ tests/models/conf/training/default.yaml | 13 -------- tests/models/test_hparams.py | 40 +++++++++++++++++++------ 3 files 
changed, 49 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 4036f331b5b58..1430ecada1e3a 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -17,14 +17,17 @@ import inspect import os from argparse import Namespace +from copy import deepcopy from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union from warnings import warn import torch import yaml +from omegaconf.dictconfig import DictConfig from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AttributeDict, OMEGACONF_AVAILABLE, rank_zero_warn +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.parsing import parse_class_init_keys @@ -337,6 +340,10 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning) return {} + if OMEGACONF_AVAILABLE: + with fs.open(config_yaml, "r") as fp: + return OmegaConf.load(fp) + with fs.open(config_yaml, "r") as fp: tags = yaml.full_load(fp) @@ -349,10 +356,13 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: config_yaml: path to new YAML file hparams: parameters to be saved """ + print(config_yaml) fs = get_filesystem(config_yaml) if not fs.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") + hparams = deepcopy(hparams) + # convert Namespace or AD to dict if isinstance(hparams, Namespace): hparams = vars(hparams) @@ -361,15 +371,14 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving with OmegaConf objects if OMEGACONF_AVAILABLE: - if OmegaConf.is_config(hparams): - with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(hparams, fp, resolve=True) - return - for v in hparams.values(): - if OmegaConf.is_config(v): - with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(OmegaConf.create(v), fp, resolve=True) - return + def resolve_dict_config(data): + data = OmegaConf.to_container(data, resolve=True) + return OmegaConf.create(data) + + hparams = apply_to_collection(hparams, DictConfig, resolve_dict_config) + with fs.open(config_yaml, "w", encoding="utf-8") as fp: + OmegaConf.save(hparams, fp) + return assert isinstance(hparams, dict) hparams_allowed = {} diff --git a/tests/models/conf/training/default.yaml b/tests/models/conf/training/default.yaml index 1c26d966a14f5..2c35b22365420 100644 --- a/tests/models/conf/training/default.yaml +++ b/tests/models/conf/training/default.yaml @@ -1,15 +1,2 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
# @package training log: "Something" diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 531a67d4e05a2..23798758dab32 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -15,6 +15,7 @@ import os import pickle from argparse import Namespace +from copy import deepcopy import cloudpickle import pytest @@ -26,6 +27,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml from pytorch_lightning.utilities import AttributeDict, HYDRA_EXPERIMENTAL_AVAILABLE, is_picklable from tests.base import BoringModel, EvalModelTemplate, TrialMNIST @@ -654,16 +656,36 @@ def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): This test relies on configuration saved under tests/models/conf/config.yaml """ - initialize(config_path="conf") + class TestHydraModel(BoringModel): - cfg = compose(config_name="config") - - class TestModel(BoringModel): - - def __init__(self, cfg): + def __init__(self, args_1, args_2, kwarg_1=None): self.save_hyperparameters() - assert isinstance(self.hparams.cfg, DictConfig) - assert self.hparams.cfg.log == "Something" + self.test_hparams() + config_file = f"{tmpdir}/hparams.yaml" + save_hparams_to_yaml(config_file, self.hparams) + self.hparams = load_hparams_from_yaml(config_file) + self.test_hparams() super().__init__() - model = TestModel(cfg) + def test_hparams(self): + assert self.hparams.args_1['cfg'].log == "Something" + assert self.hparams.args_2[0].log == "Something" + assert self.hparams.kwarg_1['cfg'][0].log == "Something" + + with initialize(config_path="conf"): + args_1 = {"cfg": compose(config_name="config")} + args_2 = [compose(config_name="config")] + kwarg_1 = {"cfg": [compose(config_name="config")]} + model = TestHydraModel(args_1, args_2, kwarg_1=kwarg_1) + epochs = 2 + checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + limit_train_batches=10, + limit_val_batches=10, + max_epochs=epochs, + logger=False, + ) + trainer.fit(model) + _ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path) From 4aedc6d5a52bb3f68bd2c2ba3e5c2f75fee30c0c Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:29:41 +0100 Subject: [PATCH 08/15] update --- pytorch_lightning/core/saving.py | 16 +++++++--------- tests/models/test_hparams.py | 6 ++++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 1430ecada1e3a..6ac49c8d97461 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -340,14 +340,14 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: rank_zero_warn(f"Missing Tags: {config_yaml}.", RuntimeWarning) return {} - if OMEGACONF_AVAILABLE: - with fs.open(config_yaml, "r") as fp: - return OmegaConf.load(fp) - with fs.open(config_yaml, "r") as fp: - tags = yaml.full_load(fp) + hparams = yaml.full_load(fp) - return tags + if OMEGACONF_AVAILABLE: + for k, v in hparams.items(): + hparams[k] = OmegaConf.create(v) + + return hparams def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: @@ -356,7 +356,6 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: config_yaml: path to new YAML file hparams: parameters to 
be saved """ - print(config_yaml) fs = get_filesystem(config_yaml) if not fs.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") @@ -372,8 +371,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving with OmegaConf objects if OMEGACONF_AVAILABLE: def resolve_dict_config(data): - data = OmegaConf.to_container(data, resolve=True) - return OmegaConf.create(data) + return OmegaConf.to_container(data, resolve=True) hparams = apply_to_collection(hparams, DictConfig, resolve_dict_config) with fs.open(config_yaml, "w", encoding="utf-8") as fp: diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 23798758dab32..1788332a81357 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -658,7 +658,7 @@ def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir): class TestHydraModel(BoringModel): - def __init__(self, args_1, args_2, kwarg_1=None): + def __init__(self, args_0, args_1, args_2, kwarg_1=None): self.save_hyperparameters() self.test_hparams() config_file = f"{tmpdir}/hparams.yaml" @@ -668,15 +668,17 @@ def __init__(self, args_1, args_2, kwarg_1=None): super().__init__() def test_hparams(self): + assert self.hparams.args_0.log == "Something" assert self.hparams.args_1['cfg'].log == "Something" assert self.hparams.args_2[0].log == "Something" assert self.hparams.kwarg_1['cfg'][0].log == "Something" with initialize(config_path="conf"): + args_0 = compose(config_name="config") args_1 = {"cfg": compose(config_name="config")} args_2 = [compose(config_name="config")] kwarg_1 = {"cfg": [compose(config_name="config")]} - model = TestHydraModel(args_1, args_2, kwarg_1=kwarg_1) + model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1) epochs = 2 checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1) trainer = Trainer( From 5a33c3f59e31a790b25a427b6c1bf9f0465a5302 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:30:46 +0100 Subject: [PATCH 09/15] exit --- pytorch_lightning/core/saving.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 6ac49c8d97461..1960149cf460b 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -360,8 +360,6 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: if not fs.isdir(os.path.dirname(config_yaml)): raise RuntimeError(f"Missing folder: {os.path.dirname(config_yaml)}.") - hparams = deepcopy(hparams) - # convert Namespace or AD to dict if isinstance(hparams, Namespace): hparams = vars(hparams) From e942045bac8105e1d21ed1a600313cdf648e53ef Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:33:01 +0100 Subject: [PATCH 10/15] update --- pytorch_lightning/core/saving.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 1960149cf460b..f44465e37bac7 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -368,6 +368,9 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving with OmegaConf objects if OMEGACONF_AVAILABLE: + # deepcopy: hparams from user is not resolved + hparams = deepcopy(hparams) + def resolve_dict_config(data): return OmegaConf.to_container(data, resolve=True) From 1b1d6942262e2d8ba589f35544c83009d62fca21 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:37:26 
+0100 Subject: [PATCH 11/15] to_container --- pytorch_lightning/core/saving.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index f44465e37bac7..57f5d7bb70d1e 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -18,6 +18,7 @@ import os from argparse import Namespace from copy import deepcopy +from functools import partial from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union from warnings import warn @@ -370,11 +371,8 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: if OMEGACONF_AVAILABLE: # deepcopy: hparams from user is not resolved hparams = deepcopy(hparams) - - def resolve_dict_config(data): - return OmegaConf.to_container(data, resolve=True) - - hparams = apply_to_collection(hparams, DictConfig, resolve_dict_config) + to_container = partial(OmegaConf.to_container, resolve=True) + hparams = apply_to_collection(hparams, DictConfig, to_container) with fs.open(config_yaml, "w", encoding="utf-8") as fp: OmegaConf.save(hparams, fp) return From 2045be692e5a6591d15351f81132ab02881b2b1f Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 12:40:00 +0100 Subject: [PATCH 12/15] typo --- pytorch_lightning/core/saving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 57f5d7bb70d1e..0ff871ac962cc 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -369,7 +369,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: # saving with OmegaConf objects if OMEGACONF_AVAILABLE: - # deepcopy: hparams from user is not resolved + # deepcopy: hparams from user shouldn't be resolved hparams = deepcopy(hparams) to_container = partial(OmegaConf.to_container, resolve=True) hparams = apply_to_collection(hparams, DictConfig, to_container) From cf297f139203cac611bb8511153f77574ae95470 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 13:01:10 +0100 Subject: [PATCH 13/15] resolve import --- pytorch_lightning/core/lightning.py | 1 + pytorch_lightning/core/saving.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 792e60f4991fa..bd6784cc3b4bb 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -34,6 +34,7 @@ from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks from pytorch_lightning.core.memory import ModelSummary +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_warn, TPU_AVAILABLE diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 0ff871ac962cc..9794debb8aebd 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -24,7 +24,6 @@ import torch import yaml -from omegaconf.dictconfig import DictConfig from pytorch_lightning import _logger as log from pytorch_lightning.utilities import AttributeDict, OMEGACONF_AVAILABLE, rank_zero_warn @@ -38,6 +37,8 @@ if OMEGACONF_AVAILABLE: from omegaconf import OmegaConf + from omegaconf.dictconfig import DictConfig + # the older 
shall be on the top CHECKPOINT_PAST_HPARAMS_KEYS = ( From 2aced4d75f838499cf092073a4875779b44892b0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 14:10:53 +0100 Subject: [PATCH 14/15] update --- pytorch_lightning/core/saving.py | 23 +++++++++++++++++------ tests/models/test_hparams.py | 6 +++--- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 9794debb8aebd..7d6e3d571af7a 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -38,6 +38,7 @@ if OMEGACONF_AVAILABLE: from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig + from omegaconf.errors import UnsupportedValueType, ValidationError # the older shall be on the top @@ -326,9 +327,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: str, use_omegaconf:bool = True) -> Dict[str, Any]: """Load hparams from a file. + Args: + config_yaml: Path to config yaml file + use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True, + the hparams will be converted to `DictConfig` if possible + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') >>> path_yaml = './testing-hparams.yaml' >>> save_hparams_to_yaml(path_yaml, hparams) @@ -346,9 +352,11 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: hparams = yaml.full_load(fp) if OMEGACONF_AVAILABLE: - for k, v in hparams.items(): - hparams[k] = OmegaConf.create(v) - + if use_omegaconf: + try: + return OmegaConf.create(hparams) + except (UnsupportedValueType, ValidationError): + pass return hparams @@ -375,8 +383,11 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: to_container = partial(OmegaConf.to_container, resolve=True) hparams = apply_to_collection(hparams, DictConfig, to_container) with fs.open(config_yaml, "w", encoding="utf-8") as fp: - OmegaConf.save(hparams, fp) - return + try: + OmegaConf.save(hparams, fp) + return + except (UnsupportedValueType, ValidationError): + pass assert isinstance(hparams, dict) hparams_allowed = {} diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index 1788332a81357..e354c6e708d95 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -489,13 +489,13 @@ def test_hparams_save_yaml(tmpdir): path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml') save_hparams_to_yaml(path_yaml, hparams) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, Namespace(**hparams)) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) - assert load_hparams_from_yaml(path_yaml) == hparams + assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) assert load_hparams_from_yaml(path_yaml) == hparams From 3427b5013c3a994dfe40683f117c76e46ca07f9b Mon Sep 17 00:00:00 2001 From: tchaton Date: Sat, 9 Jan 2021 14:12:24 +0100 Subject: [PATCH 15/15] resolve pep8 --- pytorch_lightning/core/saving.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/saving.py 
b/pytorch_lightning/core/saving.py index 7d6e3d571af7a..12a29246888f7 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -327,7 +327,7 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> writer.writerow({"key": k, "value": v}) -def load_hparams_from_yaml(config_yaml: str, use_omegaconf:bool = True) -> Dict[str, Any]: +def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]: """Load hparams from a file. Args:
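
Taken together, the series lands on a simple rule for saving: every DictConfig found in the hparams collection is resolved to concrete values first, and the caller's objects are deep-copied so their interpolations are never touched. A minimal sketch of that behaviour using omegaconf plus the apply_to_collection helper the patch relies on; the file name and config contents below are illustrative, not taken from the repository:

    from copy import deepcopy
    from functools import partial

    from omegaconf import OmegaConf
    from omegaconf.dictconfig import DictConfig

    from pytorch_lightning.utilities.apply_func import apply_to_collection

    hparams = {
        "cfg": OmegaConf.create({"training": {"log": "Something"}, "log": "${training.log}"}),
        "lr": 0.01,
    }

    # work on a copy, mirroring the patch, so the user's config keeps its interpolations
    resolved = deepcopy(hparams)
    # turn every DictConfig in the collection into a plain, fully resolved container
    resolved = apply_to_collection(resolved, DictConfig, partial(OmegaConf.to_container, resolve=True))

    OmegaConf.save(OmegaConf.create(resolved), "hparams.yaml")
    assert OmegaConf.load("hparams.yaml").cfg.log == "Something"  # not "${training.log}"
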
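The end-to-end pattern the new test guards, composing a config through Hydra's experimental API (the namespace used throughout this series) and letting save_hyperparameters() capture it, looks roughly as follows. MyModel is a placeholder, and the conf/config.yaml layout is the one added under tests/models/ above:

    from hydra.experimental import compose, initialize
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        # placeholder module: only the hyper-parameter handling is shown
        def __init__(self, cfg):
            super().__init__()
            self.save_hyperparameters()  # stores cfg under self.hparams.cfg

    with initialize(config_path="conf"):     # path is relative to the calling file
        cfg = compose(config_name="config")  # log: ${training.log} stays lazy here

    model = MyModel(cfg)
    # interpolations resolve on access; the value comes from training/default.yaml
    assert model.hparams.cfg.log == "Something"
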