Skip to content

Commit c863ee0

Browse files
committed
visualization improvements and chatgpt API
1 parent 9700f4a commit c863ee0

File tree

5 files changed

+380
-25
lines changed

5 files changed

+380
-25
lines changed

configs/base_config.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ codex:
4848
temperature: 0. # Temperature for Codex. (Almost) deterministic if 0
4949
best_of: 1 # Number of tries to choose from. Use when temperature > 0
5050
max_tokens: 512 # Maximum number of tokens to generate for Codex
51-
prompt: ./prompts/api.prompt # Codex prompt file, which defines the API
52-
model: code-davinci-002 # Codex model to use. [code-davinci-002, gpt4]
51+
prompt: ./prompts/api.prompt # Codex prompt file, which defines the API. If you use a chat-based model (gpt-3.5-turbo or gpt-4), try ./prompts/chatapi.prompt instead (it does not support video for now due to token limits)
52+
model: code-davinci-002 # Codex model to use. [code-davinci-002, gpt-3.5-turbo, gpt-4]
5353

5454
# Saving and loading parameters
5555
save: True # Save the results to a file

configs/my_config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ path_pretrained_models: './pretrained_models'
44
dataset:
55
data_path: 'data'
66
blip_v2_model_type: blip2-flan-t5-xxl # Change to blip2-flan-t5-xl for smaller GPUs
7-
blip_half_precision: False
7+
blip_half_precision: True
88
# Add more changes here, following the same format as base_config.yaml

main_simple_lib.py

+36-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# General imports and variables, as well as config
2-
2+
import ast
33
import math
44
import sys
55
import time
@@ -28,7 +28,6 @@
2828

2929
cache = Memory('cache/' if config.use_cache else None, verbose=0)
3030

31-
3231
mp.set_start_method('spawn', force=True)
3332
from vision_processes import forward, finish_all_consumers # This import loads all the models. May take a while
3433
from image_patch import *
@@ -39,6 +38,7 @@
3938

4039
time_wait_between_lines = 0.5
4140

41+
4242
def inject_saver(code, show_intermediate_steps, syntax=None, time_wait_between_lines=None, console=None):
4343
injected_function_name = 'show_all'
4444
if injected_function_name in code:
@@ -62,9 +62,10 @@ def inject_saver(code, show_intermediate_steps, syntax=None, time_wait_between_l
6262
if show_intermediate_steps:
6363
escape_thing = lambda x: x.replace("'", "\\'")
6464
injection_string_format = \
65-
lambda thing: f"{indent}{injected_function_name}(lineno={n},value=({thing}),valuename='{escape_thing(thing)}'," \
66-
f"fig=my_fig,console_in=console,time_wait_between_lines=time_wait_between_lines); " \
67-
f"CodexAtLine({n},syntax=syntax,time_wait_between_lines=time_wait_between_lines)"
65+
lambda \
66+
thing: f"{indent}{injected_function_name}(lineno={n},value=({thing}),valuename='{escape_thing(thing)}'," \
67+
f"fig=my_fig,console_in=console,time_wait_between_lines=time_wait_between_lines); " \
68+
f"CodexAtLine({n},syntax=syntax,time_wait_between_lines=time_wait_between_lines)"
6869
else:
6970
injection_string_format = lambda thing: f"{indent}CodexAtLine({n},syntax=syntax," \
7071
f"time_wait_between_lines=time_wait_between_lines)"
@@ -156,7 +157,7 @@ def get_thing_to_show_codetype(codeline):
156157

157158
if isinstance(thing_to_show, list):
158159
thing_to_show = [thing if not (thing.strip().startswith("'") and thing.strip().endswith("'"))
159-
else thing.replace("'", '"') for thing in thing_to_show if thing is not None]
160+
else thing.replace("'", '"') for thing in thing_to_show if thing is not None]
160161
elif isinstance(thing_to_show, str):
161162
thing_to_show = thing_to_show if not (thing_to_show.strip().startswith("'") and
162163
thing_to_show.strip().endswith("'")) else thing_to_show.replace("'", '"')
@@ -186,8 +187,9 @@ def CodexAtLine(lineno, syntax, time_wait_between_lines=1.):
186187
time.sleep(time_wait_between_lines)
187188

188189

189-
def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console_in=None, time_wait_between_lines=None, lastlineno=[-1]):
190-
time.sleep(0.1) # to avoid race condition!
190+
def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console_in=None, time_wait_between_lines=None,
191+
lastlineno=[-1]):
192+
time.sleep(0.1) # to avoid race condition!
191193

192194
if console_in is None:
193195
console_in = console
@@ -196,7 +198,7 @@ def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console
196198

