diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1a03c3f8c9d..7cae9f47ad3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))

+- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))
+
+
 - Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))

diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py
index 4102cf39625..fa20a3c96b6 100644
--- a/src/torchmetrics/functional/regression/mse.py
+++ b/src/torchmetrics/functional/regression/mse.py
@@ -16,10 +16,11 @@
 import torch
 from torch import Tensor

+from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
 from torchmetrics.utilities.checks import _check_same_shape


-def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
+def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]:
     """Update and returns variables required to compute Mean Squared Error.

     Check for same shape of input tensors.
@@ -27,12 +28,17 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, i
     Args:
         preds: Predicted tensor
         target: Ground truth tensor
+        num_outputs: Number of outputs in multioutput setting

     """
     _check_same_shape(preds, target)
+    _check_data_shape_to_num_outputs(preds, target, num_outputs, allow_1d_reshape=True)
+    if num_outputs == 1:
+        preds = preds.view(-1)
+        target = target.view(-1)
     diff = preds - target
-    sum_squared_error = torch.sum(diff * diff)
-    n_obs = target.numel()
+    sum_squared_error = torch.sum(diff * diff, dim=0)
+    n_obs = target.shape[0]
     return sum_squared_error, n_obs


@@ -47,7 +53,7 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: Union[int, Ten
     Example:
         >>> preds = torch.tensor([0., 1, 2, 3])
         >>> target = torch.tensor([0., 1, 2, 2])
-        >>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
+        >>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=1)
         >>> _mean_squared_error_compute(sum_squared_error, n_obs)
         tensor(0.2500)

@@ -55,13 +61,14 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: Union[int, Ten
     return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs)


-def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor:
+def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True, num_outputs: int = 1) -> Tensor:
     """Compute mean squared error.

     Args:
         preds: estimated labels
         target: ground truth labels
         squared: returns RMSE value if set to False
+        num_outputs: Number of outputs in multioutput setting

     Return:
         Tensor with MSE

@@ -74,5 +81,5 @@ def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> T
         tensor(0.2500)

     """
-    sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
+    sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=num_outputs)
     return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared)
diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py
index 609ddf88f30..59612927f26 100644
--- a/src/torchmetrics/functional/regression/utils.py
+++ b/src/torchmetrics/functional/regression/utils.py
@@ -14,15 +14,28 @@
 from torch import Tensor


-def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None:
-    """Check that predictions and target have the correct shape, else raise error."""
+def _check_data_shape_to_num_outputs(
+    preds: Tensor, target: Tensor, num_outputs: int, allow_1d_reshape: bool = False
+) -> None:
+    """Check that predictions and target have the correct shape, else raise error.
+
+    Args:
+        preds: Predicted tensor
+        target: Ground truth tensor
+        num_outputs: Number of outputs in multioutput setting
+        allow_1d_reshape: Allow that for num_outputs=1, preds and target do not need to be 1d tensors. Instead,
+            the code that follows is expected to reshape the tensors to 1d.
+
+    """
     if preds.ndim > 2 or target.ndim > 2:
         raise ValueError(
             f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
             f" but got {target.ndim} and {preds.ndim}."
         )
-    cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1)
-    cond2 = num_outputs > 1 and num_outputs != preds.shape[1]
+    cond1 = False
+    if not allow_1d_reshape:
+        cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1)
+    cond2 = num_outputs > 1 and preds.ndim > 1 and num_outputs != preds.shape[1]
     if cond1 or cond2:
         raise ValueError(
             f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
diff --git a/src/torchmetrics/regression/mse.py b/src/torchmetrics/regression/mse.py
index 4a67d54b513..34aaf196803 100644
--- a/src/torchmetrics/regression/mse.py
+++ b/src/torchmetrics/regression/mse.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from typing import Any, Optional, Sequence, Union

+import torch
 from torch import Tensor, tensor

 from torchmetrics.functional.regression.mse import _mean_squared_error_compute, _mean_squared_error_update
@@ -42,9 +43,12 @@ class MeanSquaredError(Metric):

     Args:
         squared: If True returns MSE value, if False returns RMSE value.
+        num_outputs: Number of outputs in multioutput setting
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

-    Example:
+    Example::
+        Single output mse computation:
+
         >>> from torch import tensor
         >>> from torchmetrics.regression import MeanSquaredError
         >>> target = tensor([2.5, 5.0, 4.0, 8.0])
         >>> preds = tensor([3.0, 5.0, 2.5, 7.0])
         >>> mean_squared_error = MeanSquaredError()
         >>> mean_squared_error(preds, target)
         tensor(0.8750)

@@ -53,6 +57,17 @@ class MeanSquaredError(Metric):
         >>> mean_squared_error(preds, target)
         tensor(0.8750)

+    Example::
+        Multioutput mse computation:
+
+        >>> from torch import tensor
+        >>> from torchmetrics.regression import MeanSquaredError
+        >>> target = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
+        >>> preds = tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
+        >>> mean_squared_error = MeanSquaredError(num_outputs=3)
+        >>> mean_squared_error(preds, target)
+        tensor([1., 4., 9.])
+
     """
     is_differentiable = True
     higher_is_better = False
@@ -65,17 +80,25 @@ class MeanSquaredError(Metric):
     def __init__(
         self,
         squared: bool = True,
+        num_outputs: int = 1,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)

-        self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
-        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
+        if not isinstance(squared, bool):
+            raise ValueError(f"Expected argument `squared` to be a boolean but got {squared}")
         self.squared = squared

+        if not (isinstance(num_outputs, int) and num_outputs > 0):
+            raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}")
+        self.num_outputs = num_outputs
+
+        self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
+        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
+
     def update(self, preds: Tensor, target: Tensor) -> None:
         """Update state with predictions and targets."""
-        sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
+        sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=self.num_outputs)

         self.sum_squared_error += sum_squared_error
         self.total += n_obs
diff --git a/tests/unittests/regression/test_mean_error.py b/tests/unittests/regression/test_mean_error.py
index 41eeab8075b..10671df138d 100644
--- a/tests/unittests/regression/test_mean_error.py
+++ b/tests/unittests/regression/test_mean_error.py
@@ -132,9 +132,8 @@ def _single_target_ref_metric(preds, target, sk_fn, metric_args):
 def _multi_target_ref_metric(preds, target, sk_fn, metric_args):
     sk_preds = preds.view(-1, num_targets).numpy()
     sk_target = target.view(-1, num_targets).numpy()
-
-    res = sk_fn(sk_target, sk_preds)
-
+    sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {}
+    res = sk_fn(sk_target, sk_preds, **sk_kwargs)
     return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res


@@ -150,6 +149,7 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args):
     [
         (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}),
         (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}),
+        (MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": num_targets}),
         (MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
         (MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}),
         (
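
For reviewers, a minimal usage sketch of the multioutput path added by this patch (not part of the diff itself). It assumes the public `torchmetrics.functional.mean_squared_error` export forwards the new `num_outputs` argument, and the expected values follow the multioutput docstring example above.

```python
import torch

from torchmetrics.functional import mean_squared_error
from torchmetrics.regression import MeanSquaredError

# Two samples, three outputs; targets are all zeros, so the per-output MSE is
# the mean of the squared predictions down each column.
target = torch.zeros(2, 3)
preds = torch.tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])

# Functional API: with num_outputs > 1 the squared errors are summed over dim=0
# and divided by the number of rows, giving one value per output.
print(mean_squared_error(preds, target, num_outputs=3))  # tensor([1., 4., 9.])

# Module API: the state is a vector of size num_outputs, reduced with "sum" across processes.
metric = MeanSquaredError(num_outputs=3)
metric.update(preds, target)
print(metric.compute())  # tensor([1., 4., 9.])

# squared=False returns the per-output RMSE instead.
print(mean_squared_error(preds, target, squared=False, num_outputs=3))  # tensor([1., 2., 3.])
```

With the default `num_outputs=1` the inputs are still flattened to 1d before the update, so existing single-output behaviour (a scalar result) is unchanged.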