diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7f557931cbc..661ed4d6be5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))
 
+- Fixed that `_forward_cache` and `_computed` attributes are also moved to the correct device if the metric is moved ([#413](https://github.com/PyTorchLightning/metrics/pull/413))
+
+
 - Fixed calculation in `IoU` metric when using `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328))
 
diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 736740aab5d..222ca15802e 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -21,7 +21,7 @@ from torch import nn, tensor
 
 from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
-from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
+from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
 from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6
 
 seed_all(42)
@@ -279,6 +279,7 @@ def test_device_and_dtype_transfer(tmpdir):
 
 
 def test_warning_on_compute_before_update():
+    """ test that a warning is raised if user tries to call compute before update """
     metric = DummyMetricSum()
 
     # make sure everything is fine with forward
@@ -301,13 +302,33 @@ def test_warning_on_compute_before_update():
 
 
 def test_metric_scripts():
+    """ test that metrics are scriptable """
     torch.jit.script(DummyMetric())
     torch.jit.script(DummyMetricSum())
 
 
 def test_metric_forward_cache_reset():
+    """ test that forward cache is reset when `reset` is called """
     metric = DummyMetricSum()
     _ = metric(2.0)
     assert metric._forward_cache == 2.0
     metric.reset()
     assert metric._forward_cache is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput])
+def test_forward_and_compute_to_device(metric_class):
+    metric = metric_class()
+    metric(1)
+    metric.to(device='cuda')
+
+    assert metric._forward_cache is not None
+    is_cuda = metric._forward_cache[0].is_cuda if isinstance(metric._forward_cache, list) \
+        else metric._forward_cache.is_cuda
+    assert is_cuda, 'forward cache was not moved to the correct device'
+
+    metric.compute()
+    assert metric._computed is not None
+    is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda
+    assert is_cuda, 'computed result was not moved to the correct device'
diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py
index 3da6f68133a..20628474cf3 100644
--- a/tests/helpers/testers.py
+++ b/tests/helpers/testers.py
@@ -541,3 +541,9 @@ def update(self, y):
 
     def compute(self):
         return self.x
+
+
+class DummyMetricMultiOutput(DummyMetricSum):
+
+    def compute(self):
+        return [self.x, self.x]
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 64934001559..d5e30fc355b 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -444,6 +444,13 @@ def _apply(self, fn: Callable) -> Module:
                     "Expected metric state to be either a Tensor"
                     f"or a list of Tensor, but encountered {current_val}"
                 )
+
+        # Additional apply to forward cache and computed attributes (may be nested)
+        if this._computed is not None:
+            this._computed = apply_to_collection(this._computed, Tensor, fn)
+        if this._forward_cache is not None:
+            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)
+
         return this
 
     def persistent(self, mode: bool = False) -> None:
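Note on the core change: `_forward_cache` and `_computed` are plain Python attributes rather than registered buffers or metric states, so `nn.Module._apply` alone does not move them; the patched `Metric._apply` now pushes `fn` (e.g. a device transfer) through them with `apply_to_collection`. Below is a minimal sketch of the recursion involved, assuming only list/tuple nesting; the real helper imported in `metric.py` also handles dicts and other container types, and the CPU round-trip in the usage lines is only so the sketch runs without a GPU.

    from typing import Any, Callable

    import torch
    from torch import Tensor


    def apply_to_collection(data: Any, dtype: type, fn: Callable) -> Any:
        """Recursively apply `fn` to every element of `data` of type `dtype`.

        Simplified sketch; the real torchmetrics helper also covers dicts,
        namedtuples, etc.
        """
        if isinstance(data, dtype):
            return fn(data)
        if isinstance(data, (list, tuple)):
            return type(data)(apply_to_collection(d, dtype, fn) for d in data)
        return data


    # What the patched `_apply` effectively does for a multi-output result
    # such as `DummyMetricMultiOutput.compute()`:
    computed = [torch.tensor(1.0), torch.tensor(2.0)]
    moved = apply_to_collection(computed, Tensor, lambda t: t.to("cpu"))
    assert all(t.device.type == "cpu" for t in moved)

Because the helper returns non-matching elements unchanged, a scalar `_forward_cache` and a nested list `_computed` are handled by the same two lines in `_apply`, which is why the new test parametrizes over both `DummyMetricSum` and `DummyMetricMultiOutput`.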