@@ -246,13 +246,11 @@ class MeanAveragePrecision(Metric):
246
246
If ``class_metrics`` is not a boolean
247
247
"""
248
248
249
- detection_boxes : List [Tensor ]
249
+ detections : List [Tensor ]
250
250
detection_scores : List [Tensor ]
251
251
detection_labels : List [Tensor ]
252
- groundtruth_boxes : List [Tensor ]
252
+ groundtruths : List [Tensor ]
253
253
groundtruth_labels : List [Tensor ]
254
- groundtruth_masks : List [Tensor ]
255
- detection_masks : List [Tensor ]
256
254
257
255
def __init__ (
258
256
self ,
@@ -303,13 +301,11 @@ def __init__(
303
301
raise ValueError ("Expected argument `class_metrics` to be a boolean" )
304
302
305
303
self .class_metrics = class_metrics
306
- self .add_state ("detection_boxes " , default = [], dist_reduce_fx = None )
304
+ self .add_state ("detections " , default = [], dist_reduce_fx = None )
307
305
self .add_state ("detection_scores" , default = [], dist_reduce_fx = None )
308
306
self .add_state ("detection_labels" , default = [], dist_reduce_fx = None )
309
- self .add_state ("groundtruth_boxes " , default = [], dist_reduce_fx = None )
307
+ self .add_state ("groundtruths " , default = [], dist_reduce_fx = None )
310
308
self .add_state ("groundtruth_labels" , default = [], dist_reduce_fx = None )
311
- self .add_state ("detection_masks" , default = [], dist_reduce_fx = None )
312
- self .add_state ("groundtruth_masks" , default = [], dist_reduce_fx = None )
313
309
314
310
def update (self , preds : List [Dict [str , Tensor ]], target : List [Dict [str , Tensor ]]) -> None : # type: ignore
315
311
"""Add detections and ground truth to the metric.
@@ -354,32 +350,29 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
354
350
ValueError:
355
351
If any score is not type float and of length 1
356
352
"""
357
- _input_validator (preds , target )
353
+ _input_validator (preds , target , iou_type = self . iou_type )
358
354
359
355
for item in preds :
360
- boxes , masks = self ._get_safe_item_values (item )
361
- self .detection_boxes .append (boxes )
356
+ detections = self ._get_safe_item_values (item )
357
+ self .detections .append (detections )
362
358
self .detection_labels .append (item ["labels" ])
363
359
self .detection_scores .append (item ["scores" ])
364
- self .detection_masks .append (masks )
365
360
366
361
for item in target :
367
- boxes , masks = self ._get_safe_item_values (item )
368
- self .groundtruth_boxes .append (boxes )
362
+ groundtruths = self ._get_safe_item_values (item )
363
+ self .groundtruths .append (groundtruths )
369
364
self .groundtruth_labels .append (item ["labels" ])
370
- self .groundtruth_masks .append (masks )
371
365
372
366
def _get_safe_item_values (self , item ):
373
367
if self .iou_type == "bbox" :
374
368
boxes = _fix_empty_tensors (item ["boxes" ])
375
369
boxes = box_convert (boxes , in_fmt = self .box_format , out_fmt = "xyxy" )
376
- masks = _fix_empty_tensors ( torch . Tensor ())
377
- elif self .iou_type == "masks " :
370
+ return boxes
371
+ elif self .iou_type == "segm " :
378
372
masks = _fix_empty_tensors (item ["masks" ])
379
- boxes = _fix_empty_tensors ( torch . Tensor ())
373
+ return masks
380
374
else :
381
375
raise Exception (f"IOU type { self .iou_type } is not supported" )
382
- return boxes , masks
383
376
384
377
def _get_classes (self ) -> List :
385
378
"""Returns a list of unique classes found in ground truth and detection data."""
@@ -388,18 +381,11 @@ def _get_classes(self) -> List:
388
381
return []
389
382
390
383
def _compute_iou (self , id : int , class_id : int , max_det : int ) -> Tensor :
391
- return self . _compute_iou_impl ( id , self . groundtruth_boxes , self .detection_boxes , class_id , max_det , box_iou )
384
+ iou_func = box_iou if self .iou_type == "bbox" else segm_iou
392
385
393
- # if self.iou_type == "segm":
394
- # return self._compute_iou_impl(id, self.groundtruth_masks, self.detection_masks, class_id, max_det, segm_iou)
395
- # elif self.iou_type == "bbox":
386
+ return self ._compute_iou_impl (id , class_id , max_det , iou_func )
396
387
397
- # else:
398
- # raise Exception(f"IOU type {self.iou_type} is not supported")
399
-
400
- def _compute_iou_impl (
401
- self , id : int , ground_truths , detections , class_id : int , max_det : int , compute_iou : Callable
402
- ) -> Tensor :
388
+ def _compute_iou_impl (self , id : int , class_id : int , max_det : int , compute_iou : Callable ) -> Tensor :
403
389
"""Computes the Intersection over Union (IoU) for ground truth and detection bounding boxes for the given
404
390
image and class.
405
391
@@ -412,8 +398,8 @@ def _compute_iou_impl(
412
398
Maximum number of evaluated detection bounding boxes
413
399
"""
414
400
# if self.iou_type == "bbox":
415
- gt = self .groundtruth_boxes [id ]
416
- det = self .detection_boxes [id ]
401
+ gt = self .groundtruths [id ]
402
+ det = self .detections [id ]
417
403
418
404
gt_label_mask = self .groundtruth_labels [id ] == class_id
419
405
det_label_mask = self .detection_labels [id ] == class_id
@@ -452,8 +438,8 @@ def _evaluate_image(
452
438
ious:
453
439
IoU results for image and class.
454
440
"""
455
- gt = self .groundtruth_boxes [id ]
456
- det = self .detection_boxes [id ]
441
+ gt = self .groundtruths [id ]
442
+ det = self .detections [id ]
457
443
gt_label_mask = self .groundtruth_labels [id ] == class_id
458
444
det_label_mask = self .detection_labels [id ] == class_id
459
445
if len (gt_label_mask ) == 0 or len (det_label_mask ) == 0 :
@@ -595,7 +581,7 @@ def _calculate(self, class_ids: List) -> Tuple[MAPMetricResults, MARMetricResult
595
581
class_ids:
596
582
List of label class Ids.
597
583
"""
598
- img_ids = range (len (self .groundtruth_boxes ))
584
+ img_ids = range (len (self .groundtruths ))
599
585
max_detections = self .max_detection_thresholds [- 1 ]
600
586
area_ranges = self .bbox_area_ranges .values ()
601
587
@@ -766,13 +752,11 @@ def compute(self) -> dict:
766
752
"""
767
753
768
754
# move everything to CPU, as we are faster here
769
- self .detection_boxes = [box .cpu () for box in self .detection_boxes ]
755
+ self .detections = [box .cpu () for box in self .detections ]
770
756
self .detection_labels = [label .cpu () for label in self .detection_labels ]
771
757
self .detection_scores = [score .cpu () for score in self .detection_scores ]
772
- self .groundtruth_boxes = [box .cpu () for box in self .groundtruth_boxes ]
758
+ self .groundtruths = [box .cpu () for box in self .groundtruths ]
773
759
self .groundtruth_labels = [label .cpu () for label in self .groundtruth_labels ]
774
- self .groundtruth_masks = [box .cpu () for box in self .groundtruth_masks ]
775
- self .detection_masks = [label .cpu () for label in self .detection_masks ]
776
760
777
761
classes = self ._get_classes ()
778
762
precisions , recalls = self ._calculate (classes )
0 commit comments