Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions vllm/config/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,16 +1218,21 @@ def _set_compile_ranges(self):
computed_compile_ranges_split_points = []

# The upper bound of the compile ranges is the max_num_batched_tokens.
# For speculative decoding with draft model, the compile range must be extended
# by 1 for each sequence.
# For speculative decoding, the compile range must be extended
# - Sequential: + 1 * max_num_seqs (one draft token per iteration)
# - Parallel draft: + num_speculative_tokens * max_num_seqs
compile_range_end = self.scheduler_config.max_num_batched_tokens
if compile_range_end is not None:
do_extend: bool = (
self.speculative_config is not None
and self.speculative_config.uses_draft_model()
)
if do_extend:
compile_range_end += self.scheduler_config.max_num_seqs
if self.speculative_config is not None and (
self.speculative_config.uses_draft_model()
or self.speculative_config.use_eagle()
):
multiplier = (
self.speculative_config.num_speculative_tokens
if self.speculative_config.parallel_draft
else 1
)
compile_range_end += multiplier * self.scheduler_config.max_num_seqs

computed_compile_ranges_split_points.append(compile_range_end)

Expand Down
6 changes: 5 additions & 1 deletion vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ def __init__(
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
multiplier = (
self.num_speculative_tokens if self.speculative_config.parallel_draft else 1
)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
vllm_config.scheduler_config.max_num_batched_tokens
+ max_batch_size * multiplier
)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need to get the hidden size from the draft model config because
Expand Down