diff --git a/CHANGELOG.md b/CHANGELOG.md index a0bc7c284d3..13f09f6dcf8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,6 +53,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fixed `torch.sort` currently does not support bool dtype on CUDA ([#665](https://github.com/PyTorchLightning/metrics/pull/665)) + + ## [0.6.1] - 2021-12-06 diff --git a/tests/detection/test_map.py b/tests/detection/test_map.py index 43c586cee28..83d850c8d4b 100644 --- a/tests/detection/test_map.py +++ b/tests/detection/test_map.py @@ -215,6 +215,19 @@ def test_empty_preds(): metric.compute() +_gpu_test_condition = not torch.cuda.is_available() + + +@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed") +@pytest.mark.skipif(_gpu_test_condition, reason="test requires CUDA availability") +def test_map_gpu(): + """Test predictions on single gpu.""" + metric = MAP() + metric = metric.to("cuda") + metric.update(_inputs.preds[0], _inputs.target[0]) + metric.compute() + + @pytest.mark.skipif(_pytest_condition, reason="test requires that pycocotools and torchvision=>0.8.0 is installed") def test_empty_metric(): """Test empty metric.""" diff --git a/torchmetrics/detection/map.py b/torchmetrics/detection/map.py index 4d18f5be2cd..d40370857d7 100644 --- a/torchmetrics/detection/map.py +++ b/torchmetrics/detection/map.py @@ -407,7 +407,9 @@ def _evaluate_image( ignore_area = (areas < area_range[0]) | (areas > area_range[1]) # sort dt highest score first, sort gt ignore last - ignore_area_sorted, gtind = torch.sort(ignore_area) + ignore_area_sorted, gtind = torch.sort(ignore_area.to(torch.uint8)) + # Convert to uint8 temporarily and back to bool, because "Sort currently does not support bool dtype on CUDA" + ignore_area_sorted = ignore_area_sorted.to(torch.bool) gt = gt[gtind] scores = self.detection_scores[id] scores_filtered = scores[det_label_mask] @@ -429,7 +431,7 @@ def _evaluate_image( for idx_iou, t in enumerate(self.iou_thresholds): for idx_det in range(nb_det): m = MAP._find_best_gt_match(t, nb_gt, gt_matches, idx_iou, gt_ignore, ious, idx_det) - if m is not -1: + if m != -1: det_ignore[idx_iou, idx_det] = gt_ignore[m] det_matches[idx_iou, idx_det] = True gt_matches[idx_iou, m] = True