From be12ada01ff5b53d3b68d450863923976e0f1b01 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Sun, 13 Jun 2021 10:32:29 +0530 Subject: [PATCH 01/13] Add log flag to save_hyperparameters --- pytorch_lightning/core/lightning.py | 13 ++++++++++++- pytorch_lightning/trainer/trainer.py | 3 ++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c41423948f8c8..6017bb1f68c7c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -111,6 +111,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() + self._log_hyperparams = True def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: if use_pl_optimizer: @@ -214,6 +215,14 @@ def logger(self): """ Reference to the logger object in the Trainer. """ return self.trainer.logger if self.trainer else None + @property + def log_hyperparams(self) -> bool: + return self._log_hyperparams + + @log_hyperparams.setter + def log_hyperparams(self, log: bool) -> None: + self._log_hyperparams = log + def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: @@ -1738,7 +1747,8 @@ def save_hyperparameters( self, *args, ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None + frame: Optional[types.FrameType] = None, + log: bool = True ) -> None: """Save model arguments to ``hparams`` attribute. @@ -1800,6 +1810,7 @@ class ``__init__`` to be ignored "arg1": 1 "arg3": 3.14 """ + self.log_hyperparams(log) # the frame needs to be created in this file. if not frame: frame = inspect.currentframe().f_back diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6979a859b0e9a..785a5773781c9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -882,7 +882,8 @@ def _pre_dispatch(self): # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - self.logger.log_hyperparams(self.lightning_module.hparams_initial) + if self.lightning_module.log_hyperparams: + self.logger.log_hyperparams(self.lightning_module.hparams_initial) self.logger.log_graph(self.lightning_module) self.logger.save() From 5c11860245271f55d9c8d5b8372775d1b356a92d Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 14 Jun 2021 04:47:16 +0530 Subject: [PATCH 02/13] FIx setter --- pytorch_lightning/core/lightning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 6017bb1f68c7c..fe3a020509d0e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1810,7 +1810,7 @@ class ``__init__`` to be ignored "arg1": 1 "arg3": 3.14 """ - self.log_hyperparams(log) + self.log_hyperparams = log # the frame needs to be created in this file. 
if not frame: frame = inspect.currentframe().f_back From b8d69b041075a2bcf10dcb28fea9f00625ead062 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 18 Jun 2021 09:56:38 +0530 Subject: [PATCH 03/13] Add test & Update changelog --- CHANGELOG.md | 5 ++++- tests/loggers/test_base.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 23ba6c4f26411..493323c4480d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,7 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992)) -- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980)) +- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980)) + + +- Added `log` flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960)) ### Changed diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 9209083148265..7574ff7f111b0 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock import numpy as np +import pytest from pytorch_lightning import Trainer from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger @@ -290,3 +291,37 @@ def log_hyperparams(self, params): } logger.log_hyperparams(Namespace(**np_params)) assert logger.logged_params == sanitized_params + + +@pytest.mark.parametrize("log", [True, False]) +def test_log_hyperparams_being_called(tmpdir, log): + + class TestLogger(DummyLogger): + + def __init__(self): + super().__init__() + self.log_hyperparams_called = False + + def log_hyperparams(self, *args, **kwargs): + self.log_hyperparams_called = True + + class TestModel(BoringModel): + + def __init__(self, param_one, param_two): + super().__init__() + self.save_hyperparameters(log=log) + + logger = TestLogger() + model = TestModel("pytorch", "lightning") + + trainer = Trainer( + default_root_dir=tmpdir, + logger=logger, + max_epochs=1, + limit_train_batches=0.1, + limit_val_batches=0.1, + num_sanity_val_steps=0, + ) + trainer.fit(model) + + assert log == logger.log_hyperparams_called From d00fd7bff22da6946f82bcd67084a6fae5246560 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 18 Jun 2021 13:27:18 +0530 Subject: [PATCH 04/13] Address comments --- pytorch_lightning/core/lightning.py | 13 +++---------- pytorch_lightning/trainer/trainer.py | 2 +- tests/loggers/test_base.py | 12 ++++++------ 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index fe3a020509d0e..d99ad7f37cf5e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -215,14 +215,6 @@ def logger(self): """ Reference to the logger object in the Trainer. 
""" return self.trainer.logger if self.trainer else None - @property - def log_hyperparams(self) -> bool: - return self._log_hyperparams - - @log_hyperparams.setter - def log_hyperparams(self, log: bool) -> None: - self._log_hyperparams = log - def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: Optional[int] = None ) -> Any: @@ -1748,7 +1740,7 @@ def save_hyperparameters( *args, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None, - log: bool = True + logger: bool = True ) -> None: """Save model arguments to ``hparams`` attribute. @@ -1758,6 +1750,7 @@ def save_hyperparameters( ignore: an argument name or a list of argument names from class ``__init__`` to be ignored frame: a frame object. Default is None + logger: Whether to save hyperparameters by logger. Default: True Example:: >>> class ManuallyArgsModel(LightningModule): @@ -1810,7 +1803,7 @@ class ``__init__`` to be ignored "arg1": 1 "arg3": 3.14 """ - self.log_hyperparams = log + self._log_hyperparams = logger # the frame needs to be created in this file. if not frame: frame = inspect.currentframe().f_back diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 785a5773781c9..71d9e6cb9f454 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -882,7 +882,7 @@ def _pre_dispatch(self): # log hyper-parameters if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - if self.lightning_module.log_hyperparams: + if self.lightning_module._log_hyperparams: self.logger.log_hyperparams(self.lightning_module.hparams_initial) self.logger.log_graph(self.lightning_module) self.logger.save() diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 7574ff7f111b0..981913f432063 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -293,8 +293,8 @@ def log_hyperparams(self, params): assert logger.logged_params == sanitized_params -@pytest.mark.parametrize("log", [True, False]) -def test_log_hyperparams_being_called(tmpdir, log): +@pytest.mark.parametrize("logger", [True, False]) +def test_log_hyperparams_being_called(tmpdir, logger): class TestLogger(DummyLogger): @@ -309,14 +309,14 @@ class TestModel(BoringModel): def __init__(self, param_one, param_two): super().__init__() - self.save_hyperparameters(log=log) + self.save_hyperparameters(logger=logger) - logger = TestLogger() + test_logger = TestLogger() model = TestModel("pytorch", "lightning") trainer = Trainer( default_root_dir=tmpdir, - logger=logger, + logger=test_logger, max_epochs=1, limit_train_batches=0.1, limit_val_batches=0.1, @@ -324,4 +324,4 @@ def __init__(self, param_one, param_two): ) trainer.fit(model) - assert log == logger.log_hyperparams_called + assert logger == test_logger.log_hyperparams_called From dd56dc66aed88bac7edf4f6eab6e64544479c4a7 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 21:02:44 +0530 Subject: [PATCH 05/13] Fix conflicts --- pytorch_lightning/core/datamodule.py | 2 ++ pytorch_lightning/utilities/hparams_mixin.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index b1f42c6ea4390..a55900dd15b9c 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -97,6 +97,8 @@ def __init__( self._has_teardown_test = False self._has_teardown_predict = False + 
self._log_hyperparams = True + @property def train_transforms(self): """ diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 8dd4b23c89398..e5b553ad189b1 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -30,7 +30,8 @@ def save_hyperparameters( self, *args, ignore: Optional[Union[Sequence[str], str]] = None, - frame: Optional[types.FrameType] = None + frame: Optional[types.FrameType] = None, + logger: bool = True ) -> None: """Save arguments to ``hparams`` attribute. @@ -40,6 +41,7 @@ def save_hyperparameters( ignore: an argument name or a list of argument names from class ``__init__`` to be ignored frame: a frame object. Default is None + logger: Whether to save hyperparameters by logger. Default: True Example:: >>> class ManuallyArgsModel(HyperparametersMixin): @@ -92,6 +94,7 @@ class ``__init__`` to be ignored "arg1": 1 "arg3": 3.14 """ + self._log_hyperparams = logger # the frame needs to be created in this file. if not frame: frame = inspect.currentframe().f_back From 15336b15941c779caa1e8cde42e62e2d95f1f13c Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Fri, 9 Jul 2021 21:11:37 +0530 Subject: [PATCH 06/13] Update trainer --- pytorch_lightning/trainer/trainer.py | 31 +++++++++++++++++----------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4c034ac843361..cf6722c62e18d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -907,20 +907,27 @@ def _pre_dispatch(self): def _log_hyperparams(self): # log hyper-parameters + hparams_initial = None + if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {} - lightning_hparams = self.lightning_module.hparams_initial - colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys() - if colliding_keys: - raise MisconfigurationException( - f"Error while merging hparams: the keys {colliding_keys} are present " - "in both the LightningModule's and LightningDataModule's hparams." - ) - - hparams_initial = {**lightning_hparams, **datamodule_hparams} - - self.logger.log_hyperparams(hparams_initial) + if self.lightning_module._log_hyperparams and self.datamodule._log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {} + lightning_hparams = self.lightning_module.hparams_initial + colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys() + if colliding_keys: + raise MisconfigurationException( + f"Error while merging hparams: the keys {colliding_keys} are present " + "in both the LightningModule's and LightningDataModule's hparams." 
+ ) + hparams_initial = {**lightning_hparams, **datamodule_hparams} + elif self.lightning_module._log_hyperparams: + hparams_initial = self.lightning_module.hparams_initial + elif self.datamodule._log_hyperparams: + hparams_initial = self.datamodule.hparams_initial + + if hparams_initial is not None: + self.logger.log_hyperparams(hparams_initial) self.logger.log_graph(self.lightning_module) self.logger.save() From fef7b922aef66a3ca3c5a6903dbaf936d852e131 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Fri, 9 Jul 2021 21:11:59 +0530 Subject: [PATCH 07/13] Update CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3a5c14610f143..6568614497eeb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,7 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980)) -- Added `log` flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960)) +- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960)) - Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073)) From 2992716cfc7037ff23e1988b492adce97203b900 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 12 Jul 2021 13:10:40 +0530 Subject: [PATCH 08/13] Fix datamodule hparams fix --- pytorch_lightning/trainer/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf6722c62e18d..44a12ae30bc71 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -911,9 +911,12 @@ def _log_hyperparams(self): if self.logger is not None: # save exp to get started (this is where the first experiment logs are written) - if self.lightning_module._log_hyperparams and self.datamodule._log_hyperparams: - datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {} + datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False + + if self.lightning_module._log_hyperparams and datamodule_log_hyperparams: + datamodule_hparams = self.datamodule.hparams_initial lightning_hparams = self.lightning_module.hparams_initial + colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys() if colliding_keys: raise MisconfigurationException( From c6f25d8c23a5ad1ae8ec5f55ebf4661458c54655 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 12 Jul 2021 13:21:14 +0530 Subject: [PATCH 09/13] Fix datamodule hparams fix --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 44a12ae30bc71..5f6305d5e5b34 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -926,7 +926,7 @@ def _log_hyperparams(self): hparams_initial = {**lightning_hparams, **datamodule_hparams} elif self.lightning_module._log_hyperparams: hparams_initial = 
self.lightning_module.hparams_initial - elif self.datamodule._log_hyperparams: + elif datamodule_log_hyperparams: hparams_initial = self.datamodule.hparams_initial if hparams_initial is not None: From 2d72c0e097f8e5bb13dc546a1c60f6c7a1a755fb Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 12 Jul 2021 13:41:41 +0530 Subject: [PATCH 10/13] Update test with patch --- tests/loggers/test_base.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/tests/loggers/test_base.py b/tests/loggers/test_base.py index 981913f432063..5ecc372ec0acf 100644 --- a/tests/loggers/test_base.py +++ b/tests/loggers/test_base.py @@ -14,7 +14,7 @@ import pickle from argparse import Namespace from typing import Optional -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -294,16 +294,8 @@ def log_hyperparams(self, params): @pytest.mark.parametrize("logger", [True, False]) -def test_log_hyperparams_being_called(tmpdir, logger): - - class TestLogger(DummyLogger): - - def __init__(self): - super().__init__() - self.log_hyperparams_called = False - - def log_hyperparams(self, *args, **kwargs): - self.log_hyperparams_called = True +@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams") +def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger): class TestModel(BoringModel): @@ -311,12 +303,9 @@ def __init__(self, param_one, param_two): super().__init__() self.save_hyperparameters(logger=logger) - test_logger = TestLogger() model = TestModel("pytorch", "lightning") - trainer = Trainer( default_root_dir=tmpdir, - logger=test_logger, max_epochs=1, limit_train_batches=0.1, limit_val_batches=0.1, @@ -324,4 +313,7 @@ def __init__(self, param_one, param_two): ) trainer.fit(model) - assert logger == test_logger.log_hyperparams_called + if logger: + log_hyperparams_mock.assert_called() + else: + log_hyperparams_mock.assert_not_called() From eec65550c90f9e6931d1d994ae79b58f5997f735 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 12 Jul 2021 14:34:49 +0530 Subject: [PATCH 11/13] Update pytorch_lightning/utilities/hparams_mixin.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/utilities/hparams_mixin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index e5b553ad189b1..802d3dfab453a 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -41,7 +41,7 @@ def save_hyperparameters( ignore: an argument name or a list of argument names from class ``__init__`` to be ignored frame: a frame object. Default is None - logger: Whether to save hyperparameters by logger. Default: True + logger: Whether to send the hyperparameters to the logger. 
Default: True Example:: >>> class ManuallyArgsModel(HyperparametersMixin): From 41035e6ff893f5257e298bdcf557217144d3d4e8 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 12 Jul 2021 15:04:33 +0530 Subject: [PATCH 12/13] Move log_hyperparams to mixin --- pytorch_lightning/core/datamodule.py | 2 -- pytorch_lightning/core/lightning.py | 1 - pytorch_lightning/utilities/hparams_mixin.py | 5 ++++- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index a55900dd15b9c..b1f42c6ea4390 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -97,8 +97,6 @@ def __init__( self._has_teardown_test = False self._has_teardown_predict = False - self._log_hyperparams = True - @property def train_transforms(self): """ diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d4fb0a6832a3b..735f8ab160c1f 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -108,7 +108,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._automatic_optimization: bool = True self._truncated_bptt_steps: int = 0 self._param_requires_grad_state = dict() - self._log_hyperparams = True self._metric_attributes: Optional[Dict[int, str]] = None # deprecated, will be removed in 1.6 diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 802d3dfab453a..3a33d7cba6077 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -22,10 +22,13 @@ from pytorch_lightning.utilities.parsing import save_hyperparameters -class HyperparametersMixin: +class HyperparametersMixin(object): __jit_unused_properties__ = ["hparams", "hparams_initial"] + def __init__(self) -> None: + self._log_hyperparams = True + def save_hyperparameters( self, *args, From 3b0a84ab420a272364fe8c469ab37d185b7975a3 Mon Sep 17 00:00:00 2001 From: Kaushik B Date: Mon, 12 Jul 2021 20:22:39 +0530 Subject: [PATCH 13/13] Update hparams mixin --- pytorch_lightning/utilities/hparams_mixin.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/hparams_mixin.py b/pytorch_lightning/utilities/hparams_mixin.py index 3a33d7cba6077..b1cb9492e91d5 100644 --- a/pytorch_lightning/utilities/hparams_mixin.py +++ b/pytorch_lightning/utilities/hparams_mixin.py @@ -27,6 +27,7 @@ class HyperparametersMixin(object): __jit_unused_properties__ = ["hparams", "hparams_initial"] def __init__(self) -> None: + super().__init__() self._log_hyperparams = True def save_hyperparameters(
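
The net effect of the series, as a minimal usage sketch (the model, datamodule, and hyperparameter names below are illustrative only, not taken from the patches): passing ``logger=False`` to ``save_hyperparameters`` still records the arguments in ``hparams``/``hparams_initial``, but the Trainer skips passing them to ``logger.log_hyperparams`` before training starts.

    import torch
    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def __init__(self, hidden_dim: int = 64, learning_rate: float = 1e-3):
            super().__init__()
            # Hyperparameters are still stored in self.hparams, but with
            # logger=False they are excluded from logger.log_hyperparams().
            self.save_hyperparameters(logger=False)
            self.layer = torch.nn.Linear(32, hidden_dim)

        def configure_optimizers(self):
            # hparams access works as before, regardless of the logger flag.
            return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)


    class LitData(pl.LightningDataModule):
        def __init__(self, batch_size: int = 32):
            super().__init__()
            # Default logger=True: these hyperparameters are sent to the logger.
            self.save_hyperparameters()

Because the flag lives in the shared ``HyperparametersMixin``, it behaves the same for ``LightningDataModule``; when both the module and the datamodule opt in, their ``hparams_initial`` dicts are merged and a ``MisconfigurationException`` is raised for colliding keys.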