Skip to content

Commit

Permalink
bugfix: Resolve interpolation bug with Hydra (#5406)
Browse files Browse the repository at this point in the history
* resolve bug

* Apply suggestions from code review

* resolve package import

* resolve import

* update on comments

* update on comments

* hacky fix

* update

* exit

* update

* to_container

* typo

* resolve import

* update

* resolve pep8

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Sean Naren <[email protected]>

(cherry picked from commit bb5031b)
  • Loading branch information
tchaton authored and SeanNaren committed Jan 19, 2021
1 parent b9eaa84 commit 4c308cd
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 19 deletions.
46 changes: 32 additions & 14 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,29 @@
import inspect
import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
from warnings import warn

import torch
import yaml

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities import AttributeDict, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.parsing import parse_class_init_keys

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)

if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from omegaconf.errors import UnsupportedValueType, ValidationError


# the older shall be on the top
CHECKPOINT_PAST_HPARAMS_KEYS = (
Expand Down Expand Up @@ -322,9 +328,14 @@ def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) ->
writer.writerow({"key": k, "value": v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
def load_hparams_from_yaml(config_yaml: str, use_omegaconf: bool = True) -> Dict[str, Any]:
"""Load hparams from a file.
Args:
config_yaml: Path to config yaml file
use_omegaconf: If both `OMEGACONF_AVAILABLE` and `use_omegaconf` are True,
the hparams will be converted to `DictConfig` if possible
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
Expand All @@ -339,9 +350,15 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
return {}

with fs.open(config_yaml, "r") as fp:
tags = yaml.full_load(fp)
hparams = yaml.full_load(fp)

return tags
if _OMEGACONF_AVAILABLE:
if use_omegaconf:
try:
return OmegaConf.create(hparams)
except (UnsupportedValueType, ValidationError):
pass
return hparams


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
Expand All @@ -362,15 +379,16 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:

# saving with OmegaConf objects
if _OMEGACONF_AVAILABLE:
if OmegaConf.is_config(hparams):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(hparams, fp, resolve=True)
return
for v in hparams.values():
if OmegaConf.is_config(v):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
# deepcopy: hparams from user shouldn't be resolved
hparams = deepcopy(hparams)
to_container = partial(OmegaConf.to_container, resolve=True)
hparams = apply_to_collection(hparams, DictConfig, to_container)
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
try:
OmegaConf.save(hparams, fp)
return
except (UnsupportedValueType, ValidationError):
pass

if not isinstance(hparams, dict):
raise TypeError("hparams must be dictionary")
Expand Down
36 changes: 36 additions & 0 deletions pytorch_lightning/utilities/package_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib


def _module_available(module_path: str) -> bool:
"""Testing if given module is avalaible in your env
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
# todo: find a better way than try / except
try:
mods = module_path.split('.')
assert mods, 'nothing given to test'
# it has to be tested as per partets
for i in range(len(mods)):
module_path = '.'.join(mods[:i + 1])
if importlib.util.find_spec(module_path) is None:
return False
return True
except AttributeError:
return False
3 changes: 2 additions & 1 deletion pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from typing import Dict, Tuple, Union

from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.package_utils import _module_available


def str_to_bool_or_str(val: str) -> Union[str, bool]:
Expand Down Expand Up @@ -115,7 +117,6 @@ def get_init_args(frame) -> dict:
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')

# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_var, {}))
Expand Down
17 changes: 17 additions & 0 deletions tests/models/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
defaults:
- training: default

log: ${training.log}
2 changes: 2 additions & 0 deletions tests/models/conf/training/default.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# @package training
log: "Something"
54 changes: 50 additions & 4 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
from pytorch_lightning.utilities import AttributeDict, is_picklable, _HYDRA_EXPERIMENTAL_AVAILABLE
from tests.base import BoringModel, EvalModelTemplate, TrialMNIST

if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize

class SaveHparamsModel(BoringModel):
""" Tests that a model can take an object """
Expand Down Expand Up @@ -483,13 +486,13 @@ def test_hparams_save_yaml(tmpdir):
path_yaml = os.path.join(tmpdir, 'testing-hparams.yaml')

save_hparams_to_yaml(path_yaml, hparams)
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

save_hparams_to_yaml(path_yaml, Namespace(**hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
assert load_hparams_from_yaml(path_yaml, use_omegaconf=False) == hparams

save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
assert load_hparams_from_yaml(path_yaml) == hparams
Expand Down Expand Up @@ -636,3 +639,46 @@ def test_model_with_fsspec_as_parameter(tmpdir):
)
trainer.fit(model)
trainer.test()


@pytest.mark.skipif(not HYDRA_EXPERIMENTAL_AVAILABLE, reason="Hydra experimental is not available")
def test_model_save_hyper_parameters_interpolation_with_hydra(tmpdir):
"""
This test relies on configuration saved under tests/models/conf/config.yaml
"""

class TestHydraModel(BoringModel):

def __init__(self, args_0, args_1, args_2, kwarg_1=None):
self.save_hyperparameters()
self.test_hparams()
config_file = f"{tmpdir}/hparams.yaml"
save_hparams_to_yaml(config_file, self.hparams)
self.hparams = load_hparams_from_yaml(config_file)
self.test_hparams()
super().__init__()

def test_hparams(self):
assert self.hparams.args_0.log == "Something"
assert self.hparams.args_1['cfg'].log == "Something"
assert self.hparams.args_2[0].log == "Something"
assert self.hparams.kwarg_1['cfg'][0].log == "Something"

with initialize(config_path="conf"):
args_0 = compose(config_name="config")
args_1 = {"cfg": compose(config_name="config")}
args_2 = [compose(config_name="config")]
kwarg_1 = {"cfg": [compose(config_name="config")]}
model = TestHydraModel(args_0, args_1, args_2, kwarg_1=kwarg_1)
epochs = 2
checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint_callback],
limit_train_batches=10,
limit_val_batches=10,
max_epochs=epochs,
logger=False,
)
trainer.fit(model)
_ = TestHydraModel.load_from_checkpoint(checkpoint_callback.best_model_path)

0 comments on commit 4c308cd

Please sign in to comment.