diff --git a/lm_eval/models/vllm_vlms.py b/lm_eval/models/vllm_vlms.py index 15813b8aa9a..59b25dd2da7 100644 --- a/lm_eval/models/vllm_vlms.py +++ b/lm_eval/models/vllm_vlms.py @@ -1,6 +1,5 @@ import copy import logging -from typing import Dict, List, Optional import transformers from more_itertools import distribute @@ -40,14 +39,14 @@ class VLLM_VLM(VLLM): def __init__( self, pretrained: str, - trust_remote_code: Optional[bool] = False, - revision: Optional[str] = None, + trust_remote_code: bool | None = False, + revision: str | None = None, interleave: bool = True, # TODO: handle max_images and limit_mm_per_prompt better max_images: int = 999, - image_width: Optional[int] = None, - image_height: Optional[int] = None, - image_max_side: Optional[int] = None, + image_width: int | None = None, + image_height: int | None = None, + image_max_side: int | None = None, **kwargs, ): self.image_width = image_width @@ -79,9 +78,9 @@ def __init__( def tok_batch_multimodal_encode( self, - strings: List[str], # note that input signature of this fn is different + strings: list[str], # note that input signature of this fn is different images, # TODO: typehint on this - left_truncate_len: int = None, + left_truncate_len: int | None = None, truncation: bool = False, ): images = [img[: self.max_images] for img in images] @@ -98,7 +97,7 @@ def tok_batch_multimodal_encode( ] outputs = [] - for x, i in zip(strings, images): + for x, i in zip(strings, images, strict=True): inputs = { "prompt": x, "multi_modal_data": {"image": i}, @@ -108,14 +107,14 @@ def tok_batch_multimodal_encode( def _multimodal_model_generate( self, - requests: List[List[dict]] = None, + requests: list[list[dict]] = None, generate: bool = False, max_tokens: int = None, - stop: Optional[List[str]] = None, + stop: list[str] | None = None, **kwargs, ): if generate: - kwargs = self.modify_gen_kwargs(kwargs) + kwargs, _, _ = self.modify_gen_kwargs(kwargs) sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs) else: sampling_params = SamplingParams( @@ -127,7 +126,7 @@ def _multimodal_model_generate( # see https://github.com/vllm-project/vllm/issues/973 @ray.remote def run_inference_one_model( - model_args: dict, sampling_params, requests: List[List[dict]] + model_args: dict, sampling_params, requests: list[list[dict]] ): llm = LLM(**model_args) return llm.generate(requests, sampling_params=sampling_params) @@ -147,19 +146,19 @@ def run_inference_one_model( outputs = self.model.generate( requests, sampling_params=sampling_params, - use_tqdm=True if self.batch_size == "auto" else False, + use_tqdm=self.batch_size == "auto", lora_request=self.lora_request, ) else: outputs = self.model.generate( requests, sampling_params=sampling_params, - use_tqdm=True if self.batch_size == "auto" else False, + use_tqdm=self.batch_size == "auto", ) return outputs def apply_chat_template( - self, chat_history: List[Dict[str, str]], add_generation_prompt=True + self, chat_history: list[dict[str, str]], add_generation_prompt=True ) -> str: self.chat_applied = True if not self.interleave: @@ -216,8 +215,8 @@ def apply_chat_template( ) def generate_until( - self, requests: List[Instance], disable_tqdm: bool = False - ) -> List[str]: + self, requests: list[Instance], disable_tqdm: bool = False + ) -> list[str]: if requests and len(requests[0].args) < 3: # Fall back to non-multimodal generation. return super().generate_until(requests=requests, disable_tqdm=disable_tqdm) @@ -253,7 +252,7 @@ def _collate(x): chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None) eos = self.tokenizer.decode(self.eot_token_id) for chunk in chunks: - contexts, all_gen_kwargs, aux_arguments = zip(*chunk) + contexts, all_gen_kwargs, aux_arguments = zip(*chunk, strict=True) visuals = [ [ @@ -300,7 +299,7 @@ def _collate(x): inputs, stop=until, generate=True, max_tokens=max_gen_toks, **kwargs ) - for output, context in zip(cont, contexts): + for output, context in zip(cont, contexts, strict=True): generated_text = output.outputs[0].text res.append(generated_text) self.cache_hook.add_partial( @@ -313,7 +312,7 @@ def _collate(x): pbar.close() return res - def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]: + def loglikelihood_rolling(self, requests: list[Instance]) -> list[float]: if requests and len(requests[0].args) < 3: # Fall back to non-multimodal generation. return super().loglikelihood_rolling(requests=requests)