From 7ae61a16ae599a5ff31e7726a009d51b6b2bb31f Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sun, 15 Nov 2020 15:56:00 +0100
Subject: [PATCH 1/3] fix state dict

---
 docs/source/metrics.rst             |  4 ++--
 pytorch_lightning/metrics/metric.py |  5 +++--
 tests/metrics/test_metric.py        | 13 ++++++++++++-
 3 files changed, 17 insertions(+), 5 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index d80b35f91abd1..556ff4af7d6ce 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

 .. note::

-    Metric states will as default add their internal state to the models ``state_dict``.
-    To change this after initializing the metric the method ``.persistent(mode)`` can
+    Metric states are **not** as default added to the models ``state_dict``.
+    To change this after initializing the metric, the method ``.persistent(mode)`` can
     be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

 *********************
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 9fa479dfb567a..2bc7977be25cf 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -86,7 +86,7 @@ def __init__(
         self._reductions = {}

     def add_state(
-        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+        self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False
     ):
         """
         Adds metric state variable. Only used by subclasses.
@@ -100,6 +100,7 @@ def add_state(
                 and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
                 function in this parameter.
             persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
+                Default is ``False``.

         Note:
             Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -278,7 +279,7 @@ def _apply(self, fn):
                                 f'or a list of torch.Tensor, but encountered {current_val}')
         return self

-    def persistent(self, mode: bool = True):
+    def persistent(self, mode: bool = False):
         """ Method for post-init to change if metric states should be saved to
         its state_dict
         """
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 3c85a4c126a27..d97cd1a176cf2 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -1,6 +1,7 @@
 import pickle
-
+from collections import OrderedDict
 from distutils.version import LooseVersion
+
 import cloudpickle
 import numpy as np
 import pytest
@@ -167,3 +168,13 @@ def test_pickle(tmpdir):
     metric_loaded = cloudpickle.loads(metric_pickled)

     assert metric_loaded.compute() == 1
+
+
+def test_state_dict(tmpdir):
+    """ test that metric states can be removed and added to state dict """
+    metric = Dummy()
+    assert metric.state_dict() == OrderedDict()
+    metric.persistent(True)
+    assert metric.state_dict() == OrderedDict(x=0)
+    metric.persistent(False)
+    assert metric.state_dict() == OrderedDict()

From d2d863bd62fe194c9fbf6b7427c2ce16e79d0817 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 16 Nov 2020 10:48:33 +0100
Subject: [PATCH 2/3] Update docs/source/metrics.rst

Co-authored-by: Rohit Gupta
---
 docs/source/metrics.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 556ff4af7d6ce..494f1bb443d87 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

 .. note::

-    Metric states are **not** as default added to the models ``state_dict``.
-    To change this after initializing the metric, the method ``.persistent(mode)`` can
+    Metric states are **not** added to the models ``state_dict`` by default.
+    To change this, after initializing the metric, the method ``.persistent(mode)`` can
     be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

 *********************

From 4ccd1e5917e42e0af83144fd0baf4dcfb95e43a3 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 16 Nov 2020 10:51:09 +0100
Subject: [PATCH 3/3] changelog

---
 CHANGELOG.md | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5ecc779bfbbba..01b517423aca3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

+- Metric states are no longer added to `state_dict` by default ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/))
+
+
 ### Deprecated


@@ -79,7 +82,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

-- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
+- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
 - Added `manual_optimizer_step` which work with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
 - Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
 - Added congratulations at the end of our notebooks ([#4555](https://github.com/PyTorchLightning/pytorch-lightning/pull/4555))
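
Below is a minimal usage sketch of the behaviour this patch series describes; it is not part of the patches themselves. The `SumMetric` class and its `total` state are hypothetical examples, and only `add_state`, `persistent(mode)` and `state_dict()` are taken from the patched `pytorch_lightning.metrics.Metric` API.

```python
# Illustrative sketch only (not part of the patch series); SumMetric is a
# hypothetical metric, assuming the pytorch_lightning.metrics API as patched above.
import torch
from pytorch_lightning.metrics import Metric


class SumMetric(Metric):
    def __init__(self):
        super().__init__()
        # persistent now defaults to False, so this state stays out of state_dict
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor):
        self.total += x.sum()

    def compute(self):
        return self.total


metric = SumMetric()
print(metric.state_dict())  # OrderedDict() -- metric state excluded by default
metric.persistent(True)     # opt back in after initialization
print(metric.state_dict())  # now includes the 'total' state
```

The two `state_dict()` calls mirror the assertions in the `test_state_dict` test added in the first patch.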