Skip to content

Commit

Permalink
test: Add test for MulticlassRecall with ignore_index+macro (fixes Li…
Browse files Browse the repository at this point in the history
  • Loading branch information
rittik9 committed Sep 2, 2024
1 parent b9716b2 commit 3646d68
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,6 +659,37 @@ def test_corner_case():
assert res == 1.0


def test_multiclass_recall_ignore_index():
"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
y_true = torch.tensor([0, 0, 1, 1])
y_pred = torch.tensor([
[0.9, 0.1],
[0.9, 0.1],
[0.9, 0.1],
[0.1, 0.9],
])

# Test with ignore_index=0 and average="macro"
metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
res_ignore_0 = metric_ignore_0(y_pred, y_true)
assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"

# Test with ignore_index=1 and average="macro"
metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
res_ignore_1 = metric_ignore_1(y_pred, y_true)
assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"

# Test with no ignore_index and average="macro"
metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
res_no_ignore = metric_no_ignore(y_pred, y_true)
assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"

# Test with ignore_index=0 and average="none"
metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
res_none = metric_none(y_pred, y_true)
assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"


@pytest.mark.parametrize(
("metric", "kwargs", "base_metric"),
[
Expand Down

0 comments on commit 3646d68

Please sign in to comment.