
Add Mean Absolute Percentage Error #248

Merged 45 commits into `master` from `metrics/mape` on Jun 10, 2021

Changes shown are from 39 of the 45 commits.

Commits
0f81384
initial code, test failing
pranjaldatta May 14, 2021
844fc9e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2021
ecccca9
increase requirement
SkafteNicki May 15, 2021
b04fa68
Merge branch 'master' into metrics/mape
SkafteNicki May 15, 2021
62107cd
Apply suggestions from code review
SkafteNicki May 15, 2021
a61efa6
added docs+changelog and other review changes
pranjaldatta May 15, 2021
df634e3
consistent shorted type hints
pranjaldatta May 16, 2021
2e2d582
fix tests
SkafteNicki May 18, 2021
8515662
remove arg
SkafteNicki May 18, 2021
b862861
Merge branch 'master' into metrics/mape
SkafteNicki May 18, 2021
e3c2db9
Merge branch 'master' into metrics/mape
mergify[bot] May 18, 2021
ef8e342
Merge branch 'master' into metrics/mape
pranjaldatta May 20, 2021
8b01f0b
minor merge fixes
pranjaldatta May 20, 2021
fdc3c3b
Merge branch 'master' into metrics/mape
mergify[bot] May 23, 2021
6a97671
Merge branch 'master' into metrics/mape
mergify[bot] May 25, 2021
53c218f
Merge branch 'master' into metrics/mape
mergify[bot] May 27, 2021
6c35769
Merge branch 'master' into metrics/mape
mergify[bot] May 28, 2021
e0a540d
Merge branch 'master' into metrics/mape
mergify[bot] May 31, 2021
29364ae
Merge branch 'master' into metrics/mape
mergify[bot] May 31, 2021
a20720e
Merge branch 'master' into metrics/mape
mergify[bot] Jun 3, 2021
ee9c22b
Merge branch 'master' into metrics/mape
mergify[bot] Jun 3, 2021
f0498d6
Merge branch 'master' into metrics/mape
mergify[bot] Jun 8, 2021
537302b
Merge branch 'master' into metrics/mape
mergify[bot] Jun 8, 2021
d7ceccf
Merge branch 'master' into metrics/mape
mergify[bot] Jun 8, 2021
31d3435
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
b5e7590
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
6046312
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
7b93371
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
1f490e6
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
e41b49a
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
fc99d2a
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
6dad077
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
dca19b5
dep warning added to mean_rel_err
pranjaldatta Jun 9, 2021
7f6d9da
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
f608e56
Apply suggestions from code review
Borda Jun 9, 2021
ef336e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
f679d44
Merge branch 'master' into metrics/mape
mergify[bot] Jun 9, 2021
eae2eef
applied suggestions from code review
pranjaldatta Jun 9, 2021
c8709d3
update
Borda Jun 9, 2021
13574f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
2ffd6a4
added comments regarding test case fixes
pranjaldatta Jun 9, 2021
dd860da
removed unused imports
pranjaldatta Jun 9, 2021
3f48a8d
,
Borda Jun 9, 2021
8064453
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2021
65e09f0
Merge branch 'master' into metrics/mape
mergify[bot] Jun 10, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added support for unnormalized scores (e.g. logits) in `Accuracy`, `Precision`, `Recall`, `FBeta`, `F1`, `StatScore`, `Hamming`, `ConfusionMatrix` metrics ([#200](https://github.com/PyTorchLightning/metrics/pull/200))


- Added `MeanAbsolutePercentageError` (MAPE) metric ([#248](https://github.com/PyTorchLightning/metrics/pull/248))


- Added `squared` argument to `MeanSquaredError` for computing `RMSE` ([#249](https://github.com/PyTorchLightning/metrics/pull/249))


@@ -36,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Deprecated

- Deprecated `torchmetrics.functional.mean_relative_error` in favor of `mean_absolute_percentage_error` ([#248](https://github.com/PyTorchLightning/metrics/pull/248))

### Removed

7 changes: 7 additions & 0 deletions docs/source/references/functional.rst
@@ -188,6 +188,13 @@ mean_absolute_error [func]
:noindex:


mean_absolute_percentage_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.functional.mean_absolute_percentage_error
:noindex:


mean_squared_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~

7 changes: 7 additions & 0 deletions docs/source/references/modules.rst
@@ -268,6 +268,13 @@ MeanAbsoluteError
:noindex:


MeanAbsolutePercentageError
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: torchmetrics.MeanAbsolutePercentageError
:noindex:


MeanSquaredError
~~~~~~~~~~~~~~~~

2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -15,6 +15,6 @@ yapf>=0.29.0
phmdoctest>=1.1.1

cloudpickle>=1.3
scikit-learn>0.22.1
scikit-learn>=0.24
scikit-image>0.17.1
nltk>=3.6
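For context on the requirement bump above: `sklearn.metrics.mean_absolute_percentage_error`, which the new tests compare against, first shipped in scikit-learn 0.24. A minimal guard illustrating the constraint (a sketch, not part of the PR):

```python
# Illustrative only: the reference implementation used by the new tests
# exists from scikit-learn 0.24 onwards, hence scikit-learn>=0.24 above.
try:
    from sklearn.metrics import mean_absolute_percentage_error  # noqa: F401
except ImportError as err:
    raise ImportError("Testing MAPE against scikit-learn requires scikit-learn>=0.24") from err
```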
29 changes: 24 additions & 5 deletions tests/regression/test_mean_error.py
@@ -18,13 +18,24 @@
import pytest
import torch
from sklearn.metrics import mean_absolute_error as sk_mean_absolute_error
from sklearn.metrics import mean_absolute_percentage_error as sk_mean_abs_percentage_error
from sklearn.metrics import mean_squared_error as sk_mean_squared_error
from sklearn.metrics import mean_squared_log_error as sk_mean_squared_log_error

from tests.helpers import seed_all
from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.functional import mean_absolute_error, mean_squared_error, mean_squared_log_error
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, MeanSquaredLogError
from torchmetrics.functional import (
mean_absolute_error,
mean_absolute_percentage_error,
mean_squared_error,
mean_squared_log_error,
)
from torchmetrics.regression import (
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredError,
MeanSquaredLogError,
)
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_6

seed_all(42)
@@ -47,14 +58,14 @@
def _single_target_sk_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1).numpy()
sk_target = target.view(-1).numpy()
res = sk_fn(sk_preds, sk_target)
res = sk_fn(sk_target, sk_preds)
return math.sqrt(res) if (metric_args and not metric_args['squared']) else res


def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
sk_preds = preds.view(-1, num_targets).numpy()
sk_target = target.view(-1, num_targets).numpy()
res = sk_fn(sk_preds, sk_target)
res = sk_fn(sk_target, sk_preds)
return math.sqrt(res) if (metric_args and not metric_args['squared']) else res


@@ -75,6 +86,7 @@ def _multi_target_sk_metric(preds, target, sk_fn, metric_args):
'squared': False
}),
(MeanAbsoluteError, mean_absolute_error, sk_mean_absolute_error, {}),
(MeanAbsolutePercentageError, mean_absolute_percentage_error, sk_mean_abs_percentage_error, {}),
(MeanSquaredLogError, mean_squared_log_error, sk_mean_squared_log_error, {}),
],
)
@@ -124,14 +136,21 @@ def test_mean_error_half_cpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args):
if metric_class == MeanSquaredLogError:
# MeanSquaredLogError half + cpu does not work due to missing support in torch.log
pytest.xfail("MeanSquaredLogError metric does not support cpu + half precision")

if metric_class == MeanAbsolutePercentageError:
# MeanAbsolutePercentageError half + cpu does not work due to missing half precision support on cpu
pytest.xfail("MeanAbsolutePercentageError metric does not support cpu + half precision")

self.run_precision_test_cpu(preds, target, metric_class, metric_functional)

@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
def test_mean_error_half_gpu(self, preds, target, sk_metric, metric_class, metric_functional, sk_fn, metric_args):
self.run_precision_test_gpu(preds, target, metric_class, metric_functional)


@pytest.mark.parametrize("metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError])
@pytest.mark.parametrize(
"metric_class", [MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError, MeanAbsolutePercentageError]
)
def test_error_on_different_shape(metric_class):
metric = metric_class()
with pytest.raises(RuntimeError, match='Predictions and targets are expected to have the same shape'):
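A subtlety in the test changes above: scikit-learn metrics take `(y_true, y_pred)`, so switching to `sk_fn(sk_target, sk_preds)` fixes a latent argument-order bug. The old order happened to give identical results for symmetric metrics such as MAE and MSE, but MAPE normalizes by the target, so the order now changes the result. A quick sanity check (a sketch, assuming scikit-learn>=0.24):

```python
import torch
from sklearn.metrics import mean_absolute_percentage_error as sk_mape

target = torch.tensor([1.0, 10.0, 1e6])
preds = torch.tensor([0.9, 15.0, 1.2e6])

# Correct order: scikit-learn normalizes by its first argument, y_true.
print(sk_mape(target.numpy(), preds.numpy()))  # ~0.2667, matches torchmetrics
# Swapped order normalizes by the predictions and yields a different value.
print(sk_mape(preds.numpy(), target.numpy()))  # not 0.2667
```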
1 change: 1 addition & 0 deletions torchmetrics/__init__.py
@@ -42,6 +42,7 @@
SSIM,
ExplainedVariance,
MeanAbsoluteError,
MeanAbsolutePercentageError,
MeanSquaredError,
MeanSquaredLogError,
PearsonCorrcoef,
3 changes: 3 additions & 0 deletions torchmetrics/functional/__init__.py
@@ -32,6 +32,9 @@
from torchmetrics.functional.nlp import bleu_score # noqa: F401
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_absolute_percentage_error import ( # noqa: F401
mean_absolute_percentage_error,
)
from torchmetrics.functional.regression.mean_relative_error import mean_relative_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
3 changes: 3 additions & 0 deletions torchmetrics/functional/regression/__init__.py
@@ -13,6 +13,9 @@
# limitations under the License.
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
from torchmetrics.functional.regression.mean_absolute_error import mean_absolute_error # noqa: F401
from torchmetrics.functional.regression.mean_absolute_percentage_error import ( # noqa: F401
mean_absolute_percentage_error,
)
from torchmetrics.functional.regression.mean_squared_error import mean_squared_error # noqa: F401
from torchmetrics.functional.regression.mean_squared_log_error import mean_squared_log_error # noqa: F401
from torchmetrics.functional.regression.pearson import pearson_corrcoef # noqa: F401
70 changes: 70 additions & 0 deletions torchmetrics/functional/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,70 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def _mean_absolute_percentage_error_update(
preds: Tensor,
target: Tensor,
epsilon: float = 1.17e-06,
) -> Tuple[Tensor, int]:

_check_same_shape(preds, target)

abs_diff = torch.abs(preds - target)
abs_per_error = abs_diff / torch.clamp(torch.abs(target), min=epsilon)

sum_abs_per_error = torch.sum(abs_per_error)

num_obs = target.numel()

return sum_abs_per_error, num_obs


def _mean_absolute_percentage_error_compute(sum_abs_per_error: Tensor, num_obs: int) -> Tensor:
return sum_abs_per_error / num_obs


def mean_absolute_percentage_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean absolute percentage error.

Args:
preds: estimated labels
target: ground truth labels

Return:
Tensor with MAPE

Note:
The epsilon value is taken from `scikit-learn's implementation
<https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/metrics/_regression.py#L197>`_.

Example:
>>> from torchmetrics.functional import mean_absolute_percentage_error
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> mean_absolute_percentage_error(preds, target)
tensor(0.2667)
"""
sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target)
mean_ape = _mean_absolute_percentage_error_compute(sum_abs_per_error, num_obs)

return mean_ape
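A small sketch of the clamping behavior: when a target entry is exactly 0, `|target|` is clamped to `epsilon` before the division, so the metric returns a very large finite number rather than `inf` (values chosen for illustration):

```python
import torch
from torchmetrics.functional import mean_absolute_percentage_error

target = torch.tensor([0.0, 10.0])
preds = torch.tensor([1.0, 10.0])

# First element: |1 - 0| / max(epsilon, |0|) = 1 / 1.17e-6; second element: 0.
# The mean is therefore huge but finite instead of inf.
print(mean_absolute_percentage_error(preds, target))  # roughly tensor(4.2735e+05)
```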
30 changes: 15 additions & 15 deletions torchmetrics/functional/regression/mean_relative_error.py
@@ -12,26 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from warnings import warn

import torch
from torch import Tensor

from torchmetrics.functional.regression.mean_absolute_percentage_error import (
_mean_absolute_percentage_error_compute,
_mean_absolute_percentage_error_update,
)
from torchmetrics.utilities.checks import _check_same_shape


def _mean_relative_error_update(preds: Tensor, target: Tensor) -> Tuple[Tensor, int]:
_check_same_shape(preds, target)
target_nz = target.clone()
target_nz[target == 0] = 1
sum_rltv_error = torch.sum(torch.abs((preds - target) / target_nz))
n_obs = target.numel()
return sum_rltv_error, n_obs


def _mean_relative_error_compute(sum_rltv_error: Tensor, n_obs: int) -> Tensor:
return sum_rltv_error / n_obs


def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
"""
Computes mean relative error
@@ -49,7 +41,15 @@ def mean_relative_error(preds: Tensor, target: Tensor) -> Tensor:
>>> y = torch.tensor([0., 1, 2, 2])
>>> mean_relative_error(x, y)
tensor(0.1250)

.. deprecated::
Use :func:`torchmetrics.functional.mean_absolute_percentage_error`. Will be removed in v0.5.0.

"""
sum_rltv_error, n_obs = _mean_relative_error_update(preds, target)
return _mean_relative_error_compute(sum_rltv_error, n_obs)
warn(
"Function `mean_relative_error` was deprecated v0.4 and will be removed in v0.5."
"Use `mean_absolute_percentage_error` instead.",
DeprecationWarning,
)
sum_rltv_error, n_obs = _mean_absolute_percentage_error_update(preds, target)
return _mean_absolute_percentage_error_compute(sum_rltv_error, n_obs)
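A sketch of what existing callers now see: the old entry point still returns a value, but it emits a `DeprecationWarning` and delegates to the MAPE helpers:

```python
import warnings

import torch
from torchmetrics.functional import mean_absolute_percentage_error, mean_relative_error

preds = torch.tensor([0.9, 15.0, 1.2e6])
target = torch.tensor([1.0, 10.0, 1e6])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = mean_relative_error(preds, target)

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
# For targets away from zero, the deprecated function and MAPE agree exactly.
assert torch.allclose(old, mean_absolute_percentage_error(preds, target))
```

One behavioral edge worth noting: the removed `_mean_relative_error_update` replaced zero targets with 1 in the denominator, while the MAPE update clamps `|target|` to `epsilon`, so outputs differ when targets contain exact zeros.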
1 change: 1 addition & 0 deletions torchmetrics/regression/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401
from torchmetrics.regression.mean_absolute_error import MeanAbsoluteError # noqa: F401
from torchmetrics.regression.mean_absolute_percentage_error import MeanAbsolutePercentageError # noqa: F401
from torchmetrics.regression.mean_squared_error import MeanSquaredError # noqa: F401
from torchmetrics.regression.mean_squared_log_error import MeanSquaredLogError # noqa: F401
from torchmetrics.regression.pearson import PearsonCorrcoef # noqa: F401
101 changes: 101 additions & 0 deletions torchmetrics/regression/mean_absolute_percentage_error.py
@@ -0,0 +1,101 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional

import torch
from torch import Tensor, tensor

from torchmetrics.functional.regression.mean_absolute_percentage_error import (
_mean_absolute_percentage_error_compute,
_mean_absolute_percentage_error_update,
)
from torchmetrics.metric import Metric


class MeanAbsolutePercentageError(Metric):
r"""
Computes `mean absolute percentage error <https://en.wikipedia.org/wiki/Mean_absolute_percentage_error>`_ (MAPE):

.. math:: \text{MAPE} = \frac{1}{n}\sum_{i=1}^n\frac{| y_i - \hat{y_i} |}{\max(\epsilon, | y_i |)}

Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.

Args:
compute_on_step:
Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
dist_sync_on_step:
Synchronize metric state across processes at each ``forward()``
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)

Note:
The epsilon value is taken from `scikit-learn's implementation
<https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/metrics/_regression.py#L197>`_.

Note:
MAPE output is a non-negative floating point number, and the best possible score is 0.0. Note, however,
that bad predictions can lead to arbitrarily large values, especially when some ``target`` values are close to 0.
This implementation returns a very large finite number instead of ``inf``.
For more information, `read here
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.mean_absolute_percentage_error.html>`_.

Example:
>>> from torchmetrics import MeanAbsolutePercentageError
>>> target = torch.tensor([1, 10, 1e6])
>>> preds = torch.tensor([0.9, 15, 1.2e6])
>>> mean_abs_percentage_error = MeanAbsolutePercentageError()
>>> mean_abs_percentage_error(preds, target)
tensor(0.2667)
"""

def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
)

self.add_state("sum_abs_per_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
"""
Update state with predictions and targets.

Args:
preds: Predictions from model
target: Ground truth values
"""
sum_abs_per_error, num_obs = _mean_absolute_percentage_error_update(preds, target)

self.sum_abs_per_error += sum_abs_per_error
self.total += num_obs

def compute(self) -> Tensor:
"""
Computes mean absolute percentage error over state.
"""
return _mean_absolute_percentage_error_compute(self.sum_abs_per_error, self.total)

@property
def is_differentiable(self) -> bool:
return True
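For completeness, a minimal sketch of the accumulation workflow the module metric enables (random data, illustrative only):

```python
import torch
from torchmetrics import MeanAbsolutePercentageError

mape = MeanAbsolutePercentageError()

for _ in range(4):
    preds = torch.rand(8)
    target = torch.rand(8) + 0.5  # keep targets away from zero
    mape.update(preds, target)  # accumulates sum_abs_per_error and total

print(mape.compute())  # MAPE over all 32 samples
mape.reset()
```

The `update`/`compute` split plus `dist_reduce_fx="sum"` on both states keeps the metric distributed-friendly: each process accumulates locally, the states are summed across processes, and the final division happens once in `compute`.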