diff --git a/CHANGELOG.md b/CHANGELOG.md index 94a52a9a6fa..1c14035e158 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed support for pixelwise MSE ([#2017](https://github.com/Lightning-AI/torchmetrics/pull/2017) ## [1.1.0] - 2023-08-22 diff --git a/src/torchmetrics/functional/regression/mse.py b/src/torchmetrics/functional/regression/mse.py index fa20a3c96b6..c7d6d47dbfe 100644 --- a/src/torchmetrics/functional/regression/mse.py +++ b/src/torchmetrics/functional/regression/mse.py @@ -16,7 +16,6 @@ import torch from torch import Tensor -from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs from torchmetrics.utilities.checks import _check_same_shape @@ -32,7 +31,6 @@ def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) """ _check_same_shape(preds, target) - _check_data_shape_to_num_outputs(preds, target, num_outputs, allow_1d_reshape=True) if num_outputs == 1: preds = preds.view(-1) target = target.view(-1)