diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 69249e610680..887fd52794cb 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -19,6 +19,9 @@ init_attn_backend, ) from vllm.v1.worker.gpu.block_table import BlockTables +from vllm.v1.worker.gpu.cudagraph_utils import ( + BatchExecutionDescriptor, +) from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState @@ -239,6 +242,66 @@ def generate_draft( idx_mapping, query_start_loc, pos, num_tokens_padded ) + def _dispatch_and_sync_dp( + self, + cudagraph_manager: EagleCudaGraphManager, + num_reqs: int, + num_tokens: int, + uniform_token_count: int | None, + ) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]: + batch_desc = cudagraph_manager.dispatch( + num_reqs, num_tokens, uniform_token_count + ) + num_tokens_across_dp = None + if self.dp_size > 1: + batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( + cudagraph_manager, + batch_desc, + num_tokens, + num_reqs, + uniform_token_count, + self.dp_size, + self.dp_rank, + ) + return batch_desc, num_tokens_across_dp + + def _build_draft_attn_metadata( + self, + num_reqs: int, + num_reqs_padded: int, + num_tokens_padded: int, + max_query_len: int, + ) -> dict[str, Any] | None: + if not self.draft_attn_layer_names: + return None + + query_start_loc_cpu = ( + torch.arange(num_reqs_padded + 1, dtype=torch.int32, device="cpu").clamp_( + max=num_reqs + ) + * max_query_len + ) + block_tables = [ + x[:num_reqs_padded] for x in self.block_tables.input_block_tables + ] + slot_mappings = self.block_tables.slot_mappings[:, :num_tokens_padded] + attn_metadata = build_attn_metadata( + attn_groups=self.attn_groups, + num_reqs=num_reqs_padded, + num_tokens=num_tokens_padded, + query_start_loc_gpu=self.input_buffers.query_start_loc[ + : num_reqs_padded + 1 + ], + query_start_loc_cpu=query_start_loc_cpu, + max_query_len=max_query_len, + seq_lens=self.input_buffers.seq_lens[:num_reqs_padded], + max_seq_len=self.max_model_len, + block_tables=block_tables, + slot_mappings=slot_mappings, + kv_cache_config=self.kv_cache_config, + ) + return attn_metadata + def capture_model(self) -> None: if self.num_speculative_steps == 1: return @@ -319,7 +382,6 @@ def propose( logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs - num_reqs_padded = input_batch.num_reqs_after_padding # NOTE(woosuk): For draft sampling, we only consider the temperature # and ignore the other sampling parameters such as top_k and top_p, # for simplicity and performance. @@ -366,69 +428,49 @@ def propose( self.max_num_reqs, ) - # Get batch descriptor and sync across DP ranks. - # Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode - - batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1) - num_tokens_across_dp = None - - if self.dp_size > 1: - batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( - self.cudagraph_manager, - batch_desc, - num_reqs, - num_reqs, - 1, # uniform_token_count - self.dp_size, - self.dp_rank, - ) - - if not (dummy_run and skip_attn_for_dummy_run): - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos, batch_desc.num_tokens - ) - - if batch_desc.cg_mode == CUDAGraphMode.FULL: - return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs] + # Each request produces exactly 1 token per draft decode step, + # enabling FULL cudagraph. + decode_batch_desc, num_tokens_across_dp = self._dispatch_and_sync_dp( + self.cudagraph_manager, + num_reqs, + num_reqs, + uniform_token_count=1, + ) - # Run eager or piecewise CUDA graph. attn_metadata_updated = None slot_mappings_updated = None if not (dummy_run and skip_attn_for_dummy_run): - query_start_loc_cpu = torch.arange( - num_reqs_padded + 1, dtype=torch.int32, device="cpu" - ) - block_tables = [ - x[:num_reqs_padded] for x in self.block_tables.input_block_tables - ] - - # FIXME(woosuk): This is UNSAFE!! - attn_metadata_updated = build_attn_metadata( - attn_groups=self.attn_groups, - num_reqs=num_reqs_padded, - num_tokens=num_reqs_padded, - query_start_loc_gpu=query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - max_query_len=1, - seq_lens=self.input_buffers.seq_lens[:num_reqs_padded], - max_seq_len=self.max_model_len, - block_tables=block_tables, - slot_mappings=slot_mappings, - kv_cache_config=self.kv_cache_config, + # Build attention metadata and slot mappings for the draft + # decode steps. It is necessary to rebuild the attention + # metadata even when replaying the FULL cudagraph so that + # any attention metadata builder state is updated. + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, + self.input_buffers.query_start_loc[: num_reqs + 1], + pos, + decode_batch_desc.num_tokens, ) slot_mappings_updated = build_slot_mappings_by_layer( slot_mappings, self.kv_cache_config ) + attn_metadata_updated = self._build_draft_attn_metadata( + num_reqs=num_reqs, + num_reqs_padded=decode_batch_desc.num_reqs or num_reqs, + num_tokens_padded=decode_batch_desc.num_tokens, + max_query_len=1, + ) - self.generate_draft( - num_reqs, - batch_desc.num_tokens, - attn_metadata_updated, - slot_mappings_updated, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=batch_desc.cg_mode, - ) + if decode_batch_desc.cg_mode == CUDAGraphMode.FULL: + self.cudagraph_manager.run_fullgraph(decode_batch_desc) + else: + self.generate_draft( + num_reqs, + decode_batch_desc.num_tokens, + attn_metadata_updated, + slot_mappings_updated, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=decode_batch_desc.cg_mode, + ) return self.draft_tokens[:num_reqs]