diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 513f0afbc1..9b99dee17d 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -1218,16 +1218,21 @@ def _set_compile_ranges(self): computed_compile_ranges_split_points = [] # The upper bound of the compile ranges is the max_num_batched_tokens. - # For speculative decoding with draft model, the compile range must be extended - # by 1 for each sequence. + # For speculative decoding, the compile range must be extended + # - Sequential: + 1 * max_num_seqs (one draft token per iteration) + # - Parallel draft: + num_speculative_tokens * max_num_seqs compile_range_end = self.scheduler_config.max_num_batched_tokens if compile_range_end is not None: - do_extend: bool = ( - self.speculative_config is not None - and self.speculative_config.uses_draft_model() - ) - if do_extend: - compile_range_end += self.scheduler_config.max_num_seqs + if self.speculative_config is not None and ( + self.speculative_config.uses_draft_model() + or self.speculative_config.use_eagle() + ): + multiplier = ( + self.speculative_config.num_speculative_tokens + if self.speculative_config.parallel_draft + else 1 + ) + compile_range_end += multiplier * self.scheduler_config.max_num_seqs computed_compile_ranges_split_points.append(compile_range_end) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1ae058c2ea..e10f401f1f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -76,8 +76,12 @@ def __init__( self.num_speculative_tokens = self.speculative_config.num_speculative_tokens # The drafter can get longer sequences than the target model. max_batch_size = vllm_config.scheduler_config.max_num_seqs + multiplier = ( + self.num_speculative_tokens if self.speculative_config.parallel_draft else 1 + ) self.max_num_tokens = ( - vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size + vllm_config.scheduler_config.max_num_batched_tokens + + max_batch_size * multiplier ) self.token_arange_np = np.arange(self.max_num_tokens) # We need to get the hidden size from the draft model config because