From f3b62788edd30a8ee409a19255557f2649fb9a4d Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:21:30 +0200 Subject: [PATCH] fix: pearson changes inputs (#2765) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nicki Skafte Detlefsen --- CHANGELOG.md | 2 +- src/torchmetrics/functional/image/rmse_sw.py | 3 ++- .../functional/regression/concordance.py | 2 ++ .../functional/regression/pearson.py | 7 +++--- tests/unittests/regression/test_pearson.py | 22 +++++++++++++++++++ 5 files changed, 31 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a9d9fc6764..5700a43a98b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,7 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed Pearson correlation compute overwriting its inputs in place ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765)) ## [1.4.2] - 2022-09-12 diff --git a/src/torchmetrics/functional/image/rmse_sw.py b/src/torchmetrics/functional/image/rmse_sw.py index 16f1a0af80a..a27582bd11a 100644 --- a/src/torchmetrics/functional/image/rmse_sw.py +++ b/src/torchmetrics/functional/image/rmse_sw.py @@ -104,7 +104,8 @@ def _rmse_sw_compute( """ rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None if rmse_map is not None: - rmse_map /= total_images + # prevent overwrite the inputs + rmse_map = rmse_map / total_images return rmse, rmse_map diff --git a/src/torchmetrics/functional/regression/concordance.py b/src/torchmetrics/functional/regression/concordance.py index e18afe2f619..501cf8da054 100644 --- a/src/torchmetrics/functional/regression/concordance.py +++ b/src/torchmetrics/functional/regression/concordance.py @@ -27,6 +27,8 @@ def _concordance_corrcoef_compute( ) -> Tensor: """Compute the final concordance correlation coefficient based on accumulated statistics.""" pearson = _pearson_corrcoef_compute(var_x,
var_y, corr_xy, nb) + var_x = var_x / (nb - 1) + var_y = var_y / (nb - 1) return 2.0 * pearson * var_x.sqrt() * var_y.sqrt() / (var_x + var_y + (mean_x - mean_y) ** 2) diff --git a/src/torchmetrics/functional/regression/pearson.py b/src/torchmetrics/functional/regression/pearson.py index c98bc65a85c..47b26344163 100644 --- a/src/torchmetrics/functional/regression/pearson.py +++ b/src/torchmetrics/functional/regression/pearson.py @@ -92,9 +92,10 @@ def _pearson_corrcoef_compute( nb: number of observations """ - var_x /= nb - 1 - var_y /= nb - 1 - corr_xy /= nb - 1 + # prevent overwrite the inputs + var_x = var_x / (nb - 1) + var_y = var_y / (nb - 1) + corr_xy = corr_xy / (nb - 1) # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16 # on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"): diff --git a/tests/unittests/regression/test_pearson.py b/tests/unittests/regression/test_pearson.py index 0d23507aeed..07cbf3fd65c 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -164,3 +164,25 @@ def test_single_sample_update(): metric(torch.tensor([7.0]), torch.tensor([8.0])) res2 = metric.compute() assert torch.allclose(res1, res2) + + +def test_overwrite_reference_inputs(): + """Test that the normalizations does not overwrite inputs. + + Variables var_x, var_y, corr_xy are references to the object variables and get incorrectly scaled down such that + when you update again and compute you get very wrong values. 
+ + """ + y = torch.randn(100) + y_pred = y + torch.randn(y.shape) / 5 + # Initialize Pearson correlation coefficient metric + pearson = PearsonCorrCoef() + # Compute the Pearson correlation coefficient + correlation = pearson(y, y_pred) + + pearson = PearsonCorrCoef() + for lower, upper in [(0, 33), (33, 66), (66, 99), (99, 100)]: + pearson.update(torch.tensor(y[lower:upper]), torch.tensor(y_pred[lower:upper])) + pearson.compute() + + assert torch.isclose(pearson.compute(), correlation)