# hf_vlms.py -- Hugging Face multimodal ("hf-multimodal") model wrapper.
import copy
from typing import Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from transformers import BatchEncoding
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import (
Collator,
flatten_image_list,
pad_and_concat,
replace_placeholders,
stop_sequences_criteria,
)
# Generic marker used in prompts to indicate where an image belongs. The code below
# either replaces it with the model-specific image token string or splits chat
# messages on it before handing them to the processor.
DEFAULT_IMAGE_PLACEHOLDER = "<image>"
# Logger object re-exported by lm_eval.utils; used for informational messages in this module.
eval_logger = utils.eval_logger
@register_model("hf-multimodal")
class HFMultimodalLM(HFLM):
    """
    An abstracted Hugging Face model class for multimodal LMs like Llava and Idefics.

    Text handling is inherited from `HFLM`; this subclass additionally keeps a
    `transformers` processor (`self.processor`) so that images can be encoded
    together with the prompt text.
    """

    AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq
    MULTIMODAL = True  # flag to indicate, for now, that this model type can run multimodal requests

    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        image_token_id: Optional[int] = None,
        image_string: Optional[str] = None,
        interleave: bool = True,
        # TODO: handle whitespace in image placeholder (replacement)
        max_images: Optional[int] = 999,
        convert_img_format: bool = False,
        **kwargs,
    ):
        """
        Args:
            pretrained: model name/path or an already-instantiated model; forwarded to `HFLM`.
            image_token_id: explicit id of the token marking an image location. When
                omitted, `config.image_token_id` / `config.image_token_index` is used.
            image_string: explicit image token *string*. When given it takes
                precedence and no token id is resolved.
            interleave: if True, chat messages keep images at the position of their
                placeholder; otherwise all images are placed before the text.
            max_images: upper bound on the number of images used per document.
            convert_img_format: if True, images are converted to RGB before encoding.
            **kwargs: forwarded unchanged to `HFLM.__init__`.

        Raises:
            AssertionError: if batch size is "auto", or if no image token id can be
                resolved while `image_string` is not given.
        """
        # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
        # modify init behavior.
        super().__init__(pretrained, **kwargs)

        assert (
            self.batch_size != "auto"
        ), "Batch size 'auto' is not yet supported for hf-multimodal models."
        # Set to True once apply_chat_template() has run; tok_batch_multimodal_encode()
        # then skips its own placeholder replacement.
        self.chat_applied: bool = False
        # TODO: phi-3.5 "image placeholders" are <image_1>, <image_2>, ... in order. how to handle this case

        # HF AutoModelForVision2Seq models have an `image_token_id` value in their configs
        # denoting the token which indicates a location where an image will be substituted in.
        # This can take different string values across models, e.g. <image> for Idefics2 and <|image_pad|> for Qwen2-VL
        self.interleave = interleave
        self.max_images = max_images
        self.rgb = convert_img_format
        # WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
        if not image_string:
            # Use explicit `is not None` checks so that a token id of 0 is honoured
            # instead of being mistaken for "not provided".
            if image_token_id is not None:
                self.image_token_id = int(image_token_id)
            else:
                self.image_token_id = getattr(self.config, "image_token_id", None)
                if self.image_token_id is None:
                    self.image_token_id = getattr(
                        self.config, "image_token_index", None
                    )
            assert (
                self.image_token_id is not None
            ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
            # get the string this token ID corresponds to
            self.image_token = self.tok_decode(
                [self.image_token_id], skip_special_tokens=False
            )
            if image_token_id is not None:
                eval_logger.info(
                    f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
                )
        else:
            eval_logger.info(
                f"A non-default image_token string with string value image_string='{image_string}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
            )
            self.image_token = image_string

    def _create_tokenizer(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        tokenizer: Optional[
            Union[
                str,
                transformers.ProcessorMixin,
            ]
        ],
        revision: Optional[str] = "main",
        trust_remote_code: Optional[bool] = False,
        **kwargs,
    ) -> None:
        """
        Helper method during initialization.

        For the multimodal variant, we initialize not just
        `self.tokenizer` but also `self.processor`.

        Args:
            pretrained: model name/path or model object; used to locate the
                processor when `tokenizer` is not given.
            tokenizer: optional processor name/path, or a ready processor instance.
            revision: revision passed to `AutoProcessor.from_pretrained`.
            trust_remote_code: passed to `AutoProcessor.from_pretrained`.
            **kwargs: accepted for signature compatibility and ignored here.
        """
        if tokenizer:
            # An explicitly supplied processor must be stored on the instance as well;
            # merely returning it would leave `self.processor`/`self.tokenizer` unset.
            if isinstance(tokenizer, str):
                self.processor = transformers.AutoProcessor.from_pretrained(
                    tokenizer,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    # use_fast=use_fast_tokenizer,
                )
            else:
                assert isinstance(
                    tokenizer, transformers.ProcessorMixin
                )  # TODO: check this condition
                self.processor = tokenizer
        else:
            # Get tokenizer based on 'pretrained'
            if isinstance(pretrained, str):
                model_name = pretrained
            else:
                # get the HF hub name via accessor on model
                model_name = self.model.name_or_path

            self.processor = transformers.AutoProcessor.from_pretrained(
                model_name,
                revision=revision,
                trust_remote_code=trust_remote_code,
                # use_fast=use_fast_tokenizer,
            )

        self.tokenizer = self.processor.tokenizer

    def tok_multimodal_encode(
        self, string, images, left_truncate_len=None, add_special_tokens=None
    ):
        """Helper function which encodes an image + string combo using AutoProcessor.

        Args:
            string: text to encode.
            images: image input handed to the processor as-is.
            left_truncate_len: if set, keep at most this many trailing text tokens.
            add_special_tokens: currently unused (see TODO below).

        Returns:
            `(text_encoding, encoding)`: the popped `input_ids` and the remaining
            processor output (the image-related entries).
        """
        # NOTE: special-token kwargs (as set up in HFLM.tok_encode) are deliberately not
        # forwarded to the processor.
        # TODO: why does (Qwen2-VL) processor error when attempting to add special tokens to text?
        encoding = self.processor(
            text=string, images=images, return_tensors=None
        )  # , **special_tokens_kwargs)

        # remove (and store) our tokenized text
        text_encoding = encoding.pop("input_ids")
        # the mask is not needed downstream; tolerate processors that do not emit one
        encoding.pop("attention_mask", None)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            if text_encoding and isinstance(text_encoding[0], (list, tuple)):
                # batched output (List[List[int]]): truncate every row, not the batch dim
                text_encoding = [row[-left_truncate_len:] for row in text_encoding]
            else:
                text_encoding = text_encoding[-left_truncate_len:]

        return text_encoding, encoding  # image_encoding is a dict

    def _encode_multimodal_pair(self, context, continuation, images):
        """Helper function to perform the role of TemplateLM._encode_pair
        Except allowing for image input to also be processed alongside `context`.

        This method is a bit messy due to the need to defer conversion of image and text token input
        into PyTorch tensors until the main inference loop.

        Returns:
            `(context_enc, continuation_enc, image_enc)` where the continuation
            tokens are whatever follows the context tokens in the joint encoding.
        """
        # Move trailing whitespace of the context onto the continuation so that both
        # encodings below split at the same character position.
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        # TODO: replace default <image> placeholder with self.image_token, for contexts

        whole_enc, image_enc = self.tok_multimodal_encode(
            context + continuation, images
        )
        context_enc, _ = self.tok_multimodal_encode(context, images)

        # tok_multimodal_encode returns List[List[int]] for tokenized text. Get rid of the batch dim
        # since we only are encoding a single string.
        # TODO: this is a bit hacky, it'd be nice to make this generally cleaner
        whole_enc, context_enc = whole_enc[0], context_enc[0]

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc, image_enc

    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
        """Render `chat_history` with the processor's chat template.

        Each message's string content is converted into a list of typed parts
        (`{"type": "text", ...}` / `{"type": "image", ...}`) based on the
        `DEFAULT_IMAGE_PLACEHOLDER` occurrences it contains. The caller's
        `chat_history` is left unmodified.

        Raises:
            ValueError: in interleaved mode, if the number of image parts emitted
                differs from the expected count.
        """
        self.chat_applied = True
        messages = []
        for message in chat_history:
            text = message["content"]
            placeholder_count = text.count(DEFAULT_IMAGE_PLACEHOLDER)
            # `max_images=None` means "no limit"
            expected_image_count = (
                placeholder_count
                if self.max_images is None
                else min(self.max_images, placeholder_count)
            )

            if not self.interleave:
                # All image entries first, then a single text entry with the
                # placeholders removed.
                parts = [
                    {"type": "image", "image": None}
                    for _ in range(expected_image_count)
                ]
                parts.append(
                    {"type": "text", "text": text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")}
                )
            else:
                parts = []
                actual_image_count = 0
                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)

                for i, part in enumerate(text_parts):
                    # TODO: concatenate text parts (esp. if skipping images)?
                    if part:  # Add non-empty text parts
                        parts.append({"type": "text", "text": part})
                    if (
                        i < len(text_parts) - 1
                    ) and i < expected_image_count:  # Add image placeholder after each split except the last
                        parts.append({"type": "image"})
                        actual_image_count += 1

                if actual_image_count != expected_image_count:
                    raise ValueError(
                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
                    )

            # shallow copy so the caller's message dicts are not rewritten
            messages.append({**message, "content": parts})

        return self.processor.apply_chat_template(
            messages, add_generation_prompt=True
        )

    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
        """Select the chat template via `HFLM.chat_template`.

        If the processor exposes `apply_chat_template`, it temporarily stands in
        for `self.tokenizer` while the parent implementation runs.
        """
        if not hasattr(self.processor, "apply_chat_template"):
            return super().chat_template(chat_template)

        _tokenizer = self.tokenizer
        self.tokenizer = self.processor
        try:
            return super().chat_template(chat_template)
        finally:
            # always restore the real tokenizer, even if the parent call raises
            self.tokenizer = _tokenizer

    def tok_batch_multimodal_encode(
        self,
        strings: List[str],  # note that input signature of this fn is different
        images: List[List],  # TODO: images are pil.Image at the moment, update typehint
        padding_side: str = "left",
        left_truncate_len: Optional[int] = None,
        truncation: bool = False,
    ) -> Union[
        BatchEncoding, Dict[str, torch.Tensor]
    ]:  # note that this return signature differs from HFLM tok_batch_encode.
        """Encode a batch of prompts and their images with the processor.

        Args:
            strings: prompt texts, one per example.
            images: per-example lists of images.
            padding_side: tokenizer padding side used for this call only.
            left_truncate_len: if set, keep only this many trailing positions of
                `input_ids` and `attention_mask`.
            truncation: forwarded to the processor.

        Returns:
            The processor output as "pt" tensors, after `.to(self.device, self.model.dtype)`.
        """
        # NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
        if not self.chat_applied:
            # TODO<baber>: This still keeps the whitespace in the image placeholder, which is not ideal.
            strings = [
                replace_placeholders(
                    string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
                )
                for string in strings
            ]

        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
        old_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = padding_side
        try:
            # add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

            images = [img[: self.max_images] for img in images]
            if self.rgb:
                images = [[img.convert("RGB") for img in sublist] for sublist in images]

            # certain models like llava expect a single-level image list even for bs>1, multi-image. TODO: port this over to loglikelihoods
            if getattr(self.config, "model_type", "") == "llava":
                images = flatten_image_list(images)

            encoding = self.processor(
                images=images,
                text=strings,
                truncation=truncation,
                padding="longest",
                return_tensors="pt",
                # **add_special_tokens, # TODO: at least some Processors error out when passing this. How do we control whether text gets BOS added?
            )

            encoding.to(  # TODO: our other tokenization methods in HFLM don't typically move to device. this breaks convention
                self.device, self.model.dtype
            )  # TODO: This only casts the pixel values. Should they always be float16?
            if left_truncate_len:
                encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
                encoding["attention_mask"] = encoding["attention_mask"][
                    :, -left_truncate_len:
                ]
        finally:
            # restore the tokenizer's padding side even if encoding failed
            self.tokenizer.padding_side = old_padding_side

        return encoding

    def _model_multimodal_call(self, inps, imgs, attn_mask=None, labels=None):
        """
        Run a forward pass and return the model's logits.

        Args:
            inps: input ids, passed positionally to the model.
            imgs: dict of image inputs, expanded as keyword arguments.
            attn_mask: currently unused.
            labels: currently unused.
        """
        # note: imgs is a dict.
        with torch.no_grad():
            return self.model(inps, **imgs).logits

    def _model_multimodal_generate(self, inputs, max_length, stop, **generation_kwargs):
        """Call `self.model.generate` on processor `inputs` with stop-sequence criteria.

        Args:
            inputs: mapping containing at least `input_ids`; expanded into `generate`.
            max_length: passed as `max_length` to `generate`.
            stop: stop sequences handed to `stop_sequences_criteria`.
            **generation_kwargs: extra generation arguments; `temperature` and
                `do_sample` are normalised below.
        """
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        stopping_criteria = stop_sequences_criteria(
            self.tokenizer,
            stop,
            inputs["input_ids"].shape[1],
            inputs["input_ids"].shape[0],
        )
        return self.model.generate(
            **inputs,
            max_length=max_length,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            **generation_kwargs,
        )

    def _batch_images(self, image_encs):
        """
        Helper function: batch together image encodings across examples in a batch.
        # TODO: for variable-sized images, this may break down.

        For every key of the first encoding, the per-example values are turned into
        tensors on `self.device` with `self.model.dtype` and concatenated along dim 0.
        """
        batched_imgs = {}
        for key in image_encs[0].keys():
            batched_imgs[key] = torch.cat(
                [
                    torch.tensor(
                        image_enc[key], device=self.device, dtype=self.model.dtype
                    )
                    for image_enc in image_encs
                ],
                dim=0,
            )
        return batched_imgs

    def loglikelihood_rolling(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[float]:
        """Not supported for `hf-multimodal`; always raises `NotImplementedError`."""
        # Adjacent literals are concatenated into ONE message (a comma here would turn
        # the exception args into a tuple of two strings).
        raise NotImplementedError(
            "model type `hf-multimodal` does not support loglikelihood_rolling. Use 'hf' model type for text-only loglikelihood_rolling tasks; "
            "this is because we do not support measuring the loglikelihood a model assigns to an image."
        )

    def loglikelihood(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[float, bool]]:
        """Currently disabled: always raises `NotImplementedError`.

        The statements after the `raise` are the staged (untested) implementation and
        are intentionally unreachable until the feature is enabled.
        """
        raise NotImplementedError(
            "'loglikelihood' requests for model type `hf-multimodal` are not yet tested. This feature will be enabled when a loglikelihood-based multiple-choice VQA dataset is added!"
        )

        # NOTE: unreachable on purpose -- kept for when the guard above is removed.
        new_reqs = []
        for context, continuation, aux_arguments in [req.args for req in requests]:
            if context == "":
                raise ValueError(
                    "Must get non-empty context for multimodal requests! You might be trying to run 'loglikelihood_rolling', which is not supported in the multimodal case."
                )
            else:
                visuals = aux_arguments["visual"]

                context_enc, continuation_enc, image_enc = self._encode_multimodal_pair(
                    context, continuation, visuals
                )
            # TODO: key to pick for caching images
            new_reqs.append(
                (
                    (context, continuation, visuals),
                    context_enc,
                    continuation_enc,
                    image_enc,
                )
            )

        return self._loglikelihood_tokens(new_reqs, disable_tqdm=disable_tqdm)

    def _loglikelihood_tokens(
        self,
        requests: List[
            Tuple[Tuple[str, str, List], List[int], List[int], Dict]
        ],  # ((context, continuation, visuals), context_enc, continuation_enc, image_enc)
        disable_tqdm: bool = False,
        override_bs: Optional[int] = None,
    ) -> List[Tuple[float, bool]]:
        """Score pre-tokenized (context, continuation, image) requests.

        Args:
            requests: tuples as built by `loglikelihood`:
                `((context, continuation, visuals), context_enc, continuation_enc, image_enc)`.
            disable_tqdm: if True, no progress bar is shown.
            override_bs: batch size to use when `self.batch_size` is "auto".

        Returns:
            One `(log prob, is-exact-match)` pair per request, passed through
            `Collator.get_original` before returning.
        """
        res = []

        # TODO: **improve multimodal collation.** We currently ignore image size when ordering docs. ideally we'd take them into account
        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)"
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-1] + req[-3] + req[-2][:-1]

        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"  # TODO: can't group-by just "contexts" any more, need to incorporate imgs
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests with text+image input",
        )
        for chunk in chunks:
            imgs = []
            inps = []
            cont_toks_list = []
            inplens = []

            padding_len_inp = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc, image_enc in chunk:
                # sanity check
                assert len(image_enc) > 0
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                # TODO: assuming that we won't handle enc-dec Vision2Seq models. Is that a safe assumption?
                inp = torch.tensor(
                    (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                    dtype=torch.long,
                    device=self.device,
                )
                (inplen,) = inp.shape

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

                imgs.append(image_enc)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            batched_inps = pad_and_concat(
                padding_len_inp, inps, padding_side="right"
            )  # [batch, padding_len_inp]
            # batch our examples' image inputs together
            batched_imgs = self._batch_images(
                imgs
            )  # TODO: fix/test for bs>1 case with differently-sized imgs!

            multi_logits = F.log_softmax(
                self._model_multimodal_call(batched_inps, batched_imgs, **call_kwargs),
                dim=-1,
            )  # [batch, padding_length (inp or cont), vocab]

            for (
                request_str,
                ctx_tokens,
                _,
                image_encs,
            ), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    res.append(answer)

                    self.cache_hook.add_partial(
                        "loglikelihood", request_str, answer
                    )  # TODO: choose convention for adding images into the cache key
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[str]:
        """Generate a continuation for every text+image request.

        Each request's `args` is unpacked as `(context, gen_kwargs, aux_arguments)`,
        with the images read from `aux_arguments["visual"]`. Requests are batched
        by their generation kwargs.

        Returns:
            The generated strings, cut at the first stop sequence, passed through
            `Collator.get_original` before returning.

        Raises:
            ValueError: if a request's generation kwargs are not a dict, or its
                `until` entry is neither a string nor a list.
        """
        # imported once here rather than on every batch iteration
        import gc

        # TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests with text+image input",
        )
        # TODO: port auto-batch sizing into this.

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = Collator(
            [req.args for req in requests],
            _collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)

        ### Up to here: was identical to non-multimodal HFLM generate_until ###

        for chunk in chunks:
            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)

            visuals = [arg["visual"] for arg in aux_arguments]

            if not isinstance(contexts, list):
                contexts = list(
                    contexts
                )  # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
                # TODO: could we upstream this workaround to HF?
            ### this part onward: same as HFLM ###

            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add EOS token to stop sequences
            eos = self.tok_decode(self.eot_token_id, skip_special_tokens=False)
            if not until:
                until = [eos]
            else:
                until.append(eos)
            if "max_gen_toks" in kwargs:
                max_gen_toks = kwargs.pop("max_gen_toks")
            else:
                max_gen_toks = self.max_gen_toks

            ### end stuff that's entirely copied verbatim from HFLM ###

            # leave room for the tokens that will be generated
            max_ctx_len = self.max_length - max_gen_toks

            inputs = self.tok_batch_multimodal_encode(
                contexts,
                visuals,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )

            context_enc = inputs["input_ids"]

            if "max_length" not in kwargs:
                kwargs["max_length"] = context_enc.shape[1] + max_gen_toks

            cont = self._model_multimodal_generate(inputs, stop=until, **kwargs)

            # release the batch inputs before the next batch is encoded
            del inputs
            torch.cuda.empty_cache()
            gc.collect()

            ### essentially same as HFLM beyond this line!

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only VLM
                cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append(s)
                self.cache_hook.add_partial(
                    "generate_until", (context, gen_kwargs), s
                )  # TODO: cache key for multimodal input should be what?
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res