Skip to content

Commit

Permalink
Fix bool sort on CUDA (#665)
Browse files Browse the repository at this point in the history
* Fix bool sort on CUDA
* Add dedicated gpu test
* Apply suggestions from code review

Co-authored-by: Jirka Borovec <[email protected]>
  • Loading branch information
twsl and Borda authored Dec 8, 2021
1 parent 5a200ac commit e98fbaf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed


- Fixed `torch.sort` currently does not support bool dtype on CUDA ([#665](https://github.com/PyTorchLightning/metrics/pull/665))



## [0.6.1] - 2021-12-06

Expand Down
13 changes: 13 additions & 0 deletions tests/detection/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,19 @@ def test_empty_preds():
metric.compute()


_gpu_test_condition = not torch.cuda.is_available()


@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability")
def test_map_gpu():
"""Test predictions on single gpu."""
metric = MAP()
metric = metric.to("cuda")
metric.update(_inputs.preds[0], _inputs.target[0])
metric.compute()


@pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed")
def test_empty_metric():
"""Test empty metric."""
Expand Down
6 changes: 4 additions & 2 deletions torchmetrics/detection/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,9 @@ def _evaluate_image(
ignore_area = (areas < area_range[0]) | (areas > area_range[1])

# sort dt highest score first, sort gt ignore last
ignore_area_sorted, gtind = torch.sort(ignore_area)
ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8))
# Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA"
ignore_area_sorted = ignore_area_sorted.to(torch.bool)
gt = gt[gtind]
scores = self.detection_scores[id]
scores_filtered = scores[det_label_mask]
Expand All @@ -429,7 +431,7 @@ def _evaluate_image(
for idx_iou, t in enumerate(self.iou_thresholds):
for idx_det in range(nb_det):
m = MAP._find_best_gt_match(t, nb_gt, gt_matches, idx_iou, gt_ignore, ious, idx_det)
if m is not -1:
if m != -1:
det_ignore[idx_iou, idx_det] = gt_ignore[m]
det_matches[idx_iou, idx_det] = True
gt_matches[idx_iou, m] = True
Expand Down

0 comments on commit e98fbaf

Please sign in to comment.