Skip to content

Commit c470675

Browse files
committed
clean API interface; working on tests for iou_type SEGM
1 parent 35fee46 commit c470675

File tree

2 files changed

+42
-38
lines changed

2 files changed

+42
-38
lines changed

tests/detection/test_map.py

+20
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,26 @@ def test_empty_metric():
266266
metric.compute()
267267

268268

269+
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
270+
def test_segm_iou_empty_mask():
271+
"""Test empty ground truths."""
272+
metric = MeanAveragePrecision(iou_type="segm")
273+
274+
metric.update(
275+
[
276+
dict(
277+
masks=torch.randint(0, 1, (1, 10, 10)),
278+
scores=torch.Tensor([0.5]),
279+
labels=torch.IntTensor([4]),
280+
),
281+
],
282+
[
283+
dict(masks=torch.Tensor([]), labels=torch.IntTensor([])),
284+
],
285+
)
286+
metric.compute()
287+
288+
269289
@pytest.mark.skipif(_pytest_condition, reason="test requires that torchvision=>0.8.0 is installed")
270290
def test_error_on_wrong_input():
271291
"""Test class input validation."""

torchmetrics/detection/map.py

+22-38
Original file line numberDiff line numberDiff line change
@@ -246,13 +246,11 @@ class MeanAveragePrecision(Metric):
246246
If ``class_metrics`` is not a boolean
247247
"""
248248

249-
detection_boxes: List[Tensor]
249+
detections: List[Tensor]
250250
detection_scores: List[Tensor]
251251
detection_labels: List[Tensor]
252-
groundtruth_boxes: List[Tensor]
252+
groundtruths: List[Tensor]
253253
groundtruth_labels: List[Tensor]
254-
groundtruth_masks: List[Tensor]
255-
detection_masks: List[Tensor]
256254

257255
def __init__(
258256
self,
@@ -303,13 +301,11 @@ def __init__(
303301
raise ValueError("Expected argument `class_metrics` to be a boolean")
304302

305303
self.class_metrics = class_metrics
306-
self.add_state("detection_boxes", default=[], dist_reduce_fx=None)
304+
self.add_state("detections", default=[], dist_reduce_fx=None)
307305
self.add_state("detection_scores", default=[], dist_reduce_fx=None)
308306
self.add_state("detection_labels", default=[], dist_reduce_fx=None)
309-
self.add_state("groundtruth_boxes", default=[], dist_reduce_fx=None)
307+
self.add_state("groundtruths", default=[], dist_reduce_fx=None)
310308
self.add_state("groundtruth_labels", default=[], dist_reduce_fx=None)
311-
self.add_state("detection_masks", default=[], dist_reduce_fx=None)
312-
self.add_state("groundtruth_masks", default=[], dist_reduce_fx=None)
313309

314310
def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]) -> None: # type: ignore
315311
"""Add detections and ground truth to the metric.
@@ -354,32 +350,29 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
354350
ValueError:
355351
If any score is not type float and of length 1
356352
"""
357-
_input_validator(preds, target)
353+
_input_validator(preds, target, iou_type=self.iou_type)
358354

359355
for item in preds:
360-
boxes, masks = self._get_safe_item_values(item)
361-
self.detection_boxes.append(boxes)
356+
detections = self._get_safe_item_values(item)
357+
self.detections.append(detections)
362358
self.detection_labels.append(item["labels"])
363359
self.detection_scores.append(item["scores"])
364-
self.detection_masks.append(masks)
365360

366361
for item in target:
367-
boxes, masks = self._get_safe_item_values(item)
368-
self.groundtruth_boxes.append(boxes)
362+
groundtruths = self._get_safe_item_values(item)
363+
self.groundtruths.append(groundtruths)
369364
self.groundtruth_labels.append(item["labels"])
370-
self.groundtruth_masks.append(masks)
371365

372366
def _get_safe_item_values(self, item):
373367
if self.iou_type == "bbox":
374368
boxes = _fix_empty_tensors(item["boxes"])
375369
boxes = box_convert(boxes, in_fmt=self.box_format, out_fmt="xyxy")
376-
masks = _fix_empty_tensors(torch.Tensor())
377-
elif self.iou_type == "masks":
370+
return boxes
371+
elif self.iou_type == "segm":
378372
masks = _fix_empty_tensors(item["masks"])
379-
boxes = _fix_empty_tensors(torch.Tensor())
373+
return masks
380374
else:
381375
raise Exception(f"IOU type {self.iou_type} is not supported")
382-
return boxes, masks
383376

384377
def _get_classes(self) -> List:
385378
"""Returns a list of unique classes found in ground truth and detection data."""
@@ -388,18 +381,11 @@ def _get_classes(self) -> List:
388381
return []
389382

390383
def _compute_iou(self, id: int, class_id: int, max_det: int) -> Tensor:
391-
return self._compute_iou_impl(id, self.groundtruth_boxes, self.detection_boxes, class_id, max_det, box_iou)
384+
iou_func = box_iou if self.iou_type == "bbox" else segm_iou
392385

393-
# if self.iou_type == "segm":
394-
# return self._compute_iou_impl(id, self.groundtruth_masks, self.detection_masks, class_id, max_det, segm_iou)
395-
# elif self.iou_type == "bbox":
386+
return self._compute_iou_impl(id, class_id, max_det, iou_func)
396387

397-
# else:
398-
# raise Exception(f"IOU type {self.iou_type} is not supported")
399-
400-
def _compute_iou_impl(
401-
self, id: int, ground_truths, detections, class_id: int, max_det: int, compute_iou: Callable
402-
) -> Tensor:
388+
def _compute_iou_impl(self, id: int, class_id: int, max_det: int, compute_iou: Callable) -> Tensor:
403389
"""Computes the Intersection over Union (IoU) for ground truth and detection bounding boxes for the given
404390
image and class.
405391
@@ -412,8 +398,8 @@ def _compute_iou_impl(
412398
Maximum number of evaluated detection bounding boxes
413399
"""
414400
# if self.iou_type == "bbox":
415-
gt = self.groundtruth_boxes[id]
416-
det = self.detection_boxes[id]
401+
gt = self.groundtruths[id]
402+
det = self.detections[id]
417403

418404
gt_label_mask = self.groundtruth_labels[id] == class_id
419405
det_label_mask = self.detection_labels[id] == class_id
@@ -452,8 +438,8 @@ def _evaluate_image(
452438
ious:
453439
IoU results for image and class.
454440
"""
455-
gt = self.groundtruth_boxes[id]
456-
det = self.detection_boxes[id]
441+
gt = self.groundtruths[id]
442+
det = self.detections[id]
457443
gt_label_mask = self.groundtruth_labels[id] == class_id
458444
det_label_mask = self.detection_labels[id] == class_id
459445
if len(gt_label_mask) == 0 or len(det_label_mask) == 0:
@@ -595,7 +581,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
595581
class_ids:
596582
List of label class Ids.
597583
"""
598-
img_ids = range(len(self.groundtruth_boxes))
584+
img_ids = range(len(self.groundtruths))
599585
max_detections = self.max_detection_thresholds[-1]
600586
area_ranges = self.bbox_area_ranges.values()
601587

@@ -766,13 +752,11 @@ def compute(self) -> dict:
766752
"""
767753

768754
# move everything to CPU, as we are faster here
769-
self.detection_boxes = [box.cpu() for box in self.detection_boxes]
755+
self.detections = [box.cpu() for box in self.detections]
770756
self.detection_labels = [label.cpu() for label in self.detection_labels]
771757
self.detection_scores = [score.cpu() for score in self.detection_scores]
772-
self.groundtruth_boxes = [box.cpu() for box in self.groundtruth_boxes]
758+
self.groundtruths = [box.cpu() for box in self.groundtruths]
773759
self.groundtruth_labels = [label.cpu() for label in self.groundtruth_labels]
774-
self.groundtruth_masks = [box.cpu() for box in self.groundtruth_masks]
775-
self.detection_masks = [label.cpu() for label in self.detection_masks]
776760

777761
classes = self._get_classes()
778762
precisions, recalls = self._calculate(classes)

0 commit comments

Comments
 (0)