Commit
Merge branch 'master' into master
mergify[bot] authored Apr 16, 2024
2 parents afc5629 + 5259c22 commit 51dcc91
Showing 4 changed files with 59 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -39,6 +39,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


4 changes: 4 additions & 0 deletions docs/source/pages/implement.rst
@@ -124,6 +124,10 @@ A few important things to note for this example:
``dim_zero_cat`` helper function, which will standardize the list states into a single concatenated tensor regardless
of the mode.

* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
must be taken when referencing list states. If you require the values after your metric is reset, you must first
copy the attribute to another object (e.g. using ``copy.deepcopy``).
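
For illustration, a minimal Python sketch of this caveat (the ``CollectPreds`` class below is invented for this example; ``Metric``, ``add_state``, ``reset`` and the ``dim_zero_cat`` helper are the library names used in these docs):

    import copy

    import torch
    from torchmetrics import Metric
    from torchmetrics.utilities import dim_zero_cat


    class CollectPreds(Metric):
        """Toy metric that accumulates predictions in a list state."""

        def __init__(self) -> None:
            super().__init__()
            # list state: every call to ``update`` appends a tensor to it
            self.add_state("preds", default=[], dist_reduce_fx="cat")

        def update(self, preds: torch.Tensor) -> None:
            self.preds.append(preds)

        def compute(self) -> torch.Tensor:
            # standardize the list state into a single concatenated tensor
            return dim_zero_cat(self.preds)


    metric = CollectPreds()
    metric.update(torch.tensor([1.0, 2.0]))

    kept = copy.deepcopy(metric.preds)  # copy BEFORE reset, otherwise the values are gone
    metric.reset()

    assert len(metric.preds) == 0  # the list state has been emptied
    assert len(kept) == 1          # the copied values survive the reset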

*****************
Metric attributes
*****************
29 changes: 25 additions & 4 deletions src/torchmetrics/metric.py
Expand Up @@ -238,6 +238,12 @@ def add_state(
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format discussed in the above note.

Note:
The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
device memory to be automatically reallocated, but may produce unexpected effects when referencing list
states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
object.
Raises:
ValueError:
If ``default`` is not a ``tensor`` or an ``empty list``.
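
As a short, hedged illustration of the note above (the ``SumAndTrack`` class is invented for this sketch), ``reset`` restores a tensor state to its default value but simply empties a list state:

    import torch
    from torchmetrics import Metric


    class SumAndTrack(Metric):
        """Toy metric with one tensor state and one list state."""

        def __init__(self) -> None:
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("history", default=[], dist_reduce_fx="cat")

        def update(self, x: torch.Tensor) -> None:
            self.total += x.sum()
            self.history.append(x)

        def compute(self) -> torch.Tensor:
            return self.total


    m = SumAndTrack()
    m.update(torch.ones(3))
    m.reset()
    assert m.total == 0.0   # tensor state is back at its default
    assert m.history == []  # list state is emptied; the appended tensor is released
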
@@ -325,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
cache = self._copy_state_dict()

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
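
Why a real copy is needed here: ``reset`` now empties list states in place (see the change to ``reset`` further down in this diff), so a cache that merely stores references to those lists would be wiped together with them. A plain-Python sketch of the aliasing pitfall (no torchmetrics objects involved):

    state = {"preds": [1, 2, 3]}

    cache_by_reference = {k: v for k, v in state.items()}   # previous approach: same list objects
    cache_by_copy = {k: list(v) for k, v in state.items()}  # independent copies of the containers

    state["preds"].clear()  # what reset() now does to a list state

    print(cache_by_reference["preds"])  # [] -- the cached "backup" was emptied as well
    print(cache_by_copy["preds"])       # [1, 2, 3] -- survives the in-place clear
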
@@ -358,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:
"""
# store global state and reset to default
global_state = {attr: getattr(self, attr) for attr in self._defaults}
global_state = self._copy_state_dict()
_update_count = self._update_count
self.reset()

@@ -525,7 +531,7 @@ def sync(
dist_sync_fn = gather_all_tensors

# cache prior to syncing
self._cache = {attr: getattr(self, attr) for attr in self._defaults}
self._cache = self._copy_state_dict()

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)
@@ -681,7 +687,7 @@ def reset(self) -> None:
if isinstance(default, Tensor):
setattr(self, attr, default.detach().clone().to(current_val.device))
else:
setattr(self, attr, [])
getattr(self, attr).clear() # delete/free list items

# reset internal states
self._cache = None
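
Clearing the list in place, rather than assigning a fresh empty list, matters whenever another object still holds a reference to it: with reassignment that alias would keep the old, fully populated list (and any GPU tensors inside it) alive, while ``clear()`` drops the items for every reference at once. A small plain-Python sketch of the difference (names are made up for illustration):

    import torch

    state = [torch.randn(1000)]
    alias = state  # e.g. a reference held by user code or an internal cache

    # reassignment: the alias still points at the old list and its tensor
    state = []
    print(len(alias))  # 1 -- the tensor stays reachable and cannot be freed

    # in-place clear: every reference now sees the emptied list
    alias.clear()
    print(len(alias))  # 0 -- the tensor can be garbage collected
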
@@ -870,6 +876,21 @@ def state_dict( # type: ignore[override] # todo
destination[prefix + key] = deepcopy(current_val)
return destination

def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
"""Copy the current state values."""
cache: Dict[str, Union[Tensor, List[Any]]] = {}
for attr in self._defaults:
current_value = getattr(self, attr)

if isinstance(current_value, Tensor):
cache[attr] = current_value.detach().clone().to(current_value.device)
else:
cache[attr] = [ # safely copy (non-graph leaf) Tensor elements
_.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
]

return cache

def _load_from_state_dict(
self,
state_dict: dict,
32 changes: 27 additions & 5 deletions tests/unittests/bases/test_metric.py
@@ -124,11 +124,17 @@ class B(DummyListMetric):
metric = B()
assert isinstance(metric.x, list)
assert len(metric.x) == 0
metric.x = tensor(5)
metric.x = [tensor(5)]
metric.reset()
assert isinstance(metric.x, list)
assert len(metric.x) == 0

metric = B()
metric.x = [1, 2, 3]
reference = metric.x # prevents garbage collection
metric.reset()
assert len(reference) == 0 # check list state is freed


def test_reset_compute():
"""Test that `reset`+`compute` methods works as expected."""
@@ -474,18 +480,34 @@ def test_constant_memory_on_repeat_init():
def mem():
return torch.cuda.memory_allocated() / 1024**2

x = torch.randn(10000).cuda()

for i in range(100):
m = DummyListMetric(compute_with_cache=False).cuda()
m(x)
_ = DummyListMetric(compute_with_cache=False).cuda()
if i == 0:
after_one_iter = mem()

# allow for 5% fluctuation due to measuring
assert after_one_iter * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_freed_memory_on_reset():
"""Test that resetting a metric frees all the memory allocated when updating it."""

def mem():
return torch.cuda.memory_allocated() / 1024**2

m = DummyListMetric().cuda()
after_init = mem()

for _ in range(100):
m(x=torch.randn(10000).cuda())

m.reset()

# allow for 5% fluctuation due to measuring
assert after_init * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
def test_specific_error_on_wrong_device():
"""Test that a specific error is raised if we detect input and metric are on different devices."""
