Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,10 @@ def rotate_bounding_box():
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and "pil" not in name
and name
not in {
"to_image_tensor",
}
],
)
def test_scriptable(kernel):
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,5 @@
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._type_conversion import DecodeImage, LabelToOneHot

from ._deprecated import ToTensor, ToPILImage, PILToTensor # usort: skip
40 changes: 40 additions & 0 deletions torchvision/prototype/transforms/_deprecated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Any, Dict, Optional

import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F

from ._utils import is_simple_tensor


# TODO: add deprecation warning
class ToTensor(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (PIL.Image.Image, np.ndarray)):
return _F.to_tensor(input)
else:
return input


# TODO: add deprecation warning
class PILToTensor(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, PIL.Image.Image):
return _F.pil_to_tensor(input)
else:
return input


# TODO: add deprecation warning
class ToPILImage(Transform):
def __init__(self, mode: Optional[str] = None) -> None:
super().__init__()
self.mode = mode

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if is_simple_tensor(input) or isinstance(input, (features.Image, np.ndarray)):
return _F.to_pil_image(input, mode=self.mode)
else:
return input
29 changes: 29 additions & 0 deletions torchvision/prototype/transforms/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Any, Dict

import numpy as np
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import is_simple_tensor


class DecodeImage(Transform):
def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -33,3 +37,28 @@ def extra_repr(self) -> str:
return ""

return f"num_categories={self.num_categories}"


class ToImageTensor(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
output = F.to_image_tensor(input, copy=self.copy)
return features.Image(output)
else:
return input


class ToImagePIL(Transform):
def __init__(self, *, copy: bool = False) -> None:
super().__init__()
self.copy = copy

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.Image, PIL.Image.Image, np.ndarray)) or is_simple_tensor(input):
return F.to_image_pil(input, copy=self.copy)
else:
return input
8 changes: 7 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,10 @@
ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
from ._type_conversion import (
decode_image_with_pil,
decode_video_with_av,
label_to_one_hot,
to_image_tensor,
to_image_pil,
)
23 changes: 22 additions & 1 deletion torchvision/prototype/transforms/functional/_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import unittest.mock
from typing import Dict, Any, Tuple
from typing import Dict, Any, Tuple, Union

import numpy as np
import PIL.Image
import torch
from torch.nn.functional import one_hot
from torchvision.io.video import read_video
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
from torchvision.transforms import functional as _F


def decode_image_with_pil(encoded_image: torch.Tensor) -> torch.Tensor:
Expand All @@ -23,3 +24,23 @@ def decode_video_with_av(encoded_video: torch.Tensor) -> Tuple[torch.Tensor, tor

def label_to_one_hot(label: torch.Tensor, *, num_categories: int) -> torch.Tensor:
return one_hot(label, num_classes=num_categories) # type: ignore[no-any-return]


def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> torch.Tensor:
if isinstance(image, torch.Tensor):
if copy:
return image.clone()
else:
return image

return _F.to_tensor(image)


def to_image_pil(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray], copy: bool = False) -> PIL.Image.Image:
if isinstance(image, PIL.Image.Image):
if copy:
return image.copy()
else:
return image

return _F.to_pil_image(to_image_tensor(image, copy=False))
2 changes: 1 addition & 1 deletion torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _is_numpy_image(img: Any) -> bool:
return img.ndim in {2, 3}


def to_tensor(pic):
def to_tensor(pic) -> Tensor:
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
This function does not support torchscript.

Expand Down