Commit 53869eb

Add sanitize_bounding_boxes kernel/functional (#8308)
1 parent d1f3a7b commit 53869eb

6 files changed: +182 -43 lines changed

docs/source/transforms.rst (+1)

@@ -414,6 +414,7 @@ Functionals

     v2.functional.normalize
     v2.functional.erase
+    v2.functional.sanitize_bounding_boxes
     v2.functional.clamp_bounding_boxes
     v2.functional.uniform_temporal_subsample

test/test_transforms_v2.py (+81 -18)

@@ -5675,18 +5675,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):


 class TestSanitizeBoundingBoxes:
-    @pytest.mark.parametrize("min_size", (1, 10))
-    @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
-    @pytest.mark.parametrize("sample_type", (tuple, dict))
-    def test_transform(self, min_size, labels_getter, sample_type):
-
-        if sample_type is tuple and not isinstance(labels_getter, str):
-            # The "lambda inputs: inputs["labels"]" labels_getter used in this test
-            # doesn't work if the input is a tuple.
-            return
-
-        H, W = 256, 128
-
+    def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
         boxes_and_validity = [
             ([0, 1, 10, 1], False),  # Y1 == Y2
             ([0, 1, 0, 20], False),  # X1 == X2

@@ -5706,18 +5695,31 @@ def test_transform(self, min_size, labels_getter, sample_type):
         ]

         random.shuffle(boxes_and_validity)  # For test robustness: mix order of wrong and correct cases
-        boxes, is_valid_mask = zip(*boxes_and_validity)
-        valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
-
-        boxes = torch.tensor(boxes)
-        labels = torch.arange(boxes.shape[0])
+        boxes, expected_valid_mask = zip(*boxes_and_validity)

         boxes = tv_tensors.BoundingBoxes(
             boxes,
             format=tv_tensors.BoundingBoxFormat.XYXY,
             canvas_size=(H, W),
         )

+        return boxes, expected_valid_mask
+
+    @pytest.mark.parametrize("min_size", (1, 10))
+    @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
+    @pytest.mark.parametrize("sample_type", (tuple, dict))
+    def test_transform(self, min_size, labels_getter, sample_type):
+
+        if sample_type is tuple and not isinstance(labels_getter, str):
+            # The "lambda inputs: inputs["labels"]" labels_getter used in this test
+            # doesn't work if the input is a tuple.
+            return
+
+        H, W = 256, 128
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+        valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
+
+        labels = torch.arange(boxes.shape[0])
         masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
         whatever = torch.rand(10)
         input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)

@@ -5763,6 +5765,44 @@ def test_transform(self, min_size, labels_getter, sample_type):
         # This works because we conveniently set labels to arange(num_boxes)
         assert out_labels.tolist() == valid_indices

+    @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
+    def test_functional(self, input_type):
+        # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
+        # redundancy with test_transform() in terms of correctness checks. But that's OK.
+
+        H, W, min_size = 256, 128, 10
+
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+
+        if input_type is tv_tensors.BoundingBoxes:
+            format = canvas_size = None
+        else:
+            # just passing "XYXY" explicitly to make sure we support strings
+            format, canvas_size = "XYXY", boxes.canvas_size
+            boxes = boxes.as_subclass(torch.Tensor)
+
+        boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)
+
+        assert_equal(valid, torch.tensor(expected_valid_mask))
+        assert type(valid) == torch.Tensor
+        assert boxes.shape[0] == sum(valid)
+        assert isinstance(boxes, input_type)
+
+    def test_kernel(self):
+        H, W, min_size = 256, 128, 10
+        boxes, _ = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
+
+        format, canvas_size = boxes.format, boxes.canvas_size
+        boxes = boxes.as_subclass(torch.Tensor)
+
+        check_kernel(
+            F.sanitize_bounding_boxes,
+            input=boxes,
+            format=format,
+            canvas_size=canvas_size,
+            check_batched_vs_unbatched=False,
+        )
+
     def test_no_label(self):
         # Non-regression test for https://github.com/pytorch/vision/issues/7878

@@ -5776,7 +5816,7 @@ def test_no_label(self):
         assert isinstance(out_img, tv_tensors.Image)
         assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

-    def test_errors(self):
+    def test_errors_transform(self):
         good_bbox = tv_tensors.BoundingBoxes(
             [[0, 0, 10, 10]],
             format=tv_tensors.BoundingBoxFormat.XYXY,

@@ -5799,3 +5839,26 @@ def test_errors(self):
         with pytest.raises(ValueError, match="Number of boxes"):
             different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
             transforms.SanitizeBoundingBoxes()(different_sizes)
+
+    def test_errors_functional(self):
+
+        good_bbox = tv_tensors.BoundingBoxes(
+            [[0, 0, 10, 10]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(20, 20),
+        )
+
+        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
+            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format="XYXY", canvas_size=None)
+
+        with pytest.raises(ValueError, match="canvas_size cannot be None if bounding_boxes is a pure tensor"):
+            F.sanitize_bounding_boxes(good_bbox.as_subclass(torch.Tensor), format=None, canvas_size=(10, 10))
+
+        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
+            F.sanitize_bounding_boxes(good_bbox, format="XYXY", canvas_size=None)
+
+        with pytest.raises(ValueError, match="canvas_size must be None when bounding_boxes is a tv_tensors"):
+            F.sanitize_bounding_boxes(good_bbox, format=None, canvas_size=(10, 10))
+
+        with pytest.raises(ValueError, match="bounding_boxes must be a tv_tensors.BoundingBoxes instance or a"):
+            F.sanitize_bounding_boxes(good_bbox.tolist())

torchvision/prototype/transforms/_augment.py (+1 -1)

@@ -123,7 +123,7 @@ def _extract_image_targets(
         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
-                "BoundingBoxeses, Masks and Labels or OneHotLabels."
+                "BoundingBoxes, Masks and Labels or OneHotLabels."
             )

         targets = []

torchvision/transforms/v2/_misc.py (+8 -22)

@@ -1,5 +1,5 @@
 import warnings
-from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union

 import PIL.Image

@@ -369,28 +369,14 @@ def forward(self, *inputs: Any) -> Any:
                 f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
             )

-        boxes = cast(
-            tv_tensors.BoundingBoxes,
-            F.convert_bounding_box_format(
-                boxes,
-                new_format=tv_tensors.BoundingBoxFormat.XYXY,
-            ),
+        valid = F._misc._get_sanitize_bounding_boxes_mask(
+            boxes,
+            format=boxes.format,
+            canvas_size=boxes.canvas_size,
+            min_size=self.min_size,
         )
-        ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
-        valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
-        # TODO: Do we really need to check for out of bounds here? All
-        # transforms should be clamping anyway, so this should never happen?
-        image_h, image_w = boxes.canvas_size
-        valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
-        valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
-
-        params = dict(valid=valid.as_subclass(torch.Tensor), labels=labels)
-        flat_outputs = [
-            # Even-though it may look like we're transforming all inputs, we don't:
-            # _transform() will only care about BoundingBoxeses and the labels
-            self._transform(inpt, params)
-            for inpt in flat_inputs
-        ]
+        params = dict(valid=valid, labels=labels)
+        flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]

         return tree_unflatten(flat_outputs, spec)
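With this change, SanitizeBoundingBoxes delegates its validity computation to the shared F._misc._get_sanitize_bounding_boxes_mask helper instead of inlining it. A minimal usage sketch of the transform follows (not part of this commit; the sample data is illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2 as transforms

# Two boxes: one valid, one degenerate (zero width and height).
boxes = tv_tensors.BoundingBoxes(
    [[0, 0, 10, 10], [5, 5, 5, 5]],
    format=tv_tensors.BoundingBoxFormat.XYXY,
    canvas_size=(20, 20),
)
sample = {"boxes": boxes, "labels": torch.tensor([0, 1])}

# The default labels_getter picks up the "labels" key, so the labels are
# subset with the same validity mask as the boxes.
out = transforms.SanitizeBoundingBoxes()(sample)
assert out["boxes"].shape[0] == out["labels"].shape[0] == 1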

torchvision/transforms/v2/functional/__init__.py (+1)

@@ -167,6 +167,7 @@
     normalize,
     normalize_image,
     normalize_video,
+    sanitize_bounding_boxes,
     to_dtype,
     to_dtype_image,
     to_dtype_video,

torchvision/transforms/v2/functional/_misc.py (+90 -2)

@@ -1,5 +1,5 @@
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import PIL.Image
 import torch

@@ -11,7 +11,9 @@

 from torchvision.utils import _log_api_usage_once

-from ._utils import _get_kernel, _register_kernel_internal
+from ._meta import _convert_bounding_box_format
+
+from ._utils import _get_kernel, _register_kernel_internal, is_pure_tensor


 def normalize(

@@ -275,3 +277,89 @@ def to_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float, scale:
 def _to_dtype_tensor_dispatch(inpt: torch.Tensor, dtype: torch.dtype, scale: bool = False) -> torch.Tensor:
     # We don't need to unwrap and rewrap here, since TVTensor.to() preserves the type
     return inpt.to(dtype)
+
+
+def sanitize_bounding_boxes(
+    bounding_boxes: torch.Tensor,
+    format: Optional[tv_tensors.BoundingBoxFormat] = None,
+    canvas_size: Optional[Tuple[int, int]] = None,
+    min_size: float = 1.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
+
+    This removes bounding boxes that:
+
+    - are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
+    - have any coordinate outside of their corresponding image. You may want to
+      call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
+
+    It is recommended to call it at the end of a pipeline, before passing the
+    input to the models. It is critical to call this transform if
+    :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
+    If you want to be extra careful, you may call it after all transforms that
+    may modify bounding boxes but once at the end should be enough in most
+    cases.
+
+    Args:
+        bounding_boxes (Tensor or :class:`~torchvision.tv_tensors.BoundingBoxes`): The bounding boxes to be sanitized.
+        format (str or :class:`~torchvision.tv_tensors.BoundingBoxFormat`, optional): The format of the bounding boxes.
+            Must be left to None if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
+        canvas_size (tuple of int, optional): The canvas_size of the bounding boxes
+            (size of the corresponding image/video).
+            Must be left to None if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
+        min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
+
+    Returns:
+        out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
+        The mask can then be used to subset other tensors (e.g. labels) that are associated with the bounding boxes.
+    """
+    if torch.jit.is_scripting() or is_pure_tensor(bounding_boxes):
+        if format is None or canvas_size is None:
+            raise ValueError(
+                "format and canvas_size cannot be None if bounding_boxes is a pure tensor. "
+                f"Got format={format} and canvas_size={canvas_size}. "
+                "Set those to appropriate values or pass bounding_boxes as a tv_tensors.BoundingBoxes object."
+            )
+        if isinstance(format, str):
+            format = tv_tensors.BoundingBoxFormat[format.upper()]
+        valid = _get_sanitize_bounding_boxes_mask(
+            bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
+        )
+        bounding_boxes = bounding_boxes[valid]
+    else:
+        if not isinstance(bounding_boxes, tv_tensors.BoundingBoxes):
+            raise ValueError("bounding_boxes must be a tv_tensors.BoundingBoxes instance or a pure tensor.")
+        if format is not None or canvas_size is not None:
+            raise ValueError(
+                "format and canvas_size must be None when bounding_boxes is a tv_tensors.BoundingBoxes instance. "
+                f"Got format={format} and canvas_size={canvas_size}. "
+                "Leave those to None or pass bounding_boxes as a pure tensor."
+            )
+        valid = _get_sanitize_bounding_boxes_mask(
+            bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
+        )
+        bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)
+
+    return bounding_boxes, valid
+
+
+def _get_sanitize_bounding_boxes_mask(
+    bounding_boxes: torch.Tensor,
+    format: tv_tensors.BoundingBoxFormat,
+    canvas_size: Tuple[int, int],
+    min_size: float = 1.0,
+) -> torch.Tensor:
+
+    bounding_boxes = _convert_bounding_box_format(
+        bounding_boxes, new_format=tv_tensors.BoundingBoxFormat.XYXY, old_format=format
+    )
+
+    image_h, image_w = canvas_size
+    ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
+    valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
+    # TODO: Do we really need to check for out of bounds here? All
+    # transforms should be clamping anyway, so this should never happen?
+    valid &= (bounding_boxes[:, 0] <= image_w) & (bounding_boxes[:, 2] <= image_w)
+    valid &= (bounding_boxes[:, 1] <= image_h) & (bounding_boxes[:, 3] <= image_h)
+    return valid
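The new functional accepts either a tv_tensors.BoundingBoxes object or a pure tensor plus explicit format/canvas_size, as the error branches above enforce. A minimal usage sketch of both calling conventions (not part of this commit; box values are illustrative):

import torch
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F

raw = torch.tensor([[0, 0, 10, 10], [0, 5, 0, 10]])  # second box has zero width

# 1) BoundingBoxes input: format and canvas_size must be left to None.
bbox = tv_tensors.BoundingBoxes(raw, format="XYXY", canvas_size=(20, 20))
sanitized, valid = F.sanitize_bounding_boxes(bbox)
assert isinstance(sanitized, tv_tensors.BoundingBoxes)
assert valid.tolist() == [True, False]

# 2) Pure tensor input: format and canvas_size are required.
sanitized, valid = F.sanitize_bounding_boxes(raw, format="XYXY", canvas_size=(20, 20))
assert isinstance(sanitized, torch.Tensor)

# The returned mask can subset other tensors associated with the boxes.
labels = torch.tensor([3, 7])[valid]
assert labels.tolist() == [3]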
