Skip to content

Commit ac040c8

Browse files
committed
update: code 240822
Change-Id: I979120eecab1b45527452bec8a31817b8d170c8b
1 parent b0da355 commit ac040c8

File tree

105 files changed

+28923
-1
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+28923
-1
lines changed

Grounding_bbox.py

+283
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
import argparse
2+
import datetime
3+
import json
4+
import math
5+
import os
6+
import random
7+
import time
8+
from pathlib import Path
9+
10+
import numpy as np
11+
import ruamel.yaml as yaml
12+
import torch
13+
import torch.backends.cudnn as cudnn
14+
import torch.distributed as dist
15+
16+
import utils
17+
from utils.checkpointer import Checkpointer
18+
19+
from dataset import create_dataset, create_sampler, create_loader, build_tokenizer
20+
from dataset.utils import collect_tensor_result, grounding_eval_bbox, grounding_eval_bbox_vlue
21+
22+
from models.model_grounding import XVLMForGrounding
23+
24+
from optim import create_optimizer
25+
from refTools.refer_python3 import REFER
26+
from scheduler import create_scheduler
27+
from utils.hdfs_io import hmkdir, hcopy, hexists
28+
29+
30+
def train(model, data_loader, optimizer, tokenizer, epoch, device, scheduler, config):
    """Train the grounding model for one epoch.

    Supports gradient accumulation via ``config['accumulate_steps']`` (default 1):
    the loss is divided by the window size so the applied gradient equals the
    average over the window, and the optimizer/scheduler step once per window.

    Args:
        model: grounding model returning ``(_, loss_bbox, loss_giou)`` in train mode.
        data_loader: iterable of ``(image, text, target_bbox)`` batches.
        optimizer: torch optimizer over ``model``'s parameters.
        tokenizer: HF-style tokenizer callable producing input_ids/attention_mask.
        epoch: current epoch index (for logging only).
        device: target device for batches.
        scheduler: LR scheduler stepped once per optimizer update.
        config: dict with at least 'max_tokens' and optionally 'accumulate_steps'.

    Returns:
        dict mapping metric name -> formatted global average (str).
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('loss_bbox', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    metric_logger.add_meter('loss_giou', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    accumulate_steps = int(config.get('accumulate_steps', 1))

    # Ensure no stale gradients leak into the first accumulation window.
    optimizer.zero_grad()
    pending_update = False  # gradients accumulated but not yet applied

    for i, (image, text, target_bbox) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        image = image.to(device, non_blocking=True)
        text_input = tokenizer(text, padding='longest', truncation=True, max_length=config['max_tokens'], return_tensors="pt").to(device)
        target_bbox = target_bbox.to(device)

        _, loss_bbox, loss_giou = model(image, text_input.input_ids, text_input.attention_mask, target_bbox=target_bbox)
        loss = loss_bbox + loss_giou

        if accumulate_steps > 1:
            # Scale so the summed gradient equals the window average.
            loss = loss / accumulate_steps

        loss.backward()
        pending_update = True

        if (i + 1) % accumulate_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            pending_update = False

        # Log the unscaled per-batch losses and the current learning rate.
        metric_logger.update(loss_bbox=loss_bbox.item())
        metric_logger.update(loss_giou=loss_giou.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # Fix: flush the trailing partial accumulation window. The scheduler's
    # step_per_epoch is computed with math.ceil(.../accumulate_steps) in main(),
    # which counts this final step; dropping it both wastes the last gradients
    # and desynchronizes the LR schedule.
    if pending_update:
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    # Gather the stats from all processes before reporting.
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    return {k: "{:.5f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
70+
71+
72+
def val(model, data_loader, tokenizer, device):
    """Run grounding inference over a loader and collect per-reference predictions.

    Returns a list of dicts ``{'ref_id': int, 'pred': coord_tensor}``, one per
    sample, in loader order.
    """
    model.eval()

    progress = utils.MetricLogger(delimiter=" ")

    predictions = []
    for image, text, ref_ids in progress.log_every(data_loader, 50, 'Evaluation:'):
        image = image.to(device)
        tokens = tokenizer(text, padding='longest', return_tensors="pt").to(device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            coords = model(image, tokens.input_ids, tokens.attention_mask, target_bbox=None)

        # One predicted box per reference id in the batch.
        assert len(ref_ids) == coords.shape[0]

        predictions.extend(
            {'ref_id': rid.item(), 'pred': box} for rid, box in zip(ref_ids, coords)
        )

    return predictions
93+
94+
95+
def main(args, config):
    """Entry point: build data/model, then either evaluate or train the grounding model.

    In evaluate mode, runs inference on the test set, gathers predictions across
    ranks, and scores them (RefCOCO+ or VLUE). In train mode, trains for
    config['schedular']['epochs'] epochs, evaluating after each epoch and
    checkpointing the best model by the 'val_d' accuracy.
    """
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    world_size = utils.get_world_size()

    # With >8 ranks, cross-node result collection goes through HDFS, so the
    # target directory must exist and be an hdfs path.
    if world_size > 8:
        assert hexists(args.output_hdfs) and args.output_hdfs.startswith('hdfs'), "for collect_result among nodes"

    # --bs is the *global* batch size; split it evenly across ranks.
    if args.bs > 0:
        config['batch_size'] = args.bs // world_size

    # Per-rank seeding for reproducibility (offset by rank so ranks differ).
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    print("Creating dataset")
    grd_train_dataset, grd_test_dataset = create_dataset('grounding_bbox', config, args.evaluate)

    print("Creating model")
    model = XVLMForGrounding(config=config)
    model.load_pretrained(args.checkpoint, config, load_bbox_pretrain=args.load_bbox_pretrain, is_eval=args.evaluate)
    model = model.to(device)
    print("### Total Params: ", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Keep an unwrapped handle for eval/state_dict; DDP wraps the train model.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    tokenizer = build_tokenizer(config['text_encoder'])

    print("### output_dir, ", args.output_dir, flush=True)
    print("### output_hdfs, ", args.output_hdfs, flush=True)
    start_time = time.time()

    if args.evaluate:
        # ---------------- evaluation-only path ----------------
        print("Start evaluating")

        if args.distributed:
            num_tasks = utils.get_world_size()
            global_rank = utils.get_rank()
            samplers = create_sampler([grd_test_dataset], [False], num_tasks, global_rank)
        else:
            samplers = [None]

        test_loader = create_loader([grd_test_dataset], samplers,
                                    batch_size=[config['batch_size']],
                                    num_workers=[4], is_trains=[False], collate_fns=[None])[0]

        result = val(model_without_ddp, test_loader, tokenizer, device)
        # Gather per-rank predictions; go through HDFS only when >8 ranks.
        results = collect_tensor_result(result, filename='grounding_bbox_eval', local_wdir=args.result_dir,
                                        hdfs_wdir=args.output_hdfs,
                                        write_to_hdfs=world_size > 8)

        if utils.is_main_process():
            # Score on rank 0 only.
            if 'vlue_test' in config.keys() and config['vlue_test']:
                grounding_acc = grounding_eval_bbox_vlue(results, config['test_file'][0])
            else:
                # refcoco evaluation tools
                refer = REFER(config['refcoco_data'], 'refcoco+', 'unc')
                grounding_acc = grounding_eval_bbox(results, refer)

            log_stats = {**{f'{k}': v for k, v in grounding_acc.items()}}
            print(log_stats)

        # Keep all ranks in lockstep until rank 0 finishes scoring.
        dist.barrier()

    else:
        # ---------------- training path ----------------
        print("Start training")

        datasets = [grd_train_dataset, grd_test_dataset]

        train_dataset_size = len(grd_train_dataset)
        train_batch_size = config['batch_size']

        if utils.is_main_process():
            print(f"### data {train_dataset_size}, batch size, {train_batch_size} x {world_size}")

        if args.distributed:
            num_tasks = utils.get_world_size()
            global_rank = utils.get_rank()
            samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
        else:
            samplers = [None, None]

        train_loader, test_loader = create_loader(datasets, samplers,
                                                  batch_size=[config['batch_size'], config['batch_size']],
                                                  num_workers=[4, 4], is_trains=[True, False], collate_fns=[None, None])

        arg_opt = utils.AttrDict(config['optimizer'])
        optimizer = create_optimizer(arg_opt, model)
        arg_sche = utils.AttrDict(config['schedular'])
        accumulate_steps = int(config.get('accumulate_steps', 1))
        # Optimizer steps per epoch = batches per rank / accumulation window.
        arg_sche['step_per_epoch'] = math.ceil(train_dataset_size/(train_batch_size*world_size) / accumulate_steps)
        arg_sche['min_rate'] = config['min_lr'] / arg_opt['lr'] if 'min_lr' in config else 0
        lr_scheduler = create_scheduler(arg_sche, optimizer)

        checkpointer = Checkpointer(args.output_dir)

        max_epoch = config['schedular']['epochs']
        best = 0
        best_epoch = 0
        for epoch in range(0, max_epoch):
            if args.distributed:
                # Reshuffle the distributed sampler each epoch.
                train_loader.sampler.set_epoch(epoch)
            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, device, lr_scheduler, config)

            # Evaluate with the unwrapped model and gather predictions.
            result = val(model_without_ddp, test_loader, tokenizer, device)
            results = collect_tensor_result(result, filename='epoch%d' % epoch, local_wdir=args.result_dir, hdfs_wdir=args.output_hdfs,
                                            write_to_hdfs=world_size > 8)

            if utils.is_main_process():
                # refcoco evaluation tools
                refer = REFER(config['refcoco_data'], 'refcoco+', 'unc')
                grounding_acc = grounding_eval_bbox(results, refer)
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'{k}': v for k, v in grounding_acc.items()},
                             'epoch': epoch}

                # Checkpoint whenever validation detection accuracy improves.
                if grounding_acc['val_d'] > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'config': config,
                    }

                    checkpointer.save_checkpoint(model_state=save_obj,
                                                 epoch='best', training_states=optimizer.state_dict())

                    best = grounding_acc['val_d']
                    best_epoch = epoch

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

            # Sync ranks before the next epoch (rank 0 did eval/checkpoint work).
            dist.barrier()

        if utils.is_main_process():
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write("best epoch: %d" % best_epoch)

            # Dump the training log to stdout for convenience.
            os.system(f"cat {args.output_dir}/log.txt")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('### Time {}'.format(total_time_str))
248+
249+
250+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True)
    parser.add_argument('--config', type=str, default='configs/Grounding_bbox.yaml')
    parser.add_argument('--output_dir', type=str, default='output/refcoco_bbox')
    parser.add_argument('--output_hdfs', type=str, default='', help="to collect eval results among nodes")

    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    # NOTE(review): store_false means args.distributed defaults to True and
    # passing --distributed *disables* it — confirm this inversion is intended.
    parser.add_argument('--distributed', action='store_false')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--override_cfg', default="", type=str, help="Use ; to separate keys")
    parser.add_argument('--load_bbox_pretrain', action='store_true')
    parser.add_argument('--bs', default=-1, type=int, help="for each gpu, batch_size = bs // num_gpus")

    args = parser.parse_args()

    # Fix: close the config file deterministically instead of leaking the
    # handle opened inline in the yaml.load(...) call.
    with open(args.config, 'r') as f:
        config = yaml.load(f, Loader=yaml.Loader)
    utils.update_config(config, args.override_cfg)
    if utils.is_main_process():
        print('config:', json.dumps(config))
    args.result_dir = os.path.join(args.output_dir, 'result')
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # Fix: same leak on the write side — an unclosed write handle may not be
    # flushed promptly; the context manager guarantees flush+close.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as f:
        yaml.dump(config, f)

    if len(args.output_hdfs):
        hmkdir(args.output_hdfs)

    main(args, config)

LICENSE

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Copyright (c) 2023, ByteDance Inc.
2+
All rights reserved.
3+
4+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5+
6+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7+
8+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9+
10+
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
11+
12+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

0 commit comments

Comments
 (0)