Skip to content

Commit

Permalink
Merge branch 'master' into plot/detection
Browse files Browse the repository at this point in the history
  • Loading branch information
mergify[bot] authored Mar 6, 2023
2 parents 0b89372 + 0d28f26 commit a9740a0
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed evaluation of `R2Score` with near constant target ([#1576](https://github.com/Lightning-AI/metrics/pull/1576))


- Fixed bug related to `top_k>1` and `ignore_index!=None` in `StatScores` based metrics ([#1589](https://github.com/Lightning-AI/metrics/pull/1589))


- Fixed corner case for `PearsonCorrCoef` when running in ddp mode but only on single device ([#1587](https://github.com/Lightning-AI/metrics/pull/1587))


Expand Down
5 changes: 3 additions & 2 deletions src/torchmetrics/functional/classification/stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,9 @@ def _multiclass_stat_scores_update(
preds = preds.clone()
target = target.clone()
idx = target == ignore_index
preds[idx] = num_classes
target[idx] = num_classes
idx = idx.unsqueeze(1).repeat(1, num_classes, 1) if preds.ndim > target.ndim else idx
preds[idx] = num_classes

if top_k > 1:
preds_oh = torch.movedim(select_topk(preds, topk=top_k, dim=1), 1, -1)
Expand All @@ -374,7 +375,7 @@ def _multiclass_stat_scores_update(
if 0 <= ignore_index <= num_classes - 1:
target_oh[target == ignore_index, :] = -1
else:
preds_oh = preds_oh[..., :-1]
preds_oh = preds_oh[..., :-1] if top_k == 1 else preds_oh
target_oh = target_oh[..., :-1]
target_oh[target == num_classes, :] = -1
sum_dim = [0, 1] if multidim_average == "global" else [1]
Expand Down
15 changes: 15 additions & 0 deletions tests/unittests/classification/test_stat_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,21 @@ def test_top_k_multiclass(k, preds, target, average, expected):
)


def test_top_k_ignore_index_multiclass():
"""Test that top_k argument works together with ignore_index."""
preds_without = torch.randn(10, 3).softmax(dim=-1)
target_without = torch.randint(3, (10,))
preds_with = torch.cat([preds_without, torch.randn(10, 3).softmax(dim=-1)], 0)
target_with = torch.cat([target_without, -100 * torch.ones(10)], 0).long()

res_without = multiclass_stat_scores(preds_without, target_without, num_classes=3, average="micro", top_k=2)
res_with = multiclass_stat_scores(
preds_with, target_with, num_classes=3, average="micro", top_k=2, ignore_index=-100
)

assert torch.allclose(res_without, res_with)


def test_multiclass_overflow():
"""Test that multiclass computations does not overflow even on byte input."""
preds = torch.randint(20, (100,)).byte()
Expand Down

0 comments on commit a9740a0

Please sign in to comment.