diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ba1428c42ee4..f5a9ada21899 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -379,6 +379,19 @@ def __init__( self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.max_num_tokens = scheduler_config.max_num_batched_tokens + if self.speculative_config is not None and ( + self.speculative_config.uses_draft_model() + or self.speculative_config.use_eagle() + ): + # When speculative decoding is enabled, additional slots are needed + # on top of the scheduler's allocation + multiplier = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config.parallel_drafting + else 1 + ) + self.max_num_tokens += multiplier * scheduler_config.max_num_seqs + self.max_num_reqs = scheduler_config.max_num_seqs # Broadcast PP output for external_launcher (torchrun) @@ -1662,10 +1675,7 @@ def _prepare_inputs( # Hot-Swap lora model if self.lora_config: - assert ( - np.sum(num_sampled_tokens) - <= self.vllm_config.scheduler_config.max_num_batched_tokens - ) + assert np.sum(num_sampled_tokens) <= self.max_num_tokens self.set_active_loras( self.input_batch, num_scheduled_tokens, num_sampled_tokens ) @@ -4679,7 +4689,7 @@ def _dummy_run( # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively # has num_tokens in total. - assert num_tokens <= self.scheduler_config.max_num_batched_tokens + assert num_tokens <= self.max_num_tokens max_num_reqs = self.scheduler_config.max_num_seqs if create_mixed_batch: assert not uniform_decode @@ -6107,7 +6117,7 @@ def init_routed_experts_capturer(self): + 1 ) * block_size routed_experts_capturer.init_buffer( - max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, + max_num_batched_tokens=self.max_num_tokens, max_num_kv_tokens=self.max_num_kv_tokens, vllm_config=self.vllm_config, )