Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions test/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,12 @@ def _assert_equal_tensor_to_pil(tensor, pil_image, msg=None):


def _assert_approx_equal_tensor_to_pil(
tensor, pil_image, tol=1e-5, msg=None, agg_method="mean", allowed_percentage_diff=None
tensor,
pil_image,
tol=1e-5,
msg=None,
agg_method="mean",
allowed_percentage_diff=None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Above: this looks like a cosmetic change? Try to avoid that and revert the previous state, it's distracting when reviewing and it also affects git blame unnecessarily :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this cosmetic change happens after I use pre-commit hooks which run both ufmt and flake8 as described here

):
# FIXME: this is handled automatically by `assert_close` below. Let's remove this in favor of it
# TODO: we could just merge this into _assert_equal_tensor_to_pil
Expand Down Expand Up @@ -284,8 +289,29 @@ def __init__(
mae=False,
**other_parameters,
):
if all(isinstance(input, PIL.Image.Image) for input in [actual, expected]):
actual, expected = (to_image(input) for input in [actual, expected])
# Convert PIL images to tv_tensors.Image (regardless of what the other is)
if isinstance(actual, PIL.Image.Image):
actual = to_image(actual)
if isinstance(expected, PIL.Image.Image):
expected = to_image(expected)
Comment on lines +292 to +296
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noting that the above is a change of behavior: we used to convert both inputs, or none of them, which means we'd error when comparing a tensor to a PIL image. We now accept it. I think that's fine.

@justincdavis thanks for sharing your implementation in https://github.com/pytorch/vision/pull/9284/files#diff-f833e6bb21df531837a51e84306266c2f1f6d1565340a498095868e24b3f27de. I think you were being slightly more conservative in your added logic. Was there any particular edge-case you had in mind that you wanted to guard against?


# Convert CV-CUDA tensors to torch.Tensor (regardless of what the other is)
try:
import cvcuda
from torchvision.transforms.v2.functional import cvcuda_to_tensor
Comment on lines +299 to +301
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the _is_cvcuda_available() instead of a try/except. You can then call cvcuda = _import_cvcuda() which is safer and less surprising than a raw import cvcuda.


if isinstance(actual, cvcuda.Tensor):
actual = cvcuda_to_tensor(actual)
# Remove batch dimension if it's 1 for easier comparison
if actual.shape[0] == 1:
actual = actual[0]
Comment on lines +305 to +307
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems unnecessary, we should be able to compare tensors where the batch dim is 1. Try to remove it, if it doesn't work for any reason let me know.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDIT: ah, OK, it's for when we compare a 3D PIL image to a 4D cvcuda tensor. That's... fine. Let's explain why then (addition in bold):

Remove batch dimension if it's 1 for easier comparison against 3D PIL images

if isinstance(expected, cvcuda.Tensor):
expected = cvcuda_to_tensor(expected)
# Remove batch dimension if it's 1 for easier comparison
if expected.shape[0] == 1:
expected = expected[0]
except ImportError:
pass

super().__init__(actual, expected, **other_parameters)
self.mae = mae
Expand Down Expand Up @@ -400,8 +426,8 @@ def make_image_pil(*args, **kwargs):
return to_pil_image(make_image(*args, **kwargs))


def make_image_cvcuda(*args, **kwargs):
return to_cvcuda_tensor(make_image(*args, **kwargs))
def make_image_cvcuda(*args, batch_dims=(1,), **kwargs):
return to_cvcuda_tensor(make_image(*args, batch_dims=batch_dims, **kwargs))


def make_keypoints(canvas_size=DEFAULT_SIZE, *, num_points=4, dtype=None, device="cpu"):
Expand Down Expand Up @@ -541,5 +567,9 @@ def ignore_jit_no_profile_information_warning():
# with varying `INT1` and `INT2`. Since these are uninteresting for us and only clutter the test summary, we ignore
# them.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=re.escape("operator() profile_node %"), category=UserWarning)
warnings.filterwarnings(
"ignore",
message=re.escape("operator() profile_node %"),
category=UserWarning,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, this looks like a cosmetic change?

yield
100 changes: 78 additions & 22 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1255,6 +1259,11 @@ def test_functional(self, make_input):
(F.horizontal_flip_image, torch.Tensor),
(F._geometry._horizontal_flip_image_pil, PIL.Image.Image),
(F.horizontal_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._horizontal_flip_image_cvcuda,
cvcuda.Tensor,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.horizontal_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.horizontal_flip_mask, tv_tensors.Mask),
(F.horizontal_flip_video, tv_tensors.Video),
Expand All @@ -1270,6 +1279,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1283,13 +1296,24 @@ def test_transform(self, make_input, device):
@pytest.mark.parametrize(
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
)
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.horizontal_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W]
expected = F.horizontal_flip(F.to_pil_image(image))
# CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking
assert_equal(actual, expected, check_device=False)
Comment on lines +1315 to +1316
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not do that, it weakens the check. It's safer and more explicit to just call .cpu() on image above:

