From 9ef98b430b6930ecc146cc3948d9f7abb595801c Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 24 Sep 2021 11:08:48 +0200
Subject: [PATCH] Fix child device (#542)

* fix

* fix

* changelog

* remove legacy

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec
---
 CHANGELOG.md               |  2 ++
 tests/bases/test_metric.py | 32 ++++++++++++++++++++++++--
 torchmetrics/metric.py     | 46 ++++----------------------------------
 3 files changed, 36 insertions(+), 44 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 81654fbd6af..7af766d1c87 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539))
 
+- Fixed bug where the `device` property was not properly updated when the metric was a child of a module ([#542](https://github.com/PyTorchLightning/metrics/pull/542))
+
 ## [0.5.1] - 2021-08-30
 
 ### Added
diff --git a/tests/bases/test_metric.py b/tests/bases/test_metric.py
index 88b8ee27c0e..5d65f14f202 100644
--- a/tests/bases/test_metric.py
+++ b/tests/bases/test_metric.py
@@ -18,7 +18,7 @@
 import numpy as np
 import pytest
 import torch
-from torch import nn, tensor
+from torch import Tensor, nn, tensor
 
 from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
 from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -258,7 +258,7 @@ def test_device_and_dtype_transfer(tmpdir):
 
     metric = metric.to(device="cuda")
     assert metric.x.is_cuda
-    assert metric.device == torch.device("cuda")
+    assert metric.device == torch.device("cuda", index=0)
 
     metric.set_dtype(torch.double)
     assert metric.x.dtype == torch.float64
@@ -326,3 +326,31 @@ def test_forward_and_compute_to_device(metric_class):
     assert metric._computed is not None
     is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda
     assert is_cuda, "computed result was not moved to the correct device"
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
+@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput])
+def test_device_if_child_module(metric_class):
+    """Test that if a metric is a child module, all values get moved to the correct device."""
+
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.metric = metric_class()
+            self.register_buffer("dummy", torch.zeros(1))
+
+        @property
+        def device(self):
+            return self.dummy.device
+
+    module = TestModule()
+
+    assert module.device == module.metric.device
+    if isinstance(module.metric.x, Tensor):
+        assert module.device == module.metric.x.device
+
+    module.to(device="cuda")
+
+    assert module.device == module.metric.device
+    if isinstance(module.metric.x, Tensor):
+        assert module.device == module.metric.x.device
diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py
index 698ae799cef..2d1020932d4 100644
--- a/torchmetrics/metric.py
+++ b/torchmetrics/metric.py
@@ -417,48 +417,6 @@ def device(self) -> "torch.device":
         """Return the device of the metric."""
         return self._device
 
-    def to(self, *args: Any, **kwargs: Any) -> "Metric":
-        """Moves the parameters and buffers.
-
-        Normal dtype casting is not supported by this method instead use the `set_dtype` method instead.
- """ - out = torch._C._nn._parse_to(*args, **kwargs) - if len(out) == 4: # pytorch 1.5 and higher - device, dtype, non_blocking, convert_to_format = out - else: # pytorch 1.4 and lower - device, dtype, non_blocking = out - convert_to_format = None - dtype = None # prevent dtype being casted - - def convert(t: Tensor) -> Tensor: - if convert_to_format is not None and t.dim() in (4, 5): - return t.to( - device, - dtype if t.is_floating_point() or t.is_complex() else None, - non_blocking, - memory_format=convert_to_format, - ) - return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking) - - self._device = device - return self._apply(convert) - - def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric": - """Moves all model parameters and buffers to the GPU. - - Arguments: - device: if specified, all parameters will be copied to that device - """ - if device is None or isinstance(device, int): - device = torch.device("cuda", index=device) - self._device = device - return super().cuda(device=device) - - def cpu(self) -> "Metric": - """Moves all model parameters and buffers to the CPU.""" - self._device = torch.device("cpu") - return super().cpu() - def type(self, dst_type: Union[str, torch.dtype]) -> "Metric": """Method override default and prevent dtype casting. @@ -515,6 +473,10 @@ def _apply(self, fn: Callable) -> Module: "Expected metric state to be either a Tensor" f"or a list of Tensor, but encountered {current_val}" ) + # make sure to update the device attribute + # if the dummy tensor moves device by fn function we should also update the attribute + self._device = fn(torch.zeros(1, device=self.device)).device + # Additional apply to forward cache and computed attributes (may be nested) if this._computed is not None: this._computed = apply_to_collection(this._computed, Tensor, fn)