From 615ceecdc87927c0d30499cac821e1a4db2feae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 12 Jan 2022 09:09:36 +0100 Subject: [PATCH] Avoid in-place ops during logging result updates (#11401) Co-authored-by: rohitgr7 --- CHANGELOG.md | 3 ++- .../connectors/logger_connector/result.py | 8 +++++--- tests/core/test_metric_result_integration.py | 20 +++++++++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9776d5fa85bf3..db46c604f00e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,12 +4,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [1.5.9] - 2022-01-11 +## [1.5.9] - 2022-01-18 ### Fixed - Pin sphinx-autodoc-typehints with None: # do not set a dtype in case the default dtype was changed self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum) if self.meta.is_mean_reduction: + self.cumulated_batch_size: torch.Tensor self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum) def update(self, value: _IN_METRIC, batch_size: int) -> None: @@ -240,12 +241,13 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: # perform accumulation with reduction if self.meta.is_mean_reduction: - self.value += value.mean() * batch_size - self.cumulated_batch_size += batch_size + # do not use `+=` as it doesn't do type promotion + self.value = self.value + value.mean() * batch_size + self.cumulated_batch_size = self.cumulated_batch_size + batch_size elif self.meta.is_max_reduction or self.meta.is_min_reduction: self.value = self.meta.reduce_fx(self.value, value.mean()) elif self.meta.is_sum_reduction: - self.value += value.mean() + self.value = self.value + value.mean() else: self.value = value self._forward_cache = value._forward_cache diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index d660054d39d13..85149a78211f6 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -593,6 +593,26 @@ def test_metric_result_respects_dtype(floating_dtype): torch.set_default_dtype(torch.float) +@pytest.mark.parametrize("reduce_fx", ("mean", sum)) +def test_metric_result_dtype_promotion(reduce_fx): + metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx) + metadata.sync = _Sync() + rm = ResultMetric(metadata, is_tensor=True) + assert rm.value.dtype == torch.float + + # log a double + rm.update(torch.tensor(0, dtype=torch.double), 1) + # `rm.value.dtype` is promoted + assert rm.value.dtype == torch.double + # log a float + rm.update(torch.tensor(0, dtype=torch.float), 1) + # the previous dtype stays + assert rm.value.dtype == torch.double + + total = rm.compute() + assert total.dtype == torch.double + + @pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)]) def test_result_metric_max_min(reduce_fx, expected): metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)