Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clear list states (i.e. delete their contents), not reassign the default [] #2493

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
11df0eb
Clear (i.e. delete) list state items, not simply overwrite. Previous …
Apr 6, 2024
1fa7077
Added test to check list states elements are deleted (even when refer…
Apr 7, 2024
4b5c099
Updated documentation - highlighted reset clears list states, and tha…
Apr 9, 2024
8bf151b
Add missing method (sphinx) role
dominicgkerr Apr 9, 2024
64fd4d2
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
Borda Apr 10, 2024
b991b3b
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
mergify[bot] Apr 11, 2024
82f808b
changelog
SkafteNicki Apr 12, 2024
65b02fa
Remove failing testcode example (fixing introduces too much complexity)
Apr 12, 2024
b9dcc8b
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
mergify[bot] Apr 13, 2024
5565524
Merge branch 'bugfix/2492-clear-list-states-not-reassign' of github.c…
Apr 13, 2024
6241a6b
Linting - Line break docstring
Apr 13, 2024
5758977
copy internal states in forward
SkafteNicki Apr 13, 2024
c9d2a86
Detach Tensor | list[Tensor] state values before copying.
Apr 13, 2024
e1872de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
afdd4c5
Use 'typing' type hints
Apr 13, 2024
76cc0a1
Merge remote-tracking branch 'origin/bugfix/2492-clear-list-states-no…
Apr 13, 2024
ef27215
DO not clone (when caching) Tensor states, but retain references to a…
Apr 13, 2024
e104587
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 13, 2024
21c7970
Revert "DO not clone (when caching) Tensor states, but retain referen…
Apr 13, 2024
51975a8
Added mypy type-hinting requirement/recommendation
Apr 13, 2024
5954d02
Moved update from test checking .__init__ memory leakage. Added test …
Apr 13, 2024
cd11bb0
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
mergify[bot] Apr 14, 2024
dc14e5e
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
SkafteNicki Apr 15, 2024
925a3b1
Merge branch 'master' into bugfix/2492-clear-list-states-not-reassign
Borda Apr 15, 2024
3f013cf
Fix unused loop control variable for pre-commit
stancld Apr 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462))


- Fixed memory leak in metrics using list states ([#2492](https://github.com/Lightning-AI/torchmetrics/pull/2492))


- Fixed bug in computation of `ERGAS` metric ([#2498](https://github.com/Lightning-AI/torchmetrics/pull/2498))


Expand Down
4 changes: 4 additions & 0 deletions docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ A few important things to note for this example:
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless
of the mode.

* Calling the ``reset`` method will clear the list state, deleting any values inserted into it. For this reason, care
must be taken when referencing list states. If you require the values after your metric is reset, you must first
copy the attribute to another object (e.g. using `deepcopy.copy`).

*****************
Metric attributes
*****************
Expand Down
29 changes: 25 additions & 4 deletions src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,12 @@ def add_state(
When passing a custom function to ``dist_reduce_fx``, expect the synchronized metric state to follow
the format discussed in the above note.

Note:
The values inserted into a list state are deleted whenever :meth:`~Metric.reset` is called. This allows
device memory to be automatically reallocated, but may produce unexpected effects when referencing list
states. To retain such values after :meth:`~Metric.reset` is called, you must first copy them to another
object.

Raises:
ValueError:
If ``default`` is not a ``tensor`` or an ``empty list``.
Expand Down Expand Up @@ -325,7 +331,7 @@ def _forward_full_state_update(self, *args: Any, **kwargs: Any) -> Any:
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
cache = self._copy_state_dict()

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
Expand Down Expand Up @@ -358,7 +364,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

"""
# store global state and reset to default
global_state = {attr: getattr(self, attr) for attr in self._defaults}
global_state = self._copy_state_dict()
_update_count = self._update_count
self.reset()

Expand Down Expand Up @@ -525,7 +531,7 @@ def sync(
dist_sync_fn = gather_all_tensors

# cache prior to syncing
self._cache = {attr: getattr(self, attr) for attr in self._defaults}
self._cache = self._copy_state_dict()

# sync
self._sync_dist(dist_sync_fn, process_group=process_group)
Expand Down Expand Up @@ -681,7 +687,7 @@ def reset(self) -> None:
if isinstance(default, Tensor):
setattr(self, attr, default.detach().clone().to(current_val.device))
else:
setattr(self, attr, [])
getattr(self, attr).clear() # delete/free list items

# reset internal states
self._cache = None
Expand Down Expand Up @@ -870,6 +876,21 @@ def state_dict( # type: ignore[override] # todo
destination[prefix + key] = deepcopy(current_val)
return destination

def _copy_state_dict(self) -> Dict[str, Union[Tensor, List[Any]]]:
"""Copy the current state values."""
cache: Dict[str, Union[Tensor, List[Any]]] = {}
for attr in self._defaults:
current_value = getattr(self, attr)

if isinstance(current_value, Tensor):
cache[attr] = current_value.detach().clone().to(current_value.device)
else:
cache[attr] = [ # safely copy (non-graph leaf) Tensor elements
_.detach().clone().to(_.device) if isinstance(_, Tensor) else deepcopy(_) for _ in current_value
]

return cache

def _load_from_state_dict(
self,
state_dict: dict,
Expand Down
32 changes: 27 additions & 5 deletions tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,17 @@ class B(DummyListMetric):
metric = B()
assert isinstance(metric.x, list)
assert len(metric.x) == 0
metric.x = tensor(5)
metric.x = [tensor(5)]
metric.reset()
assert isinstance(metric.x, list)
assert len(metric.x) == 0

metric = B()
metric.x = [1, 2, 3]
reference = metric.x # prevents garbage collection
metric.reset()
assert len(reference) == 0 # check list state is freed


def test_reset_compute():
"""Test that `reset`+`compute` methods works as expected."""
Expand Down Expand Up @@ -474,18 +480,34 @@ def test_constant_memory_on_repeat_init():
def mem():
return torch.cuda.memory_allocated() / 1024**2

x = torch.randn(10000).cuda()

for i in range(100):
m = DummyListMetric(compute_with_cache=False).cuda()
m(x)
_ = DummyListMetric(compute_with_cache=False).cuda()
if i == 0:
after_one_iter = mem()

# allow for 5% flucturation due to measuring
assert after_one_iter * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_freed_memory_on_reset():
"""Test that resetting a metric frees all the memory allocated when updating it."""

def mem():
return torch.cuda.memory_allocated() / 1024**2

m = DummyListMetric().cuda()
after_init = mem()

for _ in range(100):
m(x=torch.randn(10000).cuda())

m.reset()

# allow for 5% flucturation due to measuring
assert after_init * 1.05 >= mem(), "memory increased too much above base level"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires gpu")
def test_specific_error_on_wrong_device():
"""Test that a specific error is raised if we detect input and metric are on different devices."""
Expand Down
Loading