Skip to content

Commit

Permalink
Move checking in r2 for number of samples to compute (#426)
Browse files Browse the repository at this point in the history
* fix

* changelog
  • Loading branch information
SkafteNicki authored Aug 4, 2021
1 parent ad9845b commit b487ac3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use `torch.argmax` instead of `torch.topk` when `k=1` for better performance ([#419](https://github.com/PyTorchLightning/metrics/pull/419))


- Moved check for number of samples in R2 score to support single sample updating ([#426](https://github.com/PyTorchLightning/metrics/pull/426))


### Deprecated

- Rename `r2score` >> `r2_score` and `kldivergence` >> `kl_divergence` in `functional` ([#371](https://github.com/PyTorchLightning/metrics/pull/371))
Expand Down
6 changes: 6 additions & 0 deletions tests/regression/test_r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def test_error_on_too_few_samples(metric_class=R2Score):
metric = metric_class()
with pytest.raises(ValueError, match="Needs at least two samples to calculate r2 score."):
metric(torch.randn(1), torch.randn(1))
metric.reset()

# calling update twice should still work
metric.update(torch.randn(1), torch.randn(1))
metric.update(torch.randn(1), torch.randn(1))
assert metric.compute()


def test_warning_on_too_large_adjusted(metric_class=R2Score):
Expand Down
5 changes: 3 additions & 2 deletions torchmetrics/functional/regression/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def _r2_score_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor, Ten
"Expected both prediction and target to be 1D or 2D tensors,"
f" but received tensors with dimension {preds.shape}"
)
if len(preds) < 2:
raise ValueError("Needs at least two samples to calculate r2 score.")

sum_obs = torch.sum(target, dim=0)
sum_squared_obs = torch.sum(target * target, dim=0)
Expand Down Expand Up @@ -77,6 +75,9 @@ def _r2_score_compute(
>>> _r2_score_compute(sum_squared_obs, sum_obs, rss, n_obs, multioutput="raw_values")
tensor([0.9654, 0.9082])
"""
if n_obs < 2:
raise ValueError("Needs at least two samples to calculate r2 score.")

mean_obs = sum_obs / n_obs
tss = sum_squared_obs - sum_obs * mean_obs
raw_scores = 1 - (rss / tss)
Expand Down

0 comments on commit b487ac3

Please sign in to comment.