diff --git a/CHANGELOG.md b/CHANGELOG.md index 42a60a89ec9..8170e72f05a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -102,6 +102,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * `DistanceIntersectionOverUnion` +- Added `RelativeSquaredError` metric to regression package ([#1765](https://github.com/Lightning-AI/torchmetrics/pull/1765)) + + ### Changed - Changed `update_count` and `update_called` from private to public methods ([#1370](https://github.com/Lightning-AI/metrics/pull/1370)) diff --git a/docs/source/regression/rse.rst b/docs/source/regression/rse.rst new file mode 100644 index 00000000000..80c086f7c98 --- /dev/null +++ b/docs/source/regression/rse.rst @@ -0,0 +1,23 @@ +.. customcarditem:: + :header: Relative Squared Error + :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg + :tags: Regression + +.. include:: ../links.rst + +############################ +Relative Squared Error (RSE) +############################ + +Module Interface +________________ + +.. autoclass:: torchmetrics.RelativeSquaredError + :noindex: + :exclude-members: update, compute + +Functional Interface +____________________ + +.. autofunction:: torchmetrics.functional.relative_squared_error + :noindex: diff --git a/src/torchmetrics/__init__.py b/src/torchmetrics/__init__.py index b671fa88393..141ff1efd3c 100644 --- a/src/torchmetrics/__init__.py +++ b/src/torchmetrics/__init__.py @@ -91,6 +91,7 @@ MinkowskiDistance, PearsonCorrCoef, R2Score, + RelativeSquaredError, SpearmanCorrCoef, SymmetricMeanAbsolutePercentageError, TweedieDevianceScore, @@ -190,6 +191,7 @@ "Recall", "RecallAtFixedPrecision", "RelativeAverageSpectralError", + "RelativeSquaredError", "RetrievalFallOut", "RetrievalHitRate", "RetrievalMAP", diff --git a/src/torchmetrics/functional/__init__.py b/src/torchmetrics/functional/__init__.py index 00f734c2a24..87a12d2f118 100644 --- a/src/torchmetrics/functional/__init__.py +++ b/src/torchmetrics/functional/__init__.py @@ -100,6 +100,7 @@ minkowski_distance, pearson_corrcoef, r2_score, + relative_squared_error, spearman_corrcoef, symmetric_mean_absolute_percentage_error, tweedie_deviance_score, @@ -190,6 +191,7 @@ "r2_score", "recall", "relative_average_spectral_error", + "relative_squared_error", "retrieval_average_precision", "retrieval_fall_out", "retrieval_hit_rate", diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index d8f99b610ba..a4336675679 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -25,6 +25,7 @@ from torchmetrics.functional.regression.mse import mean_squared_error from torchmetrics.functional.regression.pearson import pearson_corrcoef from torchmetrics.functional.regression.r2 import r2_score +from torchmetrics.functional.regression.rse import relative_squared_error from torchmetrics.functional.regression.spearman import spearman_corrcoef from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score @@ -45,6 +46,7 @@ "mean_absolute_percentage_error", "minkowski_distance", "r2_score", + "relative_squared_error", "spearman_corrcoef", "symmetric_mean_absolute_percentage_error", "tweedie_deviance_score", diff --git a/src/torchmetrics/functional/regression/rse.py b/src/torchmetrics/functional/regression/rse.py new file mode 100644 index 00000000000..12266b5b98a --- /dev/null +++ b/src/torchmetrics/functional/regression/rse.py @@ -0,0 +1,78 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Union + +import torch +from torch import Tensor + +from torchmetrics.functional.regression.r2 import _r2_score_update + + +def _relative_squared_error_compute( + sum_squared_obs: Tensor, + sum_obs: Tensor, + sum_squared_error: Tensor, + n_obs: Union[int, Tensor], + squared: bool = True, +) -> Tensor: + """Computes Relative Squared Error. + + Args: + sum_squared_obs: Sum of square of all observations + sum_obs: Sum of all observations + sum_squared_error: Residual sum of squares + n_obs: Number of predictions or observations + squared: Returns RRSE value if set to False. + + Example: + >>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]]) + >>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]]) + >>> # RSE uses the same update function as R2 score. + >>> sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target) + >>> _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=True) + tensor(0.0632) + """ + epsilon = torch.finfo(sum_squared_error.dtype).eps + rse = sum_squared_error / torch.clamp(sum_squared_obs - sum_obs * sum_obs / n_obs, min=epsilon) + if not squared: + rse = torch.sqrt(rse) + return torch.mean(rse) + + +def relative_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor: + r"""Computes the relative squared error (RSE). + + .. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2} + + Where :math:`y` is a tensor of target values with mean :math:`\overline{y}`, and + :math:`\hat{y}` is a tensor of predictions. + + If `preds` and `targets` are 2D tensors, the RSE is averaged over the second dim. + + Args: + preds: estimated labels + target: ground truth labels + squared: returns RRSE value if set to False + Return: + Tensor with RSE + + Example: + >>> from torchmetrics.functional.regression import relative_squared_error + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> relative_squared_error(preds, target) + tensor(0.0514) + """ + sum_squared_obs, sum_obs, rss, n_obs = _r2_score_update(preds, target) + return _relative_squared_error_compute(sum_squared_obs, sum_obs, rss, n_obs, squared=squared) diff --git a/src/torchmetrics/regression/__init__.py b/src/torchmetrics/regression/__init__.py index f2b2c07dae9..21de7bfc852 100644 --- a/src/torchmetrics/regression/__init__.py +++ b/src/torchmetrics/regression/__init__.py @@ -24,6 +24,7 @@ from torchmetrics.regression.mse import MeanSquaredError from torchmetrics.regression.pearson import PearsonCorrCoef from torchmetrics.regression.r2 import R2Score +from torchmetrics.regression.rse import RelativeSquaredError from torchmetrics.regression.spearman import SpearmanCorrCoef from torchmetrics.regression.symmetric_mape import SymmetricMeanAbsolutePercentageError from torchmetrics.regression.tweedie_deviance import TweedieDevianceScore @@ -43,6 +44,7 @@ "MeanSquaredError", "PearsonCorrCoef", "R2Score", + "RelativeSquaredError", "SpearmanCorrCoef", "SymmetricMeanAbsolutePercentageError", "TweedieDevianceScore", diff --git a/src/torchmetrics/regression/rse.py b/src/torchmetrics/regression/rse.py new file mode 100644 index 00000000000..20105d5db24 --- /dev/null +++ b/src/torchmetrics/regression/rse.py @@ -0,0 +1,141 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Optional, Sequence, Union + +import torch +from torch import Tensor, tensor + +from torchmetrics.functional.regression.r2 import _r2_score_update +from torchmetrics.functional.regression.rse import _relative_squared_error_compute +from torchmetrics.metric import Metric +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["RelativeSquaredError.plot"] + + +class RelativeSquaredError(Metric): + r"""Computes the relative squared error (RSE). + + .. math:: \text{RSE} = \frac{\sum_i^N(y_i - \hat{y_i})^2}{\sum_i^N(y_i - \overline{y})^2} + + Where :math:`y` is a tensor of target values with mean :math:`\overline{y}`, and + :math:`\hat{y}` is a tensor of predictions. + + If num_outputs > 1, the returned value is averaged over all the outputs. + + As input to ``forward`` and ``update`` the metric accepts the following input: + + - ``preds`` (:class:`~torch.Tensor`): Predictions from model in float tensor with shape ``(N,)`` + or ``(N, M)`` (multioutput) + - ``target`` (:class:`~torch.Tensor`): Ground truth values in float tensor with shape ``(N,)`` + or ``(N, M)`` (multioutput) + + As output of ``forward`` and ``compute`` the metric returns the following output: + + - ``rse`` (:class:`~torch.Tensor`): A tensor with the RSE score(s) + + Args: + num_outputs: Number of outputs in multioutput setting + squared: If True returns RSE value, if False returns RRSE value. + kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. + + Example: + >>> from torchmetrics.regression import RelativeSquaredError + >>> target = torch.tensor([3, -0.5, 2, 7]) + >>> preds = torch.tensor([2.5, 0.0, 2, 8]) + >>> relative_squared_error = RelativeSquaredError() + >>> relative_squared_error(preds, target) + tensor(0.0514) + """ + is_differentiable = True + higher_is_better = False + full_state_update = False + sum_squared_error: Tensor + sum_error: Tensor + residual: Tensor + total: Tensor + + def __init__( + self, + num_outputs: int = 1, + squared: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + + self.num_outputs = num_outputs + + self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + self.squared = squared + + def update(self, preds: Tensor, target: Tensor) -> None: + """Update state with predictions and targets.""" + sum_squared_error, sum_error, residual, total = _r2_score_update(preds, target) + + self.sum_squared_error += sum_squared_error + self.sum_error += sum_error + self.residual += residual + self.total += total + + def compute(self) -> Tensor: + """Computes relative squared error over state.""" + return _relative_squared_error_compute( + self.sum_squared_error, self.sum_error, self.residual, self.total, squared=self.squared + ) + + def plot( + self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None + ) -> _PLOT_OUT_TYPE: + """Plot a single or multiple values from the metric. + + Args: + val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results. + If no value is provided, will automatically call `metric.compute` and plot that result. + ax: An matplotlib axis object. If provided will add plot to that axis + + Returns: + Figure and Axes object + + Raises: + ModuleNotFoundError: + If `matplotlib` is not installed + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting a single value + >>> from torchmetrics.regression import RelativeSquaredError + >>> metric = RelativeSquaredError() + >>> metric.update(randn(10,), randn(10,)) + >>> fig_, ax_ = metric.plot() + + .. plot:: + :scale: 75 + + >>> from torch import randn + >>> # Example plotting multiple values + >>> from torchmetrics.regression import RelativeSquaredError + >>> metric = RelativeSquaredError() + >>> values = [] + >>> for _ in range(10): + ... values.append(metric(randn(10,), randn(10,))) + >>> fig, ax = metric.plot(values) + """ + return self._plot(val, ax) diff --git a/tests/unittests/regression/test_rse.py b/tests/unittests/regression/test_rse.py new file mode 100644 index 00000000000..26ccab2bcc1 --- /dev/null +++ b/tests/unittests/regression/test_rse.py @@ -0,0 +1,148 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import namedtuple +from functools import partial + +import numpy as np +import pytest +import torch + +from torchmetrics.functional import relative_squared_error +from torchmetrics.regression import RelativeSquaredError +from unittests import BATCH_SIZE, NUM_BATCHES +from unittests.helpers import seed_all +from unittests.helpers.testers import MetricTester + +seed_all(42) + +num_targets = 5 + +Input = namedtuple("Input", ["preds", "target"]) + +_single_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE), + target=torch.rand(NUM_BATCHES, BATCH_SIZE), +) + +_multi_target_inputs = Input( + preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), + target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets), +) + + +def _sk_rse(target, preds, squared): + mean = np.mean(target, axis=0, keepdims=True) + error = target - preds + sum_squared_error = np.sum(error * error, axis=0) + deviation = target - mean + sum_squared_deviation = np.sum(deviation * deviation, axis=0) + rse = sum_squared_error / np.maximum(sum_squared_deviation, 1.17e-06) + if not squared: + rse = np.sqrt(rse) + return np.mean(rse) + + +def _single_target_ref_metric(preds, target, squared): + sk_preds = preds.view(-1).numpy() + sk_target = target.view(-1).numpy() + return _sk_rse(sk_target, sk_preds, squared=squared) + + +def _multi_target_ref_metric(preds, target, squared): + sk_preds = preds.view(-1, num_targets).numpy() + sk_target = target.view(-1, num_targets).numpy() + return _sk_rse(sk_target, sk_preds, squared=squared) + + +@pytest.mark.parametrize("squared", [False, True]) +@pytest.mark.parametrize( + "preds, target, ref_metric, num_outputs", + [ + (_single_target_inputs.preds, _single_target_inputs.target, _single_target_ref_metric, 1), + (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_ref_metric, num_targets), + ], +) +class TestRelativeSquaredError(MetricTester): + """Test class for `RelativeSquaredError` metric.""" + + @pytest.mark.parametrize("ddp", [True, False]) + def test_rse(self, squared, preds, target, ref_metric, num_outputs, ddp): + """Test class implementation of metric.""" + self.run_class_metric_test( + ddp, + preds, + target, + RelativeSquaredError, + partial(ref_metric, squared=squared), + metric_args={"squared": squared, "num_outputs": num_outputs}, + ) + + def test_rse_functional(self, squared, preds, target, ref_metric, num_outputs): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds, + target, + relative_squared_error, + partial(ref_metric, squared=squared), + metric_args={"squared": squared}, + ) + + def test_rse_differentiability(self, squared, preds, target, ref_metric, num_outputs): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=partial(RelativeSquaredError, num_outputs=num_outputs), + metric_functional=relative_squared_error, + metric_args={"squared": squared}, + ) + + @pytest.mark.xfail(raises=RuntimeError, reason="clamp_min_cpu not implented for `Half`.") + def test_rse_half_cpu(self, squared, preds, target, ref_metric, num_outputs): + """Test dtype support of the metric on CPU.""" + self.run_precision_test_cpu( + preds, + target, + partial(RelativeSquaredError, num_outputs=num_outputs), + relative_squared_error, + {"squared": squared}, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_rse_half_gpu(self, squared, preds, target, ref_metric, num_outputs): + """Test dtype support of the metric on GPU.""" + self.run_precision_test_gpu( + preds, + target, + partial(RelativeSquaredError, num_outputs=num_outputs), + relative_squared_error, + {"squared": squared}, + ) + + +def test_error_on_different_shape(metric_class=RelativeSquaredError): + """Test that error is raised on different shapes of input.""" + metric = metric_class() + with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): + metric(torch.randn(100), torch.randn(50)) + + +def test_error_on_multidim_tensors(metric_class=RelativeSquaredError): + """Test that error is raised if a larger than 2D tensor is given as input.""" + metric = metric_class() + with pytest.raises( + ValueError, + match=r"Expected both prediction and target to be 1D or 2D tensors," r" but received tensors with dimension .", + ): + metric(torch.randn(10, 20, 5), torch.randn(10, 20, 5)) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index bb294307bad..e1b17bcbdc7 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -124,6 +124,7 @@ MinkowskiDistance, PearsonCorrCoef, R2Score, + RelativeSquaredError, SpearmanCorrCoef, SymmetricMeanAbsolutePercentageError, TweedieDevianceScore, @@ -466,6 +467,7 @@ pytest.param(partial(MinkowskiDistance, p=3), _rand_input, _rand_input, id="minkowski distance"), pytest.param(PearsonCorrCoef, _rand_input, _rand_input, id="pearson corr coef"), pytest.param(R2Score, _rand_input, _rand_input, id="r2 score"), + pytest.param(RelativeSquaredError, _rand_input, _rand_input, id="relative squared error"), pytest.param(SpearmanCorrCoef, _rand_input, _rand_input, id="spearman corr coef"), pytest.param(SymmetricMeanAbsolutePercentageError, _rand_input, _rand_input, id="symmetric mape"), pytest.param(TweedieDevianceScore, _rand_input, _rand_input, id="tweedie deviance score"), @@ -733,11 +735,13 @@ def test_plot_methods_special_text_metrics(): @pytest.mark.parametrize("num_vals", [1, 2]) def test_plot_methods_retrieval(metric_class, preds, target, indexes, num_vals): """Test the plot method for retrieval metrics by themselves, since retrieval metrics requires an extra argument.""" - if num_vals != 1 and metric_class == RetrievalPrecisionRecallCurve: # curves does not support multiple step plot - pytest.skip("curve objects does not support plotting multiple steps") - metric = metric_class() + if num_vals != 1 and isinstance(metric, RetrievalPrecisionRecallCurve): + pytest.skip("curve objects does not support plotting multiple steps") + if num_vals != 1 and isinstance(metric, BinaryFairness): + pytest.skip("randomness in input leads to different keys for `BinaryFairness` metric and breaks plotting") + if num_vals == 1: metric.update(preds(), target(), indexes()) fig, ax = metric.plot()