Skip to content

Commit

Permalink
Fix inspection of unspecified args for container hparams (#9125)
Browse files Browse the repository at this point in the history
* Update parsing.py

* add todo (for single arg)

* unblock non container single arg

* init test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update CHANGELOG.md

* pep8 line length

* Update pytorch_lightning/utilities/parsing.py

* remove dict namespace conversion

* add omegaconf support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add dict test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add omegaconf test

* Update CHANGELOG.md

* Update pytorch_lightning/utilities/parsing.py

Co-authored-by: Jirka Borovec <[email protected]>

* Update pytorch_lightning/utilities/parsing.py

Co-authored-by: Jirka Borovec <[email protected]>

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
3 people authored and awaelchli committed Sep 7, 2021
1 parent 96541cf commit 185c4fd
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 40 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error handling in DDP process reconciliation when `_sync_dir` was not initialized ([#9267](https://github.com/PyTorchLightning/pytorch-lightning/pull/9267))


- Fixed inspection of other args when a container is specified in `save_hyperparameters` ([#9125](https://github.com/PyTorchLightning/pytorch-lightning/pull/9125))

## [1.4.5] - 2021-08-31

- Fixed reduction using `self.log(sync_dist=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
Expand Down
87 changes: 51 additions & 36 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.warnings import rank_zero_warn

if _OMEGACONF_AVAILABLE:
from omegaconf.dictconfig import DictConfig


def str_to_bool_or_str(val: str) -> Union[str, bool]:
"""Possibly convert a string representation of truth to bool.
Expand Down Expand Up @@ -204,46 +208,57 @@ def save_hyperparameters(
obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
) -> None:
"""See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`"""

hparams_container_types = [Namespace, dict]
if _OMEGACONF_AVAILABLE:
hparams_container_types.append(DictConfig)
# empty container
if len(args) == 1 and not isinstance(args, str) and not args[0]:
# args[0] is an empty container
return

if not frame:
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
assert init_args, "failed to inspect the obj init"

if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
if isinstance(ignore, (list, tuple)):
ignore = [arg for arg in ignore if isinstance(arg, str)]
init_args = {k: v for k, v in init_args.items() if k not in ignore}

if not args:
# take all arguments
hp = init_args
obj._hparams_name = "kwargs" if hp else None
# container
elif len(args) == 1 and isinstance(args[0], tuple(hparams_container_types)):
hp = args[0]
obj._hparams_name = "hparams"
obj._set_hparams(hp)
obj._hparams_initial = copy.deepcopy(obj._hparams)
return
# non-container args parsing
else:
# take only listed arguments in `save_hparams`
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
if len(isx_non_str) == 1:
hp = args[isx_non_str[0]]
cand_names = [k for k, v in init_args.items() if v == hp]
obj._hparams_name = cand_names[0] if cand_names else None
if not frame:
current_frame = inspect.currentframe()
# inspect.currentframe() return type is Optional[types.FrameType]
# current_frame.f_back called only if available
if current_frame:
frame = current_frame.f_back
if not isinstance(frame, types.FrameType):
raise AttributeError("There is no `frame` available while being required.")

if is_dataclass(obj):
init_args = {f.name: getattr(obj, f.name) for f in fields(obj)}
else:
init_args = get_init_args(frame)
assert init_args, f"failed to inspect the obj init - {frame}"

if ignore is not None:
if isinstance(ignore, str):
ignore = [ignore]
if isinstance(ignore, (list, tuple, set)):
ignore = [arg for arg in ignore if isinstance(arg, str)]
init_args = {k: v for k, v in init_args.items() if k not in ignore}

if not args:
# take all arguments
hp = init_args
obj._hparams_name = "kwargs" if hp else None
else:
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
obj._hparams_name = "kwargs"
# take only listed arguments in `save_hparams`
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
if len(isx_non_str) == 1:
hp = args[isx_non_str[0]]
cand_names = [k for k, v in init_args.items() if v == hp]
obj._hparams_name = cand_names[0] if cand_names else None
else:
hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
obj._hparams_name = "kwargs"

# `hparams` are expected here
if hp:
Expand Down
100 changes: 96 additions & 4 deletions tests/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from typing import Any, Dict
from unittest import mock
from unittest.mock import call, PropertyMock

import pytest
import torch
from omegaconf import OmegaConf

from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.datamodules import ClassifDataModule
Expand Down Expand Up @@ -532,12 +535,101 @@ def test_dm_init_from_datasets_dataloaders(iterable):
)


class DataModuleWithHparams(LightningDataModule):
# all args
class DataModuleWithHparams_0(LightningDataModule):
    # Test fixture for the "all args" case: `save_hyperparameters` is called
    # without any argument from inside `__init__`.
    def __init__(self, arg0, arg1, kwarg0=None):
        super().__init__()
        # No names or containers are passed here.
        self.save_hyperparameters()


def test_simple_hyperparameters_saving():
data = DataModuleWithHparams(10, "foo", kwarg0="bar")
# single arg
class DataModuleWithHparams_1(LightningDataModule):
    # Test fixture for the "single arg" case: only the first positional argument
    # is handed to `save_hyperparameters`; `*args` / `**kwargs` are accepted but
    # not passed along to it.
    def __init__(self, arg0, *args, **kwargs):
        super().__init__()
        # The value of `arg0` itself is passed, not the string "arg0".
        self.save_hyperparameters(arg0)


def test_hyperparameters_saving():
    # no argument given to `save_hyperparameters`: all three init arguments are expected in `hparams`
    all_args_dm = DataModuleWithHparams_0(10, "foo", kwarg0="bar")
    assert all_args_dm.hparams == AttributeDict({"arg0": 10, "arg1": "foo", "kwarg0": "bar"})

    # a single container argument: its content is expected as `hparams`, while the
    # remaining init arguments ("foo", kwarg0="bar") are not part of the expected result
    container_cases = [
        (Namespace(hello="world"), AttributeDict({"hello": "world"})),
        ({"hello": "world"}, AttributeDict({"hello": "world"})),
        (OmegaConf.create({"hello": "world"}), OmegaConf.create({"hello": "world"})),
    ]
    for container, expected_hparams in container_cases:
        single_arg_dm = DataModuleWithHparams_1(container, "foo", kwarg0="bar")
        assert single_arg_dm.hparams == expected_hparams


def test_define_as_dataclass():
    # A datamodule declared as a dataclass must keep working, including when the user
    # calls `super().__init__` manually with parameters. Every dataclass feature that
    # can be enabled is switched on for this first class.
    @dataclass(init=True, repr=True, eq=True, order=True, unsafe_hash=True, frozen=False)
    class BoringDataModule1(LightningDataModule):
        batch_size: int
        dims: int = 2

        def __post_init__(self):
            super().__init__(dims=self.dims)

    # dunder methods added by `dataclass` when `__init__` is implemented
    # (`__repr__`, `__eq__`, `__lt__`, `__le__`, ...)
    default_dims_dm = BoringDataModule1(batch_size=64)
    assert default_dims_dm.dims == 2
    assert BoringDataModule1(batch_size=32)
    assert hasattr(BoringDataModule1, "__repr__")
    assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)

    # the user never calls `super().__init__` here: it has to be called inherently
    @dataclass
    class BoringDataModule2(LightningDataModule):
        batch_size: int

    # dunder methods added by `dataclass` when the super class is inherently initialized
    # (`__init__`, `__repr__`, `__eq__`, `__lt__`, `__le__`, ...)
    assert BoringDataModule2(batch_size=32)
    assert hasattr(BoringDataModule2, "__repr__")
    assert BoringDataModule2(batch_size=32).prepare_data() is None
    assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)

    # multilevel inheritance scenarios for the init call on LightningDataModule:
    # dataclass / plain base, each with a dataclass and a plain derived class
    @dataclass
    class BoringModuleBase1(LightningDataModule):
        num_features: int

    class BoringModuleBase2(LightningDataModule):
        def __init__(self, num_features: int):
            self.num_features = num_features

    @dataclass
    class BoringModuleDerived1(BoringModuleBase1):
        pass

    class BoringModuleDerived2(BoringModuleBase1):
        def __init__(self):
            pass

    @dataclass
    class BoringModuleDerived3(BoringModuleBase2):
        pass

    class BoringModuleDerived4(BoringModuleBase2):
        def __init__(self):
            pass

    derived_factories = (
        lambda: BoringModuleDerived1(num_features=2),
        BoringModuleDerived2,
        BoringModuleDerived3,
        BoringModuleDerived4,
    )
    # instantiate and check one at a time, in the same order as the class declarations
    for make_module in derived_factories:
        assert hasattr(make_module(), "_has_prepared_data")


def test_inconsistent_prepare_data_per_node(tmpdir):
    # `tmpdir` is requested as a fixture but not used in the body.
    expected_message = "Inconsistent settings found for `prepare_data_per_node`."
    # Object construction stays inside the `raises` context, exactly as before, so an
    # error raised at any of these steps is still the one being matched.
    with pytest.raises(MisconfigurationException, match=expected_message):
        boring_model = BoringModel()
        boring_dm = BoringDataModule()
        trainer = Trainer(prepare_data_per_node=False)
        trainer.model, trainer.datamodule = boring_model, boring_dm
        trainer.data_connector.prepare_data()

0 comments on commit 185c4fd

Please sign in to comment.