@@ -60,16 +60,21 @@ def get_train_dataset(stage, dataset_root):
6060
6161
6262@torch .no_grad ()
63- def _validate (model , args , val_dataset , * , padder_mode , num_flow_updates = None , batch_size = None , header = None ):
63+ def _evaluate (model , args , val_dataset , * , padder_mode , num_flow_updates = None , batch_size = None , header = None ):
6464 """Helper function to compute various metrics (epe, etc.) for a model on a given dataset.
6565
6666 We process as many samples as possible with ddp, and process the rest on a single worker.
6767 """
6868 batch_size = batch_size or args .batch_size
69+ device = torch .device (args .device )
6970
7071 model .eval ()
7172
72- sampler = torch .utils .data .distributed .DistributedSampler (val_dataset , shuffle = False , drop_last = True )
73+ if args .distributed :
74+ sampler = torch .utils .data .distributed .DistributedSampler (val_dataset , shuffle = False , drop_last = True )
75+ else :
76+ sampler = torch .utils .data .SequentialSampler (val_dataset )
77+
7378 val_loader = torch .utils .data .DataLoader (
7479 val_dataset ,
7580 sampler = sampler ,
@@ -88,7 +93,7 @@ def inner_loop(blob):
8893 image1 , image2 , flow_gt = blob [:3 ]
8994 valid_flow_mask = None if len (blob ) == 3 else blob [- 1 ]
9095
91- image1 , image2 = image1 .cuda ( ), image2 .cuda ( )
96+ image1 , image2 = image1 .to ( device ), image2 .to ( device )
9297
9398 padder = utils .InputPadder (image1 .shape , mode = padder_mode )
9499 image1 , image2 = padder .pad (image1 , image2 )
@@ -115,21 +120,22 @@ def inner_loop(blob):
115120 inner_loop (blob )
116121 num_processed_samples += blob [0 ].shape [0 ] # batch size
117122
118- num_processed_samples = utils .reduce_across_processes (num_processed_samples )
119- print (
120- f"Batch-processed { num_processed_samples } / { len (val_dataset )} samples. "
121- "Going to process the remaining samples individually, if any."
122- )
123+ if args .distributed :
124+ num_processed_samples = utils .reduce_across_processes (num_processed_samples )
125+ print (
126+ f"Batch-processed { num_processed_samples } / { len (val_dataset )} samples. "
127+ "Going to process the remaining samples individually, if any."
128+ )
129+ if args .rank == 0 : # we only need to process the rest on a single worker
130+ for i in range (num_processed_samples , len (val_dataset )):
131+ inner_loop (val_dataset [i ])
123132
124- if args .rank == 0 : # we only need to process the rest on a single worker
125- for i in range (num_processed_samples , len (val_dataset )):
126- inner_loop (val_dataset [i ])
133+ logger .synchronize_between_processes ()
127134
128- logger .synchronize_between_processes ()
129135 print (header , logger )
130136
131137
132- def validate (model , args ):
138+ def evaluate (model , args ):
133139 val_datasets = args .val_dataset or []
134140
135141 if args .prototype :
@@ -145,21 +151,21 @@ def validate(model, args):
145151 if name == "kitti" :
146152 # Kitti has different image sizes so we need to individually pad them, we can't batch.
147153 # see comment in InputPadder
148- if args .batch_size != 1 and args .rank == 0 :
154+ if args .batch_size != 1 and ( not args .distributed or args . rank == 0 ) :
149155 warnings .warn (
150156 f"Batch-size={ args .batch_size } was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
151157 )
152158
153159 val_dataset = KittiFlow (root = args .dataset_root , split = "train" , transforms = preprocessing )
154- _validate (
160+ _evaluate (
155161 model , args , val_dataset , num_flow_updates = 24 , padder_mode = "kitti" , header = "Kitti val" , batch_size = 1
156162 )
157163 elif name == "sintel" :
158164 for pass_name in ("clean" , "final" ):
159165 val_dataset = Sintel (
160166 root = args .dataset_root , split = "train" , pass_name = pass_name , transforms = preprocessing
161167 )
162- _validate (
168+ _evaluate (
163169 model ,
164170 args ,
165171 val_dataset ,
@@ -172,11 +178,12 @@ def validate(model, args):
172178
173179
174180def train_one_epoch (model , optimizer , scheduler , train_loader , logger , args ):
181+ device = torch .device (args .device )
175182 for data_blob in logger .log_every (train_loader ):
176183
177184 optimizer .zero_grad ()
178185
179- image1 , image2 , flow_gt , valid_flow_mask = (x .cuda ( ) for x in data_blob )
186+ image1 , image2 , flow_gt , valid_flow_mask = (x .to ( device ) for x in data_blob )
180187 flow_predictions = model (image1 , image2 , num_flow_updates = args .num_flow_updates )
181188
182189 loss = utils .sequence_loss (flow_predictions , flow_gt , valid_flow_mask , args .gamma )
@@ -200,36 +207,68 @@ def main(args):
200207 raise ValueError ("The weights parameter works only in prototype mode. Please pass the --prototype argument." )
201208 utils .setup_ddp (args )
202209
210+ if args .distributed and args .device == "cpu" :
211+ raise ValueError ("The device must be cuda if we want to run in distributed mode using torchrun" )
212+ device = torch .device (args .device )
213+
203214 if args .prototype :
204215 model = prototype .models .optical_flow .__dict__ [args .model ](weights = args .weights )
205216 else :
206217 model = torchvision .models .optical_flow .__dict__ [args .model ](pretrained = args .pretrained )
207218
208- model = model .to (args .local_rank )
209- model = torch .nn .parallel .DistributedDataParallel (model , device_ids = [args .local_rank ])
219+ if args .distributed :
220+ model = model .to (args .local_rank )
221+ model = torch .nn .parallel .DistributedDataParallel (model , device_ids = [args .local_rank ])
222+ model_without_ddp = model .module
223+ else :
224+ model .to (device )
225+ model_without_ddp = model
210226
211227 if args .resume is not None :
212- d = torch .load (args .resume , map_location = "cpu" )
213- model .load_state_dict (d , strict = True )
228+ checkpoint = torch .load (args .resume , map_location = "cpu" )
229+ model_without_ddp .load_state_dict (checkpoint [ "model" ] )
214230
215231 if args .train_dataset is None :
216232 # Set deterministic CUDNN algorithms, since they can affect epe a fair bit.
217233 torch .backends .cudnn .benchmark = False
218234 torch .backends .cudnn .deterministic = True
219- validate (model , args )
235+ evaluate (model , args )
220236 return
221237
222238 print (f"Parameter Count: { sum (p .numel () for p in model .parameters () if p .requires_grad )} " )
223239
240+ train_dataset = get_train_dataset (args .train_dataset , args .dataset_root )
241+
242+ optimizer = torch .optim .AdamW (model .parameters (), lr = args .lr , weight_decay = args .weight_decay , eps = args .adamw_eps )
243+
244+ scheduler = torch .optim .lr_scheduler .OneCycleLR (
245+ optimizer = optimizer ,
246+ max_lr = args .lr ,
247+ epochs = args .epochs ,
248+ steps_per_epoch = ceil (len (train_dataset ) / (args .world_size * args .batch_size )),
249+ pct_start = 0.05 ,
250+ cycle_momentum = False ,
251+ anneal_strategy = "linear" ,
252+ )
253+
254+ if args .resume is not None :
255+ optimizer .load_state_dict (checkpoint ["optimizer" ])
256+ scheduler .load_state_dict (checkpoint ["scheduler" ])
257+ args .start_epoch = checkpoint ["epoch" ] + 1
258+ else :
259+ args .start_epoch = 0
260+
224261 torch .backends .cudnn .benchmark = True
225262
226263 model .train ()
227264 if args .freeze_batch_norm :
228265 utils .freeze_batch_norm (model .module )
229266
230- train_dataset = get_train_dataset (args .train_dataset , args .dataset_root )
267+ if args .distributed :
268+ sampler = torch .utils .data .distributed .DistributedSampler (train_dataset , shuffle = True , drop_last = True )
269+ else :
270+ sampler = torch .utils .data .RandomSampler (train_dataset )
231271
232- sampler = torch .utils .data .distributed .DistributedSampler (train_dataset , shuffle = True , drop_last = True )
233272 train_loader = torch .utils .data .DataLoader (
234273 train_dataset ,
235274 sampler = sampler ,
@@ -238,25 +277,15 @@ def main(args):
238277 num_workers = args .num_workers ,
239278 )
240279
241- optimizer = torch .optim .AdamW (model .parameters (), lr = args .lr , weight_decay = args .weight_decay , eps = args .adamw_eps )
242-
243- scheduler = torch .optim .lr_scheduler .OneCycleLR (
244- optimizer = optimizer ,
245- max_lr = args .lr ,
246- epochs = args .epochs ,
247- steps_per_epoch = ceil (len (train_dataset ) / (args .world_size * args .batch_size )),
248- pct_start = 0.05 ,
249- cycle_momentum = False ,
250- anneal_strategy = "linear" ,
251- )
252-
253280 logger = utils .MetricLogger ()
254281
255282 done = False
256- for current_epoch in range (args .epochs ):
283+ for current_epoch in range (args .start_epoch , args . epochs ):
257284 print (f"EPOCH { current_epoch } " )
285+ if args .distributed :
286+ # needed on distributed mode, otherwise the data loading order would be the same for all epochs
287+ sampler .set_epoch (current_epoch )
258288
259- sampler .set_epoch (current_epoch ) # needed, otherwise the data loading order would be the same for all epochs
260289 train_one_epoch (
261290 model = model ,
262291 optimizer = optimizer ,
@@ -269,13 +298,19 @@ def main(args):
269298 # Note: we don't sync the SmoothedValues across processes, so the printed metrics are just those of rank 0
270299 print (f"Epoch { current_epoch } done. " , logger )
271300
272- if args .rank == 0 :
273- # TODO: Also save the optimizer and scheduler
274- torch .save (model .state_dict (), Path (args .output_dir ) / f"{ args .name } _{ current_epoch } .pth" )
275- torch .save (model .state_dict (), Path (args .output_dir ) / f"{ args .name } .pth" )
301+ if not args .distributed or args .rank == 0 :
302+ checkpoint = {
303+ "model" : model_without_ddp .state_dict (),
304+ "optimizer" : optimizer .state_dict (),
305+ "scheduler" : scheduler .state_dict (),
306+ "epoch" : current_epoch ,
307+ "args" : args ,
308+ }
309+ torch .save (checkpoint , Path (args .output_dir ) / f"{ args .name } _{ current_epoch } .pth" )
310+ torch .save (checkpoint , Path (args .output_dir ) / f"{ args .name } .pth" )
276311
277312 if current_epoch % args .val_freq == 0 or done :
278- validate (model , args )
313+ evaluate (model , args )
279314 model .train ()
280315 if args .freeze_batch_norm :
281316 utils .freeze_batch_norm (model .module )
@@ -349,6 +384,7 @@ def get_args_parser(add_help=True):
349384 action = "store_true" ,
350385 )
351386 parser .add_argument ("--weights" , default = None , type = str , help = "the weights enum name to load." )
387+ parser .add_argument ("--device" , default = "cuda" , type = str , help = "device (Use cuda or cpu, Default: cuda)" )
352388
353389 return parser
354390
0 commit comments