Adding Mixup and Cutmix #4379
@@ -22,7 +22,7 @@
            "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
            "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
            "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
-           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
+           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize", "RandomMixupCutmix"]


 class Compose:
@@ -1953,3 +1953,126 @@ def forward(self, img):

     def __repr__(self):
         return self.__class__.__name__ + '(p={})'.format(self.p)
+
+
+class RandomMixupCutmix(torch.nn.Module):
+    """Randomly apply Mixup or Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the papers
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
+    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
+    <https://arxiv.org/abs/1905.04899>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 1.0.
+        mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Set to 0.0 to turn off. Default value is 1.0.
+        cutmix_p (float): probability of using cutmix instead of mixup when both are on.
+            Default value is 0.5.
+        cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Set to 0.0 to turn off. Default value is 0.0.
+        label_smoothing (float): the amount of smoothing used when one-hot encoding.
+            Set to 0.0 to turn off. Default value is 0.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 1.0, mixup_alpha: float = 1.0,
+                 cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
+                 label_smoothing: float = 0.0, inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.mixup_alpha = mixup_alpha
+        self.cutmix_p = cutmix_p
+        self.cutmix_alpha = cutmix_alpha
+        self.label_smoothing = label_smoothing
+        self.inplace = inplace
+
+    def _smooth_one_hot(self, target: Tensor) -> Tensor:
+        N = target.shape[0]
+        device = target.device
+        v = torch.full(size=(N, 1), fill_value=1 - self.label_smoothing, device=device)
+        return torch.full(size=(N, self.num_classes), fill_value=self.label_smoothing / self.num_classes,
+                          device=device).scatter_add_(1, target.unsqueeze(1), v)
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif target.dtype != torch.int64:
+            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+        elif batch.size(0) % 2 != 0:
+            # speed optimization, see below
+            raise ValueError("The batch size should be even.")
+
+        if not self.inplace:
+            batch = batch.clone()
+            # target = target.clone()  # not needed: _smooth_one_hot allocates a new tensor
+
+        target = self._smooth_one_hot(target)
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to flip the batch instead of shuffling it to create image pairs
+        batch_flipped = batch.flip(0)
+        target_flipped = target.flip(0)
+
+        if self.mixup_alpha <= 0.0:
+            use_mixup = False
+        else:
+            use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p
+
+        if use_mixup:
+            # Implemented as described in the mixup paper, page 3.
+            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
+            batch_flipped.mul_(1.0 - lambda_param)
+            batch.mul_(lambda_param).add_(batch_flipped)
+        else:
+            # Implemented as described in the cutmix paper, page 12 (with minor typo corrections).
+            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
+            W, H = F.get_image_size(batch)
+
+            r_x = torch.randint(W, (1,))
+            r_y = torch.randint(H, (1,))
+
+            r = 0.5 * math.sqrt(1.0 - lambda_param)
+            r_w_half = int(r * W)
+            r_h_half = int(r * H)
+
+            x1 = int(torch.clamp(r_x - r_w_half, min=0))
+            y1 = int(torch.clamp(r_y - r_h_half, min=0))
+            x2 = int(torch.clamp(r_x + r_w_half, max=W))
+            y2 = int(torch.clamp(r_y + r_h_half, max=H))
+
+            batch[:, :, y1:y2, x1:x2] = batch_flipped[:, :, y1:y2, x1:x2]
+            lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))  # recompute from the clamped box
+
+        target_flipped.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_flipped)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', mixup_alpha={mixup_alpha}'
+        s += ', cutmix_p={cutmix_p}'
+        s += ', cutmix_alpha={cutmix_alpha}'
+        s += ', label_smoothing={label_smoothing}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
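
Because it mixes pairs of images, the transform operates on whole batches rather than on individual samples, so it cannot sit in the usual per-sample Compose pipeline; it has to be applied after batching. A minimal training-loop sketch, assuming the class ships under torchvision.transforms as in this diff; the model, optimizer, and hyperparameter values below are illustrative stand-ins, not part of the PR:

    import torch
    from torchvision.transforms import RandomMixupCutmix  # assumes this PR's name/location

    # Toy stand-ins so the sketch runs end to end; a real pipeline would use
    # a DataLoader and a real model.
    model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    mixup_cutmix = RandomMixupCutmix(num_classes=10, mixup_alpha=0.2, cutmix_alpha=1.0)

    images = torch.randn(8, 3, 32, 32)             # the batch size must be even
    labels = torch.randint(0, 10, (8,))            # int64 class indices

    images, labels = mixup_cutmix(images, labels)  # labels become soft one-hot vectors
    # The criterion must accept soft targets, e.g. BCE-with-logits here.
    loss = torch.nn.functional.binary_cross_entropy_with_logits(model(images), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Note that the returned targets are float vectors of shape (B, num_classes), so a criterion that expects hard class indices would need to be swapped for one that handles probability targets.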
Uh oh!
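
The _smooth_one_hot helper fills every class slot with label_smoothing / num_classes and then scatter-adds 1 - label_smoothing at the true class, so the true class ends up at 1 - eps + eps/C and each row still sums to 1. A standalone re-run of that arithmetic (the numbers are purely illustrative):

    import torch

    eps, num_classes = 0.1, 10
    target = torch.tensor([3])  # a single sample whose true class is 3

    v = torch.full((1, 1), 1 - eps)
    smooth = torch.full((1, num_classes), eps / num_classes).scatter_add_(1, target.unsqueeze(1), v)

    print(smooth[0, 3].item())  # 0.9 + 0.1/10 = 0.91
    print(smooth.sum().item())  # 1.0 -- each row remains a probability distribution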
There was an error while loading. Please reload this page.
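
One non-obvious detail in the diff is the private torch._sample_dirichlet call: the first coordinate of a two-dimensional Dirichlet(alpha, alpha) draw is Beta(alpha, alpha)-distributed, which is the lambda both papers call for. If relying on a private function is a concern, the same draw could be made through the public distributions API, e.g.:

    import torch

    alpha = 1.0
    # Equivalent to float(torch._sample_dirichlet(torch.tensor([alpha, alpha]))[0])
    lambda_param = float(torch.distributions.Beta(alpha, alpha).sample())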