diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ed928f5681..044b2a72cba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,7 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed multi device aggregation in `PearsonCorrCoef` ([#998](https://github.com/PyTorchLightning/metrics/pull/998)) - diff --git a/tests/regression/test_pearson.py b/tests/regression/test_pearson.py index f1517b1f1c1..d8ce95a7866 100644 --- a/tests/regression/test_pearson.py +++ b/tests/regression/test_pearson.py @@ -51,8 +51,6 @@ def _sk_pearsonr(preds, target): ], ) class TestPearsonCorrcoef(MetricTester): - atol = 1e-2 - @pytest.mark.parametrize("compute_on_cpu", [True, False]) @pytest.mark.parametrize("ddp", [True, False]) def test_pearson_corrcoef(self, preds, target, compute_on_cpu, ddp): diff --git a/torchmetrics/regression/pearson.py b/torchmetrics/regression/pearson.py index c655e4b0b3e..832c7390670 100644 --- a/torchmetrics/regression/pearson.py +++ b/torchmetrics/regression/pearson.py @@ -36,19 +36,30 @@ def _final_aggregation( mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0] for i in range(1, len(means_x)): mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i] - nb = n1 + n2 mean_x = (n1 * mx1 + n2 * mx2) / nb mean_y = (n1 * my1 + n2 * my2) / nb - var_x = 1 / (n1 + n2 - 1) * ((n1 - 1) * vx1 + (n2 - 1) * vx2 + ((n1 * n2) / (n1 + n2)) * (mx1 - mx2) ** 2) - var_y = 1 / (n1 + n2 - 1) * ((n1 - 1) * vy1 + (n2 - 1) * vy2 + ((n1 * n2) / (n1 + n2)) * (my1 - my2) ** 2) - corr1 = n1 * cxy1 + n1 * (mx1 - mean_x) * (my1 - mean_y) - corr2 = n2 * cxy2 + n2 * (mx2 - mean_x) * (my2 - mean_y) - corr_xy = (corr1 + corr2) / (n1 + n2) + # var_x + element_x1 = (n1 + 1) * mean_x - n1 * mx1 + vx1 += (element_x1 - mx1) * (element_x1 - mean_x) - (element_x1 - mean_x) ** 2 + element_x2 = (n2 + 1) * mean_x - n2 * mx2 + vx2 += (element_x2 - mx2) * (element_x2 - mean_x) - (element_x2 - mean_x) ** 2 + var_x = vx1 + vx2 + + # var_y + element_y1 = (n1 + 1) * mean_y - n1 * my1 + vy1 += (element_y1 - my1) * (element_y1 - mean_y) - (element_y1 - mean_y) ** 2 + element_y2 = (n2 + 1) * mean_y - n2 * my2 + vy2 += (element_y2 - my2) * (element_y2 - mean_y) - (element_y2 - mean_y) ** 2 + var_y = vy1 + vy2 + + # corr + cxy1 += (element_x1 - mx1) * (element_y1 - mean_y) - (element_x1 - mean_x) * (element_y1 - mean_y) + cxy2 += (element_x2 - mx2) * (element_y2 - mean_y) - (element_x2 - mean_x) * (element_y2 - mean_y) + corr_xy = cxy1 + cxy2 mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb - return var_x, var_y, corr_xy, nb @@ -123,5 +134,4 @@ def compute(self) -> Tensor: var_y = self.var_y corr_xy = self.corr_xy n_total = self.n_total - return _pearson_corrcoef_compute(var_x, var_y, corr_xy, n_total)