Skip to content

Commit cbe6f3e

Browse files
committed
Fix single update in pearson corrcoef (Lightning-AI#2019)
* fix * changelog
1 parent 0e97213 commit cbe6f3e

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626

2727
### Fixed
2828

29+
- Fixed bug in `PearsonCorrCoef` is updated on single samples at a time ([#2019](https://github.com/Lightning-AI/torchmetrics/pull/2019)
30+
31+
2932
- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017)
3033

3134

src/torchmetrics/functional/regression/pearson.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,9 @@ def _pearson_corrcoef_update(
5757
if weights is not None:
5858
_check_data_shape_to_weights(preds, weights)
5959

60-
cond = n_prior.mean() > 0
61-
6260
n_obs = preds.shape[0] if weights is None else weights.sum()
61+
cond = n_prior.mean() > 0 or n_obs == 1
6362

64-
# Calculate means
6563
if cond:
6664
if weights is None:
6765
mx_new = (n_prior * mean_x + preds.sum(0)) / (n_prior + n_obs)

tests/unittests/regression/test_pearson.py

+17
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,20 @@ def test_pearsons_warning_on_small_input(dtype, scale):
193193
target = scale * torch.randn(100, dtype=dtype)
194194
with pytest.warns(UserWarning, match="The variance of predictions or target is close to zero.*"):
195195
pearson_corrcoef(preds, target)
196+
197+
198+
def test_single_sample_update():
199+
"""See issue: https://github.com/Lightning-AI/torchmetrics/issues/2014."""
200+
metric = PearsonCorrCoef()
201+
202+
# Works
203+
metric(torch.tensor([3.0, -0.5, 2.0, 7.0]), torch.tensor([2.5, 0.0, 2.0, 8.0]))
204+
res1 = metric.compute()
205+
metric.reset()
206+
207+
metric(torch.tensor([3.0]), torch.tensor([2.5]))
208+
metric(torch.tensor([-0.5]), torch.tensor([0.0]))
209+
metric(torch.tensor([2.0]), torch.tensor([2.0]))
210+
metric(torch.tensor([7.0]), torch.tensor([8.0]))
211+
res2 = metric.compute()
212+
assert torch.allclose(res1, res2)

0 commit comments

Comments
 (0)