Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
c56148c
Initial Proposal
SamuelGabriel Jul 29, 2021
f4552ed
Tensor Save Test + Test Name Fix
SamuelGabriel Jul 29, 2021
25968e6
Formatting + removing unused argument
SamuelGabriel Jul 29, 2021
2feff4f
fix old argument
SamuelGabriel Jul 29, 2021
58c7ba8
fix isnan check error + indexing error with jit
SamuelGabriel Jul 29, 2021
33d5d59
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Jul 29, 2021
a015611
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 2, 2021
7a7a739
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 17, 2021
406848a
Fix Flake8 error.
SamuelGabriel Aug 17, 2021
f743481
Fix MyPy error.
SamuelGabriel Aug 17, 2021
19d8696
Fix Flake8 error.
SamuelGabriel Aug 17, 2021
1ed1568
Fix PyTorch JIT error in UnitTests due to type annotation.
SamuelGabriel Aug 17, 2021
536446e
Merge branch 'master' into trivialaugment_implementation
SamuelGabriel Aug 17, 2021
942fb66
Merge branch 'master' into trivialaugment_implementation
datumbox Aug 17, 2021
16784a1
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 26, 2021
c8fb6c7
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 26, 2021
2fc8633
Fixing tests.
datumbox Aug 26, 2021
729c0db
Removing type ignore.
datumbox Aug 26, 2021
d02100a
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
datumbox Aug 26, 2021
83552c6
Adding support of ta_wide in references.
datumbox Aug 26, 2021
cd6a75e
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 27, 2021
1fe25fb
Move methods in classes.
datumbox Aug 27, 2021
226998c
Moving new classes to the bottom.
datumbox Aug 27, 2021
425c52d
Specialize to TA to TAwide
datumbox Aug 27, 2021
fa8a6d5
Merge branch 'main' into trivialaugment_implementation
datumbox Aug 31, 2021
7483dbc
Merge branch 'main' into SamuelGabriel_trivialaugment_implementation
datumbox Sep 2, 2021
bd2dc17
Add missing type
datumbox Sep 2, 2021
0087be0
Merge branch 'main' into trivialaugment_implementation
datumbox Sep 2, 2021
5770a03
Fixing lint
datumbox Sep 2, 2021
46f886c
Fix doc
datumbox Sep 2, 2021
2933667
Merge branch 'main' into trivialaugment_implementation
SamuelGabriel Sep 6, 2021
30bbae9
Fix search space of TrivialAugment.
SamuelGabriel Sep 6, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,6 +1490,19 @@ def test_autoaugment(policy, fill):
transform.__repr__()


@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide'])
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
def test_trivialaugment(augmentation_space, fill, num_magnitude_bins):
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.TrivialAugment(augmentation_space=augmentation_space,
fill=fill, num_magnitude_bins=num_magnitude_bins)
for _ in range(100):
img = transform(img)
transform.__repr__()


def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
Expand Down
22 changes: 22 additions & 0 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,13 +541,35 @@ def test_autoaugment(device, policy, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('augmentation_space', ['aa', 'ta_wide'])
@pytest.mark.parametrize('fill', [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1])
def test_trivialaugment(device, augmentation_space, fill):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

s_transform = None
transform = T.TrivialAugment(augmentation_space=augmentation_space, fill=fill)
s_transform = torch.jit.script(transform)
for _ in range(25):
_test_transform_vs_scripted(transform, s_transform, tensor)
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


def test_autoaugment_save():
transform = T.AutoAugment()
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))


