
Commit 135d21a

Bug fixes for base config
1 parent d1d7f6f commit 135d21a

8 files changed: +56 -17 lines changed


GLIP

Submodule GLIP updated 1 file

configs/base_config.yaml (+4 -3)

@@ -42,14 +42,14 @@ gpt3: # GPT-3 configuration
   n_votes: 1 # Number of tries to use for GPT-3. Use with temperature > 0
   qa_prompt: ./prompts/gpt3/gpt3_qa.txt
   temperature: 0. # Temperature for GPT-3. Almost deterministic if 0
-  model: text-davinci-003 # Can replace with code-davinci-002 (which is free for now) but will have worse performance as it's meant for code
+  model: text-davinci-003 # See openai.Model.list() for available models

 codex:
   temperature: 0. # Temperature for Codex. (Almost) deterministic if 0
   best_of: 1 # Number of tries to choose from. Use when temperature > 0
   max_tokens: 512 # Maximum number of tokens to generate for Codex
-  prompt: ./prompts/api.prompt # Codex prompt file, which defines the API. If you use a Chat-based model (3.5/4) try ./prompts/chatapi.prompt (doesn't support video for now due to token limits)
-  model: code-davinci-002 # Codex model to use. [code-davinci-002, gpt-3.5-turbo, gpt-4]
+  prompt: ./prompts/chatapi.prompt # Codex prompt file, which defines the API. (doesn't support video for now due to token limits)
+  model: gpt-3.5-turbo # Codex model to use. [gpt-3.5-turbo, gpt-4]. See openai.Model.list()

 # Saving and loading parameters
 save: True # Save the results to a file
@@ -60,6 +60,7 @@ clear_cache: False # Clear stored cache
 use_cached_codex: False # Use previously-computed Codex results
 cached_codex_path: '' # Path to the csv results file from which to load Codex results
 log_every: 20 # Log accuracy every n batches
+wandb: False # Use Weights and Biases

 blip_half_precision: True # Use 8bit (Faster but slightly less accurate) for BLIP if True
 blip_v2_model_type: blip2-flan-t5-xxl # Which model to use for BLIP-2
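
Both updated comments point at openai.Model.list(). A minimal sketch (not part of the commit) of checking that the models named in the config are available to your API key, using the same pre-1.0 openai SDK interface shown in this commit:

    import os
    import openai
    from omegaconf import OmegaConf

    openai.api_key = os.getenv('OPENAI_API_KEY')
    config = OmegaConf.load('configs/base_config.yaml')

    # openai.Model.list() returns the models your key can use
    available = {m['id'] for m in openai.Model.list()['data']}
    for name in (config.gpt3.model, config.codex.model):
        print(name, 'available' if name in available else 'NOT available')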

configs/my_config.yaml (+24 -1)

@@ -5,4 +5,27 @@ dataset:
   data_path: 'data'
 blip_v2_model_type: blip2-flan-t5-xxl # Change to blip2-flan-t5-xl for smaller GPUs
 blip_half_precision: True
-# Add more changes here, following the same format as base_config.yaml
+# Add more changes here, following the same format as base_config.yaml
+
+load_models: # Which pretrained models to load
+  maskrcnn: False
+  clip: False
+  glip: False
+  owlvit: False
+  tcl: False
+  gpt3_qa: True
+  gpt3_general: True
+  depth: False
+  blip: False
+  saliency: False
+  xvlm: False
+  codex: True
+  object_detector: False
+
+# wandb: False
+#
+
+# codex:
+#   model: gpt-3.5-turbo
+
+# execute_code: True
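
my_config.yaml only holds overrides; base_config.yaml supplies everything else. A minimal sketch of how the two resolve together, assuming they are merged with OmegaConf (which main_batch.py already uses; the exact loader is not part of this diff):

    from omegaconf import OmegaConf

    base = OmegaConf.load('configs/base_config.yaml')
    mine = OmegaConf.load('configs/my_config.yaml')
    config = OmegaConf.merge(base, mine)  # later configs override earlier ones

    print(config.codex.model)        # gpt-3.5-turbo (from base_config.yaml after this commit)
    print(config.load_models.codex)  # True (from my_config.yaml)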

data/queries.csv (+2 -3)

@@ -1,3 +1,2 @@
-query,answer,image_name
-What color do you get if you combine the colors of the viper and the flower?,purple,viper_flower.png
-Tell me about the competition between the two skyscrapers in the image.,,skyscrapers.png
+index,sample_id,possible_answers,query_type,info_to_prompt,query,answer,image_name,img,
+0,0,How many cookies are there?,seven,,How many cookies are there?,seven,cookies.png,cookies.png
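
The new header row matches the column names used elsewhere in this commit (query, possible_answers, sample_id, image_name, ...). A quick sketch for inspecting the schema with pandas:

    import pandas as pd

    df = pd.read_csv('data/queries.csv')
    print(list(df.columns))
    # Expect: index, sample_id, possible_answers, query_type, info_to_prompt,
    #         query, answer, image_name, img (plus an unnamed column from the
    #         trailing comma in the header row)
    print(df.iloc[0]['query'], '->', df.iloc[0]['answer'])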

datasets/dataset.py (+6 -3)

@@ -69,14 +69,17 @@ def get_video(self, video_path):
         return video

     def __getitem__(self, index):
-        query = self.df.iloc[index]["query"]
-        answer = self.df.iloc[index]["answer"]
+
+        out_dict = self.df.iloc[index].to_dict()
+
         sample_path = self.get_sample_path(index)

         # Load and transform image
         image = self.get_image(sample_path) if self.input_type == "image" else self.get_video(sample_path)

-        out_dict = {"query": query, "answer": answer, "image": image, 'index': index}
+        out_dict["image"] = image
+        out_dict["index"] = index
+
         return out_dict

     def __len__(self):
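
With this change, __getitem__ returns the whole CSV row as a dict (every column becomes a key) plus 'image' and 'index', which is what the updated main_batch.py relies on. A rough usage sketch; the import path and config plumbing are assumptions, since only the hunk above is shown:

    from omegaconf import OmegaConf
    from datasets.dataset import MyDataset  # assumption: module path of the class above

    config = OmegaConf.merge(OmegaConf.load('configs/base_config.yaml'),
                             OmegaConf.load('configs/my_config.yaml'))
    dataset = MyDataset(**config.dataset)   # as in main_batch.py

    sample = dataset[0]
    print(sorted(sample.keys()))  # every queries.csv column, plus 'image' and 'index'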

image_patch.py (+2)

@@ -141,6 +141,7 @@ def find(self, object_name: str) -> list[ImagePatch]:
         if object_name == 'person':
             object_name = 'people'  # GLIP does better at people than person

+        # all_object_coordinates = self.forward('object_detector', self.cropped_image)
         all_object_coordinates = self.forward('glip', self.cropped_image, object_name)
         if len(all_object_coordinates) == 0:
             return []
@@ -155,6 +156,7 @@ def find(self, object_name: str) -> list[ImagePatch]:
         # mask = all_areas == all_areas.max()  # At least return one element
         all_object_coordinates = all_object_coordinates[mask]

+
         return [self.crop(*coordinates) for coordinates in all_object_coordinates]

     def exists(self, object_name) -> bool:
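
find() keeps routing through GLIP; the commented-out line simply keeps the generic 'object_detector' path around as an alternative. A rough usage sketch of find(), based only on the signature shown above (the image path is hypothetical):

    from PIL import Image
    from image_patch import ImagePatch

    image = Image.open('data/cookies.png')  # hypothetical sample image
    patch = ImagePatch(image)
    cookies = patch.find('cookie')          # list[ImagePatch], one per GLIP detection
    print(len(cookies))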

main_batch.py (+16 -6)

@@ -4,6 +4,8 @@
 import pathlib
 from functools import partial
 import warnings
+import traceback
+

 import pandas as pd
 import torch.multiprocessing as mp
@@ -79,6 +81,8 @@ def run_program(parameters, queues_in_, input_type_, retrying=False):
             # Functions to be used
             llm_query_partial, bool_to_yesno, distance, best_image_match)
     except Exception as e:
+        # print full traceback
+        traceback.print_exc()
         if retrying:
             return None, code
         print(f'Sample {sample_id} failed with error: {e}. Next you will see an "expected an indented block" error. ')
@@ -110,6 +114,7 @@ def main():
     batch_size = config.dataset.batch_size
     num_processes = min(batch_size, 50)

+
     if config.multiprocessing:
         queue_results_main = manager.Queue()
         queues_results = [manager.Queue() for _ in range(batch_size)]
@@ -126,11 +131,11 @@ def main():
         import wandb
         wandb.init(project="viper", config=OmegaConf.to_container(config))
         # log the prompt file
-        wandb.save(config.prompt)
+        wandb.save(config.codex.prompt)

     dataset = MyDataset(**config.dataset)

-    with open(config.prompt) as f:
+    with open(config.codex.prompt) as f:
         base_prompt = f.read().strip()

     codes_all = None
@@ -155,12 +160,15 @@ def main():
             if config.multiprocessing else open(os.devnull, "w") as pool:
         try:
             n_batches = len(dataloader)
+
             for i, batch in tqdm(enumerate(dataloader), total=n_batches):

                 # Combine all querys and get Codex predictions for them
                 # TODO compute Codex for next batch as current batch is being processed
+
                 if not config.use_cached_codex:
-                    codes = codex(prompt=batch['info_to_prompt'], base_prompt=base_prompt)
+                    # codes = codex(prompt=batch['info_to_qprompt'], base_prompt=base_prompt)
+                    codes = codex(prompt=batch['query'], base_prompt=base_prompt)

                 else:
                     codes = codes_all[i * batch_size:(i + 1) * batch_size]  # If cache
@@ -171,13 +179,13 @@ def main():
                     # Otherwise, we would create a new model for every process
                     results = []
                     for c, sample_id, img, possible_answers, query in \
-                            zip(codes, batch['sample_id'], batch['img'], batch['possible_answers'], batch['query']):
+                            zip(codes, batch['sample_id'], batch['image'], batch['possible_answers'], batch['query']):
                         result = run_program([c, sample_id, img, possible_answers, query], queues_in, input_type)
                         results.append(result)
                 else:
                     results = list(pool.imap(partial(
                         run_program, queues_in_=queues_in, input_type_=input_type),
-                        zip(codes, batch['sample_id'], batch['img'], batch['possible_answers'], batch['query'])))
+                        zip(codes, batch['sample_id'], batch['image'], batch['possible_answers'], batch['query'])))
             else:
                 results = [(None, c) for c in codes]
                 warnings.warn("Not executing code! This is only generating the code. We set the flag "
@@ -192,7 +200,7 @@ def main():
                 all_possible_answers += batch['possible_answers']
                 all_query_types += batch['query_type']
                 all_querys += batch['query']
-                all_img_paths += [dataset.get_img_path(idx) for idx in batch['index']]
+                all_img_paths += [dataset.get_sample_path(idx) for idx in batch['index']]
                 if i % config.log_every == 0:
                     try:
                         accuracy = datasets.accuracy(all_results, all_answers, all_possible_answers, all_query_types)
@@ -201,6 +209,8 @@ def main():
                         console.print(f'Error computing accuracy: {e}')

         except Exception as e:
+            # print full stack trace
+            traceback.print_exc()
             console.print(f'Exception: {e}')
             console.print("Completing logging and exiting...")
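
The two config-key renames above follow from the config layout: the prompt path is defined under the codex section, so it must be read as config.codex.prompt rather than config.prompt. A minimal check:

    from omegaconf import OmegaConf

    config = OmegaConf.load('configs/base_config.yaml')
    print(config.codex.prompt)  # ./prompts/chatapi.prompt after this commit
    # the prompt path only exists under the codex section, not at the top level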

vision_models.py (+1)

@@ -940,6 +940,7 @@ def codex_helper(extended_prompt):
             # if len(resp) == 1:
             #     resp = resp[0]
         else:
+            warnings.warn('OpenAI Codex is deprecated. Please use GPT-4 or GPT-3.5-turbo.')
             response = openai.Completion.create(
                 model="code-davinci-002",
                 temperature=config.codex.temperature,
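
The warning is added because this else branch still calls the legacy completions endpoint with code-davinci-002; chat models such as gpt-3.5-turbo and gpt-4 go through the chat endpoint instead. A rough sketch of that chat-style call on the pre-1.0 openai SDK, not the repository's exact code, and how the query is combined with the prompt here is an assumption:

    import os
    import openai

    openai.api_key = os.getenv('OPENAI_API_KEY')
    base_prompt = open('./prompts/chatapi.prompt').read().strip()  # API-definition prompt

    response = openai.ChatCompletion.create(
        model='gpt-3.5-turbo',
        temperature=0.,
        max_tokens=512,
        messages=[
            {'role': 'system', 'content': base_prompt},
            {'role': 'user', 'content': 'How many cookies are there?'},
        ],
    )
    print(response['choices'][0]['message']['content'])  # generated program text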
