diff --git a/infer/vllm/driver b/infer/vllm/driver
index b6182bc..ce6ac1f 100755
--- a/infer/vllm/driver
+++ b/infer/vllm/driver
@@ -230,7 +230,7 @@ def run(input_size, output_size, batch_size):
     )
 
     kwargs = {
-        'prompt_token_ids' : input_batch,
+        'prompts'          : [vllm.TokensPrompt(prompt_token_ids=token_id_list) for token_id_list in input_batch],
        'sampling_params'  : sampling_params,
        'use_tqdm'         : False,
     }
@@ -239,7 +239,7 @@ def run(input_size, output_size, batch_size):
 
     for rep in range(par.reps):
         input_batch = fmwork.input_generator(par.model_path, input_size, batch_size, return_tensors='np')
-        kwargs['prompt_token_ids'] = input_batch
+        kwargs['prompts'] = [vllm.TokensPrompt(prompt_token_ids=token_id_list) for token_id_list in input_batch]
         fmwork.t0()
         #with MADProfiler(backend="rpd", nvtx_tracing=True):
         #    outputs = var.llm.generate(**kwargs)
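
Context for the change: recent vLLM releases dropped the prompt_token_ids keyword on
LLM.generate(), so a pre-tokenized batch must be wrapped element-by-element in
vllm.TokensPrompt and passed through the prompts argument instead. Below is a minimal
standalone sketch of the new calling convention, assuming a vLLM version that exports
TokensPrompt at the top level; the model name and token ids are illustrative
placeholders, not values from this driver.

    import vllm

    # Illustrative model; any vLLM-supported checkpoint works the same way.
    llm = vllm.LLM(model="facebook/opt-125m")
    params = vllm.SamplingParams(max_tokens=8, temperature=0.0)

    # A pre-tokenized batch: one list of token ids per request (placeholder ids).
    input_batch = [[2, 100, 200, 300], [2, 400, 500]]

    # New style: wrap each token id list in a TokensPrompt, pass via `prompts`.
    outputs = llm.generate(
        prompts=[vllm.TokensPrompt(prompt_token_ids=ids) for ids in input_batch],
        sampling_params=params,
        use_tqdm=False,
    )
    for out in outputs:
        print(out.outputs[0].text)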