Commit bbbdd22

Propagate vLLM batch size controls (#588)

* expose vLLM batch size control config
* comments
* type casting
* bump
* fix defaults

1 parent: bd9019d

File tree: 2 files changed, +5 −1 lines


src/lighteval/main_vllm.py

Lines changed: 1 addition & 1 deletion

@@ -123,7 +123,7 @@ def vllm(
         job_id=job_id,
         dataset_loading_processes=dataset_loading_processes,
         custom_tasks_directory=custom_tasks,
-        override_batch_size=-1,  # Cannot override batch size when using VLLM
+        override_batch_size=-1,  # Cannot override batch size when using vLLM; configure `max_num_seqs` and `max_num_batched_tokens` in `VLLMModelConfig` instead.
         num_fewshot_seeds=num_fewshot_seeds,
         max_samples=max_samples,
         use_chat_template=use_chat_template,

src/lighteval/models/vllm/vllm_model.py

Lines changed: 4 additions & 0 deletions

@@ -93,6 +93,8 @@ class VLLMModelConfig:
     )
     pairwise_tokenization: bool = False  # whether to tokenize the context and continuation separately or together.
     generation_parameters: GenerationParameters = None  # sampling parameters to use for generation
+    max_num_seqs: int = 128  # maximum number of sequences per iteration; together with `max_num_batched_tokens`, this effectively controls the batch size at the prefill stage. See https://github.com/vllm-project/vllm/issues/2492 for a detailed explanation.
+    max_num_batched_tokens: int = 2048  # maximum number of tokens per batch

     subfolder: Optional[str] = None

@@ -181,6 +183,8 @@ def _create_auto_model(self, config: VLLMModelConfig, env_config: EnvConfig) ->
             "max_model_len": self._max_length,
             "swap_space": 4,
             "seed": config.seed,
+            "max_num_seqs": int(config.max_num_seqs),
+            "max_num_batched_tokens": int(config.max_num_batched_tokens),
         }
         if int(config.data_parallel_size) > 1:
             self.model_args["distributed_executor_backend"] = "ray"
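
Taken together, the two new fields expose vLLM's scheduler limits through the model config: `max_num_seqs` caps how many sequences run per engine iteration, and `max_num_batched_tokens` caps the total tokens scheduled per step, so together they bound the effective batch size at the prefill stage. A minimal usage sketch follows; the `pretrained` field and the model name are assumptions for illustration, only the two new fields come from this diff.

    from lighteval.models.vllm.vllm_model import VLLMModelConfig

    # Hypothetical values: `pretrained` and the model name are assumed for
    # illustration; only `max_num_seqs` and `max_num_batched_tokens` are
    # introduced by this commit.
    config = VLLMModelConfig(
        pretrained="HuggingFaceH4/zephyr-7b-beta",
        max_num_seqs=64,              # fewer concurrent sequences per iteration
        max_num_batched_tokens=4096,  # larger per-step token budget
    )

Because `_create_auto_model` casts both values with int() before forwarding them to the vLLM engine arguments, string values parsed from text-based config files are also accepted.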
