Skip to content

Commit 0858d54

Browse files
authored
Merge branch 'master' into feature/normalize_image_metrics
2 parents 1b09f78 + ec5dfc8 commit 0858d54

File tree

18 files changed

+333
-10
lines changed

18 files changed

+333
-10
lines changed

.github/workflows/ci_test-conda.yml

+14-1
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,25 @@ jobs:
4545
python-version: ${{ matrix.python-version }}
4646
offset: "pt"
4747

48+
- name: Cache Conda env
49+
uses: actions/cache@v2
50+
with:
51+
path: ${{ env.CONDA }}/envs
52+
key: ${{ runner.os }}-py${{ matrix.python-version }}-pt{{ matrix.pytorch-version }}-${{ env.TIMESTAMP }}-${{ env.CACHE_NUMBER }}
53+
env:
54+
# Increase this value to reset cache if etc/example-environment.yml has not changed
55+
CACHE_NUMBER: 0
56+
# Automatically reset every week
57+
TIMESTAMP: $(/bin/date -u '+%Yw%W')
58+
id: cache
59+
4860
# https://docs.conda.io/projects/conda/en/4.6.0/_downloads/52a95608c49671267e40c689e0bc00ca/conda-cheatsheet.pdf
4961
# https://gist.github.com/mwouts/9842452d020c08faf9e84a3bba38a66f
5062
- name: Setup Miniconda
5163
uses: conda-incubator/setup-miniconda@v2
64+
if: steps.cache.outputs.cache-hit != 'true'
5265
with:
53-
miniforge-variant: Mambaforge
66+
# miniforge-variant: Mambaforge
5467
miniforge-version: latest
5568
use-mamba: true
5669
# miniconda-version: "4.7.12"

.github/workflows/ci_test-full.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ jobs:
3535
- {python-version: '3.10', requires: 'oldest'}
3636
- {python-version: '3.10', os: 'windows'} # todo: https://discuss.pytorch.org/t/numpy-is-not-available-error/146192
3737
include:
38-
- {python-version: '3.10', requires: 'latest', os: 'ubuntu-22.04'}
39-
- {python-version: '3.10', requires: 'latest', os: 'macOS-12'}
38+
- {os: 'ubuntu-22.04', python-version: '3.10'}
39+
- {os: 'macOS-12', python-version: '3.10'}
4040
env:
4141
PYTEST_ARTEFACT: test-results-${{ matrix.os }}-py${{ matrix.python-version }}-${{ matrix.requires }}.xml
4242
PYTORCH_URL: https://download.pytorch.org/whl/cpu/torch_stable.html

