Merged

36 commits (diff shown: changes from 15 commits)
5e52ed2  add prototype transforms that don't need dispatchers (pmeier, Feb 14, 2022)
225ed47  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 14, 2022)
fe82e94  cleanup (pmeier, Feb 14, 2022)
36f3e0d  remove legacy_transform decorator (pmeier, Feb 15, 2022)
757fbed  remove legacy classes (pmeier, Feb 15, 2022)
0ca3800  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 15, 2022)
dc61271  remove explicit param passing (pmeier, Feb 15, 2022)
c7c4608  streamline extra_repr (pmeier, Feb 15, 2022)
13d49cb  remove obsolete ._supports() method (pmeier, Feb 15, 2022)
4771e25  cleanup (pmeier, Feb 15, 2022)
fb2077f  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 16, 2022)
c393a43  remove Query (pmeier, Feb 16, 2022)
e7502ed  cleanup (pmeier, Feb 16, 2022)
fd752a6  fix tests (pmeier, Feb 16, 2022)
ea71c2c  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 17, 2022)
283c474  kernels -> functional (pmeier, Feb 21, 2022)
b3c0452  move image size and num channels extraction to functional (pmeier, Feb 21, 2022)
c129dea  extend legacy function to extract image size and num channels (pmeier, Feb 21, 2022)
9b18c28  implement dispatching for auto augment (pmeier, Feb 22, 2022)
3348f89  fix auto augment dispatch (pmeier, Feb 22, 2022)
25e2ec0  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 22, 2022)
90f9fa7  revert some naming changes (pmeier, Feb 22, 2022)
ddf28d2  remove ability to pass params to autoaugment (pmeier, Feb 22, 2022)
68bbb2b  fix legacy image size extraction (pmeier, Feb 22, 2022)
aa9f912  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 23, 2022)
1587588  align prototype.transforms.functional with transforms.functional (pmeier, Feb 24, 2022)
41be83c  Merge branch 'transforms-without-dispatcher' of https://github.com/pm… (pmeier, Feb 24, 2022)
9fc2693  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 24, 2022)
ab79215  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 24, 2022)
7826ab3  cleanup (pmeier, Feb 25, 2022)
ced8bcf  fix image size and channels extraction (pmeier, Feb 25, 2022)
0017807  fix affine and rotate (pmeier, Feb 25, 2022)
23955b6  Merge branch 'transforms-without-dispatcher' of https://github.com/pm… (pmeier, Feb 25, 2022)
ed32288  Merge branch 'main' into transforms-without-dispatcher (pmeier, Feb 25, 2022)
71e4c56  revert image size to (width, height) (pmeier, Feb 25, 2022)
0943de0  Minor corrections (datumbox, Feb 25, 2022)
61 changes: 14 additions & 47 deletions test/test_prototype_transforms.py
@@ -1,6 +1,5 @@
 import itertools
 
-import PIL.Image
 import pytest
 import torch
 from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels
@@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs):
         yield bounding_box.data
 
 
-INPUT_CREATIONS_FNS = {
-    features.Image: make_images,
-    features.BoundingBox: make_bounding_boxes,
-    features.OneHotLabel: make_one_hot_labels,
-    torch.Tensor: make_vanilla_tensor_images,
-    PIL.Image.Image: make_pil_images,
-}
-
-
 def parametrize(transforms_with_inputs):
     return pytest.mark.parametrize(
         ("transform", "input"),
@@ -52,15 +42,21 @@ def parametrize_from_transforms(*transforms):
     transforms_with_inputs = []
     for transform in transforms:
-        dispatcher = transform._DISPATCHER
-        if dispatcher is None:
-            continue
-
-        for type_ in dispatcher._kernels:
+        for creation_fn in [
+            make_images,
+            make_bounding_boxes,
+            make_one_hot_labels,
+            make_vanilla_tensor_images,
+            make_pil_images,
+        ]:
+            inputs = list(creation_fn())
             try:
-                inputs = INPUT_CREATIONS_FNS[type_]()
-            except KeyError:
+                output = transform(inputs[0])
+            except Exception:
                 continue
+            else:
+                if output is inputs[0]:
+                    continue
 
             transforms_with_inputs.append((transform, inputs))
@@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms):
 
 class TestSmoke:
     @parametrize_from_transforms(
-        transforms.RandomErasing(),
+        transforms.RandomErasing(p=1.0),
         transforms.HorizontalFlip(),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
@@ -141,35 +137,6 @@ def test_auto_augment(self, transform, input):
     def test_normalize(self, transform, input):
         transform(input)
 
-    @parametrize(
-        [
-            (
-                transforms.ConvertColorSpace("grayscale"),
-                itertools.chain(
-                    make_images(),
-                    make_vanilla_tensor_images(color_spaces=["rgb"]),
-                    make_pil_images(color_spaces=["rgb"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_color_space(self, transform, input):
-        transform(input)
-
-    @parametrize(
-        [
-            (
-                transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"),
-                itertools.chain(
-                    make_bounding_boxes(),
-                    make_vanilla_tensor_bounding_boxes(formats=["xywh"]),
-                ),
-            )
-        ]
-    )
-    def test_convert_bounding_box_format(self, transform, input):
-        transform(input)
-
     @parametrize(
         [
             (
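With the static `_DISPATCHER`/`INPUT_CREATIONS_FNS` tables gone, the rewritten `parametrize_from_transforms` discovers what a transform supports by probing it: call the transform on one sample from each input family, skip families that raise, and skip families the transform passes through unchanged. A minimal, self-contained sketch of that probing pattern (the toy `Negate` transform and the integer/string input factories are illustrative stand-ins, not code from this PR):

```python
import pytest


class Negate:
    # Toy stand-in for a prototype transform: handles numbers,
    # passes everything else through unchanged.
    def __call__(self, input):
        if isinstance(input, (int, float)):
            return -input
        return input


def make_ints():
    yield from (1, 2, 3)


def make_strs():
    yield from ("a", "b")


def parametrize_from_transforms(*transforms):
    transforms_with_inputs = []
    for transform in transforms:
        for creation_fn in [make_ints, make_strs]:
            inputs = list(creation_fn())
            try:
                # Probe with a single sample instead of consulting a registry.
                output = transform(inputs[0])
            except Exception:
                continue  # this input family is unsupported
            else:
                if output is inputs[0]:
                    continue  # identity pass-through: nothing to test
            transforms_with_inputs.append((transform, inputs))

    return pytest.mark.parametrize(
        ("transform", "input"),
        [
            pytest.param(transform, input, id=f"{type(transform).__name__}-{idx}")
            for transform, inputs in transforms_with_inputs
            for idx, input in enumerate(inputs)
        ],
    )


@parametrize_from_transforms(Negate())
def test_smoke(transform, input):
    assert transform(input) != input
```

Probing trades a little test-collection time for not having to keep a side table in sync with the transforms, which is the point of removing the dispatchers.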
5 changes: 3 additions & 2 deletions torchvision/prototype/transforms/__init__.py
@@ -1,6 +1,7 @@
-from torchvision.transforms import AutoAugmentPolicy, InterpolationMode  # usort: skip
+from torchvision.transforms import InterpolationMode, AutoAugmentPolicy  # usort: skip
 
-from . import kernels  # usort: skip
+from . import functional  # usort: skip
 
+from ._transform import Transform  # usort: skip
 
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
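After this change the package exposes the low-level kernels as `torchvision.prototype.transforms.functional` and re-exports the `Transform` base class next to the transform classes. A short usage sketch of the public surface this `__init__.py` wires up (the `alpha` values are illustrative, not recommendations from the PR):

```python
from torchvision.prototype import transforms

# The re-exported transform classes sample their own parameters per call;
# alpha is keyword-only, matching the __init__ signatures in _augment.py.
mixup = transforms.RandomMixup(alpha=0.2)
cutmix = transforms.RandomCutmix(alpha=1.0)
```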
61 changes: 44 additions & 17 deletions torchvision/prototype/transforms/_augment.py
@@ -6,15 +6,13 @@
 import PIL.Image
 import torch
 from torchvision.prototype import features
-from torchvision.prototype.transforms import Transform, functional as F
+from torchvision.prototype.transforms import Transform, kernels as K
+from torchvision.transforms import functional as _F
 
-from ._utils import query_image
+from ._utils import query_image, get_image_size, get_image_num_channels
 
 
 class RandomErasing(Transform):
-    _DISPATCHER = F.erase
-    _FAIL_TYPES = {PIL.Image.Image, features.BoundingBox, features.SegmentationMask}
-
     def __init__(
         self,
         p: float = 0.5,
@@ -45,8 +43,8 @@ def __init__(
 
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_h, img_w = F.get_image_size(image)
-        img_c = F.get_image_num_channels(image)
+        img_c = get_image_num_channels(image)
+        img_h, img_w = get_image_size(image)
 
         if isinstance(self.value, (int, float)):
             value = [self.value]
@@ -59,7 +57,8 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
 
         if value is not None and not (len(value) in (1, img_c)):
             raise ValueError(
-                f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)"
+                "If value is a sequence, it should have either a single value or "
+                f"{image.shape[-3]} (number of input channels)"
             )
 
         area = img_h * img_w
@@ -93,16 +92,23 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(zip("ijhwv", (i, j, h, w, v)))
 
     def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
-        if torch.rand(1) >= self.p:
+        if type(input) is torch.Tensor:
+            return _F.erase(input, **params)
+        elif type(input) is features.Image:
+            return features.Image.new_like(input, K.erase_image(input, **params))
+        elif type(input) in {features.BoundingBox, features.SegmentationMask} or isinstance(input, PIL.Image.Image):
+            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
+        else:
             return input
 
-        return super()._transform(input, params)
+    def forward(self, *inputs: Any) -> Any:
+        if torch.rand(1) >= self.p:
+            return inputs if len(inputs) > 1 else inputs[0]
+
+        return super().forward(*inputs)
 
 
 class RandomMixup(Transform):
-    _DISPATCHER = F.mixup
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
-
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -111,11 +117,20 @@ def __init__(self, *, alpha: float) -> None:
     def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(lam=float(self._dist.sample(())))
 
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if type(input) is features.Image:
+            output = K.mixup_image(input, **params)
+            return features.Image.new_like(input, output)
+        elif type(input) is features.OneHotLabel:
+            output = K.mixup_one_hot_label(input, **params)
+            return features.OneHotLabel.new_like(input, output)
+        elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}:
+            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
+        else:
+            return input
 
-class RandomCutmix(Transform):
-    _DISPATCHER = F.cutmix
-    _FAIL_TYPES = {features.BoundingBox, features.SegmentationMask}
 
+class RandomCutmix(Transform):
     def __init__(self, *, alpha: float) -> None:
         super().__init__()
         self.alpha = alpha
@@ -125,7 +140,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))
 
         image = query_image(sample)
-        H, W = F.get_image_size(image)
+        H, W = get_image_size(image)
 
         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())
@@ -143,3 +158,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
         return dict(box=box, lam_adjusted=lam_adjusted)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if type(input) is features.Image:
+            output = K.cutmix_image(input, box=params["box"])
+            return features.Image.new_like(input, output)
+        elif type(input) is features.OneHotLabel:
+            output = K.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
+            return features.OneHotLabel.new_like(input, output)
+        elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}:
+            raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()")
+        else:
+            return input
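The `_transform` methods above replace the old dispatcher indirection with explicit per-type dispatch, and `RandomErasing` additionally overrides `forward` so the `p` gate fires once per call rather than once per item in the sample. Both hooks assume the prototype `Transform` base class samples parameters once and then walks the whole sample, feeding every leaf through `_transform`. A simplified sketch of that contract (not the actual torchvision implementation; the recursion over containers is an assumption for illustration):

```python
from typing import Any, Dict

import torch


class Transform(torch.nn.Module):
    # Simplified sketch of the base-class contract the transforms above rely on.

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # Sampled once per call, so every item in the sample sees the same
        # random parameters (e.g. one erase box for both image and mask).
        return dict()

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        raise NotImplementedError

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        params = self._get_params(sample)

        def apply_recursively(obj: Any) -> Any:
            # Walk lists/tuples/dicts; every leaf goes through _transform.
            if isinstance(obj, (list, tuple)):
                return type(obj)(apply_recursively(item) for item in obj)
            elif isinstance(obj, dict):
                return {key: apply_recursively(value) for key, value in obj.items()}
            else:
                return self._transform(obj, params)

        return apply_recursively(sample)
```

Under such a contract, overriding `forward` is the only way to make an all-or-nothing random decision for the whole sample, while `_get_params` keeps per-call parameters shared across items, e.g. so an image and its one-hot label are mixed with the same `lam`.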