Skip to content

Commit

Permalink
[CHERRY-PICK] prevent unwrapping in SanitizeBoundingBoxes (#7446) (#7450
Browse files Browse the repository at this point in the history
)
  • Loading branch information
pmeier authored Mar 23, 2023
1 parent 2bda93b commit 83ef3a6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
3 changes: 3 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,9 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_image is input_img
assert out_whatever is whatever

assert isinstance(out_boxes, datapoints.BoundingBox)
assert isinstance(out_masks, datapoints.Mask)

if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out_labels is labels
else:
Expand Down
15 changes: 10 additions & 5 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,10 +397,15 @@ def forward(self, *inputs: Any) -> Any:
return tree_unflatten(flat_outputs, spec)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))

if (inpt is not None and inpt is params["labels"]) or isinstance(
inpt, (datapoints.BoundingBox, datapoints.Mask)
):
inpt = inpt[params["valid"]]
if not (is_label or is_bounding_box_or_mask):
return inpt

return inpt
output = inpt[params["valid"]]

if is_label:
return output

return type(inpt).wrap_like(inpt, output)

0 comments on commit 83ef3a6

Please sign in to comment.