VLM_stage.py
import os
import json
import base64
import random
import argparse
import natsort
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from src.run_gpt import run_gpt
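
# VLM captioning stage: for each EgoSchema question, a subset of keyframes is
# selected, sent to a GPT vision-language model through run_gpt, and one
# caption per frame is written to a JSONL answer file.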
random.seed(10)

# Replace the placeholder with a valid API key before running.
dict_api = {
    "api_key": "ADD",
}


class CustomDatasetGPT(Dataset):
    """Subsamples keyframes per question and returns them with base64 encodings."""

    def __init__(self, questions, num_kf):
        self.questions = questions
        self.num_kf = num_kf

    def __getitem__(self, index):
        line = self.questions[index]
        # Split the available keyframes into 4 equal groups and keep the first
        # num_kf // 4 frames of each group, so the selection stays spread
        # across the whole clip.
        group = 4
        newnum_per_group = self.num_kf // group
        oldnum_per_group = len(line["VLM_path"]) // group
        assert oldnum_per_group >= newnum_per_group, \
            f"oldnum_per_group:{oldnum_per_group} is smaller than newnum_per_group:{newnum_per_group}"
        new_kf_paths = []
        new_kf_timelines = []
        for i in range(group):
            start_index = i * oldnum_per_group
            end_index = start_index + oldnum_per_group
            sub_kf_paths = line["VLM_path"][start_index:min(end_index, len(line["VLM_path"]))]
            sub_kf_timelines = line["VLM_timeline"][start_index:min(end_index, len(line["VLM_timeline"]))]
            new_kf_paths.extend(sub_kf_paths[:newnum_per_group])
            new_kf_timelines.extend(sub_kf_timelines[:newnum_per_group])
        kf_paths = natsort.natsorted(new_kf_paths)
        kf_timelines = natsort.natsorted(new_kf_timelines)
        images = []
        images_base64 = []
        for e in kf_paths:
            images.append(Image.open(e).convert('RGB'))
            images_base64.append(encode_image(e))
        return images_base64, kf_paths, kf_timelines

    def __len__(self):
        return len(self.questions)


def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode('utf-8')


def create_data_loader_gpt(questions, num_kf, batch_size=1, num_workers=4):
    assert batch_size == 1, "batch_size must be 1"
    dataset = CustomDatasetGPT(questions, num_kf)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return data_loader, dataset
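

# Each line of the question file is a JSON object. A minimal sketch of the
# fields this script reads (values below are illustrative, not taken from the
# dataset):
#   {"q_uid": "...", "question": "...", "option 0": "...", ..., "option 4": "...",
#    "CA": 0, "VLM_path": ["<frame>.jpg", ...], "VLM_timeline": [1.0, ...]}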


def eval_model(args):
    base_dir, question_path, vlm, num_kf, temp = (
        args.output_dir,
        args.question_path,
        args.gptmodel,
        args.num_kf,
        args.temp,
    )
    questions = [json.loads(q) for q in open(os.path.expanduser(question_path), "r")]
    fname = question_path.split('/')[-1]
    answer_path = f"{base_dir}/egoschema/{num_kf}/{fname}"
    os.makedirs(os.path.dirname(answer_path), exist_ok=True)
    print(f"question_path:{question_path}\nanswer_path:{answer_path}")
    ans_file = open(answer_path, "w")
    data_loader, dataset = create_data_loader_gpt(questions, num_kf)
    for (base64_image, kf_paths, kf_timelines), line in tqdm(zip(data_loader, questions), total=len(questions)):
        idx = line["q_uid"]
        CA = line["CA"] if "CA" in line else None
        option0 = line['option 0']
        option1 = line['option 1']
        option2 = line['option 2']
        option3 = line['option 3']
        option4 = line['option 4']
        question = line['question']
        lenwords = "50"
        prompt = (
            f"'C' stands for the cameraman. Describe the activity depicted in this first-person "
            f"perspective image in less than {lenwords} words. In your answer, don't mention that "
            f"the image is in first-person perspective, as we already know this."
        )
        prompts = [prompt] * num_kf
        # The DataLoader wraps each item in a batch dimension of size 1, so unwrap it.
        image_paths = [e[0] for e in kf_paths]
        image_timelines = [e[0] for e in kf_timelines]
        # Caption every keyframe with the vision-language model.
        output_VLM = run_gpt(
            images=image_paths,
            texts=prompts,
            api_keys=list(dict_api.values()),
            max_tokens=2000,
            model=vlm,
            temperature=temp,
            num_threads=20,  # Tune this
            backoff_time=1 * 60,
            silent=False,
            dataset="egoschema",
            verbose=False,
        )
        output_VLM = list(output_VLM)
        # Write one record per captioned frame.
        for j, e in enumerate(image_timelines):
            line_frame = line.copy()
            line_frame["answer"] = f"At {str(e)} seconds, {output_VLM[j]}"
            line_frame["AR-VLM_model_id"] = vlm
            line_frame["AR-VLM_prompt"] = prompts[j]
            line_frame["timeline"] = float(e)
            line_frame["frame_idx"] = j
            line_frame["image_paths"] = image_paths
            if "imgidx_kw_dict" in line_frame.keys():
                line_frame.pop("imgidx_kw_dict")
            if "google_drive_id" in line_frame.keys():
                line_frame.pop("google_drive_id")
            ans_file.write(json.dumps(line_frame) + "\n")
    print(f"done.\nquestion_path:{question_path}\nanswer_path:{answer_path}")
    ans_file.close()
    return "job is done"


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str)
    parser.add_argument("--question-path", type=str, default="")
    parser.add_argument("--num-kf", type=int)
    parser.add_argument("--gptmodel", type=str, default="gpt-4o")
    parser.add_argument("--temp", type=float, default=None)
    args = parser.parse_args()
    eval_model(args)
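
# Example invocation (a sketch; the paths and values below are illustrative,
# not taken from the repository):
#   python VLM_stage.py \
#       --output-dir ./outputs \
#       --question-path ./data/questions.jsonl \
#       --num-kf 12 \
#       --gptmodel gpt-4o \
#       --temp 0.0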