2222 "RandomHorizontalFlip" , "RandomVerticalFlip" , "RandomResizedCrop" , "RandomSizedCrop" , "FiveCrop" , "TenCrop" ,
2323 "LinearTransformation" , "ColorJitter" , "RandomRotation" , "RandomAffine" , "Grayscale" , "RandomGrayscale" ,
2424 "RandomPerspective" , "RandomErasing" , "GaussianBlur" , "InterpolationMode" , "RandomInvert" , "RandomPosterize" ,
25- "RandomSolarize" , "RandomAdjustSharpness" , "RandomAutocontrast" , "RandomEqualize" , 'RandomMixupCutmix' ]
25+ "RandomSolarize" , "RandomAdjustSharpness" , "RandomAutocontrast" , "RandomEqualize" , 'RandomMixup' ,
26+ "RandomCutmix" ]
 
 
 class Compose:
@@ -515,9 +516,20 @@ def __call__(self, img):
 class RandomChoice(RandomTransforms):
     """Apply single transformation randomly picked from a list. This transform does not support torchscript.
     """
-    def __call__(self, img):
-        t = random.choice(self.transforms)
-        return t(img)
+    def __init__(self, transforms, p=None):
+        super().__init__(transforms)
+        if p is not None and not isinstance(p, Sequence):
+            raise TypeError("Argument p should be a sequence")
+        self.p = p
+
+    def __call__(self, *args):
+        t = random.choices(self.transforms, weights=self.p)[0]
+        return t(*args)
+
+    def __repr__(self):
+        format_string = super().__repr__()
+        format_string += '(p={0})'.format(self.p)
+        return format_string
 
 
 class RandomCrop(torch.nn.Module):
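
The new optional `p` argument makes `RandomChoice` pick a transform with relative weights via `random.choices`; when `p` is `None`, the choice stays uniform. A minimal usage sketch (the member transforms here are just illustrative):

```python
import torchvision.transforms as T

# Apply ColorJitter roughly 70% of the time and GaussianBlur 30%.
# random.choices treats p as relative weights, so they need not sum to 1.
augment = T.RandomChoice(
    [T.ColorJitter(brightness=0.4), T.GaussianBlur(kernel_size=3)],
    p=[0.7, 0.3],
)
```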
@@ -1956,38 +1968,103 @@ def __repr__(self):
 
 
 # TODO: move this to references before merging and delete the tests
-class RandomMixupCutmix(torch.nn.Module):
-    """Randomly apply Mixup or Cutmix to the provided batch and targets.
-    The class implements the data augmentations as described in the papers
-    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_ and
+class RandomMixup(torch.nn.Module):
+    """Randomly apply Mixup to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
+    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
+
+    Args:
+        num_classes (int): number of classes used for one-hot encoding.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for mixup.
+            Default value is 1.0.
+        inplace (bool): boolean to make this transform inplace. Default set to False.
+    """
+
+    def __init__(self, num_classes: int,
+                 p: float = 0.5, alpha: float = 1.0,
+                 inplace: bool = False) -> None:
+        super().__init__()
+        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
+        assert alpha > 0, "Alpha param can't be zero."
+
+        self.num_classes = num_classes
+        self.p = p
+        self.alpha = alpha
+        self.inplace = inplace
+
+    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Args:
+            batch (Tensor): Float tensor of size (B, C, H, W)
+            target (Tensor): Integer tensor of size (B, )
+
+        Returns:
+            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
+        """
+        if batch.ndim != 4:
+            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
+        elif target.ndim != 1:
+            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
+        elif target.dtype != torch.int64:
+            raise ValueError("Target dtype should be torch.int64. Got {}".format(target.dtype))
+
+        if not self.inplace:
+            batch = batch.clone()
+            # target = target.clone()
+
+        target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=torch.float32)
+        if torch.rand(1).item() >= self.p:
+            return batch, target
+
+        # It's faster to roll the batch by one instead of shuffling it to create image pairs
+        batch_rolled = batch.roll(1, 0)
+        target_rolled = target.roll(1)
+
+        # Implemented as in the mixup paper, page 3.
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        batch_rolled.mul_(1.0 - lambda_param)
+        batch.mul_(lambda_param).add_(batch_rolled)
+
+        target_rolled.mul_(1.0 - lambda_param)
+        target.mul_(lambda_param).add_(target_rolled)
+
+        return batch, target
+
+    def __repr__(self) -> str:
+        s = self.__class__.__name__ + '('
+        s += 'num_classes={num_classes}'
+        s += ', p={p}'
+        s += ', alpha={alpha}'
+        s += ', inplace={inplace}'
+        s += ')'
+        return s.format(**self.__dict__)
+
+
+class RandomCutmix(torch.nn.Module):
+    """Randomly apply Cutmix to the provided batch and targets.
+    The class implements the data augmentations as described in the paper
     `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
     <https://arxiv.org/abs/1905.04899>`_.
 
     Args:
         num_classes (int): number of classes used for one-hot encoding.
-        p (float): probability of the batch being transformed. Default value is 1.0.
-        mixup_alpha (float): hyperparameter of the Beta distribution used for mixup.
-            Set to 0.0 to turn off. Default value is 1.0.
-        cutmix_p (float): probability of using cutmix instead of mixup when both are on.
-            Default value is 0.5.
-        cutmix_alpha (float): hyperparameter of the Beta distribution used for cutmix.
-            Set to 0.0 to turn off. Default value is 0.0.
+        p (float): probability of the batch being transformed. Default value is 0.5.
+        alpha (float): hyperparameter of the Beta distribution used for cutmix.
+            Default value is 1.0.
         inplace (bool): boolean to make this transform inplace. Default set to False.
     """
 
     def __init__(self, num_classes: int,
-                 p: float = 1.0, mixup_alpha: float = 1.0,
-                 cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
+                 p: float = 0.5, alpha: float = 1.0,
                  inplace: bool = False) -> None:
         super().__init__()
         assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert mixup_alpha > 0 or cutmix_alpha > 0, "Both alpha params can't be zero."
+        assert alpha > 0, "Alpha param can't be zero."
 
         self.num_classes = num_classes
         self.p = p
-        self.mixup_alpha = mixup_alpha
-        self.cutmix_p = cutmix_p
-        self.cutmix_alpha = cutmix_alpha
+        self.alpha = alpha
         self.inplace = inplace
 
     def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
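
Since `RandomMixup` blends images across the batch dimension, it is meant to run on already-collated batches rather than on individual samples. A sketch of one way to wire it into a training loop (`train_loader` is a hypothetical DataLoader stand-in):

```python
import torch
from torchvision.transforms import RandomMixup  # assuming the class is exported here

mixup = RandomMixup(num_classes=1000, p=0.5, alpha=0.2)

for images, labels in train_loader:  # images: (B, C, H, W) float, labels: (B,) int64
    images, labels = mixup(images, labels)
    # labels come back as (B, num_classes) float targets even when the batch
    # is left unmixed, so the loss function must accept soft targets.
```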
@@ -2018,35 +2095,24 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
         batch_rolled = batch.roll(1, 0)
         target_rolled = target.roll(1)
 
-        if self.mixup_alpha <= 0.0:
-            use_mixup = False
-        else:
-            use_mixup = self.cutmix_alpha <= 0.0 or torch.rand(1).item() >= self.cutmix_p
-
-        if use_mixup:
-            # Implemented as on mixup paper, page 3.
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.mixup_alpha, self.mixup_alpha]))[0])
-            batch_rolled.mul_(1.0 - lambda_param)
-            batch.mul_(lambda_param).add_(batch_rolled)
-        else:
-            # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
-            lambda_param = float(torch._sample_dirichlet(torch.tensor([self.cutmix_alpha, self.cutmix_alpha]))[0])
-            W, H = F.get_image_size(batch)
+        # Implemented as in the cutmix paper, page 12 (with minor corrections on typos).
+        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
+        W, H = F.get_image_size(batch)
 
-            r_x = torch.randint(W, (1,))
-            r_y = torch.randint(H, (1,))
+        r_x = torch.randint(W, (1,))
+        r_y = torch.randint(H, (1,))
 
-            r = 0.5 * math.sqrt(1.0 - lambda_param)
-            r_w_half = int(r * W)
-            r_h_half = int(r * H)
+        r = 0.5 * math.sqrt(1.0 - lambda_param)
+        r_w_half = int(r * W)
+        r_h_half = int(r * H)
 
-            x1 = int(torch.clamp(r_x - r_w_half, min=0))
-            y1 = int(torch.clamp(r_y - r_h_half, min=0))
-            x2 = int(torch.clamp(r_x + r_w_half, max=W))
-            y2 = int(torch.clamp(r_y + r_h_half, max=H))
+        x1 = int(torch.clamp(r_x - r_w_half, min=0))
+        y1 = int(torch.clamp(r_y - r_h_half, min=0))
+        x2 = int(torch.clamp(r_x + r_w_half, max=W))
+        y2 = int(torch.clamp(r_y + r_h_half, max=H))
 
-            batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
-            lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
+        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
+        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
 
         target_rolled.mul_(1.0 - lambda_param)
         target.mul_(lambda_param).add_(target_rolled)
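
Note that `lambda_param` is recomputed from the clipped box so that the label mixing ratio matches the actual fraction of pasted pixels. For example, with W = H = 224 and a sampled λ = 0.25, r = 0.5·√0.75 ≈ 0.433, so the unclipped box is 192×192 and the adjusted λ is 1 − 192²/224² ≈ 0.27, close to the sampled value; if the box is clipped at the image border to, say, 150×192, λ rises to about 0.43.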
@@ -2057,9 +2123,7 @@ def __repr__(self) -> str:
         s = self.__class__.__name__ + '('
         s += 'num_classes={num_classes}'
         s += ', p={p}'
-        s += ', mixup_alpha={mixup_alpha}'
-        s += ', cutmix_p={cutmix_p}'
-        s += ', cutmix_alpha={cutmix_alpha}'
+        s += ', alpha={alpha}'
         s += ', inplace={inplace}'
         s += ')'
         return s.format(**self.__dict__)
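
With the combined class split in two, the previous "mixup or cutmix per batch" behavior can be reconstructed by composing the new classes with the weighted `RandomChoice` above (this relies on its `__call__` now forwarding `*args`). A sketch, assuming both classes end up exported from `torchvision.transforms`:

```python
from torchvision.transforms import RandomChoice, RandomCutmix, RandomMixup

# Roughly equivalent to the removed RandomMixupCutmix with
# p=1.0, mixup_alpha=0.2, cutmix_p=0.5 and cutmix_alpha=1.0:
mixupcutmix = RandomChoice(
    [RandomMixup(num_classes=1000, p=1.0, alpha=0.2),
     RandomCutmix(num_classes=1000, p=1.0, alpha=1.0)],
    p=[0.5, 0.5],
)
batch, target = mixupcutmix(batch, target)
```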