Skip to content

Commit

Permalink
Fix metric state reset (#5273)
Browse files Browse the repository at this point in the history
* Fix metric state reset

* Fix test

* Improve formatting

Co-authored-by: Ananya Harsh Jha <[email protected]>
  • Loading branch information
tadejsv and ananyahjha93 authored Dec 29, 2020
1 parent dabfeca commit 4913cbb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def add_state(
reset to this value when ``self.reset()`` is called.
dist_reduce_fx (Optional): Function to reduce state accross mutliple processes in distributed mode.
If value is ``"sum"``, ``"mean"``, or ``"cat"``, we will use ``torch.sum``, ``torch.mean``,
and ``torch.cat`` respectively, each with argument ``dim=0``. The user can also pass a custom
and ``torch.cat`` respectively, each with argument ``dim=0``. Note that the ``"cat"`` reduction
only makes sense if the state is a list, and not a tensor. The user can also pass a custom
function in this parameter.
persistent (Optional): whether the state will be saved as part of the modules ``state_dict``.
Default is ``False``.
Expand Down Expand Up @@ -244,7 +245,7 @@ def reset(self):
"""
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(current_val, torch.Tensor):
if isinstance(default, torch.Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))
Expand Down
23 changes: 23 additions & 0 deletions tests/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ def compute(self):
pass


class DummyList(Metric):
name = "DummyList"

def __init__(self):
super().__init__()
self.add_state("x", list(), dist_reduce_fx=None)

def update(self):
pass

def compute(self):
pass


def test_inherit():
a = Dummy()

Expand Down Expand Up @@ -77,12 +91,21 @@ def test_reset():
class A(Dummy):
pass

class B(DummyList):
pass

a = A()
assert a.x == 0
a.x = torch.tensor(5)
a.reset()
assert a.x == 0

b = B()
assert isinstance(b.x, list) and len(b.x) == 0
b.x = torch.tensor(5)
b.reset()
assert isinstance(b.x, list) and len(b.x) == 0


def test_update():
class A(Dummy):
Expand Down

0 comments on commit 4913cbb

Please sign in to comment.