diff --git a/test/common_utils.py b/test/common_utils.py index 8c3c9dd58a8..6bd585d394d 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -188,7 +188,12 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None): def _assert_approx_equal_tensor_to_pil( - tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None + tensor, + pil_image, + tol=1e-5, + msg=None, + agg_method="mean", + allowed_percentage_diff=None, ): # FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it # TODO: we could just merge this into _assert_equal_tensor_to_pil @@ -284,8 +289,29 @@ def __init__( mae=False, **other_parameters, ): - if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]): - actual, expected = (to_image(input) for input in [actual, expected]) + # Convert PIL images to tv_tensors.Image (regardless of what the other is) + if isinstance(actual, PIL.Image.Image): + actual = to_image(actual) + if isinstance(expected, PIL.Image.Image): + expected = to_image(expected) + + # Convert CV-CUDA tensors to torch.Tensor (regardless of what the other is) + try: + import cvcuda + from torchvision.transforms.v2.functional import cvcuda_to_tensor + + if isinstance(actual, cvcuda.Tensor): + actual = cvcuda_to_tensor(actual) + # Remove batch dimension if it's 1 for easier comparison + if actual.shape[0] == 1: + actual = actual[0] + if isinstance(expected, cvcuda.Tensor): + expected = cvcuda_to_tensor(expected) + # Remove batch dimension if it's 1 for easier comparison + if expected.shape[0] == 1: + expected = expected[0] + except ImportError: + pass super().__init__(actual, expected, **other_parameters) self.mae = mae @@ -400,8 +426,8 @@ def make_image_pil(*args, **kwargs): return to_pil_image(make_image(*args, **kwargs)) -def make_image_cvcuda(*args, **kwargs): - return to_cvcuda_tensor(make_image(*args, **kwargs)) +def make_image_cvcuda(*args, batch_dims=(1,), **kwargs): + return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs)) def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"): @@ -541,5 +567,9 @@ def ignore_jit_no_profile_information_warning(): # with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore # them. with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning) + warnings.filterwarnings( + "ignore", + message=re.escape("operator() profile_node %"), + category=UserWarning, + ) yield diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 670a9d00ffb..1a21c08013a 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -1240,6 +1240,10 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1255,6 +1259,11 @@ def test_functional(self, make_input): (F.horizontal_flip_image, torch.Tensor), (F._geometry._horizontal_flip_image_pil, PIL.Image.Image), (F.horizontal_flip_image, tv_tensors.Image), + pytest.param( + F._geometry._horizontal_flip_image_cvcuda, + cvcuda.Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), (F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.horizontal_flip_mask, tv_tensors.Mask), (F.horizontal_flip_video, tv_tensors.Video), @@ -1270,6 +1279,10 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1283,13 +1296,24 @@ def test_transform(self, make_input, device): @pytest.mark.parametrize( "fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)] ) - def test_image_correctness(self, fn): - image = make_image(dtype=torch.uint8, device="cpu") - + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), + ], + ) + def test_image_correctness(self, fn, make_input): + image = make_input() actual = fn(image) - expected = F.to_image(F.horizontal_flip(F.to_pil_image(image))) - - torch.testing.assert_close(actual, expected) + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + expected = F.horizontal_flip(F.to_pil_image(image)) + # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking + assert_equal(actual, expected, check_device=False) def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1345,6 +1369,10 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1354,12 +1382,12 @@ def test_keypoints_correctness(self, fn): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform_noop(self, make_input, device): input = make_input(device=device) - transform = transforms.RandomHorizontalFlip(p=0) - output = transform(input) - - assert_equal(output, input) + if isinstance(input, cvcuda.Tensor): + assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) + else: + assert_equal(output, input) class TestAffine: @@ -1856,6 +1884,10 @@ def test_kernel_video(self): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1871,6 +1903,11 @@ def test_functional(self, make_input): (F.vertical_flip_image, torch.Tensor), (F._geometry._vertical_flip_image_pil, PIL.Image.Image), (F.vertical_flip_image, tv_tensors.Image), + pytest.param( + F._geometry._vertical_flip_image_cvcuda, + cvcuda.Tensor, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), (F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes), (F.vertical_flip_mask, tv_tensors.Mask), (F.vertical_flip_video, tv_tensors.Video), @@ -1886,6 +1923,10 @@ def test_functional_signature(self, kernel, input_type): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1897,13 +1938,24 @@ def test_transform(self, make_input, device): check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device)) @pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)]) - def test_image_correctness(self, fn): - image = make_image(dtype=torch.uint8, device="cpu") - + @pytest.mark.parametrize( + "make_input", + [ + make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), + ], + ) + def test_image_correctness(self, fn, make_input): + image = make_input() actual = fn(image) - expected = F.to_image(F.vertical_flip(F.to_pil_image(image))) - - torch.testing.assert_close(actual, expected) + if make_input is make_image_cvcuda: + image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W] + expected = F.vertical_flip(F.to_pil_image(image)) + # CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking + assert_equal(actual, expected, check_device=False) def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes): affine_matrix = np.array( @@ -1955,6 +2007,10 @@ def test_keypoints_correctness(self, fn): make_image_tensor, make_image_pil, make_image, + pytest.param( + make_image_cvcuda, + marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"), + ), make_bounding_boxes, make_segmentation_mask, make_video, @@ -1964,12 +2020,12 @@ def test_keypoints_correctness(self, fn): @pytest.mark.parametrize("device", cpu_and_cuda()) def test_transform_noop(self, make_input, device): input = make_input(device=device) - transform = transforms.RandomVerticalFlip(p=0) - output = transform(input) - - assert_equal(output, input) + if isinstance(input, cvcuda.Tensor): + assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input)) + else: + assert_equal(output, input) class TestRotate: @@ -7101,7 +7157,7 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t out = t(sample) - assert type(out) == type(sample) + assert type(out) is type(sample) if dataset_return_type is tuple: out_image, out_label = out @@ -7412,7 +7468,7 @@ def test_functional(self, input_type): boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size) assert_equal(valid, torch.tensor(expected_valid_mask)) - assert type(valid) == torch.Tensor + assert type(valid) is torch.Tensor assert boxes.shape[0] == sum(valid) assert isinstance(boxes, input_type) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index 1418a6b4953..7bb17aa7f41 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -11,7 +11,7 @@ from torchvision.ops.boxes import box_iou from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform -from torchvision.transforms.v2.functional._utils import _FillType +from torchvision.transforms.v2.functional._utils import _FillType, _import_cvcuda, _is_cvcuda_available from ._transform import _RandomApplyTransform from ._utils import ( @@ -30,6 +30,10 @@ query_size, ) +CVCUDA_AVAILABLE = _is_cvcuda_available() +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() + class RandomHorizontalFlip(_RandomApplyTransform): """Horizontally flip the input with a given probability. @@ -45,6 +49,9 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.horizontal_flip, inpt) @@ -63,6 +70,9 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip + if CVCUDA_AVAILABLE: + _transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor) + def transform(self, inpt: Any, params: dict[str, Any]) -> Any: return self._call_kernel(F.vertical_flip, inpt) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index ac84fcb6c82..610e7d7e83b 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -13,6 +13,8 @@ from .functional._utils import _get_kernel +CVCUDA_AVAILABLE = _is_cvcuda_available() + class Transform(nn.Module): """Base class to implement your own v2 transforms. diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4fcb7fabe0d..d0e76cdc358 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -2,7 +2,7 @@ import numbers import warnings from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import PIL.Image import torch @@ -26,7 +26,20 @@ from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format -from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal +from ._utils import ( + _FillTypeJIT, + _get_kernel, + _import_cvcuda, + _is_cvcuda_available, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, +) + +CVCUDA_AVAILABLE = _is_cvcuda_available() +if TYPE_CHECKING: + import cvcuda +if CVCUDA_AVAILABLE: + cvcuda = _import_cvcuda() def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -62,6 +75,16 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.hflip(image) +def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=1) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)( + _horizontal_flip_image_cvcuda + ) + + @_register_kernel_internal(horizontal_flip, tv_tensors.Mask) def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor: return horizontal_flip_image(mask) @@ -150,6 +173,16 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image: return _FP.vflip(image) +def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor": + return _import_cvcuda().flip(image, flipCode=0) + + +if CVCUDA_AVAILABLE: + _register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)( + _vertical_flip_image_cvcuda + ) + + @_register_kernel_internal(vertical_flip, tv_tensors.Mask) def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor: return vertical_flip_image(mask)