Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/multiclass recall macro avg ignore index #2710

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed multiclass recall macro avg. ignore index ([#2710](https://github.com/Lightning-AI/torchmetrics/pull/2710))


---
Expand Down Expand Up @@ -117,12 +117,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [1.4.3] - 2024-10-10

### Fixed

- Fixed for Pearson changes inputs ([#2765](https://github.com/Lightning-AI/torchmetrics/pull/2765))
- Fixed bug in `PESQ` metric where `NoUtterancesError` prevented calculating on a batch of data ([#2753](https://github.com/Lightning-AI/torchmetrics/pull/2753))
- Fixed corner case in `MatthewsCorrCoef` ([#2743](https://github.com/Lightning-AI/torchmetrics/pull/2743))



## [1.4.2] - 2022-09-12

### Added
Expand Down
1 change: 1 addition & 0 deletions src/torchmetrics/classification/precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ def compute(self) -> Tensor:
multidim_average=self.multidim_average,
top_k=self.top_k,
zero_division=self.zero_division,
ignore_index=self.ignore_index,
)

def plot(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _precision_recall_reduce(
multilabel: bool = False,
top_k: int = 1,
zero_division: float = 0,
ignore_index: Optional[int] = None,
) -> Tensor:
different_stat = fp if stat == "precision" else fn # this is what differs between the two scores
if average == "binary":
Expand All @@ -56,7 +57,7 @@ def _precision_recall_reduce(
return _safe_divide(tp, tp + different_stat, zero_division)

score = _safe_divide(tp, tp + different_stat, zero_division)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k, ignore_index=ignore_index)


def binary_precision(
Expand Down Expand Up @@ -253,6 +254,7 @@ def multiclass_precision(
multidim_average=multidim_average,
top_k=top_k,
zero_division=zero_division,
ignore_index=ignore_index,
)


Expand Down Expand Up @@ -563,6 +565,7 @@ def multiclass_recall(
multidim_average=multidim_average,
top_k=top_k,
zero_division=zero_division,
ignore_index=ignore_index,
)


Expand Down Expand Up @@ -676,6 +679,7 @@ def multilabel_recall(
multidim_average=multidim_average,
multilabel=True,
zero_division=zero_division,
ignore_index=ignore_index,
)


Expand Down
13 changes: 12 additions & 1 deletion src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens


def _adjust_weights_safe_divide(
score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
score: Tensor,
average: Optional[str],
multilabel: bool,
tp: Tensor,
fp: Tensor,
fn: Tensor,
top_k: int = 1,
ignore_index: Optional[int] = None,
) -> Tensor:
if average is None or average == "none":
return score
Expand All @@ -78,6 +85,10 @@ def _adjust_weights_safe_divide(
weights = torch.ones_like(score)
if not multilabel:
weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0

if ignore_index is not None and 0 <= ignore_index < len(score):
weights[..., ignore_index] = 0.0

return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)


Expand Down
43 changes: 41 additions & 2 deletions tests/unittests/classification/test_precision_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def _reference_sklearn_precision_recall_multiclass(
if preds.ndim == target.ndim + 1:
preds = torch.argmax(preds, 1)

valid_labels = list(range(NUM_CLASSES))
if ignore_index is not None:
valid_labels = [label for label in valid_labels if label != ignore_index]

if multidim_average == "global":
preds = preds.numpy().flatten()
target = target.numpy().flatten()
Expand All @@ -212,7 +216,9 @@ def _reference_sklearn_precision_recall_multiclass(
target,
preds,
average=average,
labels=list(range(num_classes)) if average is None else None,
labels=valid_labels
if average in ("macro", "weighted")
else (list(range(num_classes)) if average is None else None),
zero_division=zero_division,
)

Expand All @@ -237,7 +243,9 @@ def _reference_sklearn_precision_recall_multiclass(
true,
pred,
average=average,
labels=list(range(num_classes)) if average is None else None,
labels=valid_labels
if average in ("macro", "weighted")
else (list(range(num_classes)) if average is None else None),
zero_division=zero_division,
)
res.append(0.0 if np.isnan(r).any() else r)
Expand Down Expand Up @@ -661,6 +669,37 @@ def test_corner_case():
assert res == 1.0


def test_multiclass_recall_ignore_index():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems we are already testing various ignore_index with reference metric so if we had it wrong this did not pass already... it is possible that we also have a bug in the reference metric?
cc: @SkafteNicki

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking to the code and the ignore index is already applied in _multilabel_stat_scores_format which reduces the preds/target size the same way as the reference metric so calling it with null weights in fact ignores additional index

Copy link
Contributor Author

@rittik9 rittik9 Nov 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is we are using sklearn's recall_score as a reference for our unittests. So even if in _reference_sklearn_precision_recall_multiclass() function we are using remove_ignore_index function for removing those predictions whose real values are ignore_index class before passing it to recall_score function, it does not matter. Because whenever average='macro' sklearn's recall_score will always return mean cosidering the total no. of classes (as we are passing all the classes in recall_score() function's labels argument). That is the reason why unittests failed in the first place. I think we need to fix the unittests to take care of ignore_index using sklearn's recall_score() function's labels argument. I've prepared a notebook for explanation. cc:@Borda.

"""Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
y_true = torch.tensor([0, 0, 1, 1])
y_pred = torch.tensor([
[0.9, 0.1],
[0.9, 0.1],
[0.9, 0.1],
[0.1, 0.9],
])

# Test with ignore_index=0 and average="macro"
metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
res_ignore_0 = metric_ignore_0(y_pred, y_true)
assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"

# Test with ignore_index=1 and average="macro"
metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
res_ignore_1 = metric_ignore_1(y_pred, y_true)
assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"

# Test with no ignore_index and average="macro"
metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
res_no_ignore = metric_no_ignore(y_pred, y_true)
assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"

# Test with ignore_index=0 and average="none"
metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
res_none = metric_none(y_pred, y_true)
assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"


@pytest.mark.parametrize(
("metric", "kwargs", "base_metric"),
[
Expand Down
Loading