Skip to content

Commit

Permalink
Deprecate num_outputs in R2 because it is no longer needed (#2705)
Browse files Browse the repository at this point in the history
* trying to see what happens
* fix doctests
* add deprecation test
* Apply suggestions from code review

---------

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
SkafteNicki and Borda authored Sep 11, 2024
1 parent 708f11d commit 4ce2278
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 17 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- update `InfoLM` class to dynamically set `higher_is_better` ([#2674](https://github.com/Lightning-AI/torchmetrics/pull/2674))


### Deprecated

- Deprecated `num_outputs` in `R2Score` ([#2705](https://github.com/Lightning-AI/torchmetrics/pull/2705))


### Removed

-
Expand Down
47 changes: 31 additions & 16 deletions src/torchmetrics/regression/r2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

Expand Down Expand Up @@ -65,23 +65,32 @@ class R2Score(Metric):
* ``'variance_weighted'`` scores are weighted by their individual variances
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
.. warning::
Argument ``num_outputs`` in ``R2Score`` has been deprecated because it is no longer necessary and will be
removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape
of the input tensors.
Raises:
ValueError:
If ``adjusted`` parameter is not an integer larger or equal to 0.
ValueError:
If ``multioutput`` is not one of ``"raw_values"``, ``"uniform_average"`` or ``"variance_weighted"``.
Example:
Example (single output):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> target = tensor([3, -0.5, 2, 7])
>>> preds = tensor([2.5, 0.0, 2, 8])
>>> r2score = R2Score()
>>> r2score(preds, target)
tensor(0.9486)
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
Example (multioutput):
>>> from torch import tensor
>>> from torchmetrics.regression import R2Score
>>> target = tensor([[0.5, 1], [-1, 1], [7, -6]])
>>> preds = tensor([[0, 2], [-1, 2], [8, -5]])
>>> r2score = R2Score(multioutput='raw_values')
>>> r2score(preds, target)
tensor([0.9654, 0.9082])
Expand All @@ -100,14 +109,20 @@ class R2Score(Metric):

def __init__(
self,
num_outputs: int = 1,
num_outputs: Optional[int] = None,
adjusted: int = 0,
multioutput: str = "uniform_average",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.num_outputs = num_outputs
if num_outputs is not None:
rank_zero_warn(
"Argument `num_outputs` in `R2Score` has been deprecated because it is no longer necessary and will be"
"removed in v1.6.0 of TorchMetrics. The number of outputs is now automatically inferred from the shape"
"of the input tensors.",
DeprecationWarning,
)

if adjusted < 0 or not isinstance(adjusted, int):
raise ValueError("`adjusted` parameter should be an integer larger or equal to 0.")
Expand All @@ -120,19 +135,19 @@ def __init__(
)
self.multioutput = multioutput

self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("sum_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("residual", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target)

self.sum_squared_error += sum_squared_error
self.sum_error += sum_error
self.residual += residual
self.total += total
self.sum_squared_error = self.sum_squared_error + sum_squared_error
self.sum_error = self.sum_error + sum_error
self.residual = self.residual + residual
self.total = self.total + total

def compute(self) -> Tensor:
"""Compute r2 score over the metric states."""
Expand Down
8 changes: 7 additions & 1 deletion tests/unittests/test_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import torch
from torchmetrics.functional.regression import kl_divergence
from torchmetrics.regression import KLDivergence
from torchmetrics.regression import KLDivergence, R2Score


def test_deprecated_kl_divergence_input_order():
Expand All @@ -14,3 +14,9 @@ def test_deprecated_kl_divergence_input_order():

with pytest.deprecated_call(match="The input order and naming in metric `KLDivergence` is set to be deprecated.*"):
KLDivergence()


def test_deprecated_r2_score_num_outputs():
"""Ensure that the deprecated num_outputs argument in R2Score raises a warning."""
with pytest.deprecated_call(match="Argument `num_outputs` in `R2Score` has been deprecated"):
R2Score(num_outputs=2)

0 comments on commit 4ce2278

Please sign in to comment.