Skip to content

Commit

Permalink
Set weights only for the classes axis
Browse files Browse the repository at this point in the history
  • Loading branch information
baskrahmer authored and rittik9 committed Dec 7, 2024
1 parent bf1c29f commit 930fba3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def _adjust_weights_safe_divide(
weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0

if ignore_index is not None and 0 <= ignore_index < len(score):
weights[ignore_index] = 0.0
weights[..., ignore_index] = 0.0

return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)

Expand Down

0 comments on commit 930fba3

Please sign in to comment.