def test_trivialaugment_save():
transform = T.TrivialAugment()
s_transform = torch.jit.script(transform)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize(
'config', [
Expand Down
190 changes: 130 additions & 60 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "TrivialAugment"]


class AutoAugmentPolicy(Enum):
Expand Down Expand Up @@ -108,25 +108,132 @@ def _get_transforms( # type: ignore[return]
]


def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]:
_BINS = 10
return {
def _get_magnitudes(
augmentation_space: str, image_size: List[int], num_bins: int = 10
) -> Dict[str, Tuple[Tensor, bool]]:
if augmentation_space == 'aa':
shear_max = 0.3
translate_max_x = 150.0 / 331.0 * image_size[0]
translate_max_y = 150.0 / 331.0 * image_size[1]
rotate_max = 30.0
enhancer_max = 0.9
posterize_min_bits = 4

elif augmentation_space == 'ta_wide':
shear_max = 0.99
translate_max_x = 32.0 # this is an absolute
translate_max_y = 32.0 # this is an absolute
rotate_max = 135.0
enhancer_max = 0.99
posterize_min_bits = 2
else:
raise ValueError(f"Provided augmentation_space arguments {augmentation_space} not available.")

magnitudes = {
# name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
"ShearX": (torch.linspace(0.0, shear_max, num_bins), True),
"ShearY": (torch.linspace(0.0, shear_max, num_bins), True),
"TranslateX": (torch.linspace(0.0, translate_max_x, num_bins), True),
"TranslateY": (torch.linspace(0.0, translate_max_y, num_bins), True),
"Rotate": (torch.linspace(0.0, rotate_max, num_bins), True),
"Brightness": (torch.linspace(0.0, enhancer_max, num_bins), True),
"Color": (torch.linspace(0.0, enhancer_max, num_bins), True),
"Contrast": (torch.linspace(0.0, enhancer_max, num_bins), True),
"Sharpness": (torch.linspace(0.0, enhancer_max, num_bins), True),
"Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / (8 - posterize_min_bits))).round().int(), False),
"Solarize": (torch.linspace(256.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(float('nan')), False),
"Equalize": (torch.tensor(float('nan')), False),
"Invert": (torch.tensor(float('nan')), False),
}
return magnitudes


def apply_aug(img: Tensor, op_name: str, magnitude: float,
interpolation: InterpolationMode, fill: Optional[List[float]]):
if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=interpolation, fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0,
interpolation=interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img


class TrivialAugment(torch.nn.Module):
r"""Dataset-independent data-augmentation with TrivialAugment, as described in
`"TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation" <https://arxiv.org/abs/2103.10158>`.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Args:
augmentation_space (str): A string defining which augmentation space to use.
The augmentation space can either set to be the one used for AutoAugment (`aa`)
or to the strongest augmentation space from the TrivialAugment paper (`ta_wide`).
num_magnitude_bins (int): The number of different magnitude values.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.
"""

def __init__(self, augmentation_space: str = 'ta_wide', num_magnitude_bins: int = 30,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[List[float]] = None):
super().__init__()
self.augmentation_space = augmentation_space
self.num_magnitude_bins = num_magnitude_bins
self.interpolation = interpolation
self.fill = fill

def forward(self, img: Tensor):
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]

op_meta = _get_magnitudes(self.augmentation_space, F._get_image_size(img), num_bins=self.num_magnitude_bins)
op_index = torch.randint(len(op_meta), (1,))
op_name = list(op_meta.keys())[op_index.item()] # type: ignore[index]
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[torch.randint(len(magnitudes), (1,), dtype=torch.long)].item()) \
if not magnitudes.isnan().all() else 0.0
if signed and torch.randint(2, (1,)):
magnitude *= -1.0

return apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)


class AutoAugment(torch.nn.Module):
Expand Down Expand Up @@ -160,7 +267,6 @@ def __init__(
self.transforms = _get_transforms(policy)
if self.transforms is None:
raise ValueError("The provided policy {} is not recognized.".format(policy))
self._op_meta = _get_magnitudes()

@staticmethod
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
Expand All @@ -175,9 +281,6 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:

return policy_id, probs, signs

def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]

def forward(self, img: Tensor) -> Tensor:
"""
img (PIL Image or Tensor): Image to be transformed.
Expand All @@ -196,46 +299,13 @@ def forward(self, img: Tensor) -> Tensor:

for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
op_meta = _get_magnitudes('aa', F._get_image_size(img))
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
if not magnitudes.isnan().all() and magnitude_id is not None else 0.0
if signed and signs[i] == 0:
magnitude *= -1.0

if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=self.interpolation, fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
img = apply_aug(img, op_name, magnitude, interpolation=self.interpolation, fill=fill)

return img

Expand Down