Bugfix/avg prec auroc compute groups #1086

Merged · 3 commits · Jun 13, 2022
3 changes: 1 addition & 2 deletions CHANGELOG.md
@@ -42,8 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed mAP calculation for areas with 0 predictions ([#1080](https://github.com/PyTorchLightning/metrics/pull/1080))


-
- Fixed bug where average precision state and AUROC state were not merged when using MetricCollections ([#1086](https://github.com/PyTorchLightning/metrics/pull/1086))


## [0.9.1] - 2022-06-08
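For context, a minimal sketch of the scenario the changelog entry describes (not part of the diff; metric arguments follow the 0.9.x classification API). With `compute_groups=True`, `MetricCollection` lets metrics whose `update` produces identical state share that state, and `AveragePrecision`/`AUROC` keep it as lists of tensors:

```python
import torch
from torchmetrics import AUROC, AveragePrecision, MetricCollection

# AveragePrecision and AUROC accumulate preds/target as list states; with
# compute_groups=True the collection is expected to merge/share those states.
collection = MetricCollection(
    AveragePrecision(num_classes=3, average="macro"),
    AUROC(num_classes=3, average="macro"),
    compute_groups=True,
)

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
collection.update(preds, target)
print(collection.compute())  # {'AveragePrecision': ..., 'AUROC': ...}
```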
@@ -43,14 +43,8 @@ def _average_precision_update(
average: reduction method for multi-class or multi-label problems
"""
preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
-    if average == "micro":
-        if preds.ndim == target.ndim:
-            # Considering each element of the label indicator matrix as a label
-            preds = preds.flatten()
-            target = target.flatten()
-            num_classes = 1
-        else:
-            raise ValueError("Cannot use `micro` average with multi-class input")
+    if average == "micro" and preds.ndim != target.ndim:
+        raise ValueError("Cannot use `micro` average with multi-class input")

return preds, target, num_classes, pos_label

@@ -97,6 +91,11 @@ def _average_precision_compute(
"""

# todo: `sample_weights` is unused
+    if average == "micro" and preds.ndim == target.ndim:
+        preds = preds.flatten()
+        target = target.flatten()
+        num_classes = 1

precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
if average == "weighted":
if preds.ndim == target.ndim and target.ndim > 1:
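A note on the change above (editorial reading, not text from the PR): moving the micro-average flattening out of `_average_precision_update` and into `_average_precision_compute` means the state stored during `update` keeps its original shape, matching what `AUROC` stores, which is what allows the two metrics to share a compute group. As a small illustration of what that flattening does in the multilabel case (shapes chosen for the example):

```python
import torch

# Multilabel inputs: preds and target have the same number of dimensions,
# so "micro" averaging treats every (sample, label) pair as one binary decision.
preds = torch.rand(10, 3)           # probabilities for 10 samples x 3 labels
target = torch.randint(2, (10, 3))  # binary targets with the same shape

preds_flat, target_flat = preds.flatten(), target.flatten()
num_classes = 1                     # a single binary problem after flattening
assert preds_flat.shape == target_flat.shape == (30,)
```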
67 changes: 53 additions & 14 deletions test/unittests/bases/test_collections.py
@@ -19,7 +19,9 @@
import torch

from torchmetrics import (
AUROC,
Accuracy,
AveragePrecision,
CohenKappa,
ConfusionMatrix,
F1Score,
@@ -29,6 +31,7 @@
Precision,
Recall,
)
from torchmetrics.utilities.checks import _allclose_recursive
from unittests.helpers import seed_all
from unittests.helpers.testers import DummyMetricDiff, DummyMetricSum

@@ -267,6 +270,8 @@ def test_collection_filtering():
"""Test that collections works with the kwargs argument."""

class DummyMetric(Metric):
full_state_update = True

def __init__(self):
super().__init__()

@@ -277,6 +282,8 @@ def compute(self):
return

class MyAccuracy(Metric):
full_state_update = True

def __init__(self):
super().__init__()

@@ -292,21 +299,30 @@ def compute(self):
mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2")


# multiclass and multilabel inputs reused by the parametrized tests below
_mc_preds = torch.randn(10, 3).softmax(dim=-1)
_mc_target = torch.randint(3, (10,))
_ml_preds = torch.rand(10, 3)
_ml_target = torch.randint(2, (10, 3))


@pytest.mark.parametrize(
"metrics, expected",
"metrics, expected, preds, target",
[
# single metric forms its own compute group
(Accuracy(3), {0: ["Accuracy"]}),
(Accuracy(3), {0: ["Accuracy"]}, _mc_preds, _mc_target),
# two metrics of same class forms a compute group
({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}),
({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}, _mc_preds, _mc_target),
# two metrics from registry forms a compute group
([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}),
([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}, _mc_preds, _mc_target),
# two metrics from different classes gives two compute groups
([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}),
([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}, _mc_preds, _mc_target),
# multi group multi metric
(
[ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)],
{0: ["ConfusionMatrix", "CohenKappa"], 1: ["Recall", "Precision"]},
_mc_preds,
_mc_target,
),
# Complex example
(
@@ -319,6 +335,33 @@ def compute(self):
"confmat": ConfusionMatrix(3),
},
{0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
_mc_preds,
_mc_target,
),
# With list states
(
[AUROC(average="macro", num_classes=3), AveragePrecision(average="macro", num_classes=3)],
{0: ["AUROC", "AveragePrecision"]},
_mc_preds,
_mc_target,
),
# Nested collections
(
[
MetricCollection(
AUROC(average="micro", num_classes=3),
AveragePrecision(average="micro", num_classes=3),
postfix="_micro",
),
MetricCollection(
AUROC(average="macro", num_classes=3),
AveragePrecision(average="macro", num_classes=3),
postfix="_macro",
),
],
{0: ["AUROC_micro", "AveragePrecision_micro", "AUROC_macro", "AveragePrecision_macro"]},
_ml_preds,
_ml_target,
),
],
)
@@ -332,8 +375,10 @@ class TestComputeGroups:
["prefix_", "_postfix"],
],
)
def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix):
def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
if isinstance(metrics, MetricCollection):
prefix, postfix = None, None # disable for nested collections
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without compute groups for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)
@@ -342,8 +387,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
assert m2.compute_groups == {}

for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

@@ -353,8 +396,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
assert m.compute_groups == expected
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)
@@ -372,22 +413,20 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, method):
def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
"""Check that whenever user call a methods that give access to the indivitual metric that state are copied
instead of just passed by reference."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

for _ in range(2): # repeat to emulate effect of multiple epochs
for _ in range(2): # repeat to emulate effect of multiple batches
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

def _compare(m1, m2):
for state in m1._defaults:
assert torch.allclose(getattr(m1, state), getattr(m2, state))
assert _allclose_recursive(getattr(m1, state), getattr(m2, state))
# if states are still by reference the reset will make following metrics fail
m1.reset()
m2.reset()
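Aside (an illustration, not the library's implementation): the `_compare` helper above switches from `torch.allclose` to `_allclose_recursive` because `AUROC` and `AveragePrecision` store their state as lists of tensors, which plain `torch.allclose` cannot compare directly. A minimal sketch of such a recursive check might look like:

```python
import torch

def allclose_recursive(a, b) -> bool:
    """Compare tensor states and list/tuple states element by element."""
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return torch.allclose(a, b)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(allclose_recursive(x, y) for x, y in zip(a, b))
    return a == b
```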
2 changes: 2 additions & 0 deletions test/unittests/bases/test_composition.py
@@ -22,6 +22,8 @@


class DummyMetric(Metric):
full_state_update = True

def __init__(self, val_to_return):
super().__init__()
self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
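For reference (assumes the 0.9-era `Metric` API): `full_state_update` is the class attribute the dummy metrics above now set explicitly; since torchmetrics 0.9, custom metrics declare whether their `update` depends on previously accumulated state. A minimal sketch of such a metric:

```python
import torch
from torchmetrics import Metric

class UpdateCounter(Metric):
    # True: update() builds on state from earlier batches, so forward() cannot
    # shortcut to a batch-only computation.
    full_state_update = True

    def __init__(self):
        super().__init__()
        self.add_state("num_updates", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, *args, **kwargs):
        self.num_updates += 1

    def compute(self):
        return self.num_updates
```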