diff --git a/lm_eval/models/vllm_causallms.py b/lm_eval/models/vllm_causallms.py
index 390a14a7e3b..e35cac2a079 100644
--- a/lm_eval/models/vllm_causallms.py
+++ b/lm_eval/models/vllm_causallms.py
@@ -1,6 +1,5 @@
 import copy
 import gc
-import inspect
 import logging
 import os
 from importlib.metadata import version
@@ -33,7 +32,7 @@
 try:
     import ray
-    from vllm import LLM, SamplingParams
+    from vllm import LLM, SamplingParams, TokensPrompt
     from vllm.lora.request import LoRARequest
     from vllm.transformers_utils.tokenizer import get_tokenizer
     from vllm.utils import get_open_port
@@ -79,7 +78,7 @@ def _vllm_mp_worker(
     try:
         llm = LLM(**model_args)
         res = llm.generate(
-            prompt_token_ids=requests,
+            [TokensPrompt(prompt_token_ids=request) for request in requests],
             sampling_params=sampling_params,
             lora_request=lora_request,
         )
@@ -239,13 +238,6 @@ def __init__(
                 model_config = engine_args.create_model_config()
                 kwargs_resolve_hf_chat_template["model_config"] = model_config
-
-                # https://github.com/vllm-project/vllm/pull/18259
-                if (
-                    "trsut_remote_code"
-                    in inspect.signature(resolve_hf_chat_template).parameters
-                ):
-                    kwargs_resolve_hf_chat_template["trsut_remote_code"] = trust_remote_code
             else:
                 kwargs_resolve_hf_chat_template["trust_remote_code"] = trust_remote_code
@@ -395,7 +387,7 @@ def run_inference_one_model(
             ):
                 llm = LLM(**model_args)
                 return llm.generate(
-                    prompt_token_ids=requests,
+                    [TokensPrompt(prompt_token_ids=request) for request in requests],
                     sampling_params=sampling_params,
                     lora_request=lora_request,
                 )
@@ -484,7 +476,7 @@ def run_inference_one_model(
         else:
             outputs = self.model.generate(
-                prompt_token_ids=requests,
+                [TokensPrompt(prompt_token_ids=request) for request in requests],
                 sampling_params=sampling_params,
                 use_tqdm=True if self.batch_size == "auto" else False,
                 lora_request=self.lora_request,
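
The core change above is swapping the `prompt_token_ids=` keyword of `LLM.generate` for a list of `TokensPrompt` objects. As a point of reference, here is a minimal standalone sketch of the new call pattern; the model name and token IDs are placeholders and not values taken from lm_eval:

```python
# Minimal sketch of the TokensPrompt-based generate() call used in this diff.
# Model name and token IDs below are placeholders for illustration only.
from vllm import LLM, SamplingParams, TokensPrompt

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)

# Pre-tokenized prompts: one list of token IDs per request.
requests = [[2, 100, 200, 300], [2, 400, 500]]

# Old (deprecated) style: llm.generate(prompt_token_ids=requests, ...)
# New style: wrap each token-ID list in a TokensPrompt.
outputs = llm.generate(
    [TokensPrompt(prompt_token_ids=request) for request in requests],
    sampling_params=sampling_params,
)
for out in outputs:
    print(out.outputs[0].text)
```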