diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f30ab1d83e..a6c5885b728 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419)) +- Moved check for number of samples in R2 score to support single sample updating ([#426](https://github.com/PyTorchLightning/metrics/pull/426)) + + ### Deprecated - Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371)) diff --git a/tests/regression/test_r2.py b/tests/regression/test_r2.py index d723c9478cb..6882191b870 100644 --- a/tests/regression/test_r2.py +++ b/tests/regression/test_r2.py @@ -143,6 +143,12 @@ def test_error_on_too_few_samples(metric_class=R2Score): metric = metric_class() with pytest.raises(ValueError, match="Needs at least two samples to calculate r2 score."): metric(torch.randn(1), torch.randn(1)) + metric.reset() + + # calling update twice should still work + metric.update(torch.randn(1), torch.randn(1)) + metric.update(torch.randn(1), torch.randn(1)) + assert metric.compute() def test_warning_on_too_large_adjusted(metric_class=R2Score): diff --git a/torchmetrics/functional/regression/r2.py b/torchmetrics/functional/regression/r2.py index c7b9073ef3e..c522f2e03fe 100644 --- a/torchmetrics/functional/regression/r2.py +++ b/torchmetrics/functional/regression/r2.py @@ -34,8 +34,6 @@ def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Ten "Expected both prediction and target to be 1D or 2D tensors," f" but received tensors with dimension {preds.shape}" ) - if len(preds) < 2: - raise ValueError("Needs at least two samples to calculate r2 score.") sum_obs = torch.sum(target, dim=0) sum_squared_obs = torch.sum(target * target, dim=0) @@ -77,6 +75,9 @@ def _r2_score_compute( >>> _r2_score_compute(sum_squared_obs, sum_obs, rss, n_obs, multioutput="raw_values") tensor([0.9654, 0.9082]) """ + if n_obs < 2: + raise ValueError("Needs at least two samples to calculate r2 score.") + mean_obs = sum_obs / n_obs tss = sum_squared_obs - sum_obs * mean_obs raw_scores = 1 - (rss / tss)