From ee69d5553b9307515a097d19b68cc0f8b41b6d1a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 Jul 2023 10:40:09 +0200 Subject: [PATCH 1/8] implementation --- src/torchmetrics/functional/regression/mse.py | 17 +++++++--- .../functional/regression/utils.py | 21 ++++++++++--- src/torchmetrics/regression/mse.py | 31 ++++++++++++++++--- 3 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index 4102cf39625..86cdfda0108 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -16,10 +16,11 @@ import torch from torch import Tensor +from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs from torchmetrics.utilities.checks import _check_same_shape -def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]: +def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]: """Update and returns variables required to compute Mean Squared Error. Check for same shape of input tensors. @@ -27,12 +28,17 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i Args: preds: Predicted tensor target: Ground truth tensor + num_outputs: Number of outputs in multioutput setting """ _check_same_shape(preds, target) + _check_data_shape_to_num_outputs(preds, target, num_outputs, allow_1d_reshape=True) + if num_outputs == 1: + preds = preds.view(-1) + target = target.view(-1) diff = preds - target - sum_squared_error = torch.sum(diff * diff) - n_obs = target.numel() + sum_squared_error = torch.sum(diff * diff, dim=0) + n_obs = target.shape[0] return sum_squared_error, n_obs @@ -55,13 +61,14 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: Union[int, Ten return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs) -def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor: +def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True, num_outputs: int = 1) -> Tensor: """Compute mean squared error. Args: preds: estimated labels target: ground truth labels squared: returns RMSE value if set to False + num_outputs: Number of outputs in multioutput setting Return: Tensor with MSE @@ -74,5 +81,5 @@ def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> T tensor(0.2500) """ - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=num_outputs) return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared) diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py index 609ddf88f30..59612927f26 100644 --- a/src/torchmetrics/functional/regression/utils.py +++ b/src/torchmetrics/functional/regression/utils.py @@ -14,15 +14,28 @@ from torch import Tensor -def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None: - """Check that predictions and target have the correct shape, else raise error.""" +def _check_data_shape_to_num_outputs( + preds: Tensor, target: Tensor, num_outputs: int, allow_1d_reshape: bool = False +) -> None: + """Check that predictions and target have the correct shape, else raise error. + + Args: + preds: Predicted tensor + target: Ground truth tensor + num_outputs: Number of outputs in multioutput setting + allow_1d_reshape: Allow that for num_outputs=1 that preds and target does not need to be 1d tensors. Instead + code that follows are expected to reshape the tensors to 1d. + + """ if preds.ndim > 2 or target.ndim > 2: raise ValueError( f"Expected both predictions and target to be either 1- or 2-dimensional tensors," f" but got {target.ndim} and {preds.ndim}." ) - cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1) - cond2 = num_outputs > 1 and num_outputs != preds.shape[1] + cond1 = False + if not allow_1d_reshape: + cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1) + cond2 = num_outputs > 1 and preds.ndim > 1 and num_outputs != preds.shape[1] if cond1 or cond2: raise ValueError( f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}" diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py index 4a67d54b513..34aaf196803 100644 --- a/src/torchmetrics/regression/mse.py +++ b/src/torchmetrics/regression/mse.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Optional, Sequence, Union +import torch from torch import Tensor, tensor from torchmetrics.functional.regression.mse import _mean_squared_error_compute, _mean_squared_error_update @@ -42,9 +43,12 @@ class MeanSquaredError(Metric): Args: squared: If True returns MSE value, if False returns RMSE value. + num_outputs: Number of outputs in multioutput setting kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Example: + Example:: + Single output mse computation: + >>> from torch import tensor >>> from torchmetrics.regression import MeanSquaredError >>> target = tensor([2.5, 5.0, 4.0, 8.0]) @@ -53,6 +57,17 @@ class MeanSquaredError(Metric): >>> mean_squared_error(preds, target) tensor(0.8750) + Example:: + Multioutput mse computation: + + >>> from torch import tensor + >>> from torchmetrics.regression import MeanSquaredError + >>> target = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]) + >>> preds = tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]) + >>> mean_squared_error = MeanSquaredError(num_outputs=3) + >>> mean_squared_error(preds, target) + tensor([1., 4., 9.]) + """ is_differentiable = True higher_is_better = False @@ -65,17 +80,25 @@ class MeanSquaredError(Metric): def __init__( self, squared: bool = True, + num_outputs: int = 1, **kwargs: Any, ) -> None: super().__init__(**kwargs) - self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum") - self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + if not isinstance(squared, bool): + raise ValueError(f"Expected argument `squared` to be a boolean but got {squared}") self.squared = squared + if not (isinstance(num_outputs, int) and num_outputs > 0): + raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}") + self.num_outputs = num_outputs + + self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum") + self.add_state("total", default=tensor(0), dist_reduce_fx="sum") + def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" - sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=self.num_outputs) self.sum_squared_error += sum_squared_error self.total += n_obs From af88b110ab9c8dc56ae5e833f9a7b048f921bc97 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 Jul 2023 10:40:25 +0200 Subject: [PATCH 2/8] tests --- tests/unittests/regression/test_mean_error.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 41eeab8075b..e5fc1810fbc 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -133,7 +133,10 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - res = sk_fn(sk_target, sk_preds) + if metric_args and "num_outputs" in metric_args: + res = sk_fn(sk_target, sk_preds, multioutput="raw_values") + else: + res = sk_fn(sk_target, sk_preds) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res @@ -150,6 +153,7 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): [ (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}), (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}), + (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": num_targets}), (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}), (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}), ( From d6d3286d01b496a69d5acca88d03814c0693c94b Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 Jul 2023 10:42:47 +0200 Subject: [PATCH 3/8] CHANGELOG.md --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05fb07fbd12..26b8682f383 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added warning to `PearsonCorrCoeff` if input has a very small variance for its given dtype ([#1926](https://github.com/Lightning-AI/torchmetrics/pull/1926)) + +- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937)) + ### Changed - From 4836654da804738f7a1d10e1547f3b131ee0d77f Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 Jul 2023 12:04:43 +0200 Subject: [PATCH 4/8] fix mistake --- src/torchmetrics/functional/regression/mse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index 86cdfda0108..fa20a3c96b6 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -53,7 +53,7 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: Union[int, Ten Example: >>> preds = torch.tensor([0., 1, 2, 3]) >>> target = torch.tensor([0., 1, 2, 2]) - >>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target) + >>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=1) >>> _mean_squared_error_compute(sum_squared_error, n_obs) tensor(0.2500) From 3ddb5a815fbfadbe80b92dbafd5f7d8a7cc30853 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 1 Aug 2023 14:44:29 +0200 Subject: [PATCH 5/8] Update tests/unittests/regression/test_mean_error.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- tests/unittests/regression/test_mean_error.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index e5fc1810fbc..4468db4e2ee 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -134,9 +134,8 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): sk_target = target.view(-1, num_targets).numpy() if metric_args and "num_outputs" in metric_args: - res = sk_fn(sk_target, sk_preds, multioutput="raw_values") - else: - res = sk_fn(sk_target, sk_preds) + sk_kwargs = dict(multioutput="raw_values") if if metric_args and "num_outputs" in metric_args else {} + res = sk_fn(sk_target, sk_preds, **sk_kwargs) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res From 75ba98358ebf9d0f6b907d00138cf64ed1115305 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 1 Aug 2023 08:31:27 -1000 Subject: [PATCH 6/8] fix @Borda typo --- tests/unittests/regression/test_mean_error.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 4468db4e2ee..a01d2e85f12 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -133,8 +133,7 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - if metric_args and "num_outputs" in metric_args: - sk_kwargs = dict(multioutput="raw_values") if if metric_args and "num_outputs" in metric_args else {} + sk_kwargs = dict(multioutput="raw_values") if metric_args and "num_outputs" in metric_args else {} res = sk_fn(sk_target, sk_preds, **sk_kwargs) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res From bad919aeda79fd47058c6431ff635dec5f4eec42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 1 Aug 2023 18:32:04 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/regression/test_mean_error.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index a01d2e85f12..7848e16a8b5 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -133,7 +133,7 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - sk_kwargs = dict(multioutput="raw_values") if metric_args and "num_outputs" in metric_args else {} + sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} res = sk_fn(sk_target, sk_preds, **sk_kwargs) return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res From 87187dd06ef15f4d54ec8f26262d7ada816e0160 Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Tue, 1 Aug 2023 08:35:36 -1000 Subject: [PATCH 8/8] one --- tests/unittests/regression/test_mean_error.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py index 7848e16a8b5..10671df138d 100644 --- a/tests/unittests/regression/test_mean_error.py +++ b/tests/unittests/regression/test_mean_error.py @@ -132,10 +132,8 @@ def _single_target_ref_metric(preds, target, sk_fn, metric_args): def _multi_target_ref_metric(preds, target, sk_fn, metric_args): sk_preds = preds.view(-1, num_targets).numpy() sk_target = target.view(-1, num_targets).numpy() - sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {} res = sk_fn(sk_target, sk_preds, **sk_kwargs) - return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res