Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,19 @@ def __init__(
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
if self.speculative_config is not None and (
self.speculative_config.uses_draft_model()
or self.speculative_config.use_eagle()
):
# When speculative decoding is enabled, additional slots are needed
# on top of the scheduler's allocation
multiplier = (
self.speculative_config.num_speculative_tokens
if self.speculative_config.parallel_drafting
else 1
)
self.max_num_tokens += multiplier * scheduler_config.max_num_seqs

self.max_num_reqs = scheduler_config.max_num_seqs

# Broadcast PP output for external_launcher (torchrun)
Expand Down Expand Up @@ -1662,10 +1675,7 @@ def _prepare_inputs(

# Hot-Swap lora model
if self.lora_config:
assert (
np.sum(num_sampled_tokens)
<= self.vllm_config.scheduler_config.max_num_batched_tokens
)
assert np.sum(num_sampled_tokens) <= self.max_num_tokens
self.set_active_loras(
self.input_batch, num_scheduled_tokens, num_sampled_tokens
)
Expand Down Expand Up @@ -4679,7 +4689,7 @@ def _dummy_run(
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
# for dummy run with LoRA so that the num_reqs collectively
# has num_tokens in total.
assert num_tokens <= self.scheduler_config.max_num_batched_tokens
assert num_tokens <= self.max_num_tokens
max_num_reqs = self.scheduler_config.max_num_seqs
if create_mixed_batch:
assert not uniform_decode
Expand Down Expand Up @@ -6107,7 +6117,7 @@ def init_routed_experts_capturer(self):
+ 1
) * block_size
routed_experts_capturer.init_buffer(
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
max_num_batched_tokens=self.max_num_tokens,
max_num_kv_tokens=self.max_num_kv_tokens,
vllm_config=self.vllm_config,
)
Expand Down