Skip to content

Commit

Permalink
Move forward cache and computed to device (#413)
Browse files Browse the repository at this point in the history
* move to device

* changelog

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 2, 2021
1 parent d5aa720 commit 4fd18b3
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


- Fixed that `_forward_cache` and `_computed` attributes are also moved to the correct device if metric is moved ([#413](https://github.com/PyTorchLightning/metrics/pull/413))


- Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))


Expand Down
23 changes: 22 additions & 1 deletion tests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from torch import nn, tensor

from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6

seed_all(42)
Expand Down Expand Up @@ -279,6 +279,7 @@ def test_device_and_dtype_transfer(tmpdir):


def test_warning_on_compute_before_update():
""" test that an warning is raised if user tries to call compute before update """
metric = DummyMetricSum()

# make sure everything is fine with forward
Expand All @@ -301,13 +302,33 @@ def test_warning_on_compute_before_update():


def test_metric_scripts():
""" test that metrics are scriptable """
torch.jit.script(DummyMetric())
torch.jit.script(DummyMetricSum())


def test_metric_forward_cache_reset():
""" test that forward cache is reset when `reset` is called """
metric = DummyMetricSum()
_ = metric(2.0)
assert metric._forward_cache == 2.0
metric.reset()
assert metric._forward_cache is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput])
def test_forward_and_compute_to_device(metric_class):
metric = metric_class()
metric(1)
metric.to(device='cuda')

assert metric._forward_cache is not None
is_cuda = metric._forward_cache[0].is_cuda if isinstance(metric._forward_cache, list) \
else metric._forward_cache.is_cuda
assert is_cuda, 'forward cache was not moved to the correct device'

metric.compute()
assert metric._computed is not None
is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda
assert is_cuda, 'computed result was not moved to the correct device'
6 changes: 6 additions & 0 deletions tests/helpers/testers.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,9 @@ def update(self, y):

def compute(self):
return self.x


class DummyMetricMultiOutput(DummyMetricSum):

def compute(self):
return [self.x, self.x]
7 changes: 7 additions & 0 deletions torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,13 @@ def _apply(self, fn: Callable) -> Module:
"Expected metric state to be either a Tensor"
f"or a list of Tensor, but encountered {current_val}"
)

# Additional apply to forward cache and computed attributes (may be nested)
if this._computed is not None:
this._computed = apply_to_collection(this._computed, Tensor, fn)
if this._forward_cache is not None:
this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)

return this

def persistent(self, mode: bool = False) -> None:
Expand Down

0 comments on commit 4fd18b3

Please sign in to comment.