Skip to content

Commit

Permalink
[metrics] Accuracy num_classes error fix (#3764)
Browse files Browse the repository at this point in the history
* change accuracy error to warning

* changelog
  • Loading branch information
SkafteNicki authored Oct 1, 2020
1 parent 8be002c commit 9a7d1a1
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `ModelCheckpoint` with `save_top_k=-1` option not tracking the best models when a monitor metric is available ([#3735](https://github.com/PyTorchLightning/pytorch-lightning/pull/3735))

- Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))

## [0.9.0] - YYYY-MM-DD

### Added
Expand Down
9 changes: 4 additions & 5 deletions pytorch_lightning/metrics/functional/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def get_num_classes(
if num_classes is None:
num_classes = num_all_classes
elif num_classes != num_all_classes:
rank_zero_warn(f'You have set {num_classes} number of classes if different from'
f' predicted ({num_pred_classes}) and target ({num_target_classes}) number of classes')
rank_zero_warn(f'You have set {num_classes} number of classes which is'
f' different from predicted ({num_pred_classes}) and'
f' target ({num_target_classes}) number of classes',
RuntimeWarning)
return num_classes


Expand Down Expand Up @@ -266,9 +268,6 @@ def accuracy(
tensor(0.7500)
"""
if not (target > 0).any() and num_classes is None:
raise RuntimeError("cannot infer num_classes when target is all zero")

tps, fps, tns, fns, sups = stat_scores_multiple_classes(
pred=pred, target=target, num_classes=num_classes)

Expand Down
8 changes: 6 additions & 2 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,12 @@ def test_multilabel_accuracy():
assert torch.allclose(accuracy(y2, torch.logical_not(y2), class_reduction='none'), torch.tensor([0., 0.]))
assert torch.allclose(accuracy(y1, torch.logical_not(y1), class_reduction='none'), torch.tensor([0., 0.]))

with pytest.raises(RuntimeError):
accuracy(y2, torch.zeros_like(y2), class_reduction='none')
# If num_classes does not match the number extracted from the input, we expect a warning
with pytest.warns(RuntimeWarning,
match=r'You have set .* number of classes which is'
r' different from predicted (.*) and'
r' target (.*) number of classes'):
_ = accuracy(y2, torch.zeros_like(y2), num_classes=3)


def test_accuracy():
Expand Down

0 comments on commit 9a7d1a1

Please sign in to comment.