Move forward cache and computed to device #413

Merged 9 commits on Aug 2, 2021
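To illustrate the behavior this PR fixes, here is a minimal sketch. `SumMetric` is a hypothetical toy metric written only for illustration (it is not part of this PR or the library), and the snippet assumes a CUDA-capable machine: before this change, `metric.to("cuda")` moved the metric states but left `_forward_cache` and `_computed` behind on the CPU.

```python
import torch
from torchmetrics import Metric


class SumMetric(Metric):
    """Toy metric tracking a running sum (illustration only, not from this PR)."""

    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, x: torch.Tensor) -> None:
        self.total += x

    def compute(self) -> torch.Tensor:
        return self.total


metric = SumMetric()
metric(torch.tensor(1.0))  # forward() populates metric._forward_cache
metric.compute()           # populates metric._computed
metric.to("cuda")          # with this PR, both caches follow the metric

assert metric._forward_cache.is_cuda, "forward cache should follow the metric"
assert metric._computed.is_cuda, "computed result should follow the metric"
```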
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -80,6 +80,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `weighted`, `multi-class` AUROC computation to allow for 0 observations of some class, as the contribution to the final AUROC is 0 ([#348](https://github.com/PyTorchLightning/metrics/issues/348))


- Fixed device placement so that the `_forward_cache` and `_computed` attributes are also moved to the correct device when the metric is moved ([#413](https://github.com/PyTorchLightning/metrics/pull/413))


- Fixed calculation in the `IoU` metric when using the `ignore_index` argument ([#328](https://github.com/PyTorchLightning/metrics/pull/328)); see the usage sketch after this changelog excerpt


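For the `IoU` entry above, a hedged usage sketch, assuming the 0.x-era `torchmetrics.IoU` class (later renamed `JaccardIndex`) and its `ignore_index` argument; the tensors are made-up examples:

```python
import torch
from torchmetrics import IoU

preds = torch.tensor([0, 1, 2, 2])
target = torch.tensor([0, 1, 2, 0])

# With ignore_index=0, class 0 is excluded from the score,
# which is the code path fixed by #328.
iou = IoU(num_classes=3, ignore_index=0)
score = iou(preds, target)
print(score)
```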
23 changes: 22 additions & 1 deletion tests/bases/test_metric.py
@@ -21,7 +21,7 @@
from torch import nn, tensor

from tests.helpers import _LIGHTNING_GREATER_EQUAL_1_3, seed_all
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricSum
from tests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from torchmetrics.utilities.imports import _LIGHTNING_AVAILABLE, _TORCH_LOWER_1_6

seed_all(42)
@@ -279,6 +279,7 @@ def test_device_and_dtype_transfer(tmpdir):


def test_warning_on_compute_before_update():
    """ Test that a warning is raised if the user tries to call compute before update. """
    metric = DummyMetricSum()

    # make sure everything is fine with forward
@@ -301,13 +302,33 @@ def test_warning_on_compute_before_update():


def test_metric_scripts():
    """ Test that metrics are scriptable. """
    torch.jit.script(DummyMetric())
    torch.jit.script(DummyMetricSum())


def test_metric_forward_cache_reset():
    """ Test that the forward cache is reset when `reset` is called. """
    metric = DummyMetricSum()
    _ = metric(2.0)
    assert metric._forward_cache == 2.0
    metric.reset()
    assert metric._forward_cache is None


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
@pytest.mark.parametrize("metric_class", [DummyMetricSum, DummyMetricMultiOutput])
def test_forward_and_compute_to_device(metric_class):
    """ Test that `_forward_cache` and `_computed` are moved to the correct device together with the metric. """
    metric = metric_class()
    metric(1)
    metric.to(device='cuda')

    assert metric._forward_cache is not None
    is_cuda = (
        metric._forward_cache[0].is_cuda if isinstance(metric._forward_cache, list) else metric._forward_cache.is_cuda
    )
    assert is_cuda, 'forward cache was not moved to the correct device'

    metric.compute()
    assert metric._computed is not None
    is_cuda = metric._computed[0].is_cuda if isinstance(metric._computed, list) else metric._computed.is_cuda
    assert is_cuda, 'computed result was not moved to the correct device'
6 changes: 6 additions & 0 deletions tests/helpers/testers.py
@@ -541,3 +541,9 @@ def update(self, y):

    def compute(self):
        return self.x


class DummyMetricMultiOutput(DummyMetricSum):
    """ Dummy metric whose `compute` returns a list of tensors. """

    def compute(self):
        return [self.x, self.x]
7 changes: 7 additions & 0 deletions torchmetrics/metric.py
@@ -444,6 +444,13 @@ def _apply(self, fn: Callable) -> Module:
"Expected metric state to be either a Tensor"
f"or a list of Tensor, but encountered {current_val}"
)

        # Also apply fn to the forward cache and computed result, which may be nested collections of tensors
        if this._computed is not None:
            this._computed = apply_to_collection(this._computed, Tensor, fn)
        if this._forward_cache is not None:
            this._forward_cache = apply_to_collection(this._forward_cache, Tensor, fn)

        return this

    def persistent(self, mode: bool = False) -> None:
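For background on the helper used in the `_apply` change above: `apply_to_collection` recursively walks nested containers and applies a function to every element of the requested type, which is what lets the fix handle both a plain tensor and the list returned by `DummyMetricMultiOutput.compute`. Below is a simplified sketch of that behavior, not the library's actual implementation:

```python
from typing import Any, Callable


def apply_to_collection_sketch(data: Any, dtype: type, fn: Callable) -> Any:
    """Simplified stand-in for apply_to_collection: apply fn to every
    element of type dtype inside (possibly nested) lists, tuples, and dicts."""
    if isinstance(data, dtype):
        return fn(data)
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection_sketch(d, dtype, fn) for d in data)
    if isinstance(data, dict):
        return {k: apply_to_collection_sketch(v, dtype, fn) for k, v in data.items()}
    return data
```

With a helper like this, `this._computed = apply_to_collection(this._computed, Tensor, fn)` moves a single tensor, a list of tensors, or any nested mixture of them in one call.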