Skip to content

Commit

Permalink
allow decorate model init with saving hparams (#4662)
Browse files Browse the repository at this point in the history
* addd tests

* use boring model

* parsing init

* chlog

* double decorate

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <[email protected]>

* bug

Co-authored-by: chaton <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Nicki Skafte <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
  • Loading branch information
5 people authored Nov 16, 2020
1 parent 886702a commit be60efb
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 19 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))
- Allowing decorate model init with saving `hparams` inside ([#4662](https://github.com/PyTorchLightning/pytorch-lightning/pull/4662))


- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))



Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import get_filesystem

from pytorch_lightning.utilities.parsing import parse_class_init_keys

PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
Expand Down Expand Up @@ -159,8 +159,8 @@ def _load_model_state(cls, checkpoint: Dict[str, Any], strict: bool = True, **cl
cls_spec = inspect.getfullargspec(cls.__init__)
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

self_name = cls_spec.args[0]
drop_names = (self_name, cls_spec.varargs, cls_spec.varkw)
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
drop_names = [n for n in (self_var, args_var, kwargs_var) if n]
cls_init_args_name = list(filter(lambda n: n not in drop_names, cls_init_args_name))

cls_kwargs_loaded = {}
Expand Down
43 changes: 33 additions & 10 deletions pytorch_lightning/utilities/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import inspect
import pickle
from argparse import Namespace
from typing import Dict, Union
from typing import Dict, Union, Tuple

from pytorch_lightning.utilities import rank_zero_warn

Expand Down Expand Up @@ -79,23 +79,46 @@ def clean_namespace(hparams):
del hparams_dict[k]


def parse_class_init_keys(cls) -> Tuple[str, str, str]:
"""Parse key words for standard self, *args and **kwargs
>>> class Model():
... def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
... pass
>>> parse_class_init_keys(Model)
('self', 'my_args', 'my_kwargs')
"""
init_parameters = inspect.signature(cls.__init__).parameters
# docs claims the params are always ordered
# https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
init_params = list(init_parameters.values())
# self is always first
n_self = init_params[0].name

def _get_first_if_any(params, param_type):
for p in params:
if p.kind == param_type:
return p.name

n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL)
n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD)

return n_self, n_args, n_kwargs


def get_init_args(frame) -> dict:
_, _, _, local_vars = inspect.getargvalues(frame)
if '__class__' not in local_vars:
return
return {}
cls = local_vars['__class__']
spec = inspect.getfullargspec(cls.__init__)
init_parameters = inspect.signature(cls.__init__).parameters
self_identifier = spec.args[0] # "self" unless user renames it (always first arg)
varargs_identifier = spec.varargs # by convention this is named "*args"
kwargs_identifier = spec.varkw # by convention this is named "**kwargs"
exclude_argnames = (
varargs_identifier, kwargs_identifier, self_identifier, '__class__', 'frame', 'frame_args'
)
self_var, args_var, kwargs_var = parse_class_init_keys(cls)
filtered_vars = [n for n in (self_var, args_var, kwargs_var) if n]
exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args')

# only collect variables that appear in the signature
local_args = {k: local_vars[k] for k in init_parameters.keys()}
local_args.update(local_args.get(kwargs_identifier, {}))
local_args.update(local_args.get(kwargs_var, {}))
local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames}
return local_args

Expand Down
43 changes: 38 additions & 5 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import pickle
from argparse import Namespace
Expand All @@ -29,20 +30,46 @@
from tests.base import EvalModelTemplate, TrialMNIST, BoringModel


class SaveHparamsModel(EvalModelTemplate):
class SaveHparamsModel(BoringModel):
""" Tests that a model can take an object """
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)


class AssignHparamsModel(EvalModelTemplate):
class AssignHparamsModel(BoringModel):
""" Tests that a model can take an object with explicit setter """
def __init__(self, hparams):
super().__init__()
self.hparams = hparams


def decorate(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)

return wrapper


class SaveHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object """
@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.save_hyperparameters(hparams)


class AssignHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object with explicit setter"""
@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.hparams = hparams


# -------------------------
# STANDARD TESTS
# -------------------------
Expand Down Expand Up @@ -78,7 +105,9 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):
return raw_checkpoint_path


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_namespace_hparams(tmpdir, cls):
# init model
model = cls(hparams=Namespace(test_arg=14))
Expand All @@ -87,7 +116,9 @@ def test_namespace_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_dict_hparams(tmpdir, cls):
# init model
model = cls(hparams={'test_arg': 14})
Expand All @@ -96,7 +127,9 @@ def test_dict_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)


@pytest.mark.parametrize("cls", [SaveHparamsModel, AssignHparamsModel])
@pytest.mark.parametrize("cls", [
SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel
])
def test_omega_conf_hparams(tmpdir, cls):
# init model
conf = OmegaConf.create(dict(test_arg=14, mylist=[15.4, dict(a=1, b=2)]))
Expand Down

0 comments on commit be60efb

Please sign in to comment.