CHANGELOG.md

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2323
- Added `KendallRankCorrCoef` to regression package ([#1271](https://github.com/Lightning-AI/metrics/pull/1271))
2424

2525

26+
- Added `LogCoshError` to regression package ([#1316](https://github.com/Lightning-AI/metrics/pull/1316))
27+
28+
2629
### Changed
2730

2831
- Changed `MeanAveragePrecision` to vectorize `_find_best_gt_match` operation ([#1259](https://github.com/Lightning-AI/metrics/pull/1259))
@@ -47,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4750
- Fixed bug in `Metrictracker.best_metric` when `return_step=False` ([#1306](https://github.com/Lightning-AI/metrics/pull/1306))
4851

4952

53+
- Fixed bug to prevent users from going into a infinite loop if trying to iterate of a single metric ([#1320](https://github.com/Lightning-AI/metrics/pull/1320))
54+
55+
5056
## [0.10.2] - 2022-10-31
5157

5258
### Changed

docs/source/links.rst

+1
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,4 @@
9595
.. _Fisher-Rao distance: http://www.scholarpedia.org/article/Fisher-Rao_metric
9696
.. _Kendall Rank Correlation Coefficient: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
9797
.. _The Treatment of Ties in Ranking Problems: https://www.jstor.org/stable/2332303
98+
.. _LogCosh Error: https://arxiv.org/pdf/2101.10427.pdf
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
.. customcarditem::
2+
:header: Log Cosh Error
3+
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
4+
:tags: Regression
5+
6+
.. include:: ../links.rst
7+
8+
##############
9+
Log Cosh Error
10+
##############
11+
12+
Module Interface
13+
________________
14+
15+
.. autoclass:: torchmetrics.LogCoshError
16+
:noindex:
17+
18+
Functional Interface
19+
____________________
20+
21+
.. autofunction:: torchmetrics.functional.log_cosh_error
22+
:noindex:

src/torchmetrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ExplainedVariance,
6060
KendallRankCorrCoef,
6161
KLDivergence,
62+
LogCoshError,
6263
MeanAbsoluteError,
6364
MeanAbsolutePercentageError,
6465
MeanSquaredError,
@@ -132,6 +133,7 @@
132133
"JaccardIndex",
133134
"KendallRankCorrCoef",
134135
"KLDivergence",
136+
"LogCoshError",
135137
"MatchErrorRate",
136138
"MatthewsCorrCoef",
137139
"MaxMetric",

src/torchmetrics/functional/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from torchmetrics.functional.regression.explained_variance import explained_variance
5252
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef
5353
from torchmetrics.functional.regression.kl_divergence import kl_divergence
54+
from torchmetrics.functional.regression.log_cosh import log_cosh_error
5455
from torchmetrics.functional.regression.log_mse import mean_squared_log_error
5556
from torchmetrics.functional.regression.mae import mean_absolute_error
5657
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error
@@ -115,6 +116,7 @@
115116
"jaccard_index",
116117
"kendall_rank_corrcoef",
117118
"kl_divergence",
119+
"log_cosh_error",
118120
"match_error_rate",
119121
"matthews_corrcoef",
120122
"mean_absolute_error",

src/torchmetrics/functional/regression/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torchmetrics.functional.regression.explained_variance import explained_variance # noqa: F401
1717
from torchmetrics.functional.regression.kendall import kendall_rank_corrcoef # noqa: F401
1818
from torchmetrics.functional.regression.kl_divergence import kl_divergence # noqa: F401
19+
from torchmetrics.functional.regression.log_cosh import log_cosh_error # noqa: F401
1920
from torchmetrics.functional.regression.log_mse import mean_squared_log_error # noqa: F401
2021
from torchmetrics.functional.regression.mae import mean_absolute_error # noqa: F401
2122
from torchmetrics.functional.regression.mape import mean_absolute_percentage_error # noqa: F401

src/torchmetrics/functional/regression/kendall.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from torch import Tensor
1818
from typing_extensions import Literal
1919

20-
from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
20+
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
2121
from torchmetrics.utilities.checks import _check_same_shape
2222
from torchmetrics.utilities.data import _bincount, dim_zero_cat
2323
from torchmetrics.utilities.enums import EnumStr
@@ -265,7 +265,7 @@ def _kendall_corrcoef_update(
265265
"""
266266
# Data checking
267267
_check_same_shape(preds, target)
268-
_check_data_shape_for_corr_coef(preds, target, num_outputs)
268+
_check_data_shape_to_num_outputs(preds, target, num_outputs)
269269

270270
if num_outputs == 1:
271271
preds = preds.unsqueeze(1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Tuple
15+
16+
import torch
17+
from torch import Tensor
18+
19+
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
20+
from torchmetrics.utilities.checks import _check_same_shape
21+
22+
23+
def _unsqueeze_tensors(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
24+
if preds.ndim == 2:
25+
return preds, target
26+
return preds.unsqueeze(1), target.unsqueeze(1)
27+
28+
29+
def _log_cosh_error_update(preds: Tensor, target: Tensor, num_outputs: int) -> Tuple[Tensor, Tensor]:
30+
"""Updates and returns variables required to compute LogCosh error.
31+
32+
Checks for same shape of input tensors.
33+
34+
Args:
35+
preds: Predicted tensor
36+
target: Ground truth tensor
37+
38+
Return:
39+
Sum of LogCosh error over examples, and total number of examples
40+
"""
41+
_check_same_shape(preds, target)
42+
_check_data_shape_to_num_outputs(preds, target, num_outputs)
43+
44+
preds, target = _unsqueeze_tensors(preds, target)
45+
diff = preds - target
46+
sum_log_cosh_error = torch.log((torch.exp(diff) + torch.exp(-diff)) / 2).sum(0).squeeze()
47+
n_obs = torch.tensor(target.shape[0], device=preds.device)
48+
return sum_log_cosh_error, n_obs
49+
50+
51+
def _log_cosh_error_compute(sum_log_cosh_error: Tensor, n_obs: Tensor) -> Tensor:
52+
"""Computes Mean Squared Error.
53+
54+
Args:
55+
sum_squared_error: Sum of LogCosh errors over all observations
56+
n_obs: Number of predictions or observations
57+
"""
58+
return (sum_log_cosh_error / n_obs).squeeze()
59+
60+
61+
def log_cosh_error(preds: Tensor, target: Tensor) -> Tensor:
62+
r"""Compute the `LogCosh Error`_.
63+
64+
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right)
65+
66+
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
67+
68+
Args:
69+
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
70+
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
71+
72+
Return:
73+
Tensor with LogCosh error
74+
75+
Example (single output regression)::
76+
>>> from torchmetrics.functional import log_cosh_error
77+
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
78+
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
79+
>>> log_cosh_error(preds, target)
80+
tensor(0.3523)
81+
82+
Example (multi output regression)::
83+
>>> from torchmetrics.functional import log_cosh_error
84+
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
85+
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
86+
>>> log_cosh_error(preds, target)
87+
tensor([0.9176, 0.4277, 0.2194])
88+
"""
89+
sum_log_cosh_error, n_obs = _log_cosh_error_update(
90+
preds, target, num_outputs=1 if preds.ndim == 1 else preds.shape[-1]
91+
)
92+
return _log_cosh_error_compute(sum_log_cosh_error, n_obs)

src/torchmetrics/functional/regression/pearson.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from torch import Tensor
1818

19-
from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
19+
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
2020
from torchmetrics.utilities.checks import _check_same_shape
2121

2222

@@ -45,7 +45,7 @@ def _pearson_corrcoef_update(
4545
"""
4646
# Data checking
4747
_check_same_shape(preds, target)
48-
_check_data_shape_for_corr_coef(preds, target, num_outputs)
48+
_check_data_shape_to_num_outputs(preds, target, num_outputs)
4949

5050
n_obs = preds.shape[0]
5151
mx_new = (n_prior * mean_x + preds.mean(0) * n_obs) / (n_prior + n_obs)

src/torchmetrics/functional/regression/spearman.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import torch
1717
from torch import Tensor
1818

19-
from torchmetrics.functional.regression.utils import _check_data_shape_for_corr_coef
19+
from torchmetrics.functional.regression.utils import _check_data_shape_to_num_outputs
2020
from torchmetrics.utilities.checks import _check_same_shape
2121

2222

@@ -68,7 +68,7 @@ def _spearman_corrcoef_update(preds: Tensor, target: Tensor, num_outputs: int) -
6868
"Expected `preds` and `target` both to be floating point tensors, but got {pred.dtype} and {target.dtype}"
6969
)
7070
_check_same_shape(preds, target)
71-
_check_data_shape_for_corr_coef(preds, target, num_outputs)
71+
_check_data_shape_to_num_outputs(preds, target, num_outputs)
7272

7373
return preds, target
7474

src/torchmetrics/functional/regression/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch import Tensor
1515

1616

17-
def _check_data_shape_for_corr_coef(preds: Tensor, target: Tensor, num_outputs: int) -> None:
17+
def _check_data_shape_to_num_outputs(preds: Tensor, target: Tensor, num_outputs: int) -> None:
1818
"""Check that predictions and target have the correct shape, else raise error."""
1919
if preds.ndim > 2 or target.ndim > 2:
2020
raise ValueError(

src/torchmetrics/metric.py

+3
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,9 @@ def __getitem__(self, idx: int) -> "Metric":
858858
def __getnewargs__(self) -> Tuple:
859859
return (Metric.__str__(self),)
860860

861+
def __iter__(self):
862+
raise NotImplementedError("Metrics does not support iteration.")
863+
861864

862865
def _neg(x: Tensor) -> Tensor:
863866
return -torch.abs(x)

src/torchmetrics/regression/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torchmetrics.regression.explained_variance import ExplainedVariance # noqa: F401
1717
from torchmetrics.regression.kendall import KendallRankCorrCoef # noqa: F401
1818
from torchmetrics.regression.kl_divergence import KLDivergence # noqa: F401
19+
from torchmetrics.regression.log_cosh import LogCoshError # noqa: F401
1920
from torchmetrics.regression.log_mse import MeanSquaredLogError # noqa: F401
2021
from torchmetrics.regression.mae import MeanAbsoluteError # noqa: F401
2122
from torchmetrics.regression.mape import MeanAbsolutePercentageError # noqa: F401
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any
15+
16+
import torch
17+
from torch import Tensor
18+
19+
from torchmetrics.functional.regression.log_cosh import _log_cosh_error_compute, _log_cosh_error_update
20+
from torchmetrics.metric import Metric
21+
22+
23+
class LogCoshError(Metric):
24+
r"""Compute the `LogCosh Error`_.
25+
26+
.. math:: \text{LogCoshError} = \log\left(\frac{\exp(\hat{y} - y) + \exp(\hat{y - y})}{2}\right)
27+
28+
Where :math:`y` is a tensor of target values, and :math:`\hat{y}` is a tensor of predictions.
29+
30+
Args:
31+
num_outputs: Number of outputs in multioutput setting
32+
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
33+
34+
Example (single output regression)::
35+
>>> from torchmetrics import LogCoshError
36+
>>> preds = torch.tensor([3.0, 5.0, 2.5, 7.0])
37+
>>> target = torch.tensor([2.5, 5.0, 4.0, 8.0])
38+
>>> log_cosh_error = LogCoshError()
39+
>>> log_cosh_error(preds, target)
40+
tensor(0.3523)
41+
42+
Example (multi output regression)::
43+
>>> from torchmetrics import LogCoshError
44+
>>> preds = torch.tensor([[3.0, 5.0, 1.2], [-2.1, 2.5, 7.0]])
45+
>>> target = torch.tensor([[2.5, 5.0, 1.3], [0.3, 4.0, 8.0]])
46+
>>> log_cosh_error = LogCoshError(num_outputs=3)
47+
>>> log_cosh_error(preds, target)
48+
tensor([0.9176, 0.4277, 0.2194])
49+
"""
50+
51+
is_differentiable = True
52+
higher_is_better = False
53+
full_state_update = False
54+
sum_log_cosh_error: Tensor
55+
total: Tensor
56+
57+
def __init__(self, num_outputs: int = 1, **kwargs: Any) -> None:
58+
super().__init__(**kwargs)
59+
60+
if not isinstance(num_outputs, int) and num_outputs < 1:
61+
raise ValueError(f"Expected argument `num_outputs` to be an int larger than 0, but got {num_outputs}")
62+
self.num_outputs = num_outputs
63+
self.add_state("sum_log_cosh_error", default=torch.zeros(num_outputs), dist_reduce_fx="sum")
64+
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
65+
66+
def update(self, preds: Tensor, target: Tensor) -> None:
67+
"""Update state with predictions and targets.
68+
69+
Args:
70+
preds: estimated labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
71+
target: ground truth labels with shape ``(batch_size,)`` or `(batch_size, num_outputs)``
72+
73+
Raises:
74+
ValueError:
75+
If ``preds`` or ``target`` has multiple outputs when ``num_outputs=1``
76+
"""
77+
sum_log_cosh_error, n_obs = _log_cosh_error_update(preds, target, self.num_outputs)
78+
self.sum_log_cosh_error += sum_log_cosh_error
79+
self.total += n_obs
80+
81+
def compute(self) -> Tensor:
82+
"""Compute LogCosh error over state."""
83+
return _log_cosh_error_compute(self.sum_log_cosh_error, self.total)

tests/unittests/bases/test_metric.py

+7
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,10 @@ def test_custom_availability_check_and_sync_fn():
464464
acc.compute()
465465
dummy_availability_check.assert_called_once()
466466
assert dummy_dist_sync_fn.call_count == 4 # tp, fp, tn, fn
467+
468+
469+
def test_no_iteration_allowed():
470+
metric = DummyMetric()
471+
with pytest.raises(NotImplementedError, match="Metrics does not support iteration."):
472+
for m in metric:
473+
continue

0 commit comments

Comments
 (0)