diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a51b9d3264f9..94971f3cde2ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267)) +- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125)) + ## [1.4.5] - 2021-08-31 - Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142)) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 03a748b050dbc..2e347175a534e 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -22,8 +22,12 @@ from typing_extensions import Literal import pytorch_lightning as pl +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.warnings import rank_zero_warn +if _OMEGACONF_AVAILABLE: + from omegaconf.dictconfig import DictConfig + def str_to_bool_or_str(val: str) -> Union[str, bool]: """Possibly convert a string representation of truth to bool. 
@@ -204,46 +208,57 @@ def save_hyperparameters( obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`""" - + hparams_container_types = [Namespace, dict] + if _OMEGACONF_AVAILABLE: + hparams_container_types.append(DictConfig) + # empty container if len(args) == 1 and not isinstance(args, str) and not args[0]: - # args[0] is an empty container return - - if not frame: - current_frame = inspect.currentframe() - # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available - if current_frame: - frame = current_frame.f_back - if not isinstance(frame, types.FrameType): - raise AttributeError("There is no `frame` available while being required.") - - if is_dataclass(obj): - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} - else: - init_args = get_init_args(frame) - assert init_args, "failed to inspect the obj init" - - if ignore is not None: - if isinstance(ignore, str): - ignore = [ignore] - if isinstance(ignore, (list, tuple)): - ignore = [arg for arg in ignore if isinstance(arg, str)] - init_args = {k: v for k, v in init_args.items() if k not in ignore} - - if not args: - # take all arguments - hp = init_args - obj._hparams_name = "kwargs" if hp else None + # container + elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)): + hp = args[0] + obj._hparams_name = "hparams" + obj._set_hparams(hp) + obj._hparams_initial = copy.deepcopy(obj._hparams) + return + # non-container args parsing else: - # take only listed arguments in `save_hparams` - isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] - if len(isx_non_str) == 1: - hp = args[isx_non_str[0]] - cand_names = [k for k, v in init_args.items() if v == hp] - obj._hparams_name = cand_names[0] if cand_names else None + if not frame: + current_frame = 
inspect.currentframe() + # inspect.currentframe() return type is Optional[types.FrameType] + # current_frame.f_back called only if available + if current_frame: + frame = current_frame.f_back + if not isinstance(frame, types.FrameType): + raise AttributeError("There is no `frame` available while being required.") + + if is_dataclass(obj): + init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} + else: + init_args = get_init_args(frame) + assert init_args, f"failed to inspect the obj init - {frame}" + + if ignore is not None: + if isinstance(ignore, str): + ignore = [ignore] + if isinstance(ignore, (list, tuple, set)): + ignore = [arg for arg in ignore if isinstance(arg, str)] + init_args = {k: v for k, v in init_args.items() if k not in ignore} + + if not args: + # take all arguments + hp = init_args + obj._hparams_name = "kwargs" if hp else None else: - hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} - obj._hparams_name = "kwargs" + # take only listed arguments in `save_hparams` + isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)] + if len(isx_non_str) == 1: + hp = args[isx_non_str[0]] + cand_names = [k for k, v in init_args.items() if v == hp] + obj._hparams_name = cand_names[0] if cand_names else None + else: + hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)} + obj._hparams_name = "kwargs" # `hparams` are expected here if hp: diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index eebbe3ec2138f..78d6fbe7b32fa 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pickle -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace +from dataclasses import dataclass from typing import Any, Dict from unittest import mock from unittest.mock import call, PropertyMock import pytest import torch +from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel from tests.helpers.datamodules import ClassifDataModule @@ -532,12 +535,101 @@ def test_dm_init_from_datasets_dataloaders(iterable): ) -class DataModuleWithHparams(LightningDataModule): +# all args +class DataModuleWithHparams_0(LightningDataModule): def __init__(self, arg0, arg1, kwarg0=None): super().__init__() self.save_hyperparameters() -def test_simple_hyperparameters_saving(): - data = DataModuleWithHparams(10, "foo", kwarg0="bar") +# single arg +class DataModuleWithHparams_1(LightningDataModule): + def __init__(self, arg0, *args, **kwargs): + super().__init__() + self.save_hyperparameters(arg0) + + +def test_hyperparameters_saving(): + data = DataModuleWithHparams_0(10, "foo", kwarg0="bar") assert data.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"}) + + data = DataModuleWithHparams_1(Namespace(**{"hello": "world"}), "foo", kwarg0="bar") + assert data.hparams == AttributeDict({"hello": "world"}) + + data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar") + assert data.hparams == AttributeDict({"hello": "world"}) + + data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar") + assert data.hparams == OmegaConf.create({"hello": "world"}) + + +def test_define_as_dataclass(): + # makes sure that no functionality is broken and the user can 
still manually make + # super().__init__ call with parameters + # also tests all the dataclass features that can be enabled without breaking anything + @dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False) + class BoringDataModule1(LightningDataModule): + batch_size: int + dims: int = 2 + + def __post_init__(self): + super().__init__(dims=self.dims) + + # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e. + # __repr__, __eq__, __lt__, __le__, etc. + assert BoringDataModule1(batch_size=64).dims == 2 + assert BoringDataModule1(batch_size=32) + assert hasattr(BoringDataModule1, "__repr__") + assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32) + + # asserts inherent calling of super().__init__ in case user doesn't make the call + @dataclass + class BoringDataModule2(LightningDataModule): + batch_size: int + + # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e. + # __init__, __repr__, __eq__, __lt__, __le__, etc. + assert BoringDataModule2(batch_size=32) + assert hasattr(BoringDataModule2, "__repr__") + assert BoringDataModule2(batch_size=32).prepare_data() is None + assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32) + + # checking for all the different multilevel inheritance scenarios, for init call on LightningDataModule + @dataclass + class BoringModuleBase1(LightningDataModule): + num_features: int + + class BoringModuleBase2(LightningDataModule): + def __init__(self, num_features: int): + self.num_features = num_features + + @dataclass + class BoringModuleDerived1(BoringModuleBase1): + ... + + class BoringModuleDerived2(BoringModuleBase1): + def __init__(self): + ... + + @dataclass + class BoringModuleDerived3(BoringModuleBase2): + ... + + class BoringModuleDerived4(BoringModuleBase2): + def __init__(self): + ... 
+ + assert hasattr(BoringModuleDerived1(num_features=2), "_has_prepared_data") + assert hasattr(BoringModuleDerived2(), "_has_prepared_data") + assert hasattr(BoringModuleDerived3(), "_has_prepared_data") + assert hasattr(BoringModuleDerived4(), "_has_prepared_data") + + +def test_inconsistent_prepare_data_per_node(tmpdir): + with pytest.raises(MisconfigurationException, match="Inconsistent settings found for `prepare_data_per_node`."): + model = BoringModel() + dm = BoringDataModule() + trainer = Trainer(prepare_data_per_node=False) + trainer.model = model + trainer.datamodule = dm + trainer.data_connector.prepare_data()