@@ -165,10 +165,16 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)
+
+    collate_fn = None
+    if args.mixup_alpha > 0.0 or args.cutmix_alpha > 0.0:
+        mixupcutmix = torchvision.transforms.RandomMixupCutmix(len(dataset.classes), mixup_alpha=args.mixup_alpha,
+                                                               cutmix_alpha=args.cutmix_alpha)
+        collate_fn = lambda batch: mixupcutmix(*(torch.utils.data._utils.collate.default_collate(batch)))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
-
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
@@ -254,7 +260,6 @@ def main(args):
 def get_args_parser(add_help=True):
     import argparse
     parser = argparse.ArgumentParser(description='PyTorch Classification Training', add_help=add_help)
-
     parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
     parser.add_argument('--model', default='resnet18', help='model')
     parser.add_argument('--device', default='cuda', help='device')
@@ -273,6 +278,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
+    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
+    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
     parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
     parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
@@ -306,7 +313,6 @@ def get_args_parser(add_help=True):
     )
     parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
     parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
-
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
                         help='Use apex for mixed precision training')
@@ -315,7 +321,6 @@ def get_args_parser(add_help=True):
         'O0 for FP32 training, O1 for mixed precision training.'
         'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
     )
-
     # distributed training parameters
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
@@ -326,7 +331,6 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         '--model-ema-decay', type=float, default=0.99,
         help='decay factor for Exponential Moving Average of model parameters(default: 0.99)')
-
     return parser

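The change above wires MixUp/CutMix into the DataLoader by composing default_collate with a batch-level transform: default_collate stacks the samples, then the transform blends the stacked images and turns the integer labels into blended one-hot targets. Below is a minimal, self-contained sketch of that collate_fn pattern. SimpleMixup, the synthetic TensorDataset, and the alpha/num_classes values are hypothetical stand-ins for the prototype torchvision.transforms.RandomMixupCutmix and the ImageNet folders used by this script; it illustrates the pattern, it is not the code from this commit.

# Sketch of the collate_fn pattern used in the diff above (illustrative only).
# SimpleMixup is a hypothetical stand-in for the prototype RandomMixupCutmix:
# it mixes each batch with a partner batch obtained by rolling it by one sample.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data._utils.collate import default_collate


class SimpleMixup:
    def __init__(self, num_classes, alpha=0.2):
        self.num_classes = num_classes  # needed to one-hot encode integer labels
        self.alpha = alpha              # Beta(alpha, alpha) draws the mixing ratio

    def __call__(self, images, targets):
        # One-hot encode so the labels can be linearly blended like the images.
        targets = torch.nn.functional.one_hot(targets, self.num_classes).float()
        lam = torch.distributions.Beta(self.alpha, self.alpha).sample().item()
        images = lam * images + (1.0 - lam) * images.roll(1, 0)
        targets = lam * targets + (1.0 - lam) * targets.roll(1, 0)
        return images, targets


# Toy data standing in for the ImageFolder datasets built by load_data().
dataset = TensorDataset(torch.randn(32, 3, 224, 224), torch.randint(0, 10, (32,)))
mixup = SimpleMixup(num_classes=10, alpha=0.2)
loader = DataLoader(
    dataset, batch_size=8,
    collate_fn=lambda batch: mixup(*default_collate(batch)))  # same shape as the diff

images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([8, 3, 224, 224]) torch.Size([8, 10])

Because the collated targets come back as soft probability vectors rather than class indices, the criterion has to accept them (torch.nn.CrossEntropyLoss does in recent PyTorch releases). With the new --mixup-alpha and --cutmix-alpha flags left at their 0.0 defaults, the guard in main() keeps collate_fn as None and training behaves exactly as before.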