diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index bca8925c71d..b5a8c9ead7e 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2021,6 +2021,9 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
     assert out_image is input_img
     assert out_whatever is whatever
 
+    assert isinstance(out_boxes, datapoints.BoundingBox)
+    assert isinstance(out_masks, datapoints.Mask)
+
     if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
         assert out_labels is labels
     else:
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index c9b9025ebd9..8f063fd6049 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -397,10 +397,15 @@ def forward(self, *inputs: Any) -> Any:
         return tree_unflatten(flat_outputs, spec)
 
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        is_label = inpt is not None and inpt is params["labels"]
+        is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
 
-        if (inpt is not None and inpt is params["labels"]) or isinstance(
-            inpt, (datapoints.BoundingBox, datapoints.Mask)
-        ):
-            inpt = inpt[params["valid"]]
+        if not (is_label or is_bounding_box_or_mask):
+            return inpt
 
-        return inpt
+        output = inpt[params["valid"]]
+
+        if is_label:
+            return output
+
+        return type(inpt).wrap_like(inpt, output)