diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1ed35ef6f61f5..236784afec84e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
 
+- Metric states are no longer added to `state_dict` by default ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685))
+
+
 ### Deprecated
 
 
@@ -81,7 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) 
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
 - Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
 - Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
 - Added congratulations at the end of our notebooks ([#4555](https://github.com/PyTorchLightning/pytorch-lightning/pull/4555))
diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index d80b35f91abd1..494f1bb443d87 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
 
 .. note::
 
-    Metric states will as default add their internal state to the models ``state_dict``.
-    To change this after initializing the metric the method ``.persistent(mode)`` can
+    Metric states are **not** added to the model's ``state_dict`` by default.
+    To change this, after initializing the metric, the method ``.persistent(mode)`` can
     be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
 
 *********************
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 9fa479dfb567a..2bc7977be25cf 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -86,7 +86,7 @@ def __init__(
         self._reductions = {}
 
     def add_state(
-        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False
     ):
         """
         Adds metric state variable. Only used by subclasses.
@@ -100,6 +100,7 @@ def add_state(
                 and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
                 function in this parameter.
             persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
+                Default is ``False``.
 
         Note:
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -278,7 +279,7 @@ def _apply(self, fn):
                                 f'or a list of torch.Tensor, but encountered {current_val}')
         return self
 
-    def persistent(self, mode: bool = True):
+    def persistent(self, mode: bool = False):
         """ Method for post-init to change if metric states should be saved to
         its state_dict
         """
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 3c85a4c126a27..d97cd1a176cf2 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -1,6 +1,7 @@
 import pickle
-
+from collections import OrderedDict
 from distutils.version import LooseVersion
+
 import cloudpickle
 import numpy as np
 import pytest
@@ -167,3 +168,13 @@ def test_pickle(tmpdir):
 
     metric_loaded = cloudpickle.loads(metric_pickled)
     assert metric_loaded.compute() == 1
+
+
+def test_state_dict(tmpdir):
+    """Test that metric states can be removed from and added to the state dict."""
+    metric = Dummy()
+    assert metric.state_dict() == OrderedDict()
+    metric.persistent(True)
+    assert metric.state_dict() == OrderedDict(x=0)
+    metric.persistent(False)
+    assert metric.state_dict() == OrderedDict()
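
In short, after this patch a metric's state stays out of checkpoints unless explicitly opted in, either per state via `add_state(..., persistent=True)` or for all states post-init via `.persistent(True)`. Below is a minimal sketch of the resulting behavior; the `SumMetric` subclass is hypothetical (modeled on the `Dummy` metric used in the tests), and the assertions mirror `test_state_dict` above:

    import torch
    from pytorch_lightning.metrics import Metric


    class SumMetric(Metric):
        def __init__(self):
            super().__init__()
            # persistent now defaults to False, so "total" is kept out of state_dict
            self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        def update(self, value: torch.Tensor):
            self.total += value

        def compute(self):
            return self.total


    metric = SumMetric()
    assert metric.state_dict() == {}       # state is no longer saved by default
    metric.persistent(True)                # opt back in after init
    assert "total" in metric.state_dict()
    metric.persistent(False)               # opt out again
    assert metric.state_dict() == {}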