197199
if lineno is not None and lineno != lastlineno[0]:
198200
console_in.rule(f"[bold]Line {lineno}[/bold]", style="chartreuse2")
199-
lastlineno[0] = lineno # ugly hack
201+
lastlineno[0] = lineno # ugly hack
200202

201203
if usefig:
202204
plt.clf()
@@ -211,18 +213,26 @@ def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console
211213
if valuename:
212214
console_in.print(f'{rich_escape(valuename)} = ')
213215
show_one_image(thing_to_show.cropped_image, ax)
214-
elif isinstance(thing_to_show, list):
216+
elif isinstance(thing_to_show, list) or isinstance(thing_to_show, tuple):
215217
if len(thing_to_show) > 0:
216218
for i, thing in enumerate(thing_to_show):
217-
disp_ = disp or i < len(thing_to_show)-1
219+
disp_ = disp or i < len(thing_to_show) - 1
218220
show_all(None, thing, f"{rich_escape(valuename)}[{i}]", fig=fig, disp=disp_)
219221
return
220222
else:
221223
console_in.print(f"{rich_escape(valuename)} is empty")
224+
elif isinstance(thing_to_show, dict):
225+
if len(thing_to_show) > 0:
226+
for i, (thing_k, thing_v) in enumerate(thing_to_show.items()):
227+
disp_ = disp or i < len(thing_to_show) - 1
228+
show_all(None, thing_v, f"{rich_escape(valuename)}['{thing_k}']", fig=fig, disp=disp_)
229+
return
230+
else:
231+
console_in.print(f"{rich_escape(valuename)} is empty")
222232
else:
223233
console_in.print(f"{rich_escape(valuename)} = {thing_to_show}")
224234
if time_wait_between_lines is not None:
225-
time.sleep(time_wait_between_lines/2)
235+
time.sleep(time_wait_between_lines / 2)
226236
return
227237

228238
# display small
@@ -243,14 +253,15 @@ def load_image(path):
243253

244254

245255
def get_code(query):
246-
show_intermediate_steps = True
247256
code = forward('codex', prompt=query, input_type="image")
248257
code = f'def execute_command(image, my_fig, time_wait_between_lines, syntax):' + code
249258
code_for_syntax = code.replace("(image, my_fig, time_wait_between_lines, syntax)", "(image)")
250-
251-
syntax = Syntax(code_for_syntax, "python", theme="monokai", line_numbers=True, start_line=0)
252-
console.print(syntax)
253-
return code, syntax
259+
syntax_1 = Syntax(code_for_syntax, "python", theme="monokai", line_numbers=True, start_line=0)
260+
console.print(syntax_1)
261+
code = ast.unparse(ast.parse(code))
262+
code_for_syntax_2 = code.replace("(image, my_fig, time_wait_between_lines, syntax)", "(image)")
263+
syntax_2 = Syntax(code_for_syntax_2, "python", theme="monokai", line_numbers=True, start_line=0)
264+
return code, syntax_2
254265

255266

256267
def execute_code(code, im, show_intermediate_steps=True):
@@ -261,9 +272,14 @@ def execute_code(code, im, show_intermediate_steps=True):
261272

262273
with Live(Padding(syntax, 1), refresh_per_second=10, console=console, auto_refresh=True) as live:
263274
my_fig = plt.figure(figsize=(4, 4))
275+
try:
276+
exec(compile(code_line, 'Codex', 'exec'), globals())
277+
result = execute_command(im, my_fig, time_wait_between_lines, syntax) # The code is created in the exec()
278+
except Exception as e:
279+
print(f"Encountered error {e} when trying to run with visualizations. Trying from scratch.")
280+
exec(compile(code, 'Codex', 'exec'), globals())
281+
result = execute_command(im, my_fig, time_wait_between_lines, syntax) # The code is created in the exec()
264282

265-
exec(compile(code_line, 'Codex', 'exec'), globals())
266-
result = execute_command(im, my_fig, time_wait_between_lines, syntax) # The code is created in the exec()
267283
plt.close(my_fig)
268284

269285
f = None
@@ -277,7 +293,7 @@ def execute_code(code, im, show_intermediate_steps=True):
277293

278294

279295
def show_single_image(im):
280-
im = Image.fromarray((im.detach().cpu().numpy().transpose(1,2,0)*255).astype("uint8"))
296+
im = Image.fromarray((im.detach().cpu().numpy().transpose(1, 2, 0) * 255).astype("uint8"))
281297
im.copy()
282298
im.thumbnail((400, 400))
283299
display(im)

0 commit comments

Comments
 (0)