Skip to content

Commit

Permalink
Fix cornercase of average precision (#2507)
Browse files Browse the repository at this point in the history
* fix + tests
* changelog

---------

Co-authored-by: Daniel Stancl <[email protected]>
  • Loading branch information
SkafteNicki and stancld authored Apr 17, 2024
1 parent 9acf6ba commit be39708
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed warnings being suppressed in `MeanAveragePrecision` when requested ([#2501](https://github.com/Lightning-AI/torchmetrics/pull/2501))


- Fixed cornercase in `binary_average_precision` when only negative samples are provided ([#2507](https://github.com/Lightning-AI/torchmetrics/pull/2507))


## [1.3.2] - 2024-03-18

### Fixed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from torchmetrics.utilities.compute import _safe_divide, interp
from torchmetrics.utilities.data import _bincount, _cumsum
from torchmetrics.utilities.enums import ClassificationTask
from torchmetrics.utilities.prints import rank_zero_warn


def _binary_clf_curve(
Expand Down Expand Up @@ -274,6 +275,12 @@ def _binary_precision_recall_curve_compute(
fps, tps, thresholds = _binary_clf_curve(state[0], state[1], pos_label=pos_label)
precision = tps / (tps + fps)
recall = tps / tps[-1]
if (state[1] == 0).all(): # all labels are negative, recall is undefined
rank_zero_warn(
"No positive samples found in target, recall is undefined. Setting recall to one for all thresholds.",
UserWarning,
)
recall = torch.ones_like(recall)

# need to call reversed explicitly, since including that to slice would
# introduce negative strides that are not yet supported in pytorch
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/classification/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ def test_binary_average_precision_threshold_arg(self, inputs, threshold_fn):
assert torch.allclose(ap1, ap2)


def test_warning_on_no_positives():
"""Test that a warning is raised when there are no positive samples in the target."""
preds = torch.rand(100)
target = torch.zeros(100).long()
with pytest.warns(UserWarning, match="No positive samples found in target, recall is undefined. Setting recall.*"):
binary_average_precision(preds, target)


def _reference_sklearn_avg_precision_multiclass(preds, target, average="macro", ignore_index=None):
preds = np.moveaxis(preds.numpy(), 1, -1).reshape((-1, preds.shape[1]))
target = target.numpy().flatten()
Expand Down

0 comments on commit be39708

Please sign in to comment.