[metrics] change default behaviour of state dict #4685

Merged · 5 commits · Nov 16, 2020
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -38,6 +38,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))


- Metric states are no longer added to `state_dict` by default ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/))


### Deprecated


@@ -81,7 +84,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775))
- Added `manual_optimizer_step` which works with `AMP Native` and `accumulated_grad_batches` ([#4485](https://github.com/PyTorchLightning/pytorch-lightning/pull/4485))
- Added `persistent(mode)` method to metrics, to enable and disable metric states being added to `state_dict` ([#4482](https://github.com/PyTorchLightning/pytorch-lightning/pull/4482))
- Added congratulations at the end of our notebooks ([#4555](https://github.com/PyTorchLightning/pytorch-lightning/pull/4555))
4 changes: 2 additions & 2 deletions docs/source/metrics.rst
@@ -133,8 +133,8 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us

.. note::

-   Metric states will as default add their internal state to the models ``state_dict``.
-   To change this after initializing the metric the method ``.persistent(mode)`` can
+   Metric states are **not** added to the model's ``state_dict`` by default.
+   To change this, after initializing the metric, the method ``.persistent(mode)`` can
    be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.

*********************
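
A minimal sketch of the behaviour the updated note describes. `Accuracy` stands in here for any metric class; the exact states it holds (e.g. `correct`/`total`) are an assumption, but `persistent(mode)` is the method changed in this PR:

```python
import torch
from pytorch_lightning.metrics import Accuracy  # any Metric subclass behaves the same

metric = Accuracy()
metric(torch.tensor([1, 1, 0]), torch.tensor([1, 0, 0]))

# New default: internal metric states are left out of state_dict
print(metric.state_dict())   # OrderedDict() -- no metric states

# Opt back in after init if the states should be checkpointed
metric.persistent(True)
print(metric.state_dict())   # now includes the metric's internal states
```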
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -86,7 +86,7 @@ def __init__(
self._reductions = {}

def add_state(
-    self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = True
+    self, name: str, default, dist_reduce_fx: Optional[Union[str, Callable]] = None, persistent: bool = False
):
"""
Adds metric state variable. Only used by subclasses.
@@ -100,6 +100,7 @@ def add_state(
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Default is ``False``.

Note:
Setting ``dist_reduce_fx`` to None will return the metric state synchronized across different processes.
@@ -278,7 +279,7 @@ def _apply(self, fn):
f'or a list of torch.Tensor, but encountered {current_val}')
return self

-def persistent(self, mode: bool = True):
+def persistent(self, mode: bool = False):
""" Method for post-init to change whether metric states should be saved to
its state_dict
"""
13 changes: 12 additions & 1 deletion tests/metrics/test_metric.py
@@ -1,6 +1,7 @@
import pickle

from collections import OrderedDict
from distutils.version import LooseVersion

import cloudpickle
import numpy as np
import pytest
@@ -167,3 +168,13 @@ def test_pickle(tmpdir):
metric_loaded = cloudpickle.loads(metric_pickled)

assert metric_loaded.compute() == 1


def test_state_dict(tmpdir):
""" test that metric states can be removed and added to state dict """
metric = Dummy()
assert metric.state_dict() == OrderedDict()
metric.persistent(True)
assert metric.state_dict() == OrderedDict(x=0)
metric.persistent(False)
assert metric.state_dict() == OrderedDict()
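
A follow-up sketch of the checkpoint round-trip this test implies, reusing the test's `Dummy` metric (whose single state `x` is assumed to default to `torch.tensor(0)`):

```python
import torch

metric = Dummy()
metric.x += 5  # mutate the state
torch.save(metric.state_dict(), "metric.ckpt")  # saves an empty OrderedDict by default

restored = Dummy()
restored.load_state_dict(torch.load("metric.ckpt"))  # nothing to restore
assert restored.x == 0  # state falls back to its registered default
```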