Simple aggregation metrics #506

Merged 44 commits on Oct 13, 2021
Changes from 43 commits
Commits
cd7d788
update
SkafteNicki Sep 8, 2021
32d2c42
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
3f75b10
changelog
SkafteNicki Sep 8, 2021
17fd6d1
pep8
SkafteNicki Sep 8, 2021
ce9621f
docs
SkafteNicki Sep 8, 2021
11d6780
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
2357353
examples
SkafteNicki Sep 8, 2021
b0cfb03
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Sep 8, 2021
03c1ddb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2021
eb9ff4a
Merge branch 'master' into aggregation
SkafteNicki Sep 10, 2021
43ad430
Merge branch 'master' into aggregation
Borda Sep 14, 2021
227e0cd
Merge branch 'master' into aggregation
SkafteNicki Sep 21, 2021
f42dddb
change max and min
SkafteNicki Sep 21, 2021
6acea7e
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Sep 21, 2021
950c91c
mask gpu testing
SkafteNicki Sep 21, 2021
8cf52a3
remove half test
SkafteNicki Sep 21, 2021
4447272
fix tests
SkafteNicki Sep 24, 2021
7606768
aggr
SkafteNicki Sep 24, 2021
04e44f9
Merge branch 'master' into aggregation
SkafteNicki Sep 24, 2021
15bfcb6
Merge branch 'master' into aggregation
SkafteNicki Sep 24, 2021
59936ac
fix
SkafteNicki Sep 24, 2021
08786e2
Merge branch 'master' into aggregation
SkafteNicki Sep 24, 2021
8f6f4ce
Merge branch 'master' into aggregation
SkafteNicki Sep 27, 2021
7a44155
fix test
SkafteNicki Sep 30, 2021
0eaade7
fix attribute
SkafteNicki Sep 30, 2021
a25a0ec
remove
SkafteNicki Sep 30, 2021
3cff7f8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
c65d0b3
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Sep 30, 2021
9990f58
Merge branch 'master' into aggregation
SkafteNicki Sep 30, 2021
84ab17f
add docstrings
SkafteNicki Sep 30, 2021
cab4d21
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Sep 30, 2021
a8d2cce
fix mistake
SkafteNicki Sep 30, 2021
eba0aa6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 30, 2021
ec6c17f
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Sep 30, 2021
0ad2b41
Update torchmetrics/aggregation.py
SkafteNicki Sep 30, 2021
e22ef2c
Merge branch 'master' into aggregation
SkafteNicki Oct 1, 2021
bc891a9
suggestions
SkafteNicki Oct 7, 2021
8766315
Merge branch 'master' into aggregation
SkafteNicki Oct 7, 2021
f302e11
diff test
SkafteNicki Oct 8, 2021
2df33e0
Merge branch 'aggregation' of https://github.com/PyTorchLightning/met…
SkafteNicki Oct 8, 2021
57dccee
docs
SkafteNicki Oct 8, 2021
1c74b95
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2021
c26cd4f
Merge branch 'master' into aggregation
SkafteNicki Oct 11, 2021
6461b90
Merge branch 'master' into aggregation
Borda Oct 13, 2021
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -34,6 +34,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `SacreBLEUScore` metric to text package ([#546](https://github.com/PyTorchLightning/metrics/pull/546))


- Added simple aggregation metrics: `SumMetric`, `MeanMetric`, `CatMetric`, `MinMetric`, `MaxMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))


### Changed

- `AveragePrecision` now outputs the `macro` average by default for multilabel and multiclass problems ([#477](https://github.com/PyTorchLightning/metrics/pull/477))
@@ -42,9 +45,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `half`, `double`, `float` will no longer change the dtype of the metric states. Use `metric.set_dtype` instead ([#493](https://github.com/PyTorchLightning/metrics/pull/493))


- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551))
- Renamed `AverageMeter` to `MeanMetric` ([#506](https://github.com/PyTorchLightning/metrics/pull/506))


- Changed `is_differentiable` from property to a constant attribute ([#551](https://github.com/PyTorchLightning/metrics/pull/551))

### Deprecated


38 changes: 35 additions & 3 deletions docs/source/references/modules.rst
@@ -14,10 +14,42 @@ metrics.
.. autoclass:: torchmetrics.Metric
    :noindex:

We also have an ``AverageMeter`` class that is helpful for defining ad-hoc metrics, when creating
your own metric type might be too burdensome.

.. autoclass:: torchmetrics.AverageMeter
*************************
Basic Aggregation Metrics
*************************

TorchMetrics comes with a number of metrics for aggregating basic statistics (mean, max, min, etc.) of
either tensors or native Python floats.

CatMetric
~~~~~~~~~

.. autoclass:: torchmetrics.CatMetric
    :noindex:

MaxMetric
~~~~~~~~~

.. autoclass:: torchmetrics.MaxMetric
    :noindex:

MeanMetric
~~~~~~~~~~

.. autoclass:: torchmetrics.MeanMetric
    :noindex:

MinMetric
~~~~~~~~~

.. autoclass:: torchmetrics.MinMetric
    :noindex:

SumMetric
~~~~~~~~~

.. autoclass:: torchmetrics.SumMetric
    :noindex:

*************
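The aggregation metrics documented in this hunk all follow an accumulate-then-compute pattern: each `update` folds a batch into a running state, and `compute` reads the result. The following is a minimal pure-Python sketch of that idea, not the torchmetrics implementation. The class name `RunningAggregates` and its attributes are hypothetical and chosen only for illustration; the `inf` and `-inf` starting values mirror what the tests in this PR expect from `MinMetric` and `MaxMetric` when no valid value has been seen.

```python
import math


class RunningAggregates:
    """Hypothetical stand-in that accumulates basic statistics over batches."""

    def __init__(self):
        self.total = 0.0          # running sum of all values
        self.count = 0            # number of values seen so far
        self.minimum = math.inf   # smallest value seen (inf before any update)
        self.maximum = -math.inf  # largest value seen (-inf before any update)
        self.values = []          # every value seen, in order (the "cat" state)

    def update(self, batch):
        """Fold one batch (an iterable of floats) into the running state."""
        for value in batch:
            value = float(value)
            self.total += value
            self.count += 1
            self.minimum = min(self.minimum, value)
            self.maximum = max(self.maximum, value)
            self.values.append(value)

    def compute(self):
        """Return the aggregated statistics for everything seen so far."""
        mean = self.total / self.count if self.count else math.nan
        return {
            "sum": self.total,
            "mean": mean,
            "min": self.minimum,
            "max": self.maximum,
            "cat": list(self.values),
        }


agg = RunningAggregates()
agg.update([1.0, 2.0])
agg.update([3.0, 6.0])
result = agg.compute()
print(result)
```

Keeping only a small running state (sum, count, min, max) is what lets this kind of metric be updated batch by batch without storing the full history; only the concatenating variant needs to keep every value.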
166 changes: 166 additions & 0 deletions tests/bases/test_aggregation.py
@@ -0,0 +1,166 @@
import numpy as np
import pytest
import torch

from tests.helpers.testers import BATCH_SIZE, NUM_BATCHES, MetricTester
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric


def compare_mean(values, weights):
    """reference implementation for mean aggregation."""
    return np.average(values.numpy(), weights=weights)


def compare_sum(values, weights):
    """reference implementation for sum aggregation."""
    return np.sum(values.numpy())


def compare_min(values, weights):
    """reference implementation for min aggregation."""
    return np.min(values.numpy())


def compare_max(values, weights):
    """reference implementation for max aggregation."""
    return np.max(values.numpy())


# wrap all metrics other than the mean metric to take an additional argument
# this lets them fit into the testing framework
class WrappedMinMetric(MinMetric):
    """Wrapped min metric."""

    def update(self, values, weights):
        """only pass values on."""
        super().update(values)


class WrappedMaxMetric(MaxMetric):
    """Wrapped max metric."""

    def update(self, values, weights):
        """only pass values on."""
        super().update(values)


class WrappedSumMetric(SumMetric):
    """Wrapped sum metric."""

    def update(self, values, weights):
        """only pass values on."""
        super().update(values)


class WrappedCatMetric(CatMetric):
    """Wrapped cat metric."""

    def update(self, values, weights):
        """only pass values on."""
        super().update(values)


@pytest.mark.parametrize(
    "values, weights",
    [
        (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.ones(NUM_BATCHES, BATCH_SIZE)),
        (torch.rand(NUM_BATCHES, BATCH_SIZE), torch.rand(NUM_BATCHES, BATCH_SIZE) > 0.5),
        (torch.rand(NUM_BATCHES, BATCH_SIZE, 2), torch.rand(NUM_BATCHES, BATCH_SIZE, 2) > 0.5),
    ],
)
@pytest.mark.parametrize(
    "metric_class, compare_fn",
    [
        (WrappedMinMetric, compare_min),
        (WrappedMaxMetric, compare_max),
        (WrappedSumMetric, compare_sum),
        (MeanMetric, compare_mean),
    ],
)
class TestAggregation(MetricTester):
    """Test aggregation metrics."""

    @pytest.mark.parametrize("ddp", [False, True])
    @pytest.mark.parametrize("dist_sync_on_step", [False])
    def test_aggreagation(self, ddp, dist_sync_on_step, metric_class, compare_fn, values, weights):
        """test modular implementation."""
        self.run_class_metric_test(
            ddp=ddp,
            dist_sync_on_step=dist_sync_on_step,
            metric_class=metric_class,
            sk_metric=compare_fn,
            check_scriptable=True,
            # Abuse of names here
            preds=values,
            target=weights,
        )


_case1 = float("nan") * torch.ones(5)
_case2 = torch.tensor([1.0, 2.0, float("nan"), 4.0, 5.0])


@pytest.mark.parametrize("value", [_case1, _case2])
@pytest.mark.parametrize("nan_strategy", ["error", "warn"])
@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric])
def test_nan_error(value, nan_strategy, metric_class):
    """test correct errors are raised."""
    metric = metric_class(nan_strategy=nan_strategy)
    if nan_strategy == "error":
        with pytest.raises(RuntimeError, match="Encounted `nan` values in tensor"):
            metric(value.clone())
    elif nan_strategy == "warn":
        with pytest.warns(UserWarning, match="Encounted `nan` values in tensor"):
            metric(value.clone())


@pytest.mark.parametrize(
    "metric_class, nan_strategy, value, expected",
    [
        (MinMetric, "ignore", _case1, torch.tensor(float("inf"))),
        (MinMetric, 2.0, _case1, 2.0),
        (MinMetric, "ignore", _case2, 1.0),
        (MinMetric, 2.0, _case2, 1.0),
        (MaxMetric, "ignore", _case1, -torch.tensor(float("inf"))),
        (MaxMetric, 2.0, _case1, 2.0),
        (MaxMetric, "ignore", _case2, 5.0),
        (MaxMetric, 2.0, _case2, 5.0),
        (SumMetric, "ignore", _case1, 0.0),
        (SumMetric, 2.0, _case1, 10.0),
        (SumMetric, "ignore", _case2, 12.0),
        (SumMetric, 2.0, _case2, 14.0),
        (MeanMetric, "ignore", _case1, torch.tensor([float("nan")])),
        (MeanMetric, 2.0, _case1, 2.0),
        (MeanMetric, "ignore", _case2, 3.0),
        (MeanMetric, 2.0, _case2, 2.8),
        (CatMetric, "ignore", _case1, []),
        (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])),
        (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])),
        (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])),
    ],
)
def test_nan_expected(metric_class, nan_strategy, value, expected):
    """test that nan values are handled correctly."""
    metric = metric_class(nan_strategy=nan_strategy)
    metric.update(value.clone())
    out = metric.compute()
    assert np.allclose(out, expected, equal_nan=True)


@pytest.mark.parametrize("metric_class", [MinMetric, MaxMetric, SumMetric, MeanMetric, CatMetric])
def test_error_on_wrong_nan_strategy(metric_class):
    """test error raised on wrong nan_strategy argument."""
    with pytest.raises(ValueError, match="Arg `nan_strategy` should either .*"):
        metric_class(nan_strategy=[])
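
The `nan_strategy` tests above exercise four behaviours: raise on `"error"`, warn on `"warn"`, drop NaNs on `"ignore"`, and substitute a float when a number is given. A rough pure-Python sketch of that dispatch is shown here to make the expected values in the parametrized table easier to follow. The helper name `apply_nan_strategy` and its messages are hypothetical, and the assumption that `"warn"` also drops the NaN entries after warning is made for this sketch only; none of it is taken from the library source.

```python
import math
import warnings


def apply_nan_strategy(values, nan_strategy):
    """Hypothetical helper: filter or replace NaNs before aggregating."""
    allowed = ("error", "warn", "ignore")
    if not isinstance(nan_strategy, float) and nan_strategy not in allowed:
        raise ValueError(f"Unsupported nan_strategy: {nan_strategy!r}")

    has_nan = any(math.isnan(v) for v in values)
    if has_nan and nan_strategy == "error":
        raise RuntimeError("NaN values encountered")
    if isinstance(nan_strategy, float):
        # A float strategy replaces each NaN with that number.
        return [nan_strategy if math.isnan(v) else v for v in values]
    if has_nan and nan_strategy == "warn":
        warnings.warn("NaN values encountered; dropping them", UserWarning)
    # Assumption of this sketch: "warn" and "ignore" both drop NaN entries.
    return [v for v in values if not math.isnan(v)]


case2 = [1.0, 2.0, float("nan"), 4.0, 5.0]

ignored = apply_nan_strategy(case2, "ignore")
replaced = apply_nan_strategy(case2, 2.0)

print(sum(ignored), sum(ignored) / len(ignored))
print(sum(replaced), sum(replaced) / len(replaced))
```

With `_case2`, dropping the NaN leaves four values (sum 12, mean 3), while replacing it with 2.0 keeps five values (sum 14, mean 2.8), which matches the `SumMetric` and `MeanMetric` rows of the table.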


@pytest.mark.skipif(not hasattr(torch, "broadcast_to"), reason="PyTorch <1.8 does not have broadcast_to")
@pytest.mark.parametrize(
    "weights, expected", [(1, 11.5), (torch.ones(2, 1, 1), 11.5), (torch.tensor([1, 2]).reshape(2, 1, 1), 13.5)]
)
def test_mean_metric_broadcasting(weights, expected):
    """check that weight broadcasting works for mean metric."""
    values = torch.arange(24).reshape(2, 3, 4)
    avg = MeanMetric()

    assert avg(values, weights) == expected
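
The expected values 11.5 and 13.5 in the broadcasting test can be checked by hand. The sketch below reproduces the arithmetic with plain Python lists rather than tensors; `weighted_mean` is a hypothetical helper written for this illustration, where one weight per outer block plays the role of a `(2, 1, 1)` weight tensor broadcast over `(2, 3, 4)` values.

```python
def weighted_mean(blocks, block_weights):
    """Weighted mean where every element of a block shares that block's weight."""
    weighted_sum = 0.0
    weight_total = 0.0
    for block, weight in zip(blocks, block_weights):
        for row in block:
            for value in row:
                weighted_sum += weight * value
                weight_total += weight
    return weighted_sum / weight_total


# Same layout as torch.arange(24).reshape(2, 3, 4), built with nested lists.
values = [[[b * 12 + r * 4 + c for c in range(4)] for r in range(3)] for b in range(2)]

uniform = weighted_mean(values, [1, 1])  # all weights equal
skewed = weighted_mean(values, [1, 2])   # second block counts twice
print(uniform, skewed)
```

The first block holds 0..11 (sum 66) and the second 12..23 (sum 210). With equal weights the mean is 276 / 24 = 11.5; with weights 1 and 2 it is (66 + 2 * 210) / (12 + 24) = 486 / 36 = 13.5.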
88 changes: 0 additions & 88 deletions tests/bases/test_average.py

This file was deleted.

8 changes: 6 additions & 2 deletions torchmetrics/__init__.py
@@ -12,8 +12,8 @@
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)

from torchmetrics import functional # noqa: E402
from torchmetrics.aggregation import CatMetric, MaxMetric, MeanMetric, MinMetric, SumMetric # noqa: E402
from torchmetrics.audio import PESQ, PIT, SI_SDR, SI_SNR, SNR # noqa: E402
from torchmetrics.average import AverageMeter # noqa: E402
from torchmetrics.classification import ( # noqa: E402
    AUC,
    AUROC,
@@ -71,7 +71,6 @@
    "Accuracy",
    "AUC",
    "AUROC",
    "AverageMeter",
    "AveragePrecision",
    "BinnedAveragePrecision",
    "BinnedPrecisionRecallCurve",
@@ -80,6 +79,7 @@
    "BLEUScore",
    "BootStrapper",
    "CalibrationError",
    "CatMetric",
    "CohenKappa",
    "ConfusionMatrix",
    "CosineSimilarity",
@@ -96,13 +96,16 @@
    "KLDivergence",
    "LPIPS",
    "MatthewsCorrcoef",
    "MaxMetric",
    "MeanAbsoluteError",
    "MeanAbsolutePercentageError",
    "MeanMetric",
    "MeanSquaredError",
    "MeanSquaredLogError",
    "Metric",
    "MetricCollection",
    "MetricTracker",
    "MinMetric",
    "MultioutputWrapper",
    "PearsonCorrcoef",
    "PESQ",
@@ -128,6 +131,7 @@
    "Specificity",
    "SSIM",
    "StatScores",
    "SumMetric",
    "SymmetricMeanAbsolutePercentageError",
    "WER",
]