@@ -228,21 +228,25 @@ def _collate(x):
228
228
until = [until ]
229
229
elif not isinstance (until , list ):
230
230
raise ValueError (f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got { type (until )} " )
231
+
232
+ if isinstance (contexts , tuple ):
233
+ contexts = list (contexts )
234
+
231
235
for i in range (len (contexts )):
232
236
if "<image>" in contexts [i ]:
233
- context [i ] = contexts [i ].replace ("<image>" , "" )
234
- questions = [self .prompt .format (visual_path , context ) for visual_path , context in zip (visual_paths , contexts )]
237
+ contexts [i ] = contexts [i ].replace ("<image>" , "" )
235
238
236
239
# Similar to llava, if visual paths has len 0
237
240
# Then nothing will be executed
238
241
query = []
239
- for visual_path , context in zip (visual_paths , contexts ):
240
- query .append ({"image" : visual_path })
241
- query .append ({"text" : context })
242
-
243
242
if len (visual_paths ) == 0 :
244
243
for context in contexts :
245
244
query .append ({"text" : context })
245
+ else :
246
+ for visual_path , context in zip (visual_paths , contexts ):
247
+ query .append ({"image" : visual_path })
248
+ query .append ({"text" : context })
249
+
246
250
247
251
questions = self .tokenizer .from_list_format (query )
248
252
input_ids = self .tokenizer (questions , return_tensors = "pt" , padding = "longest" )
0 commit comments