24
24
25
25
from torchmetrics .detection .helpers import _fix_empty_tensors , _input_validator
26
26
from torchmetrics .metric import Metric
27
+ from torchmetrics .utilities import rank_zero_warn
27
28
from torchmetrics .utilities .imports import (
28
29
_MATPLOTLIB_AVAILABLE ,
29
30
_PYCOCOTOOLS_AVAILABLE ,
@@ -239,6 +240,8 @@ class MeanAveragePrecision(Metric):
239
240
groundtruth_crowds : List [Tensor ]
240
241
groundtruth_area : List [Tensor ]
241
242
243
+ warn_on_many_detections : bool = True
244
+
242
245
def __init__ (
243
246
self ,
244
247
box_format : Literal ["xyxy" , "xywh" , "cxcywh" ] = "xyxy" ,
@@ -329,7 +332,7 @@ def update(self, preds: List[Dict[str, Tensor]], target: List[Dict[str, Tensor]]
329
332
_input_validator (preds , target , iou_type = self .iou_type )
330
333
331
334
for item in preds :
332
- detections = self ._get_safe_item_values (item )
335
+ detections = self ._get_safe_item_values (item , warn = self . warn_on_many_detections )
333
336
334
337
self .detections .append (detections )
335
338
self .detection_labels .append (item ["labels" ])
@@ -542,11 +545,12 @@ def tm_to_coco(self, name: str = "tm_map_input") -> None:
542
545
with open (f"{ name } _target.json" , "w" ) as f :
543
546
f .write (target_json )
544
547
545
- def _get_safe_item_values (self , item : Dict [str , Any ]) -> Union [Tensor , Tuple ]:
548
+ def _get_safe_item_values (self , item : Dict [str , Any ], warn : bool = False ) -> Union [Tensor , Tuple ]:
546
549
"""Convert and return the boxes or masks from the item depending on the iou_type.
547
550
548
551
Args:
549
552
item: input dictionary containing the boxes or masks
553
+ warn: whether to warn if the number of boxes or masks exceeds the max_detection_thresholds
550
554
551
555
Returns:
552
556
boxes or masks depending on the iou_type
@@ -556,12 +560,16 @@ def _get_safe_item_values(self, item: Dict[str, Any]) -> Union[Tensor, Tuple]:
556
560
boxes = _fix_empty_tensors (item ["boxes" ])
557
561
if boxes .numel () > 0 :
558
562
boxes = box_convert (boxes , in_fmt = self .box_format , out_fmt = "xywh" )
563
+ if warn and len (boxes ) > self .max_detection_thresholds [- 1 ]:
564
+ _warning_on_too_many_detections (self .max_detection_thresholds [- 1 ])
559
565
return boxes
560
566
if self .iou_type == "segm" :
561
567
masks = []
562
568
for i in item ["masks" ].cpu ().numpy ():
563
569
rle = mask_utils .encode (np .asfortranarray (i ))
564
570
masks .append ((tuple (rle ["size" ]), rle ["counts" ]))
571
+ if warn and len (masks ) > self .max_detection_thresholds [- 1 ]:
572
+ _warning_on_too_many_detections (self .max_detection_thresholds [- 1 ])
565
573
return tuple (masks )
566
574
raise Exception (f"IOU type { self .iou_type } is not supported" )
567
575
@@ -747,3 +755,13 @@ def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any]
747
755
dist .all_gather_object (list_gathered , list_to_gather , group = process_group )
748
756
749
757
return [list_gathered [rank ][idx ] for idx in range (len (list_gathered [0 ])) for rank in range (world_size )]
758
+
759
+
760
+ def _warning_on_too_many_detections (limit : int ) -> None :
761
+ rank_zero_warn (
762
+ f"Encountered more than { limit } detections in a single image. This means that certain detections with the"
763
+ " lowest scores will be ignored, that may have an undesirable impact on performance. Please consider adjusting"
764
+ " the `max_detection_threshold` to suit your use case. To disable this warning, set attribute class"
765
+ " `warn_on_many_detections=False`, after initializing the metric." ,
766
+ UserWarning ,
767
+ )
0 commit comments