1
1
# General imports and variables, as well as config
2
-
2
+ import ast
3
3
import math
4
4
import sys
5
5
import time
28
28
29
29
cache = Memory ('cache/' if config .use_cache else None , verbose = 0 )
30
30
31
-
32
31
mp .set_start_method ('spawn' , force = True )
33
32
from vision_processes import forward , finish_all_consumers # This import loads all the models. May take a while
34
33
from image_patch import *
39
38
40
39
time_wait_between_lines = 0.5
41
40
41
+
42
42
def inject_saver (code , show_intermediate_steps , syntax = None , time_wait_between_lines = None , console = None ):
43
43
injected_function_name = 'show_all'
44
44
if injected_function_name in code :
@@ -62,9 +62,10 @@ def inject_saver(code, show_intermediate_steps, syntax=None, time_wait_between_l
62
62
if show_intermediate_steps :
63
63
escape_thing = lambda x : x .replace ("'" , "\\ '" )
64
64
injection_string_format = \
65
- lambda thing : f"{ indent } { injected_function_name } (lineno={ n } ,value=({ thing } ),valuename='{ escape_thing (thing )} '," \
66
- f"fig=my_fig,console_in=console,time_wait_between_lines=time_wait_between_lines); " \
67
- f"CodexAtLine({ n } ,syntax=syntax,time_wait_between_lines=time_wait_between_lines)"
65
+ lambda \
66
+ thing : f"{ indent } { injected_function_name } (lineno={ n } ,value=({ thing } ),valuename='{ escape_thing (thing )} '," \
67
+ f"fig=my_fig,console_in=console,time_wait_between_lines=time_wait_between_lines); " \
68
+ f"CodexAtLine({ n } ,syntax=syntax,time_wait_between_lines=time_wait_between_lines)"
68
69
else :
69
70
injection_string_format = lambda thing : f"{ indent } CodexAtLine({ n } ,syntax=syntax," \
70
71
f"time_wait_between_lines=time_wait_between_lines)"
@@ -156,7 +157,7 @@ def get_thing_to_show_codetype(codeline):
156
157
157
158
if isinstance (thing_to_show , list ):
158
159
thing_to_show = [thing if not (thing .strip ().startswith ("'" ) and thing .strip ().endswith ("'" ))
159
- else thing .replace ("'" , '"' ) for thing in thing_to_show if thing is not None ]
160
+ else thing .replace ("'" , '"' ) for thing in thing_to_show if thing is not None ]
160
161
elif isinstance (thing_to_show , str ):
161
162
thing_to_show = thing_to_show if not (thing_to_show .strip ().startswith ("'" ) and
162
163
thing_to_show .strip ().endswith ("'" )) else thing_to_show .replace ("'" , '"' )
@@ -186,8 +187,9 @@ def CodexAtLine(lineno, syntax, time_wait_between_lines=1.):
186
187
time .sleep (time_wait_between_lines )
187
188
188
189
189
- def show_all (lineno , value , valuename , fig = None , usefig = True , disp = True , console_in = None , time_wait_between_lines = None , lastlineno = [- 1 ]):
190
- time .sleep (0.1 ) # to avoid race condition!
190
+ def show_all (lineno , value , valuename , fig = None , usefig = True , disp = True , console_in = None , time_wait_between_lines = None ,
191
+ lastlineno = [- 1 ]):
192
+ time .sleep (0.1 ) # to avoid race condition!
191
193
192
194
if console_in is None :
193
195
console_in = console
@@ -196,7 +198,7 @@ def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console
196
198
197
199
if lineno is not None and lineno != lastlineno [0 ]:
198
200
console_in .rule (f"[bold]Line { lineno } [/bold]" , style = "chartreuse2" )
199
- lastlineno [0 ] = lineno # ugly hack
201
+ lastlineno [0 ] = lineno # ugly hack
200
202
201
203
if usefig :
202
204
plt .clf ()
@@ -211,18 +213,26 @@ def show_all(lineno, value, valuename, fig=None, usefig=True, disp=True, console
211
213
if valuename :
212
214
console_in .print (f'{ rich_escape (valuename )} = ' )
213
215
show_one_image (thing_to_show .cropped_image , ax )
214
- elif isinstance (thing_to_show , list ):
216
+ elif isinstance (thing_to_show , list ) or isinstance ( thing_to_show , tuple ) :
215
217
if len (thing_to_show ) > 0 :
216
218
for i , thing in enumerate (thing_to_show ):
217
- disp_ = disp or i < len (thing_to_show )- 1
219
+ disp_ = disp or i < len (thing_to_show ) - 1
218
220
show_all (None , thing , f"{ rich_escape (valuename )} [{ i } ]" , fig = fig , disp = disp_ )
219
221
return
220
222
else :
221
223
console_in .print (f"{ rich_escape (valuename )} is empty" )
224
+ elif isinstance (thing_to_show , dict ):
225
+ if len (thing_to_show ) > 0 :
226
+ for i , (thing_k , thing_v ) in enumerate (thing_to_show .items ()):
227
+ disp_ = disp or i < len (thing_to_show ) - 1
228
+ show_all (None , thing_v , f"{ rich_escape (valuename )} ['{ thing_k } ']" , fig = fig , disp = disp_ )
229
+ return
230
+ else :
231
+ console_in .print (f"{ rich_escape (valuename )} is empty" )
222
232
else :
223
233
console_in .print (f"{ rich_escape (valuename )} = { thing_to_show } " )
224
234
if time_wait_between_lines is not None :
225
- time .sleep (time_wait_between_lines / 2 )
235
+ time .sleep (time_wait_between_lines / 2 )
226
236
return
227
237
228
238
# display small
@@ -243,14 +253,15 @@ def load_image(path):
243
253
244
254
245
255
def get_code (query ):
246
- show_intermediate_steps = True
247
256
code = forward ('codex' , prompt = query , input_type = "image" )
248
257
code = f'def execute_command(image, my_fig, time_wait_between_lines, syntax):' + code
249
258
code_for_syntax = code .replace ("(image, my_fig, time_wait_between_lines, syntax)" , "(image)" )
250
-
251
- syntax = Syntax (code_for_syntax , "python" , theme = "monokai" , line_numbers = True , start_line = 0 )
252
- console .print (syntax )
253
- return code , syntax
259
+ syntax_1 = Syntax (code_for_syntax , "python" , theme = "monokai" , line_numbers = True , start_line = 0 )
260
+ console .print (syntax_1 )
261
+ code = ast .unparse (ast .parse (code ))
262
+ code_for_syntax_2 = code .replace ("(image, my_fig, time_wait_between_lines, syntax)" , "(image)" )
263
+ syntax_2 = Syntax (code_for_syntax_2 , "python" , theme = "monokai" , line_numbers = True , start_line = 0 )
264
+ return code , syntax_2
254
265
255
266
256
267
def execute_code (code , im , show_intermediate_steps = True ):
@@ -261,9 +272,14 @@ def execute_code(code, im, show_intermediate_steps=True):
261
272
262
273
with Live (Padding (syntax , 1 ), refresh_per_second = 10 , console = console , auto_refresh = True ) as live :
263
274
my_fig = plt .figure (figsize = (4 , 4 ))
275
+ try :
276
+ exec (compile (code_line , 'Codex' , 'exec' ), globals ())
277
+ result = execute_command (im , my_fig , time_wait_between_lines , syntax ) # The code is created in the exec()
278
+ except Exception as e :
279
+ print (f"Encountered error { e } when trying to run with visualizations. Trying from scratch." )
280
+ exec (compile (code , 'Codex' , 'exec' ), globals ())
281
+ result = execute_command (im , my_fig , time_wait_between_lines , syntax ) # The code is created in the exec()
264
282
265
- exec (compile (code_line , 'Codex' , 'exec' ), globals ())
266
- result = execute_command (im , my_fig , time_wait_between_lines , syntax ) # The code is created in the exec()
267
283
plt .close (my_fig )
268
284
269
285
f = None
@@ -277,7 +293,7 @@ def execute_code(code, im, show_intermediate_steps=True):
277
293
278
294
279
295
def show_single_image (im ):
280
- im = Image .fromarray ((im .detach ().cpu ().numpy ().transpose (1 ,2 , 0 ) * 255 ).astype ("uint8" ))
296
+ im = Image .fromarray ((im .detach ().cpu ().numpy ().transpose (1 , 2 , 0 ) * 255 ).astype ("uint8" ))
281
297
im .copy ()
282
298
im .thumbnail ((400 , 400 ))
283
299
display (im )
0 commit comments