diff --git a/src/lighteval/main_vllm.py b/src/lighteval/main_vllm.py
index 4ffcd60e2..6f0ef659e 100644
--- a/src/lighteval/main_vllm.py
+++ b/src/lighteval/main_vllm.py
@@ -123,7 +123,7 @@ def vllm(
         job_id=job_id,
         dataset_loading_processes=dataset_loading_processes,
         custom_tasks_directory=custom_tasks,
-        override_batch_size=-1,  # Cannot override batch size when using VLLM
+        override_batch_size=-1,  # Cannot override batch size when using vLLM; configure `max_num_seqs` and `max_num_batched_tokens` in `VLLMModelConfig` instead.
         num_fewshot_seeds=num_fewshot_seeds,
         max_samples=max_samples,
         use_chat_template=use_chat_template,
diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py
index 98309f7d6..7174430fc 100644
--- a/src/lighteval/models/vllm/vllm_model.py
+++ b/src/lighteval/models/vllm/vllm_model.py
@@ -93,6 +93,8 @@ class VLLMModelConfig:
     )
     pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
     generation_parameters: GenerationParameters = None  # sampling parameters to use for generation
+    max_num_seqs: int = 128  # maximum number of sequences per iteration; together with `max_num_batched_tokens`, this effectively controls the batch size at the prefill stage. See https://github.com/vllm-project/vllm/issues/2492 for a detailed explanation.
+    max_num_batched_tokens: int = 2048  # maximum number of tokens per batch

     subfolder: Optional[str] = None

@@ -183,6 +185,8 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
             "max_model_len": self._max_length,
             "swap_space": 4,
             "seed": config.seed,
+            "max_num_seqs": int(config.max_num_seqs),
+            "max_num_batched_tokens": int(config.max_num_batched_tokens),
         }
         if int(config.data_parallel_size) > 1:
             self.model_args["distributed_executor_backend"] = "ray"
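
Usage sketch (not part of the diff): assuming `pretrained` remains the model identifier field of `VLLMModelConfig` and the class is imported from the module touched above, the new knobs can be set like this; both values below are illustrative.

    from lighteval.models.vllm.vllm_model import VLLMModelConfig

    config = VLLMModelConfig(
        pretrained="HuggingFaceH4/zephyr-7b-beta",  # hypothetical model choice
        max_num_seqs=64,              # fewer concurrent sequences -> smaller prefill batches, lower peak memory
        max_num_batched_tokens=1024,  # cap on tokens scheduled per vLLM engine step
    )

If the CLI's comma-separated model-args string maps onto these dataclass fields the same way it does for the existing ones, passing `max_num_seqs=64,max_num_batched_tokens=1024` there should have the same effect.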