From be60efb3cf6bbb9c0dbc0279f9bdbb805e5856f8 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 16 Nov 2020 11:02:26 +0100 Subject: [PATCH] allow decorate model init with saving hparams (#4662) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * addd tests * use boring model * parsing init * chlog * double decorate * Apply suggestions from code review Co-authored-by: Carlos MocholĂ­ * bug Co-authored-by: chaton Co-authored-by: Carlos MocholĂ­ Co-authored-by: Nicki Skafte Co-authored-by: Roger Shieh --- CHANGELOG.md | 4 ++- pytorch_lightning/core/saving.py | 6 ++-- pytorch_lightning/utilities/parsing.py | 43 ++++++++++++++++++++------ tests/models/test_hparams.py | 43 +++++++++++++++++++++++--- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ecc779bfbbba..1ed35ef6f61f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,8 +48,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608)) +- Allowing decorate model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662)) + +- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608)) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 2662aa6758332..53210009db9ed 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import rank_zero_warn, AttributeDict from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.cloud_io import get_filesystem - +from pytorch_lightning.utilities.parsing import parse_class_init_keys PRIMITIVE_TYPES = (bool, int, float, str) ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace) @@ -159,8 +159,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl cls_spec = inspect.getfullargspec(cls.__init__) cls_init_args_name = inspect.signature(cls.__init__).parameters.keys() - self_name = cls_spec.args[0] - drop_names = (self_name, cls_spec.varargs, cls_spec.varkw) + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + drop_names = [n for n in (self_var, args_var, kwargs_var) if n] cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name)) cls_kwargs_loaded = {} diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 348eec110c3a1..b207320c25ccc 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -15,7 +15,7 @@ import inspect import pickle from argparse import Namespace -from typing import Dict, Union +from typing import Dict, Union, Tuple from pytorch_lightning.utilities import rank_zero_warn @@ -79,23 +79,46 @@ def clean_namespace(hparams): del hparams_dict[k] +def parse_class_init_keys(cls) -> Tuple[str, str, str]: + """Parse key words for standard self, *args and **kwargs + + >>> class Model(): + ... def __init__(self, hparams, *my_args, anykw=42, **my_kwargs): + ... pass + >>> parse_class_init_keys(Model) + ('self', 'my_args', 'my_kwargs') + """ + init_parameters = inspect.signature(cls.__init__).parameters + # docs claims the params are always ordered + # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + init_params = list(init_parameters.values()) + # self is always first + n_self = init_params[0].name + + def _get_first_if_any(params, param_type): + for p in params: + if p.kind == param_type: + return p.name + + n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL) + n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD) + + return n_self, n_args, n_kwargs + + def get_init_args(frame) -> dict: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: - return + return {} cls = local_vars['__class__'] - spec = inspect.getfullargspec(cls.__init__) init_parameters = inspect.signature(cls.__init__).parameters - self_identifier = spec.args[0] # "self" unless user renames it (always first arg) - varargs_identifier = spec.varargs # by convention this is named "*args" - kwargs_identifier = spec.varkw # by convention this is named "**kwargs" - exclude_argnames = ( - varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args' - ) + self_var, args_var, kwargs_var = parse_class_init_keys(cls) + filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n] + exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} - local_args.update(local_args.get(kwargs_identifier, {})) + local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} return local_args diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index b7d0be01e9622..7df78d9760bd9 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import os import pickle from argparse import Namespace @@ -29,20 +30,46 @@ from tests.base import EvalModelTemplate, TrialMNIST, BoringModel -class SaveHparamsModel(EvalModelTemplate): +class SaveHparamsModel(BoringModel): """ Tests that a model can take an object """ def __init__(self, hparams): super().__init__() self.save_hyperparameters(hparams) -class AssignHparamsModel(EvalModelTemplate): +class AssignHparamsModel(BoringModel): """ Tests that a model can take an object with explicit setter """ def __init__(self, hparams): super().__init__() self.hparams = hparams +def decorate(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +class SaveHparamsDecoratedModel(BoringModel): + """ Tests that a model can take an object """ + @decorate + @decorate + def __init__(self, hparams, *my_args, **my_kwargs): + super().__init__() + self.save_hyperparameters(hparams) + + +class AssignHparamsDecoratedModel(BoringModel): + """ Tests that a model can take an object with explicit setter""" + @decorate + @decorate + def __init__(self, hparams, *my_args, **my_kwargs): + super().__init__() + self.hparams = hparams + + # ------------------------- # STANDARD TESTS # ------------------------- @@ -78,7 +105,9 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False): return raw_checkpoint_path -@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel]) +@pytest.mark.parametrize("cls", [ + SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel +]) def test_namespace_hparams(tmpdir, cls): # init model model = cls(hparams=Namespace(test_arg=14)) @@ -87,7 +116,9 @@ def test_namespace_hparams(tmpdir, cls): _run_standard_hparams_test(tmpdir, model, cls) -@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel]) +@pytest.mark.parametrize("cls", [ + SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel +]) def test_dict_hparams(tmpdir, cls): # init model model = cls(hparams={'test_arg': 14}) @@ -96,7 +127,9 @@ def test_dict_hparams(tmpdir, cls): _run_standard_hparams_test(tmpdir, model, cls) -@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel]) +@pytest.mark.parametrize("cls", [ + SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel +]) def test_omega_conf_hparams(tmpdir, cls): # init model conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))