77
88from . import functional as F , InterpolationMode
99
10- __all__ = ["AutoAugmentPolicy" , "AutoAugment" ]
10+ __all__ = ["AutoAugmentPolicy" , "AutoAugment" , "RandAugment" ]
1111
1212
1313def _apply_op (img : Tensor , op_name : str , magnitude : float ,
@@ -58,6 +58,7 @@ class AutoAugmentPolicy(Enum):
5858 SVHN = "svhn"
5959
6060
61+ # FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
6162class AutoAugment (torch .nn .Module ):
6263 r"""AutoAugment data augmentation method based on
6364 `"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
@@ -85,9 +86,9 @@ def __init__(
8586 self .policy = policy
8687 self .interpolation = interpolation
8788 self .fill = fill
88- self .transforms = self ._get_transforms (policy )
89+ self .policies = self ._get_policies (policy )
8990
90- def _get_transforms (
91+ def _get_policies (
9192 self ,
9293 policy : AutoAugmentPolicy
9394 ) -> List [Tuple [Tuple [str , float , Optional [int ]], Tuple [str , float , Optional [int ]]]]:
@@ -178,9 +179,9 @@ def _get_transforms(
178179 else :
179180 raise ValueError ("The provided policy {} is not recognized." .format (policy ))
180181
181- def _get_magnitudes (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
182+ def _augmentation_space (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
182183 return {
183- # name : (magnitudes, signed)
184+ # op_name : (magnitudes, signed)
184185 "ShearX" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
185186 "ShearY" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
186187 "TranslateX" : (torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
@@ -224,11 +225,11 @@ def forward(self, img: Tensor) -> Tensor:
224225 elif fill is not None :
225226 fill = [float (f ) for f in fill ]
226227
227- transform_id , probs , signs = self .get_params (len (self .transforms ))
228+ transform_id , probs , signs = self .get_params (len (self .policies ))
228229
229- for i , (op_name , p , magnitude_id ) in enumerate (self .transforms [transform_id ]):
230+ for i , (op_name , p , magnitude_id ) in enumerate (self .policies [transform_id ]):
230231 if probs [i ] <= p :
231- op_meta = self ._get_magnitudes (10 , F .get_image_size (img ))
232+ op_meta = self ._augmentation_space (10 , F .get_image_size (img ))
232233 magnitudes , signed = op_meta [op_name ]
233234 magnitude = float (magnitudes [magnitude_id ].item ()) if magnitude_id is not None else 0.0
234235 if signed and signs [i ] == 0 :
@@ -239,3 +240,87 @@ def forward(self, img: Tensor) -> Tensor:
239240
240241 def __repr__ (self ) -> str :
241242 return self .__class__ .__name__ + '(policy={}, fill={})' .format (self .policy , self .fill )
243+
244+
245+ class RandAugment (torch .nn .Module ):
246+ r"""RandAugment data augmentation method based on
247+ `"RandAugment: Practical automated data augmentation with a reduced search space"
248+ <https://arxiv.org/abs/1909.13719>`.
249+ If the image is torch Tensor, it should be of type torch.uint8, and it is expected
250+ to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
251+ If img is PIL Image, it is expected to be in mode "L" or "RGB".
252+
253+ Args:
254+ num_ops (int): Number of augmentation transformations to apply sequentially.
255+ magnitude (int): Magnitude for all the transformations.
256+ num_magnitude_bins (int): The number of different magnitude values.
257+ interpolation (InterpolationMode): Desired interpolation enum defined by
258+ :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
259+ If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
260+ fill (sequence or number, optional): Pixel fill value for the area outside the transformed
261+ image. If given a number, the value is used for all bands respectively.
262+ """
263+
264+ def __init__ (self , num_ops : int = 2 , magnitude : int = 9 , num_magnitude_bins : int = 30 ,
265+ interpolation : InterpolationMode = InterpolationMode .NEAREST ,
266+ fill : Optional [List [float ]] = None ) -> None :
267+ super ().__init__ ()
268+ self .num_ops = num_ops
269+ self .magnitude = magnitude
270+ self .num_magnitude_bins = num_magnitude_bins
271+ self .interpolation = interpolation
272+ self .fill = fill
273+
274+ def _augmentation_space (self , num_bins : int , image_size : List [int ]) -> Dict [str , Tuple [Tensor , bool ]]:
275+ return {
276+ # op_name: (magnitudes, signed)
277+ "ShearX" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
278+ "ShearY" : (torch .linspace (0.0 , 0.3 , num_bins ), True ),
279+ "TranslateX" : (torch .linspace (0.0 , 150.0 / 331.0 * image_size [0 ], num_bins ), True ),
280+ "TranslateY" : (torch .linspace (0.0 , 150.0 / 331.0 * image_size [1 ], num_bins ), True ),
281+ "Rotate" : (torch .linspace (0.0 , 30.0 , num_bins ), True ),
282+ "Brightness" : (torch .linspace (0.0 , 0.9 , num_bins ), True ),
283+ "Color" : (torch .linspace (0.0 , 0.9 , num_bins ), True ),
284+ "Contrast" : (torch .linspace (0.0 , 0.9 , num_bins ), True ),
285+ "Sharpness" : (torch .linspace (0.0 , 0.9 , num_bins ), True ),
286+ "Posterize" : (8 - (torch .arange (num_bins ) / ((num_bins - 1 ) / 4 )).round ().int (), False ),
287+ "Solarize" : (torch .linspace (256.0 , 0.0 , num_bins ), False ),
288+ "AutoContrast" : (torch .tensor (0.0 ), False ),
289+ "Equalize" : (torch .tensor (0.0 ), False ),
290+ "Invert" : (torch .tensor (0.0 ), False ),
291+ }
292+
293+ def forward (self , img : Tensor ) -> Tensor :
294+ """
295+ img (PIL Image or Tensor): Image to be transformed.
296+ Returns:
297+ PIL Image or Tensor: Transformed image.
298+ """
299+ fill = self .fill
300+ if isinstance (img , Tensor ):
301+ if isinstance (fill , (int , float )):
302+ fill = [float (fill )] * F .get_image_num_channels (img )
303+ elif fill is not None :
304+ fill = [float (f ) for f in fill ]
305+
306+ for _ in range (self .num_ops ):
307+ op_meta = self ._augmentation_space (self .num_magnitude_bins , F .get_image_size (img ))
308+ op_index = int (torch .randint (len (op_meta ), (1 ,)).item ())
309+ op_name = list (op_meta .keys ())[op_index ]
310+ magnitudes , signed = op_meta [op_name ]
311+ magnitude = float (magnitudes [self .magnitude ].item ()) if magnitudes .ndim > 0 else 0.0
312+ if signed and torch .randint (2 , (1 ,)):
313+ magnitude *= - 1.0
314+ img = _apply_op (img , op_name , magnitude , interpolation = self .interpolation , fill = fill )
315+
316+ return img
317+
318+ def __repr__ (self ) -> str :
319+ s = self .__class__ .__name__ + '('
320+ s += 'num_ops={num_ops}'
321+ s += ', magnitude={magnitude}'
322+ s += ', num_magnitude_bins={num_magnitude_bins}'
323+ s += ', interpolation={interpolation}'
324+ s += ', fill={fill}'
325+ s += ')'
326+ return s .format (** self .__dict__ )
0 commit comments