41 changes: 20 additions & 21 deletions lm_eval/models/vllm_vlms.py
@@ -1,6 +1,5 @@
import copy
import logging
from typing import Dict, List, Optional

import transformers
from more_itertools import distribute
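
The dropped typing import reflects the switch throughout the file to builtin generics (PEP 585, Python 3.9+) and `X | None` unions (PEP 604, Python 3.10+, or earlier with `from __future__ import annotations`). A minimal before/after illustration (the function names are hypothetical, not from the diff):

from typing import Dict, List, Optional


def before(strings: List[str], stop: Optional[List[str]] = None) -> Dict[str, str]:
    # Pre-change style: typing-module generics and Optional[...].
    return {}


def after(strings: list[str], stop: list[str] | None = None) -> dict[str, str]:
    # Post-change style: builtin generics (PEP 585) and X | None unions (PEP 604).
    return {}
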
@@ -40,14 +39,14 @@ class VLLM_VLM(VLLM):
def __init__(
self,
pretrained: str,
trust_remote_code: Optional[bool] = False,
revision: Optional[str] = None,
trust_remote_code: bool | None = False,
revision: str | None = None,
interleave: bool = True,
# TODO<baber>: handle max_images and limit_mm_per_prompt better
max_images: int = 999,
image_width: Optional[int] = None,
image_height: Optional[int] = None,
image_max_side: Optional[int] = None,
image_width: int | None = None,
image_height: int | None = None,
image_max_side: int | None = None,
**kwargs,
):
self.image_width = image_width
@@ -79,9 +78,9 @@ def __init__(

def tok_batch_multimodal_encode(
self,
strings: List[str], # note that input signature of this fn is different
strings: list[str], # note that input signature of this fn is different
images, # TODO: typehint on this
left_truncate_len: int = None,
left_truncate_len: int | None = None,
truncation: bool = False,
):
images = [img[: self.max_images] for img in images]
@@ -98,7 +97,7 @@ def tok_batch_multimodal_encode(
]

outputs = []
for x, i in zip(strings, images):
for x, i in zip(strings, images, strict=True):
inputs = {
"prompt": x,
"multi_modal_data": {"image": i},
@@ -108,14 +107,14 @@ def _multimodal_model_generate(

def _multimodal_model_generate(
self,
requests: List[List[dict]] = None,
requests: list[list[dict]] = None,
generate: bool = False,
max_tokens: int = None,
stop: Optional[List[str]] = None,
stop: list[str] | None = None,
**kwargs,
):
if generate:
kwargs = self.modify_gen_kwargs(kwargs)
kwargs, _, _ = self.modify_gen_kwargs(kwargs)
sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
else:
sampling_params = SamplingParams(
@@ -127,7 +126,7 @@ def _multimodal_model_generate(
# see https://github.com/vllm-project/vllm/issues/973
@ray.remote
def run_inference_one_model(
model_args: dict, sampling_params, requests: List[List[dict]]
model_args: dict, sampling_params, requests: list[list[dict]]
):
llm = LLM(**model_args)
return llm.generate(requests, sampling_params=sampling_params)
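
The `@ray.remote` worker above constructs one vLLM engine per task so requests can be sharded across data-parallel replicas (see the linked vLLM issue). A rough, hypothetical sketch of how such a worker might be driven; `model_args`, `sampling_params`, `requests`, and `data_parallel_size` are assumed to be supplied by the surrounding class and are not part of this diff:

import ray
from more_itertools import distribute

# Round-robin the flat request list into one shard per data-parallel replica.
shards = [list(shard) for shard in distribute(data_parallel_size, requests)]

# Launch one remote worker (and therefore one LLM engine) per shard.
futures = [
    run_inference_one_model.remote(model_args, sampling_params, shard)
    for shard in shards
]
results = ray.get(futures)  # one list of RequestOutput objects per replica
ray.shutdown()

# Flatten shard results back into a single list.
outputs = [out for shard_outputs in results for out in shard_outputs]
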
@@ -147,19 +146,19 @@ def run_inference_one_model(
outputs = self.model.generate(
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
use_tqdm=self.batch_size == "auto",
lora_request=self.lora_request,
)
else:
outputs = self.model.generate(
requests,
sampling_params=sampling_params,
use_tqdm=True if self.batch_size == "auto" else False,
use_tqdm=self.batch_size == "auto",
)
return outputs

def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt=True
self, chat_history: list[dict[str, str]], add_generation_prompt=True
) -> str:
self.chat_applied = True
if not self.interleave:
@@ -216,8 +215,8 @@ def apply_chat_template(
)

def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().generate_until(requests=requests, disable_tqdm=disable_tqdm)
@@ -253,7 +252,7 @@ def _collate(x):
chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
eos = self.tokenizer.decode(self.eot_token_id)
for chunk in chunks:
contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
contexts, all_gen_kwargs, aux_arguments = zip(*chunk, strict=True)

visuals = [
[
@@ -300,7 +299,7 @@ def _collate(x):
inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs
)

for output, context in zip(cont, contexts):
for output, context in zip(cont, contexts, strict=True):
generated_text = output.outputs[0].text
res.append(generated_text)
self.cache_hook.add_partial(
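
The `strict=True` added to the `zip` calls in this function (and in `tok_batch_multimodal_encode` above) requires Python 3.10+ and turns silent truncation on mismatched lengths into an error. A small, hypothetical illustration (the sample data is not from the diff):

contexts = ["describe the image", "caption this"]
generations = ["a red square"]  # one output missing

# Plain zip() would silently drop the unmatched context; strict=True raises instead.
try:
    paired = list(zip(generations, contexts, strict=True))
except ValueError as err:
    print(err)  # ValueError: zip() argument 2 is longer than argument 1
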
@@ -313,7 +312,7 @@ def _collate(x):
pbar.close()
return res

def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]:
if requests and len(requests[0].args) < 3:
# Fall back to non-multimodal generation.
return super().loglikelihood_rolling(requests=requests)