2 changes: 1 addition & 1 deletion torchvision/prototype/features/_bounding_box.py
@@ -64,7 +64,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
from torchvision.prototype.transforms.functional import convert_bounding_box_format

if isinstance(format, str):
- format = BoundingBoxFormat[format]
+ format = BoundingBoxFormat.from_str(format.upper())

return BoundingBox.new_like(
self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
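With this change, BoundingBox.to_format accepts the format name in any letter case, since the string is upper-cased before the enum lookup instead of requiring the exact member name. A quick sketch (the BoundingBox constructor arguments are assumptions based on the rest of this diff):

import torch
from torchvision.prototype import features

box = features.BoundingBox(
    torch.tensor([[10.0, 10.0, 40.0, 30.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(100, 100),
)

# Lower-case spelling now resolves to BoundingBoxFormat.XYWH; previously only
# the exact member name "XYWH" was accepted by BoundingBoxFormat[format].
print(box.to_format("xywh"))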
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -7,7 +7,7 @@
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
- from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+ from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, RandomZoomOut
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
61 changes: 60 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -1,6 +1,6 @@
import math
import warnings
- from typing import Any, Dict, List, Union, Sequence, Tuple, cast
+ from typing import Any, Dict, List, Union, Sequence, Tuple, cast, Optional

import PIL.Image
import torch
@@ -168,3 +168,62 @@ def forward(self, *inputs: Any) -> Any:
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
return super().forward(sample)


class RandomZoomOut(Transform):
def __init__(
self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
) -> None:
super().__init__()

if fill is None:
fill = [0.0, 0.0, 0.0]
self.fill = fill

self.side_range = side_range
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
raise ValueError(f"Invalid canvas side range provided {side_range}.")

self.p = p

def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, orig_h, orig_w = get_image_dimensions(image)

r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
canvas_width = int(orig_w * r)
canvas_height = int(orig_h * r)

r = torch.rand(2)
left = int((canvas_width - orig_w) * r[0])
top = int((canvas_height - orig_h) * r[1])
right = canvas_width - (left + orig_w)
bottom = canvas_height - (top + orig_h)

return dict(left=left, top=top, right=right, bottom=bottom)

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, features.Image):
output = F.zoom_out_image_tensor(input, **params, fill=self.fill)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor) and not isinstance(input, features._Feature):
return F.zoom_out_image_tensor(input, **params, fill=self.fill)
elif isinstance(input, PIL.Image.Image):
return F.zoom_out_image_pil(input, **params, fill=self.fill)
elif isinstance(input, features.BoundingBox):
output = F.zoom_out_bounding_box(input, **params, format=input.format)

height, width = input.image_size
height += params["top"] + params["bottom"]
width += params["left"] + params["right"]

return features.BoundingBox.new_like(input, output, image_size=(height, width))
else:
return input

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if torch.rand(1) >= self.p:
return sample

return super().forward(sample)
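RandomZoomOut grows the canvas by a factor sampled from side_range, places the original image at a random offset inside it, fills the new border with the given colour, and shifts bounding boxes by the same offset. A minimal usage sketch; the prototype feature constructors and the multi-input call are assumptions based on this diff, and the image and box values are made up:

import torch
from torchvision.prototype import features, transforms

# Hypothetical sample: a 3x100x100 image and one box in XYXY format.
image = features.Image(torch.rand(3, 100, 100))
boxes = features.BoundingBox(
    torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    format=features.BoundingBoxFormat.XYXY,
    image_size=(100, 100),
)

# p=1.0 so the zoom-out always fires; fill is the per-channel border colour.
transform = transforms.RandomZoomOut(fill=[0.485, 0.456, 0.406], side_range=(1.0, 4.0), p=1.0)
zoomed_image, zoomed_boxes = transform(image, boxes)

# The canvas grew by a factor in [1.0, 4.0); the boxes were shifted by the sampled
# left/top padding and their image_size metadata now matches the new canvas.
print(zoomed_image.shape, zoomed_boxes.image_size)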
4 changes: 4 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -54,12 +54,16 @@
rotate_image_pil,
pad_image_tensor,
pad_image_pil,
pad_bounding_box,
crop_image_tensor,
crop_image_pil,
perspective_image_tensor,
perspective_image_pil,
vertical_flip_image_tensor,
vertical_flip_image_pil,
zoom_out_image_pil,
zoom_out_image_tensor,
zoom_out_bounding_box,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
86 changes: 85 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
@@ -21,7 +21,7 @@ def horizontal_flip_bounding_box(
shape = bounding_box.shape

bounding_box = convert_bounding_box_format(
- bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+ bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY, copy=True
).view(-1, 4)

bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
@@ -210,6 +210,47 @@ def rotate_image_pil(
pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad


# TODO: this was copy-pasted from _FT.pad. Reuse a single helper once _FT.pad is actually defined here.
def _parse_pad_padding(padding: List[int]) -> List[int]:
if isinstance(padding, int):
if torch.jit.is_scripting():
# This may be unreachable
raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
pad_left = pad_right = pad_top = pad_bottom = padding
elif len(padding) == 1:
pad_left = pad_right = pad_top = pad_bottom = padding[0]
elif len(padding) == 2:
pad_left = pad_right = padding[0]
pad_top = pad_bottom = padding[1]
else:
pad_left = padding[0]
pad_top = padding[1]
pad_right = padding[2]
pad_bottom = padding[3]

return [pad_left, pad_right, pad_top, pad_bottom]


def pad_bounding_box(
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
) -> torch.Tensor:
left, _, top, _ = _parse_pad_padding(padding)

shape = bounding_box.shape

bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY, copy=True
).view(-1, 4)

bounding_box[:, 0::2] += left
bounding_box[:, 1::2] += top

return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
).view(shape)


crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop

@@ -314,3 +355,46 @@ def resized_crop_image_pil(
) -> PIL.Image.Image:
img = crop_image_pil(img, top, left, height, width)
return resize_image_pil(img, size, interpolation=interpolation)


def zoom_out_image_tensor(
image: torch.Tensor,
left: int,
top: int,
right: int,
bottom: int,
fill: List[float] = (0.0,), # type: ignore[assignment]
) -> torch.Tensor:
num_channels, height, width = get_dimensions_image_tensor(image)

# PyTorch's pad only accepts an integer fill, so pad with 0 first and then overwrite the padded regions with the requested colour.
output = pad_image_tensor(image, [left, top, right, bottom], fill=0, padding_mode="constant")

if not isinstance(fill, (list, tuple)):
fill = [fill] * num_channels
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).view(-1, 1, 1)

output[..., :top, :] = fill
output[..., :, :left] = fill
output[..., (top + height) :, :] = fill
output[..., :, (left + width) :] = fill

return output


def zoom_out_image_pil(
img: PIL.Image.Image,
left: int,
top: int,
right: int,
bottom: int,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> PIL.Image.Image:
fill = tuple(int(v) for v in _FP._parse_fill(fill, img, name="fill")["fill"])
return pad_image_pil(img, [left, top, right, bottom], fill=fill, padding_mode="constant")


def zoom_out_bounding_box(
bounding_box: torch.Tensor, left: int, top: int, right: int, bottom: int, format: features.BoundingBoxFormat
) -> torch.Tensor:
return pad_bounding_box(bounding_box, [left, top, right, bottom], format=format)
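Taken together, the zoom-out kernels reduce to one constant pad of [left, top, right, bottom] plus a left/top shift of the box coordinates. A rough sketch of calling them directly, assuming the functional namespace exposes them as added above; the sizes are made up:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

image = torch.rand(3, 100, 100)                    # hypothetical CHW image
boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])   # XYXY coordinates

# Pad 20px on the left/top and 30px on the right/bottom.
padded = F.zoom_out_image_tensor(image, left=20, top=20, right=30, bottom=30, fill=[0.5, 0.5, 0.5])
shifted = F.zoom_out_bounding_box(
    boxes, left=20, top=20, right=30, bottom=30, format=features.BoundingBoxFormat.XYXY
)

print(padded.shape)  # torch.Size([3, 150, 150])
print(shifted)       # tensor([[30., 30., 70., 70.]]) -- only left/top move the coordinates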
21 changes: 15 additions & 6 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -39,10 +39,13 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:


def convert_bounding_box_format(
- bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+ bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = False
) -> torch.Tensor:
if new_format == old_format:
- return bounding_box.clone()
+ if copy:
+     return bounding_box.clone()
+ else:
+     return bounding_box

if old_format == BoundingBoxFormat.XYWH:
bounding_box = _xywh_to_xyxy(bounding_box)
@@ -64,10 +67,13 @@ def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor:


def convert_image_color_space_tensor(
- image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+ image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = False
) -> torch.Tensor:
if new_color_space == old_color_space:
- return image.clone()
+ if copy:
+     return image.clone()
+ else:
+     return image

if old_color_space == ColorSpace.GRAYSCALE:
image = _grayscale_to_rgb_tensor(image)
@@ -83,10 +89,13 @@ def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image:


def convert_image_color_space_pil(
- image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace
+ image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = False
) -> PIL.Image.Image:
if new_color_space == old_color_space:
- return image.copy()
+ if copy:
+     return image.copy()
+ else:
+     return image

if old_color_space == ColorSpace.GRAYSCALE:
image = _grayscale_to_rgb_pil(image)
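The copy flag added here flips the no-op case from always cloning to returning the input untouched; callers that mutate the result in place, such as pad_bounding_box and horizontal_flip_bounding_box above, now request the defensive copy explicitly with copy=True. A small sketch of the resulting behaviour, assuming the default stays copy=False:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import convert_bounding_box_format

boxes = torch.tensor([[10.0, 10.0, 50.0, 50.0]])  # already XYXY

same = convert_bounding_box_format(
    boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=features.BoundingBoxFormat.XYXY
)
copied = convert_bounding_box_format(
    boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=features.BoundingBoxFormat.XYXY, copy=True
)

print(same is boxes)    # True  -- no clone when the formats already match
print(copied is boxes)  # False -- an explicit copy was requested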