Skip to content

Commit

Permalink
Bugfix/avg prec auroc compute groups (#1086)
Browse files Browse the repository at this point in the history
* fix bugs
* changelog

(cherry picked from commit c2f55fa)
  • Loading branch information
SkafteNicki authored and Borda committed Jun 14, 2022
1 parent 659216b commit fbb470a
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 24 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed mAP calculation for areas with 0 predictions ([#1080](https://github.com/PyTorchLightning/metrics/pull/1080))


-
- Fixed bug where average precision state and AUROC state were not merged when using MetricCollection ([#1086](https://github.com/PyTorchLightning/metrics/pull/1086))


## [0.9.1] - 2022-06-08
Expand Down
68 changes: 54 additions & 14 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from tests.helpers import seed_all
from tests.helpers.testers import DummyMetricDiff, DummyMetricSum
from torchmetrics import (
AUROC,
Accuracy,
AveragePrecision,
CohenKappa,
ConfusionMatrix,
F1Score,
Expand All @@ -31,6 +33,8 @@
Precision,
Recall,
)
from torchmetrics.utilities.checks import _allclose_recursive


seed_all(42)

Expand Down Expand Up @@ -267,6 +271,8 @@ def test_collection_filtering():
"""Test that collections works with the kwargs argument."""

class DummyMetric(Metric):
full_state_update = True

def __init__(self):
super().__init__()

Expand All @@ -277,6 +283,8 @@ def compute(self):
return

class MyAccuracy(Metric):
full_state_update = True

def __init__(self):
super().__init__()

Expand All @@ -292,21 +300,30 @@ def compute(self):
mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2")


# Shared fixture data for the compute-group test cases below.
# _mc_*: multi-class input — softmaxed scores of shape (10 samples, 3 classes)
#        with integer class targets in [0, 3).
# _ml_*: multi-label input — per-label probabilities of shape (10, 3) with
#        binary (0/1) targets per label.
_mc_preds = torch.randn(10, 3).softmax(dim=-1)
_mc_target = torch.randint(3, (10,))
_ml_preds = torch.rand(10, 3)
_ml_target = torch.randint(2, (10, 3))


@pytest.mark.parametrize(
"metrics, expected",
"metrics, expected, preds, target",
[
# single metric forms its own compute group
(Accuracy(3), {0: ["Accuracy"]}),
(Accuracy(3), {0: ["Accuracy"]}, _mc_preds, _mc_target),
# two metrics of same class forms a compute group
({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}),
({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}, _mc_preds, _mc_target),
# two metrics from the registry form a compute group
([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}),
([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}, _mc_preds, _mc_target),
# two metrics from different classes gives two compute groups
([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}),
([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}, _mc_preds, _mc_target),
# multi group multi metric
(
[ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)],
{0: ["ConfusionMatrix", "CohenKappa"], 1: ["Recall", "Precision"]},
_mc_preds,
_mc_target,
),
# Complex example
(
Expand All @@ -319,6 +336,33 @@ def compute(self):
"confmat": ConfusionMatrix(3),
},
{0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
_mc_preds,
_mc_target,
),
# With list states
(
[AUROC(average="macro", num_classes=3), AveragePrecision(average="macro", num_classes=3)],
{0: ["AUROC", "AveragePrecision"]},
_mc_preds,
_mc_target,
),
# Nested collections
(
[
MetricCollection(
AUROC(average="micro", num_classes=3),
AveragePrecision(average="micro", num_classes=3),
postfix="_micro",
),
MetricCollection(
AUROC(average="macro", num_classes=3),
AveragePrecision(average="macro", num_classes=3),
postfix="_macro",
),
],
{0: ["AUROC_micro", "AveragePrecision_micro", "AUROC_macro", "AveragePrecision_macro"]},
_ml_preds,
_ml_target,
),
],
)
Expand All @@ -332,8 +376,10 @@ class TestComputeGroups:
["prefix_", "_postfix"],
],
)
def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix):
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
if isinstance(metrics, MetricCollection):
prefix, postfix = None, None # disable for nested collections
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)
Expand All @@ -342,8 +388,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
assert m2.compute_groups == {}

for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

Expand All @@ -353,8 +397,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
assert m.compute_groups == expected
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)
Expand All @@ -372,22 +414,20 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, method):
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
"""Check that whenever user call a methods that give access to the indivitual metric that state are copied
instead of just passed by reference."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

for _ in range(2): # repeat to emulate effect of multiple epochs
for _ in range(2): # repeat to emulate effect of multiple batches
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

def _compare(m1, m2):
for state in m1._defaults:
assert torch.allclose(getattr(m1, state), getattr(m2, state))
assert _allclose_recursive(getattr(m1, state), getattr(m2, state))
# if states are still by reference the reset will make following metrics fail
m1.reset()
m2.reset()
Expand Down
2 changes: 2 additions & 0 deletions tests/bases/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@


class DummyMetric(Metric):
full_state_update = True

def __init__(self, val_to_return):
super().__init__()
self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
Expand Down
15 changes: 7 additions & 8 deletions torchmetrics/functional/classification/average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,8 @@ def _average_precision_update(
average: reduction method for multi-class or multi-label problems
"""
preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
if average == "micro":
if preds.ndim == target.ndim:
# Considering each element of the label indicator matrix as a label
preds = preds.flatten()
target = target.flatten()
num_classes = 1
else:
raise ValueError("Cannot use `micro` average with multi-class input")
if average == "micro" and preds.ndim != target.ndim:
raise ValueError("Cannot use `micro` average with multi-class input")

return preds, target, num_classes, pos_label

Expand Down Expand Up @@ -97,6 +91,11 @@ def _average_precision_compute(
"""

# todo: `sample_weights` is unused
if average == "micro" and preds.ndim == target.ndim:
preds = preds.flatten()
target = target.flatten()
num_classes = 1

precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
if average == "weighted":
if preds.ndim == target.ndim and target.ndim > 1:
Expand Down

0 comments on commit fbb470a

Please sign in to comment.