Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: multioutput mse #1937

Merged
merged 22 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
ee69d55
implementation
SkafteNicki Jul 25, 2023
af88b11
tests
SkafteNicki Jul 25, 2023
d6d3286
CHANGELOG.md
SkafteNicki Jul 25, 2023
4836654
fix mistake
SkafteNicki Jul 25, 2023
6923bed
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Jul 25, 2023
d65ce08
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Jul 28, 2023
3ddb5a8
Update tests/unittests/regression/test_mean_error.py
SkafteNicki Aug 1, 2023
0c88aa2
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Aug 1, 2023
75ba983
fix @Borda typo
Borda Aug 1, 2023
bad919a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2023
8ce3361
Merge branch 'master' into feature/multioutput_mse
Borda Aug 1, 2023
87187dd
one
Borda Aug 1, 2023
788fbba
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Aug 3, 2023
89e6aa0
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Aug 3, 2023
2852a9f
Merge branch 'master' into feature/multioutput_mse
Borda Aug 3, 2023
c63b6ef
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Aug 5, 2023
883ef75
Merge branch 'master' into feature/multioutput_mse
mergify[bot] Aug 7, 2023
59510cc
Merge branch 'master' into feature/multioutput_mse
SkafteNicki Aug 7, 2023
43ad379
Merge branch 'master' into feature/multioutput_mse
mergify[bot] Aug 7, 2023
1ce63fb
Merge branch 'master' into feature/multioutput_mse
mergify[bot] Aug 7, 2023
28b2da7
Merge branch 'master' into feature/multioutput_mse
mergify[bot] Aug 7, 2023
68ea8d0
Merge branch 'master' into feature/multioutput_mse
mergify[bot] Aug 7, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `top_k` argument to `RetrievalMRR` in retrieval package ([#1961](https://github.com/Lightning-AI/torchmetrics/pull/1961))


- Added support for multioutput evaluation in `MeanSquaredError` ([#1937](https://github.com/Lightning-AI/torchmetrics/pull/1937))


- Added warning to `MeanAveragePrecision` if too many detections are observed ([#1978](https://github.com/Lightning-AI/torchmetrics/pull/1978))


Expand Down
19 changes: 13 additions & 6 deletions src/torchmetrics/functional/regression/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,29 @@
import torch
from torch import Tensor

from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
from torchmetrics.utilities.checks import _check_same_shape


def _mean_squared_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
def _mean_squared_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, int]:
"""Update and returns variables required to compute Mean Squared Error.

Check for same shape of input tensors.

Args:
preds: Predicted tensor
target: Ground truth tensor
num_outputs: Number of outputs in multioutput setting

"""
_check_same_shape(preds, target)
_check_data_shape_to_num_outputs(preds, target, num_outputs, allow_1d_reshape=True)
if num_outputs == 1:
preds = preds.view(-1)
target = target.view(-1)
diff = preds - target
sum_squared_error = torch.sum(diff * diff)
n_obs = target.numel()
sum_squared_error = torch.sum(diff * diff, dim=0)
n_obs = target.shape[0]
return sum_squared_error, n_obs


Expand All @@ -47,21 +53,22 @@ def _mean_squared_error_compute(sum_squared_error: Tensor, n_obs: Union[int, Ten
Example:
>>> preds = torch.tensor([0., 1, 2, 3])
>>> target = torch.tensor([0., 1, 2, 2])
>>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
>>> sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=1)
>>> _mean_squared_error_compute(sum_squared_error, n_obs)
tensor(0.2500)

"""
return sum_squared_error / n_obs if squared else torch.sqrt(sum_squared_error / n_obs)


def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> Tensor:
def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True, num_outputs: int = 1) -> Tensor:
"""Compute mean squared error.

Args:
preds: estimated labels
target: ground truth labels
squared: returns RMSE value if set to False
num_outputs: Number of outputs in multioutput setting

Return:
Tensor with MSE
Expand All @@ -74,5 +81,5 @@ def mean_squared_error(preds: Tensor, target: Tensor, squared: bool = True) -> T
tensor(0.2500)

"""
sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=num_outputs)
return _mean_squared_error_compute(sum_squared_error, n_obs, squared=squared)
21 changes: 17 additions & 4 deletions src/torchmetrics/functional/regression/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,28 @@
from torch import Tensor


def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None:
"""Check that predictions and target have the correct shape, else raise error."""
def _check_data_shape_to_num_outputs(
preds: Tensor, target: Tensor, num_outputs: int, allow_1d_reshape: bool = False
) -> None:
"""Check that predictions and target have the correct shape, else raise error.

Args:
preds: Predicted tensor
target: Ground truth tensor
num_outputs: Number of outputs in multioutput setting
allow_1d_reshape: Allow that for num_outputs=1 that preds and target does not need to be 1d tensors. Instead
code that follows are expected to reshape the tensors to 1d.

"""
if preds.ndim > 2 or target.ndim > 2:
raise ValueError(
f"Expected both predictions and target to be either 1- or 2-dimensional tensors,"
f" but got {target.ndim} and {preds.ndim}."
)
cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1)
cond2 = num_outputs > 1 and num_outputs != preds.shape[1]
cond1 = False
if not allow_1d_reshape:
cond1 = num_outputs == 1 and not (preds.ndim == 1 or preds.shape[1] == 1)
cond2 = num_outputs > 1 and preds.ndim > 1 and num_outputs != preds.shape[1]
if cond1 or cond2:
raise ValueError(
f"Expected argument `num_outputs` to match the second dimension of input, but got {num_outputs}"
Expand Down
31 changes: 27 additions & 4 deletions src/torchmetrics/regression/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from typing import Any, Optional, Sequence, Union

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.mse import _mean_squared_error_compute, _mean_squared_error_update
Expand Down Expand Up @@ -42,9 +43,12 @@ class MeanSquaredError(Metric):

Args:
squared: If True returns MSE value, if False returns RMSE value.
num_outputs: Number of outputs in multioutput setting
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

Example:
Example::
Single output mse computation:

>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([2.5, 5.0, 4.0, 8.0])
Expand All @@ -53,6 +57,17 @@ class MeanSquaredError(Metric):
>>> mean_squared_error(preds, target)
tensor(0.8750)

Example::
Multioutput mse computation:

>>> from torch import tensor
>>> from torchmetrics.regression import MeanSquaredError
>>> target = tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
>>> preds = tensor([[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]])
>>> mean_squared_error = MeanSquaredError(num_outputs=3)
>>> mean_squared_error(preds, target)
tensor([1., 4., 9.])

"""
is_differentiable = True
higher_is_better = False
Expand All @@ -65,17 +80,25 @@ class MeanSquaredError(Metric):
def __init__(
self,
squared: bool = True,
num_outputs: int = 1,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)

self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
if not isinstance(squared, bool):
raise ValueError(f"Expected argument `squared` to be a boolean but got {squared}")
self.squared = squared

if not (isinstance(num_outputs, int) and num_outputs > 0):
raise ValueError(f"Expected num_outputs to be a positive integer but got {num_outputs}")
self.num_outputs = num_outputs

self.add_state("sum_squared_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""Update state with predictions and targets."""
sum_squared_error, n_obs = _mean_squared_error_update(preds, target)
sum_squared_error, n_obs = _mean_squared_error_update(preds, target, num_outputs=self.num_outputs)

self.sum_squared_error += sum_squared_error
self.total += n_obs
Expand Down
6 changes: 3 additions & 3 deletions tests/unittests/regression/test_mean_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,8 @@ def _single_target_ref_metric(preds, target, sk_fn, metric_args):
def _multi_target_ref_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()

res = sk_fn(sk_target, sk_preds)

sk_kwargs = {"multioutput": "raw_values"} if metric_args and "num_outputs" in metric_args else {}
res = sk_fn(sk_target, sk_preds, **sk_kwargs)
return math.sqrt(res) if (metric_args and not metric_args["squared"]) else res


Expand All @@ -150,6 +149,7 @@ def _multi_target_ref_metric(preds, target, sk_fn, metric_args):
[
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True}),
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": False}),
(MeanSquaredError, mean_squared_error, sk_mean_squared_error, {"squared": True, "num_outputs": num_targets}),
(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
(MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}),
(
Expand Down