Merged
Changes from 4 commits
78 changes: 78 additions & 0 deletions test/test_transforms_tensor.py
@@ -1,6 +1,9 @@
import os
import torch
from torch._utils_internal import get_file_path_2
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms as T
from torchvision.io import read_image
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

@@ -715,3 +718,78 @@ def test_gaussian_blur(device, meth_kwargs):
T.GaussianBlur, meth_kwargs=meth_kwargs,
test_exact_match=False, device=device, agg_method="max", tol=tol
)


@pytest.mark.parametrize('device', cpu_and_gpu())
@pytest.mark.parametrize('alphas', [
{"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 1.0},
{"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'cutmix_p': 0.0},
{"mixup_alpha": 1.0, "cutmix_alpha": 1.0, 'p': 0.0},
{"mixup_alpha": 0.0, "cutmix_alpha": 1.0},
{"mixup_alpha": 1.0, "cutmix_alpha": 0.0},
])
@pytest.mark.parametrize('label_smoothing', [0.0, 0.1])
@pytest.mark.parametrize('inplace', [True, False])
def test_random_mixupcutmix(device, alphas, label_smoothing, inplace):
batch_size = 32
num_classes = 10
batch = torch.rand(batch_size, 3, 44, 56, device=device)
targets = torch.randint(num_classes, (batch_size, ), device=device, dtype=torch.int64)

fn = T.RandomMixupCutmix(num_classes, label_smoothing=label_smoothing, inplace=inplace, **alphas)
scripted_fn = torch.jit.script(fn)

seed = torch.seed()
output = fn(batch.clone(), targets.clone())

torch.manual_seed(seed)
output_scripted = scripted_fn(batch.clone(), targets.clone())
assert_equal(output[0], output_scripted[0])
assert_equal(output[1], output_scripted[1])

fn.__repr__()


def test_random_mixupcutmix_with_invalid_data():
with pytest.raises(AssertionError, match="Please provide a valid positive value for the num_classes."):
T.RandomMixupCutmix(0)
with pytest.raises(AssertionError, match="Both alpha params can't be zero."):
T.RandomMixupCutmix(10, mixup_alpha=0.0, cutmix_alpha=0.0)

t = T.RandomMixupCutmix(10)
with pytest.raises(ValueError, match="Batch ndim should be 4."):
t(torch.rand(3, 60, 60), torch.randint(10, (1, )))
with pytest.raises(ValueError, match="Target ndim should be 1."):
t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, 1)))
with pytest.raises(ValueError, match="Target dtype should be torch.int64."):
t(torch.rand(32, 3, 60, 60), torch.randint(10, (32, ), dtype=torch.int32))
with pytest.raises(ValueError, match="The batch size should be even."):
t(torch.rand(31, 3, 60, 60), torch.randint(10, (31, )))


def test_random_mixupcutmix_with_real_data():
torch.manual_seed(12)

# Build dummy dataset
images = []
for test_file in [("encode_jpeg", "grace_hopper_517x606.jpg"), ("fakedata", "logos", "rgb_pytorch.png")]:
fullpath = (os.path.dirname(os.path.abspath(__file__)), 'assets') + test_file
img = read_image(get_file_path_2(*fullpath))
images.append(F.resize(img, [224, 224]))
dataset = TensorDataset(torch.stack(images).to(torch.float32), torch.tensor([0, 1]))

# Use mixup in the collate
mixup = T.RandomMixupCutmix(2, cutmix_alpha=1.0, mixup_alpha=1.0, label_smoothing=0.1)
dataloader = DataLoader(dataset, batch_size=2,
collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch))))

# Test against known statistics about the produced images
stats = []
for _ in range(25):
for b, t in dataloader:
stats.append(torch.stack([b.mean(), b.std(), t.std()]))

torch.testing.assert_close(
torch.stack(stats).mean(dim=0),
torch.tensor([46.94434738, 64.79092407, 0.23949696])
)
125 changes: 124 additions & 1 deletion torchvision/transforms/transforms.py
@@ -22,7 +22,7 @@
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", 'RandomMixupCutmix']


class Compose:
@@ -1953,3 +1953,126 @@ def forward(self, img):

def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)


class RandomMixupCutmix(torch.nn.Module):
"""Randomly apply Mixum or Cutmix to the provided batch and targets.
The class implements the data augmentations as described in the papers
`"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
`"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
<https://arxiv.org/abs/1905.04899>`_.

Args:
num_classes (int): number of classes used for one-hot encoding.
p (float): probability of the batch being transformed. Default value is 1.0.
mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
Set to 0.0 to turn off. Default value is 1.0.
cutmix_p (float): probability of using cutmix instead of mixup when both are on.
Default value is 0.5.
cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
Set to 0.0 to turn off. Default value is 0.0.
label_smoothing (float): the amount of smoothing used when one-hot encoding.
Set to 0.0 to turn off. Default value is 0.0.
inplace (bool): boolean to make this transform inplace. Default set to False.
"""

def __init__(self, num_classes: int,
p: float = 1.0, mixup_alpha: float = 1.0,
cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
label_smoothing: float = 0.0, inplace: bool = False) -> None:
Contributor Author:
I understand that the choice of offering inplace support won't make everyone happy. The reason I decided to support it is:

  • It is more performant in terms of memory and speed (a quick sketch of the difference follows below).
  • Most of the tensor-only transforms in torchvision typically support inplace operations (see Normalize and RandomErasing).
  • The ClassyVision version of this method supported inplace operations and thus to ensure we are equally performant we should support it as well.
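
For reference, a minimal sketch of the memory trade-off behind the first point; the shapes and lambda value are illustrative, and the two in-place calls mirror the mixup branch of this PR's forward:

```python
import torch

lam = 0.7
batch = torch.rand(32, 3, 224, 224)
batch_flipped = batch.flip(0)

# Out-of-place: allocates fresh tensors for the intermediates and the result.
mixed = lam * batch + (1.0 - lam) * batch_flipped

# In-place (as in the mixup branch of forward): reuses the existing buffers,
# so no additional (B, C, H, W) tensors are allocated.
batch_flipped.mul_(1.0 - lam)
batch.mul_(lam).add_(batch_flipped)
```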

Member:
I'm still a bit skeptical about adding this in-place operation, as most of the operators in here don't actually perform any in-place operation, so the gains of having an inplace flag might not be that large. But we would need to benchmark to know for sure.
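
Purely as a sketch of what such a benchmark could look like (the transform name and constructor arguments follow this PR; the input shapes and timing setup are illustrative assumptions):

```python
import torch
import torch.utils.benchmark as benchmark
from torchvision import transforms as T

batch = torch.rand(64, 3, 224, 224)
target = torch.randint(10, (64,))

results = []
for inplace in (False, True):
    mixup = T.RandomMixupCutmix(10, mixup_alpha=1.0, cutmix_alpha=1.0, inplace=inplace)
    # With inplace=True the batch is mutated across runs; that is harmless for
    # timing purposes since only the values change, not the shapes or dtypes.
    timer = benchmark.Timer(
        stmt="mixup(batch, target)",
        globals={"mixup": mixup, "batch": batch, "target": target},
        label="RandomMixupCutmix",
        sub_label=f"inplace={inplace}",
    )
    results.append(timer.blocked_autorange(min_run_time=1))

benchmark.Compare(results).print()
```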

Contributor Author:
I don't see why we can't support an in-place operation if one can be implemented correctly. For now I plan to move the entire implementation to the references script so that we have time to think this through before we commit to the API, so we don't have to solve this now. But overall I think inplace is a valid optimization, and provided it can be implemented properly, there is no reason not to get the extra speed improvements.

Contributor Author:
As discussed offline with @fmassa we are going to review the use of inplace before moving this from references to transforms.

super().__init__()
assert num_classes > 0, "Please provide a valid positive value for the num_classes."
assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."

self.num_classes = num_classes
self.p = p
self.mixup_alpha = mixup_alpha
self.cutmix_p = cutmix_p
self.cutmix_alpha = cutmix_alpha
self.label_smoothing = label_smoothing
self.inplace = inplace

def _smooth_one_hot(self, target: Tensor) -> Tensor:
N = target.shape[0]
device = target.device
v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device)
return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes,
device=device).scatter_add_(1, target.unsqueeze(1), v)

def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
"""
Args:
batch (Tensor): Float tensor of size (B, C, H, W)
target (Tensor): Integer tensor of size (B, )

Returns:
Tuple[Tensor, Tensor]: Randomly transformed batch and target.
"""
if batch.ndim != 4:
raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
elif target.ndim != 1:
raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
elif target.dtype != torch.int64:
raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
elif batch.size(0) % 2 != 0:
# speed optimization, see below
raise ValueError("The batch size should be even.")

if not self.inplace:
batch = batch.clone()
# target = target.clone()

target = self._smooth_one_hot(target)
if torch.rand(1).item() >= self.p:
return batch, target

# It's faster to flip the batch instead of shuffling it to create image pairs
batch_flipped = batch.flip(0)
target_flipped = target.flip(0)

if self.mixup_alpha <= 0.0:
use_mixup = False
else:
use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p

if use_mixup:
# Implemented as described in the mixup paper, page 3.
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
batch_flipped.mul_(1.0 - lambda_param)
batch.mul_(lambda_param).add_(batch_flipped)
else:
# Implemented as described in the cutmix paper, page 12 (with minor corrections on typos).
lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
W, H = F.get_image_size(batch)

r_x = torch.randint(W, (1,))
r_y = torch.randint(H, (1,))

r = 0.5 * math.sqrt(1.0 - lambda_param)
r_w_half = int(r * W)
r_h_half = int(r * H)

x1 = int(torch.clamp(r_x - r_w_half, min=0))
y1 = int(torch.clamp(r_y - r_h_half, min=0))
x2 = int(torch.clamp(r_x + r_w_half, max=W))
y2 = int(torch.clamp(r_y + r_h_half, max=H))

batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

target_flipped.mul_(1.0 - lambda_param)
target.mul_(lambda_param).add_(target_flipped)

return batch, target

def __repr__(self) -> str:
s = self.__class__.__name__ + '('
s += 'num_classes={num_classes}'
s += ', p={p}'
s += ', mixup_alpha={mixup_alpha}'
s += ', cutmix_p={cutmix_p}'
s += ', cutmix_alpha={cutmix_alpha}'
s += ', label_smoothing={label_smoothing}'
s += ', inplace={inplace}'
s += ')'
return s.format(**self.__dict__)
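
For completeness, a minimal usage sketch along the lines of the unit test above; the random dataset, class count, and batch size are placeholders, while the collate pattern mirrors test_random_mixupcutmix_with_real_data:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms as T

num_classes = 10
# Placeholder dataset: random images with integer class labels.
dataset = TensorDataset(torch.rand(128, 3, 224, 224), torch.randint(num_classes, (128,)))

mixup = T.RandomMixupCutmix(num_classes, mixup_alpha=1.0, cutmix_alpha=1.0, label_smoothing=0.1)

# Apply the transform to whole batches via the collate function, as in the test above.
# Note that the batch size must be even for the flip-based pairing to work.
loader = DataLoader(
    dataset, batch_size=32, shuffle=True,
    collate_fn=lambda batch: mixup(*(torch.stack(x) for x in zip(*batch))),
)

for images, targets in loader:
    # Targets are now soft labels of shape (batch_size, num_classes).
    assert targets.shape == (images.size(0), num_classes)
    break
```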