image = F.cvcuda_to_tensor(image)[0].cpu()


def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
Expand Down Expand Up @@ -1345,6 +1369,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1354,12 +1382,12 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomHorizontalFlip(p=0)

output = transform(input)

assert_equal(output, input)
if isinstance(input, cvcuda.Tensor):
assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input))
else:
assert_equal(output, input)


class TestAffine:
Expand Down Expand Up @@ -1856,6 +1884,10 @@ def test_kernel_video(self):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1871,6 +1903,11 @@ def test_functional(self, make_input):
(F.vertical_flip_image, torch.Tensor),
(F._geometry._vertical_flip_image_pil, PIL.Image.Image),
(F.vertical_flip_image, tv_tensors.Image),
pytest.param(
F._geometry._vertical_flip_image_cvcuda,
cvcuda.Tensor,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
(F.vertical_flip_bounding_boxes, tv_tensors.BoundingBoxes),
(F.vertical_flip_mask, tv_tensors.Mask),
(F.vertical_flip_video, tv_tensors.Video),
Expand All @@ -1886,6 +1923,10 @@ def test_functional_signature(self, kernel, input_type):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1897,13 +1938,24 @@ def test_transform(self, make_input, device):
check_transform(transforms.RandomVerticalFlip(p=1), make_input(device=device))

@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
def test_image_correctness(self, fn):
image = make_image(dtype=torch.uint8, device="cpu")

@pytest.mark.parametrize(
"make_input",
[
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
],
)
def test_image_correctness(self, fn, make_input):
image = make_input()
actual = fn(image)
expected = F.to_image(F.vertical_flip(F.to_pil_image(image)))

torch.testing.assert_close(actual, expected)
if make_input is make_image_cvcuda:
image = F.cvcuda_to_tensor(image)[0] # Remove batch dimension: [1, C, H, W] -> [C, H, W]
expected = F.vertical_flip(F.to_pil_image(image))
# CV-CUDA tensors are on CUDA, PIL images are on CPU, so disable device checking
assert_equal(actual, expected, check_device=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

above: same comments


def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
affine_matrix = np.array(
Expand Down Expand Up @@ -1955,6 +2007,10 @@ def test_keypoints_correctness(self, fn):
make_image_tensor,
make_image_pil,
make_image,
pytest.param(
make_image_cvcuda,
marks=pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="CVCUDA is not available"),
),
make_bounding_boxes,
make_segmentation_mask,
make_video,
Expand All @@ -1964,12 +2020,12 @@ def test_keypoints_correctness(self, fn):
@pytest.mark.parametrize("device", cpu_and_cuda())
def test_transform_noop(self, make_input, device):
input = make_input(device=device)

transform = transforms.RandomVerticalFlip(p=0)

output = transform(input)

assert_equal(output, input)
if isinstance(input, cvcuda.Tensor):
assert_equal(F.cvcuda_to_tensor(output), F.cvcuda_to_tensor(input))
else:
assert_equal(output, input)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there should be no need to manually convert anymore, assert_equal should be able to handle it. Also don't make raw cvcuda accesses! It would force a hard dependency :)



class TestRotate:
Expand Down Expand Up @@ -7101,7 +7157,7 @@ def test_classification_preset(image_type, label_type, dataset_return_type, to_t

out = t(sample)

assert type(out) == type(sample)
assert type(out) is type(sample)

if dataset_return_type is tuple:
out_image, out_label = out
Expand Down Expand Up @@ -7412,7 +7468,7 @@ def test_functional(self, input_type):
boxes, valid = F.sanitize_bounding_boxes(boxes, format=format, canvas_size=canvas_size, min_size=min_size)

assert_equal(valid, torch.tensor(expected_valid_mask))
assert type(valid) == torch.Tensor
assert type(valid) is torch.Tensor
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

revert both above, they seem unrelated and not needed.

assert boxes.shape[0] == sum(valid)
assert isinstance(boxes, input_type)

Expand Down
12 changes: 11 additions & 1 deletion torchvision/transforms/v2/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torchvision.ops.boxes import box_iou
from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
from torchvision.transforms.v2.functional._utils import _FillType
from torchvision.transforms.v2.functional._utils import _FillType, _import_cvcuda, _is_cvcuda_available

from ._transform import _RandomApplyTransform
from ._utils import (
Expand All @@ -30,6 +30,10 @@
query_size,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda()


class RandomHorizontalFlip(_RandomApplyTransform):
"""Horizontally flip the input with a given probability.
Expand All @@ -45,6 +49,9 @@ class RandomHorizontalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomHorizontalFlip

if CVCUDA_AVAILABLE:
_transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor)
Comment on lines +52 to +53
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be

Suggested change
if CVCUDA_AVAILABLE:
_transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor)
_transformed_types = _RandomApplyTransform._transformed_types + (is_cvcuda_tensor, )

where is_cvcuda_tensor is @justincdavis 's implementation from:
https://github.com/pytorch/vision/pull/9283/files#diff-6e892912803ab29861746f7118e9e462a5c9eaab50ad286903e87f09f6c44ff3R175

The function is great and it allows us to avoid the big if CVCUDA_AVAILABLE.

To follow-up from the discussion in https://github.com/pytorch/vision/pull/9277/files#r2566354008 about whether this _transformed_types should be in each child transform class or in the base Transform class: I think it should be in each child transform for now (as done here):

  • we may have a to push a release where we won't have implemented cvcuda support for all transforms, so if this was in the base class then we'd be claiming cvcuda support for some transforms that actually do not support it
  • we may publish new transforms in the future which won't immediately come with cvcuda support, causing the same kind of problem.

We could revisit this eventually. It might eventually be simpler to just have that once and for all in the base Transform class. But having it in each child classes isn't much work, and it makes it easier to track which ones support cvcuda, and which ones don't.


def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.horizontal_flip, inpt)

Expand All @@ -63,6 +70,9 @@ class RandomVerticalFlip(_RandomApplyTransform):

_v1_transform_cls = _transforms.RandomVerticalFlip

if CVCUDA_AVAILABLE:
_transformed_types = (torch.Tensor, PIL.Image.Image, cvcuda.Tensor)

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
return self._call_kernel(F.vertical_flip, inpt)

Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/v2/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .functional._utils import _get_kernel

CVCUDA_AVAILABLE = _is_cvcuda_available()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't seem to be used, we can remove it.



class Transform(nn.Module):
"""Base class to implement your own v2 transforms.
Expand Down
37 changes: 35 additions & 2 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numbers
import warnings
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import PIL.Image
import torch
Expand All @@ -26,7 +26,20 @@

from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format

from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
from ._utils import (
_FillTypeJIT,
_get_kernel,
_import_cvcuda,
_is_cvcuda_available,
_register_five_ten_crop_kernel_internal,
_register_kernel_internal,
)

CVCUDA_AVAILABLE = _is_cvcuda_available()
if TYPE_CHECKING:
import cvcuda
if CVCUDA_AVAILABLE:
cvcuda = _import_cvcuda()


def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode:
Expand Down Expand Up @@ -62,6 +75,16 @@ def _horizontal_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.hflip(image)


def _horizontal_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
return _import_cvcuda().flip(image, flipCode=1)


if CVCUDA_AVAILABLE:
_register_kernel_internal(horizontal_flip, _import_cvcuda().Tensor)(
_horizontal_flip_image_cvcuda
)


@_register_kernel_internal(horizontal_flip, tv_tensors.Mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image(mask)
Expand Down Expand Up @@ -150,6 +173,16 @@ def _vertical_flip_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
return _FP.vflip(image)


def _vertical_flip_image_cvcuda(image: "cvcuda.Tensor") -> "cvcuda.Tensor":
return _import_cvcuda().flip(image, flipCode=0)


if CVCUDA_AVAILABLE:
_register_kernel_internal(vertical_flip, _import_cvcuda().Tensor)(
_vertical_flip_image_cvcuda
)


@_register_kernel_internal(vertical_flip, tv_tensors.Mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image(mask)
Expand Down