@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform):
321
321
- have any coordinate outside of their corresponding image. You may want to
322
322
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
323
323
324
+ It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
325
+ (see ``labels_getter`` parameter).
326
+
324
327
It is recommended to call it at the end of a pipeline, before passing the
325
328
input to the models. It is critical to call this transform if
326
329
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -330,18 +333,26 @@ class SanitizeBoundingBoxes(Transform):
330
333
331
334
Args:
332
335
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
333
- labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
336
+ labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
337
+ (or anything else that needs to be sanitized along with the bounding boxes).
334
338
By default, this will try to find a "labels" key in the input (case-insensitive), if
335
339
the input is a dict or it is a tuple whose second element is a dict.
336
340
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
337
- It can also be a callable that takes the same input
338
- as the transform, and returns the labels.
341
+
342
+ It can also be a callable that takes the same input as the transform, and returns either:
343
+
344
+ - A single tensor (the labels)
345
+ - A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes.
346
+ This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties
347
+ from COCO.
348
+
349
+ If ``labels_getter`` is None then only bounding boxes are sanitized.
339
350
"""
340
351
341
352
def __init__ (
342
353
self ,
343
354
min_size : float = 1.0 ,
344
- labels_getter : Union [Callable [[Any ], Optional [ torch . Tensor ] ], str , None ] = "default" ,
355
+ labels_getter : Union [Callable [[Any ], Any ], str , None ] = "default" ,
345
356
) -> None :
346
357
super ().__init__ ()
347
358
@@ -356,18 +367,28 @@ def forward(self, *inputs: Any) -> Any:
356
367
inputs = inputs if len (inputs ) > 1 else inputs [0 ]
357
368
358
369
labels = self ._labels_getter (inputs )
359
- if labels is not None and not isinstance (labels , torch .Tensor ):
360
- raise ValueError (
361
- f"The labels in the input to forward() must be a tensor or None, got { type (labels )} instead."
362
- )
370
+ if labels is not None :
371
+ msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
372
+ if isinstance (labels , torch .Tensor ):
373
+ labels = (labels ,)
374
+ elif isinstance (labels , (tuple , list )):
375
+ for entry in labels :
376
+ if not isinstance (entry , torch .Tensor ):
377
+ # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
378
+ raise ValueError (msg .format (type = type (entry )))
379
+ else :
380
+ raise ValueError (msg .format (type = type (labels )))
363
381
364
382
flat_inputs , spec = tree_flatten (inputs )
365
383
boxes = get_bounding_boxes (flat_inputs )
366
384
367
- if labels is not None and boxes .shape [0 ] != labels .shape [0 ]:
368
- raise ValueError (
369
- f"Number of boxes (shape={ boxes .shape } ) and number of labels (shape={ labels .shape } ) do not match."
370
- )
385
+ if labels is not None :
386
+ for label in labels :
387
+ if boxes .shape [0 ] != label .shape [0 ]:
388
+ raise ValueError (
389
+ f"Number of boxes (shape={ boxes .shape } ) and must match the number of labels."
390
+ f"Found labels with shape={ label .shape } )."
391
+ )
371
392
372
393
valid = F ._misc ._get_sanitize_bounding_boxes_mask (
373
394
boxes ,
@@ -381,7 +402,7 @@ def forward(self, *inputs: Any) -> Any:
381
402
return tree_unflatten (flat_outputs , spec )
382
403
383
404
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
384
- is_label = inpt is not None and inpt is params ["labels" ]
405
+ is_label = params [ "labels" ] is not None and any ( inpt is label for label in params ["labels" ])
385
406
is_bounding_boxes_or_mask = isinstance (inpt , (tv_tensors .BoundingBoxes , tv_tensors .Mask ))
386
407
387
408
if not (is_label or is_bounding_boxes_or_mask ):
@@ -391,5 +412,5 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
391
412
392
413
if is_label :
393
414
return output
394
-
395
- return tv_tensors .wrap (output , like = inpt )
415
+ else :
416
+ return tv_tensors .wrap (output , like = inpt )
0 commit comments