Skip to content

Commit 5cbac0e

Browse files
authored
Merge branch 'master' into feature-#672
2 parents 2c38938 + d071eb2 commit 5cbac0e

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

tests/detection/test_map.py

+20
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,26 @@ def test_empty_preds():
215215
metric.compute()
216216

217217

218+
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
219+
def test_empty_ground_truths():
220+
"""Test empty ground truths."""
221+
metric = MAP()
222+
223+
metric.update(
224+
[
225+
dict(
226+
boxes=torch.Tensor([[214.1500, 41.2900, 562.4100, 285.0700]]),
227+
scores=torch.Tensor([0.5]),
228+
labels=torch.IntTensor([4]),
229+
),
230+
],
231+
[
232+
dict(boxes=torch.Tensor([]), labels=torch.IntTensor([])),
233+
],
234+
)
235+
metric.compute()
236+
237+
218238
_gpu_test_condition = not torch.cuda.is_available()
219239

220240

torchmetrics/detection/map.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
356356
det = self.detection_boxes[id]
357357
gt_label_mask = self.groundtruth_labels[id] == class_id
358358
det_label_mask = self.detection_labels[id] == class_id
359-
if len(det_label_mask) == 0 or len(det_label_mask) == 0:
359+
if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
360360
return Tensor([])
361361
gt = gt[gt_label_mask]
362362
det = det[det_label_mask]
@@ -396,7 +396,7 @@ def _evaluate_image(
396396
det = self.detection_boxes[id]
397397
gt_label_mask = self.groundtruth_labels[id] == class_id
398398
det_label_mask = self.detection_labels[id] == class_id
399-
if len(det_label_mask) == 0 or len(det_label_mask) == 0:
399+
if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
400400
return None
401401
gt = gt[gt_label_mask]
402402
det = det[det_label_mask]

0 commit comments

Comments
 (0)