@@ -173,8 +173,15 @@ def main(args):
173173
174174 criterion = nn .CrossEntropyLoss ()
175175
176- optimizer = torch .optim .SGD (
177- model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
176+ opt_name = args .opt .lower ()
177+ if opt_name == 'sgd' :
178+ optimizer = torch .optim .SGD (
179+ model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
180+ elif opt_name == 'rmsprop' :
181+ optimizer = torch .optim .RMSprop (
182+ model .parameters (), lr = args .lr , momentum = args .momentum , weight_decay = args .weight_decay )
183+ else :
184+ raise RuntimeError ("Invalid optimizer {}. Only SGD and RMSprop are supported." .format (args .opt ))
178185
179186 if args .apex :
180187 model , optimizer = amp .initialize (model , optimizer ,
@@ -191,9 +198,11 @@ def main(args):
191198 if args .resume :
192199 checkpoint = torch .load (args .resume , map_location = 'cpu' )
193200 model_without_ddp .load_state_dict (checkpoint ['model' ])
194- optimizer .load_state_dict (checkpoint ['optimizer' ])
195- lr_scheduler .load_state_dict (checkpoint ['lr_scheduler' ])
196- args .start_epoch = checkpoint ['epoch' ] + 1
201+ if not args .no_resume_opt :
202+ optimizer .load_state_dict (checkpoint ['optimizer' ])
203+ if not args .no_resume_sched :
204+ lr_scheduler .load_state_dict (checkpoint ['lr_scheduler' ])
205+ args .start_epoch = checkpoint ['epoch' ] + 1
197206
198207 if args .test_only :
199208 evaluate (model , criterion , data_loader_test , device = device )
@@ -238,6 +247,7 @@ def parse_args():
238247 help = 'number of total epochs to run' )
239248 parser .add_argument ('-j' , '--workers' , default = 16 , type = int , metavar = 'N' ,
240249 help = 'number of data loading workers (default: 16)' )
250+ parser .add_argument ('--opt' , default = 'sgd' , type = str , help = 'optimizer' )
241251 parser .add_argument ('--lr' , default = 0.1 , type = float , help = 'initial learning rate' )
242252 parser .add_argument ('--momentum' , default = 0.9 , type = float , metavar = 'M' ,
243253 help = 'momentum' )
@@ -275,6 +285,18 @@ def parse_args():
275285 help = "Use pre-trained models from the modelzoo" ,
276286 action = "store_true" ,
277287 )
288+ parser .add_argument (
289+ "--no-resume-opt" ,
290+ dest = "no_resume_opt" ,
291+ help = "When resuming from checkpoint it ignores the optimizer state" ,
292+ action = "store_true" ,
293+ )
294+ parser .add_argument (
295+ "--no-resume-sched" ,
296+ dest = "no_resume_sched" ,
297+ help = "When resuming from checkpoint it ignores the scheduler state" ,
298+ action = "store_true" ,
299+ )
278300
279301 # Mixed precision training parameters
280302 parser .add_argument ('--apex' , action = 'store_true' ,
0 commit comments