import argparse
import datetime
import json
import math
import os
import random
import time
from pathlib import Path

import numpy as np
import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

import utils
from utils.checkpointer import Checkpointer

from dataset import create_dataset, create_sampler, create_loader, build_tokenizer
from dataset.utils import collect_tensor_result, grounding_eval_bbox, grounding_eval_bbox_vlue

from models.model_grounding import XVLMForGrounding

from optim import create_optimizer
from refTools.refer_python3 import REFER
from scheduler import create_scheduler
from utils.hdfs_io import hmkdir, hcopy, hexists


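# One fine-tuning epoch: the model predicts a bounding box for each (image, text)
# pair and is optimized on the sum of a box regression loss and a GIoU loss.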
def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config):
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_bbox', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_giou', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    accumulate_steps = int(config.get('accumulate_steps', 1))
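    # With accumulate_steps > 1, gradients from several mini-batches are summed
    # and the loss is scaled down so each optimizer step matches one step over
    # the combined batch.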
    for i, (image, text, target_bbox) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)
        text_input = tokenizer(text, padding='longest', truncation=True,
                               max_length=config['max_tokens'], return_tensors="pt").to(device)
        target_bbox = target_bbox.to(device)

        _, loss_bbox, loss_giou = model(image, text_input.input_ids, text_input.attention_mask,
                                        target_bbox=target_bbox)
        loss = loss_bbox + loss_giou

        if accumulate_steps > 1:
            loss = loss / accumulate_steps

        # backward
        loss.backward()

        if (i + 1) % accumulate_steps == 0:
            # update
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        metric_logger.update(loss_bbox=loss_bbox.item())
        metric_logger.update(loss_giou=loss_giou.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}


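# Inference pass: predicts one box per (image, text) pair and keys each
# prediction by its ref_id so results can later be merged across ranks.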
def val(model, data_loader, tokenizer, device):
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'
    print_freq = 50

    result = []
    for image, text, ref_ids in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device)
        text_input = tokenizer(text, padding='longest', return_tensors="pt").to(device)

        with torch.no_grad():
            outputs_coord = model(image, text_input.input_ids, text_input.attention_mask, target_bbox=None)

        assert len(ref_ids) == outputs_coord.shape[0]

        for r_id, coord in zip(ref_ids, outputs_coord):
            result.append({'ref_id': r_id.item(), 'pred': coord})

    return result


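# Entry point: sets up (distributed) training, builds the dataset, model and
# tokenizer, then either runs a single evaluation pass or the full fine-tuning loop.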
def main(args, config):
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    world_size = utils.get_world_size()

    if world_size > 8:
        assert hexists(args.output_hdfs) and args.output_hdfs.startswith('hdfs'), "for collect_result among nodes"

    if args.bs > 0:
        config['batch_size'] = args.bs // world_size

    # offset the seed by rank so each process draws a different random stream
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

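    # The 'grounding_bbox' dataset yields (image, text, target_bbox) triples for
    # training and (image, text, ref_id) triples for evaluation.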
    print("Creating dataset")
    grd_train_dataset, grd_test_dataset = create_dataset('grounding_bbox', config, args.evaluate)

    print("Creating model")
    model = XVLMForGrounding(config=config)
    model.load_pretrained(args.checkpoint, config, load_bbox_pretrain=args.load_bbox_pretrain, is_eval=args.evaluate)
    model = model.to(device)
    print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    tokenizer = build_tokenizer(config['text_encoder'])

    print("### output_dir, ", args.output_dir, flush=True)
    print("### output_hdfs, ", args.output_hdfs, flush=True)
    start_time = time.time()

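    # Evaluation-only mode: a single validation pass followed by accuracy
    # computation on the main process.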
    if args.evaluate:
        print("Start evaluating")

        if args.distributed:
            num_tasks = utils.get_world_size()
            global_rank = utils.get_rank()
            samplers = create_sampler([grd_test_dataset], [False], num_tasks, global_rank)
        else:
            samplers = [None]

        test_loader = create_loader([grd_test_dataset], samplers,
                                    batch_size=[config['batch_size']],
                                    num_workers=[4], is_trains=[False], collate_fns=[None])[0]

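        # Run inference on this rank's shard; collect_tensor_result merges the
        # per-rank predictions (via HDFS when running on more than 8 GPUs).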
        result = val(model_without_ddp, test_loader, tokenizer, device)
        results = collect_tensor_result(result, filename='grounding_bbox_eval', local_wdir=args.result_dir,
                                        hdfs_wdir=args.output_hdfs,
                                        write_to_hdfs=world_size > 8)

        if utils.is_main_process():
            if 'vlue_test' in config.keys() and config['vlue_test']:
                grounding_acc = grounding_eval_bbox_vlue(results, config['test_file'][0])
            else:
                # refcoco evaluation tools
                refer = REFER(config['refcoco_data'], 'refcoco+', 'unc')
                grounding_acc = grounding_eval_bbox(results, refer)

            log_stats = {f'{k}': v for k, v in grounding_acc.items()}
            print(log_stats)

        dist.barrier()

    else:
        print("Start training")

        datasets = [grd_train_dataset, grd_test_dataset]

        train_dataset_size = len(grd_train_dataset)
        train_batch_size = config['batch_size']

        if utils.is_main_process():
            print(f"### data {train_dataset_size}, batch size, {train_batch_size} x {world_size}")

        if args.distributed:
            num_tasks = utils.get_world_size()
            global_rank = utils.get_rank()
            samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
        else:
            samplers = [None, None]

        train_loader, test_loader = create_loader(datasets, samplers,
                                                  batch_size=[config['batch_size'], config['batch_size']],
                                                  num_workers=[4, 4], is_trains=[True, False],
                                                  collate_fns=[None, None])

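        # step_per_epoch counts optimizer updates, so it is divided by
        # accumulate_steps when gradient accumulation is enabled.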
        arg_opt = utils.AttrDict(config['optimizer'])
        optimizer = create_optimizer(arg_opt, model)
        arg_sche = utils.AttrDict(config['schedular'])
        accumulate_steps = int(config.get('accumulate_steps', 1))
        arg_sche['step_per_epoch'] = math.ceil(train_dataset_size / (train_batch_size * world_size) / accumulate_steps)
        arg_sche['min_rate'] = config['min_lr'] / arg_opt['lr'] if 'min_lr' in config else 0
        lr_scheduler = create_scheduler(arg_sche, optimizer)

        checkpointer = Checkpointer(args.output_dir)

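        # Train, evaluate after every epoch, and checkpoint whenever the
        # RefCOCO+ validation accuracy ('val_d') improves.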
        max_epoch = config['schedular']['epochs']
        best = 0
        best_epoch = 0
        for epoch in range(0, max_epoch):
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, device, lr_scheduler, config)

            result = val(model_without_ddp, test_loader, tokenizer, device)
            results = collect_tensor_result(result, filename='epoch%d' % epoch, local_wdir=args.result_dir,
                                            hdfs_wdir=args.output_hdfs,
                                            write_to_hdfs=world_size > 8)

            if utils.is_main_process():
                # refcoco evaluation tools
                refer = REFER(config['refcoco_data'], 'refcoco+', 'unc')
                grounding_acc = grounding_eval_bbox(results, refer)
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'{k}': v for k, v in grounding_acc.items()},
                             'epoch': epoch}

                if grounding_acc['val_d'] > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'config': config,
                    }

                    checkpointer.save_checkpoint(model_state=save_obj,
                                                 epoch='best', training_states=optimizer.state_dict())

                    best = grounding_acc['val_d']
                    best_epoch = epoch

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

            dist.barrier()

        if utils.is_main_process():
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write("best epoch: %d" % best_epoch)

            os.system(f"cat {args.output_dir}/log.txt")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('### Time {}'.format(total_time_str))


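# Note: --distributed uses action='store_false', so distributed mode is enabled
# by default and passing the flag turns it off.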
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--config', type=str, default='configs/Grounding_bbox.yaml')
    parser.add_argument('--output_dir', type=str, default='output/refcoco_bbox')
    parser.add_argument('--output_hdfs', type=str, default='', help="to collect eval results among nodes")

    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--distributed', action='store_false')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--override_cfg', default="", type=str, help="Use ; to separate keys")
    parser.add_argument('--load_bbox_pretrain', action='store_true')
    parser.add_argument('--bs', default=-1, type=int, help="global batch size; per-GPU batch_size = bs // num_gpus")

    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
    utils.update_config(config, args.override_cfg)
    if utils.is_main_process():
        print('config:', json.dumps(config))
    args.result_dir = os.path.join(args.output_dir, 'result')
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    if len(args.output_hdfs):
        hmkdir(args.output_hdfs)

    main(args, config)