
Commit b551c48

NicolasHug authored and facebook-github-bot committed
[fbsync] Allow SanitizeBoundingBoxes to sanitize more labels (#8319)
Reviewed By: vmoens
Differential Revision: D55062811
fbshipit-source-id: 2216cc6805bff1c69d3dd7e570b80c927b82d19a
1 parent 101731a commit b551c48

File tree

3 files changed (+62 -21 lines)


test/test_transforms_v2.py

+24 -2
@@ -5706,7 +5706,17 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
         return boxes, expected_valid_mask
 
     @pytest.mark.parametrize("min_size", (1, 10))
-    @pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None))
+    @pytest.mark.parametrize(
+        "labels_getter",
+        (
+            "default",
+            lambda inputs: inputs["labels"],
+            lambda inputs: (inputs["labels"], inputs["other_labels"]),
+            lambda inputs: [inputs["labels"], inputs["other_labels"]],
+            None,
+            lambda inputs: None,
+        ),
+    )
     @pytest.mark.parametrize("sample_type", (tuple, dict))
     def test_transform(self, min_size, labels_getter, sample_type):
@@ -5721,12 +5731,16 @@ def test_transform(self, min_size, labels_getter, sample_type):
 
         labels = torch.arange(boxes.shape[0])
         masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
+        # other_labels corresponds to properties from COCO like iscrowd, area...
+        # We only sanitize it when labels_getter returns a tuple
+        other_labels = torch.arange(boxes.shape[0])
         whatever = torch.rand(10)
         input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
         sample = {
             "image": input_img,
             "labels": labels,
             "boxes": boxes,
+            "other_labels": other_labels,
             "whatever": whatever,
             "None": None,
             "masks": masks,
@@ -5741,12 +5755,14 @@ def test_transform(self, min_size, labels_getter, sample_type):
         if sample_type is tuple:
             out_image = out[0]
             out_labels = out[1]["labels"]
+            out_other_labels = out[1]["other_labels"]
             out_boxes = out[1]["boxes"]
             out_masks = out[1]["masks"]
             out_whatever = out[1]["whatever"]
         else:
             out_image = out["image"]
             out_labels = out["labels"]
+            out_other_labels = out["other_labels"]
             out_boxes = out["boxes"]
             out_masks = out["masks"]
             out_whatever = out["whatever"]
@@ -5757,14 +5773,20 @@ def test_transform(self, min_size, labels_getter, sample_type):
         assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
         assert isinstance(out_masks, tv_tensors.Mask)
 
-        if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
+        if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
             assert out_labels is labels
+            assert out_other_labels is other_labels
         else:
             assert isinstance(out_labels, torch.Tensor)
             assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
             # This works because we conveniently set labels to arange(num_boxes)
             assert out_labels.tolist() == valid_indices
 
+            if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)):
+                assert_equal(out_other_labels, out_labels)
+            else:
+                assert_equal(out_other_labels, other_labels)
+
     @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
     def test_functional(self, input_type):
         # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
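For illustration, here is a minimal usage sketch of what the updated test exercises; it is not part of the commit and assumes a torchvision build that includes this change. The sample keys ("labels", "other_labels") simply mirror the test above.

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

H, W = 256, 128
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[0, 0, 50, 50], [0, 0, 2, 2]]),  # second box is smaller than min_size
    format="XYXY",
    canvas_size=(H, W),
)
sample = {
    "boxes": boxes,
    "labels": torch.tensor([1, 2]),
    "other_labels": torch.tensor([0, 1]),  # e.g. an iscrowd-like field
}

transform = v2.SanitizeBoundingBoxes(
    min_size=10,
    # Returning a tuple means every returned tensor is filtered with the same validity mask.
    labels_getter=lambda inputs: (inputs["labels"], inputs["other_labels"]),
)
out = transform(sample)
# out["boxes"], out["labels"] and out["other_labels"] all keep only the first entry.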

torchvision/transforms/v2/_misc.py

+36 -15
@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform):
     - have any coordinate outside of their corresponding image. You may want to
       call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
 
+    It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
+    (see ``labels_getter`` parameter).
+
     It is recommended to call it at the end of a pipeline, before passing the
     input to the models. It is critical to call this transform if
     :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -330,18 +333,26 @@ class SanitizeBoundingBoxes(Transform):
 
     Args:
         min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
-        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input.
+        labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
+            (or anything else that needs to be sanitized along with the bounding boxes).
             By default, this will try to find a "labels" key in the input (case-insensitive), if
             the input is a dict or it is a tuple whose second element is a dict.
             This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
-            It can also be a callable that takes the same input
-            as the transform, and returns the labels.
+
+            It can also be a callable that takes the same input as the transform, and returns either:
+
+            - A single tensor (the labels)
+            - A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes.
+              This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties
+              from COCO.
+
+            If ``labels_getter`` is None then only bounding boxes are sanitized.
     """
 
     def __init__(
         self,
         min_size: float = 1.0,
-        labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
+        labels_getter: Union[Callable[[Any], Any], str, None] = "default",
     ) -> None:
         super().__init__()
 
@@ -356,18 +367,28 @@ def forward(self, *inputs: Any) -> Any:
         inputs = inputs if len(inputs) > 1 else inputs[0]
 
         labels = self._labels_getter(inputs)
-        if labels is not None and not isinstance(labels, torch.Tensor):
-            raise ValueError(
-                f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead."
-            )
+        if labels is not None:
+            msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
+            if isinstance(labels, torch.Tensor):
+                labels = (labels,)
+            elif isinstance(labels, (tuple, list)):
+                for entry in labels:
+                    if not isinstance(entry, torch.Tensor):
+                        # TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
+                        raise ValueError(msg.format(type=type(entry)))
+            else:
+                raise ValueError(msg.format(type=type(labels)))
 
         flat_inputs, spec = tree_flatten(inputs)
         boxes = get_bounding_boxes(flat_inputs)
 
-        if labels is not None and boxes.shape[0] != labels.shape[0]:
-            raise ValueError(
-                f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
-            )
+        if labels is not None:
+            for label in labels:
+                if boxes.shape[0] != label.shape[0]:
+                    raise ValueError(
+                        f"Number of boxes (shape={boxes.shape}) and must match the number of labels."
+                        f"Found labels with shape={label.shape})."
+                    )
 
         valid = F._misc._get_sanitize_bounding_boxes_mask(
             boxes,
@@ -381,7 +402,7 @@ def forward(self, *inputs: Any) -> Any:
         return tree_unflatten(flat_outputs, spec)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        is_label = inpt is not None and inpt is params["labels"]
+        is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
         is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
 
         if not (is_label or is_bounding_boxes_or_mask):
@@ -391,5 +412,5 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
         if is_label:
             return output
-
-        return tv_tensors.wrap(output, like=inpt)
+        else:
+            return tv_tensors.wrap(output, like=inpt)
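Conceptually, forward() now normalizes whatever labels_getter returns into a tuple of tensors and applies the same per-box validity mask to each of them in _transform(). A rough standalone sketch of that masking step (hypothetical variable names, not the library code):

import torch

valid = torch.tensor([True, False, True])         # e.g. produced by the min_size check
labels = torch.tensor([7, 8, 9])
iscrowd = torch.tensor([0, 1, 0])                 # a COCO-style extra field

labels_like = (labels, iscrowd)                   # what a tuple-returning labels_getter yields
sanitized = tuple(t[valid] for t in labels_like)  # the same boolean mask, applied entry-wise
# sanitized == (tensor([7, 9]), tensor([0, 0]))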

torchvision/transforms/v2/_utils.py

+2 -4
@@ -4,7 +4,7 @@
 import numbers
 from contextlib import suppress
 
-from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Sequence, Tuple, Type, Union
 
 import PIL.Image
 import torch
@@ -139,9 +139,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
     return inputs[candidate_key]
 
 
-def _parse_labels_getter(
-    labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None]
-) -> Callable[[Any], Optional[torch.Tensor]]:
+def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) -> Callable[[Any], Any]:
     if labels_getter == "default":
         return _find_labels_default_heuristic
     elif callable(labels_getter):