From a8bf78ba1e2afce050940669093d7edb70abc8c0 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Jul 2021 17:08:44 +0200 Subject: [PATCH 1/3] move to device --- tests/bases/test_metric.py | 23 ++++++++++++++++++++++- tests/helpers/testers.py | 5 +++++ torchmetrics/metric.py | 7 +++++++ 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 736740aab5d..7b17254b8d8 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -21,7 +21,7 @@ from torch import nn, tensor from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all -from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum +from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum, DummyMetricMultiOutput from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6 seed_all(42) @@ -279,6 +279,7 @@ def test_device_and_dtype_transfer(tmpdir): def test_warning_on_compute_before_update(): + """ test that an warning is raised if user tries to call compute before update """ metric = DummyMetricSum() # make sure everything is fine with forward @@ -301,13 +302,33 @@ def test_warning_on_compute_before_update(): def test_metric_scripts(): + """ test that metrics are scriptable """ torch.jit.script(DummyMetric()) torch.jit.script(DummyMetricSum()) def test_metric_forward_cache_reset(): + """ test that forward cache is reset when `reset` is called """ metric = DummyMetricSum() _ = metric(2.0) assert metric._forward_cache == 2.0 metric.reset() assert metric._forward_cache is None + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.") +@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput]) +def test_forward_and_compute_to_device(metric_class): + metric = metric_class() + metric(1) + metric.to(device='cuda') + + assert metric._forward_cache is not None + is_cuda = metric._forward_cache[0].is_cuda if isinstance(metric._forward_cache, list) \ + else metric._forward_cache.is_cuda + assert is_cuda, 'forward cache was not moved to the correct device' + + metric.compute() + assert metric._computed is not None + is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda + assert is_cuda, 'computed result was not moved to the correct device' diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index 3da6f68133a..e8602c6a49b 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -541,3 +541,8 @@ def update(self, y): def compute(self): return self.x + + +class DummyMetricMultiOutput(DummyMetricSum): + def compute(self): + return [self.x, self.x] diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index 64934001559..d5e30fc355b 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -444,6 +444,13 @@ def _apply(self, fn: Callable) -> Module: "Expected metric state to be either a Tensor" f"or a list of Tensor, but encountered {current_val}" ) + + # Additional apply to forward cache and computed attributes (may be nested) + if this._computed is not None: + this._computed = apply_to_collection(this._computed, Tensor, fn) + if this._forward_cache is not None: + this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn) + return this def persistent(self, mode: bool = False) -> None: From f94c3e045a26a4b423c6b888a93af1da07c2b221 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 15:11:01 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_metric.py | 2 +- tests/helpers/testers.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py index 7b17254b8d8..222ca15802e 100644 --- a/tests/bases/test_metric.py +++ b/tests/bases/test_metric.py @@ -21,7 +21,7 @@ from torch import nn, tensor from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all -from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum, DummyMetricMultiOutput +from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6 seed_all(42) diff --git a/tests/helpers/testers.py b/tests/helpers/testers.py index e8602c6a49b..20628474cf3 100644 --- a/tests/helpers/testers.py +++ b/tests/helpers/testers.py @@ -544,5 +544,6 @@ def compute(self): class DummyMetricMultiOutput(DummyMetricSum): + def compute(self): return [self.x, self.x] From faef0136fd13aabd53dd75d4985761c27f2cd8b5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 29 Jul 2021 17:12:25 +0200 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6473041ae79..5508fb6cec9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as contribution to final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348)) +- Fixed that `_forward_cache` and `_computed` attributes are also moved to the correct device if metric is moved ([#413](https://github.com/PyTorchLightning/metrics/pull/413)) + + ## [0.4.1] - 2021-07-05 ### Changed