Merge branch 'master' into aggregation
SkafteNicki authored Sep 24, 2021
2 parents 04e44f9 + 9ef98b4 commit 15bfcb6
Showing 3 changed files with 36 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -54,6 +54,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `SSIM` metric using too much memory ([#539](https://github.com/PyTorchLightning/metrics/pull/539))


- Fixed bug where `device` property was not properly updated when metric was a child of a module ([#542](https://github.com/PyTorchLightning/metrics/pull/542))

## [0.5.1] - 2021-08-30

### Added
32 changes: 30 additions & 2 deletions tests/bases/test_metric.py
@@ -18,7 +18,7 @@
import numpy as np
import pytest
import torch
from torch import nn, tensor
from torch import Tensor, nn, tensor

from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
@@ -258,7 +258,7 @@ def test_device_and_dtype_transfer(tmpdir):

metric = metric.to(device="cuda")
assert metric.x.is_cuda
assert metric.device == torch.device("cuda")
assert metric.device == torch.device("cuda", index=0)

metric.set_dtype(torch.double)
assert metric.x.dtype == torch.float64
@@ -326,3 +326,31 @@ def test_forward_and_compute_to_device(metric_class):
assert metric._computed is not None
is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda
assert is_cuda, "computed result was not moved to the correct device"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput])
def test_device_if_child_module(metric_class):
"""Test that if a metric is a child module all values gets moved to the correct device."""

class TestModule(nn.Module):
def __init__(self):
super().__init__()
self.metric = metric_class()
self.register_buffer("dummy", torch.zeros(1))

@property
def device(self):
return self.dummy.device

module = TestModule()

assert module.device == module.metric.device
if isinstance(module.metric.x, Tensor):
assert module.device == module.metric.x.device

module.to(device="cuda")

assert module.device == module.metric.device
if isinstance(module.metric.x, Tensor):
assert module.device == module.metric.x.device
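
For context, the scenario exercised by this new test can be reproduced outside the test suite with a short standalone sketch (illustrative only; `SumMetric` and `Model` are made-up names, not part of the repository, and the snippet assumes a CUDA device is available):

```python
import torch
from torch import nn
from torchmetrics import Metric


class SumMetric(Metric):
    """Minimal metric with a single tensor state."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x.sum()

    def compute(self) -> torch.Tensor:
        return self.total


class Model(nn.Module):
    """Module that owns a metric as a child, mirroring the test above."""

    def __init__(self):
        super().__init__()
        self.metric = SumMetric()
        self.register_buffer("dummy", torch.zeros(1))


if torch.cuda.is_available():
    model = Model().to("cuda")
    # With the fix to ``Metric._apply`` below, the child metric reports the
    # same device as the rest of the module, and its state has moved with it.
    assert model.metric.device == model.dummy.device
    assert model.metric.total.device == model.dummy.device
```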
46 changes: 4 additions & 42 deletions torchmetrics/metric.py
@@ -421,48 +421,6 @@ def device(self) -> "torch.device":
"""Return the device of the metric."""
return self._device

def to(self, *args: Any, **kwargs: Any) -> "Metric":
"""Moves the parameters and buffers.
Normal dtype casting is not supported by this method; use the `set_dtype` method instead.
"""
out = torch._C._nn._parse_to(*args, **kwargs)
if len(out) == 4: # pytorch 1.5 and higher
device, dtype, non_blocking, convert_to_format = out
else: # pytorch 1.4 and lower
device, dtype, non_blocking = out
convert_to_format = None
dtype = None # prevent dtype from being cast

def convert(t: Tensor) -> Tensor:
if convert_to_format is not None and t.dim() in (4, 5):
return t.to(
device,
dtype if t.is_floating_point() or t.is_complex() else None,
non_blocking,
memory_format=convert_to_format,
)
return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)

self._device = device
return self._apply(convert)

def cuda(self, device: Optional[Union[torch.device, int]] = None) -> "Metric":
"""Moves all model parameters and buffers to the GPU.
Arguments:
device: if specified, all parameters will be copied to that device
"""
if device is None or isinstance(device, int):
device = torch.device("cuda", index=device)
self._device = device
return super().cuda(device=device)

def cpu(self) -> "Metric":
"""Moves all model parameters and buffers to the CPU."""
self._device = torch.device("cpu")
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> "Metric":
"""Method override default and prevent dtype casting.
@@ -519,6 +477,10 @@ def _apply(self, fn: Callable) -> Module:
"Expected metric state to be either a Tensor" f"or a list of Tensor, but encountered {current_val}"
)

# make sure to update the device attribute
# if fn moves the dummy tensor to a new device, update the attribute accordingly
self._device = fn(torch.zeros(1, device=self.device)).device

# Additional apply to forward cache and computed attributes (may be nested)
if this._computed is not None:
this._computed = apply_to_collection(this._computed, Tensor, fn)
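
The added `_apply` lines work because the `fn` that `nn.Module.to()`, `.cuda()` and `.cpu()` pass down encodes the target device, so applying it to a throwaway tensor reveals where the states ended up. A minimal sketch of the same idea outside torchmetrics (illustrative class name, not the library's implementation):

```python
import torch
from torch import nn


class DeviceAwareModule(nn.Module):
    """Tracks its own device by probing the function passed to ``_apply``."""

    def __init__(self):
        super().__init__()
        self._device = torch.device("cpu")
        self.register_buffer("state", torch.zeros(3))

    @property
    def device(self) -> torch.device:
        return self._device

    def _apply(self, fn):
        module = super()._apply(fn)
        # Applying ``fn`` to a throwaway tensor reveals where ``.to()``,
        # ``.cuda()`` or ``.cpu()`` actually moved the buffers.
        module._device = fn(torch.zeros(1, device=module._device)).device
        return module


m = DeviceAwareModule()
assert m.device == torch.device("cpu")
if torch.cuda.is_available():
    m = m.cuda()
    assert m.device == m.state.device  # both report the CUDA device
```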