diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index e27b5ee38834..ee19c47f1b48 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -31,8 +31,9 @@ class CudagraphDispatcher: runnable without cudagraph (if the mode does not match or mode is NONE). """ - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, for_draft_model: bool = False): self.vllm_config = vllm_config + self.for_draft_model = for_draft_model self.compilation_config = vllm_config.compilation_config self.uniform_decode_query_len = ( 1 @@ -134,9 +135,11 @@ def _create_padded_batch_descriptor( uniform_decode: bool, has_lora: bool, num_active_loras: int = 0, + uniform_decode_query_len: int | None = None, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs - uniform_decode_query_len = self.uniform_decode_query_len + if uniform_decode_query_len is None: + uniform_decode_query_len = self.uniform_decode_query_len num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): @@ -229,6 +232,26 @@ def initialize_cudagraph_keys( ), ) + if self.for_draft_model and cudagraph_mode.has_full_cudagraphs(): + max_num_tokens = self.vllm_config.scheduler_config.max_num_seqs + assert self.compilation_config.cudagraph_capture_sizes is not None, ( + "Cudagraph capture sizes must be set when full mode is enabled." + ) + capture_sizes_for_draft_model = [] + for size in self.compilation_config.cudagraph_capture_sizes: + capture_sizes_for_draft_model.append(size) + if size >= max_num_tokens: + break + for bs, num_active_loras in product( + capture_sizes_for_draft_model, lora_cases + ): + self.add_cudagraph_key( + CUDAGraphMode.FULL, + self._create_padded_batch_descriptor( + bs, True, num_active_loras > 0, num_active_loras, 1 + ), + ) + self.keys_initialized = True def dispatch( @@ -239,6 +262,7 @@ def dispatch( num_active_loras: int = 0, valid_modes: AbstractSet[CUDAGraphMode] | None = None, invalid_modes: AbstractSet[CUDAGraphMode] | None = None, + uniform_decode_query_len: int | None = None, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using piecewise only), @@ -298,9 +322,15 @@ def dispatch( ) effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1 - normalized_uniform = uniform_decode and self.cudagraph_mode.separate_routine() + normalized_uniform = uniform_decode and ( + self.cudagraph_mode.separate_routine() or self.for_draft_model + ) batch_desc = self._create_padded_batch_descriptor( - num_tokens, normalized_uniform, has_lora, effective_num_active_loras + num_tokens, + normalized_uniform, + has_lora, + effective_num_active_loras, + uniform_decode_query_len, ) if CUDAGraphMode.FULL in allowed_modes: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 4b20413ca702..9de81047da8f 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,13 +9,14 @@ import torch import torch.nn as nn +from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import ( CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, ) from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model @@ -123,7 +124,7 @@ def __init__( # Keys are initialized later via initialize_cudagraph_keys() called from # gpu_model_runner._check_and_update_cudagraph_mode after # adjust_cudagraph_sizes_for_spec_decode is called. - self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config, True) # persistent buffers for cuda graph self.input_ids = torch.zeros( @@ -359,21 +360,11 @@ def _get_slot_mapping( return {name: view for name in self._draft_attn_layer_names} def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: - """Initialize cudagraph dispatcher keys for eagle. + """Initialize cudagraph dispatcher keys for eagle.""" + if self.speculative_config.enforce_eager: + cudagraph_mode = CUDAGraphMode.NONE - Eagle only supports PIECEWISE cudagraphs (via mixed_mode). - This should be called after adjust_cudagraph_sizes_for_spec_decode. - """ - if ( - not self.speculative_config.enforce_eager - and cudagraph_mode.mixed_mode() - in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] - ): - eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE - else: - eagle_cudagraph_mode = CUDAGraphMode.NONE - - self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) + self.cudagraph_dispatcher.initialize_cudagraph_keys(cudagraph_mode) def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: """Greedy-sample draft tokens from hidden states.""" @@ -393,6 +384,7 @@ def propose( next_token_ids: torch.Tensor, token_indices_to_sample: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, + target_model_batch_desc: BatchDescriptor, sampling_metadata: SamplingMetadata, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, @@ -403,8 +395,12 @@ def propose( batch_size = common_attn_metadata.batch_size() if self.method == "eagle3": + if isinstance(self.model, CUDAGraphWrapper): + model = self.model.unwrap() + else: + model = self.model assert isinstance( - self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) + model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) ) target_hidden_states = self.model.combine_hidden_states( target_hidden_states @@ -431,8 +427,9 @@ def propose( for layer_name in attn_group.layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( - self._determine_batch_execution_and_padding(num_tokens) + uniform_decode = target_model_batch_desc.uniform + cudagraph_runtime_mode, batch_desc, num_input_tokens, num_tokens_across_dp = ( + self._determine_batch_execution_and_padding(num_tokens, uniform_decode) ) if self.supports_mm_inputs: @@ -464,6 +461,7 @@ def propose( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=self._get_slot_mapping( num_input_tokens, common_attn_metadata.slot_mapping ), @@ -517,14 +515,18 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = ( - self._determine_batch_execution_and_padding(batch_size) + cudagraph_runtime_mode, batch_desc, input_batch_size, batch_size_across_dp = ( + self._determine_batch_execution_and_padding( + batch_size, True, uniform_decode_query_len=1 + ) ) common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] - common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[ + : batch_size + 1 + ] + common_attn_metadata.query_start_loc_cpu[: batch_size + 1] = torch.from_numpy( self.token_arange_np[: batch_size + 1] ).clone() @@ -628,6 +630,7 @@ def propose( num_tokens=input_batch_size, num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=self._get_slot_mapping(input_batch_size), ): ret_hidden_states = self.model(**model_kwargs) @@ -823,6 +826,7 @@ def prepare_next_token_ids_padded( requests: dict[str, CachedRequestState], gpu_input_batch: InputBatch, discard_request_mask: torch.Tensor, + batch_size: int, ) -> tuple[torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding. @@ -844,16 +848,17 @@ def prepare_next_token_ids_padded( self.backup_next_token_ids.copy_to_gpu(num_reqs) backup_tokens_gpu = self.backup_next_token_ids.gpu - batch_size, num_tokens = sampled_token_ids.shape + _, num_tokens = sampled_token_ids.shape device = sampled_token_ids.device assert discard_request_mask.dtype == torch.bool assert backup_tokens_gpu.dtype == torch.int32 - next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device) - valid_sampled_tokens_count = next_token_ids.new_empty(batch_size) + next_token_ids = torch.zeros(batch_size, dtype=torch.int32, device=device) + valid_sampled_tokens_count = next_token_ids.new_zeros(batch_size) # Kernel grid: one program per request (row) + # NOTE: For CUDA Graph, we need the `batch_size` to be `num_reqs_padded` here grid = (batch_size,) # Find the next power of 2 for block sizes @@ -878,6 +883,7 @@ def prepare_inputs_padded( common_attn_metadata: CommonAttentionMetadata, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor, + gpu_input_batch: InputBatch, ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding @@ -897,6 +903,14 @@ def prepare_inputs_padded( (num_reqs,), dtype=torch.int32, device=device ) + actual_num_reqs = gpu_input_batch.num_reqs + spec_decode_metadata.cu_num_draft_tokens = nn.functional.pad( + spec_decode_metadata.cu_num_draft_tokens, + (0, num_reqs - actual_num_reqs), + mode="constant", + value=spec_decode_metadata.cu_num_draft_tokens[-1], + ) + grid = (num_reqs,) eagle_prepare_inputs_padded_kernel[grid]( spec_decode_metadata.cu_num_draft_tokens, @@ -1237,6 +1251,19 @@ def load_model(self, target_model: nn.Module) -> None: ) self.model = self._get_model() + # wrap the model with full cudagraph wrapper if needed. + cudagraph_mode = self.compilation_config.cudagraph_mode + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.vllm_config.parallel_config.use_ubatching + and not self.speculative_config.disable_padded_drafter_batch + ): + # Currently Ubatch does not support FULL in speculative decoding, unpadded + # drafter batch either due to the dynamic number of tokens. + # We can consider supporting FULL for these cases in the future if needed. + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) # Find draft layers (attention layers added by draft model) all_attn_layers = get_layers_from_vllm_config( @@ -1469,6 +1496,7 @@ def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, + common_attn_metadata: CommonAttentionMetadata | None = None, use_cudagraphs: bool = True, is_graph_capturing: bool = False, slot_mappings: dict[str, torch.Tensor] | None = None, @@ -1476,14 +1504,38 @@ def dummy_run( # FIXME: when using tree-based specdec, adjust number of forward-passes # according to the depth of the tree. for fwd_idx in range( - self.num_speculative_tokens if not is_graph_capturing else 1 + self.num_speculative_tokens + if not is_graph_capturing + else min(self.num_speculative_tokens, 2) ): - if fwd_idx <= 1: - cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( - self._determine_batch_execution_and_padding( - num_tokens, use_cudagraphs=use_cudagraphs - ) + if fwd_idx > 0 and common_attn_metadata is not None: + # All speculative steps except the first one typically use + # a uniform decode with 1 token per request. + uniform_decode = True + num_tokens = common_attn_metadata.num_reqs + uniform_decode_query_len = 1 + else: + # For the first step, note that for FULL_DECODE_ONLY and + # FULL_AND_PIECEWISE we need to set uniform_decode to True + # while for FULL we don't + mode = self.cudagraph_dispatcher.cudagraph_mode + is_full_sep = ( + mode.decode_mode() == CUDAGraphMode.FULL and mode.separate_routine() ) + uniform_decode = is_full_sep and common_attn_metadata is not None + uniform_decode_query_len = None + + ( + cudagraph_runtime_mode, + batch_desc, + num_input_tokens, + num_tokens_across_dp, + ) = self._determine_batch_execution_and_padding( + num_tokens, + uniform_decode, + use_cudagraphs=use_cudagraphs, + uniform_decode_query_len=uniform_decode_query_len, + ) # Make sure to use EAGLE's own buffer during cudagraph capture. if ( @@ -1495,12 +1547,26 @@ def dummy_run( else: slot_mapping_dict = slot_mappings or {} + if common_attn_metadata is not None: + dummy_attn_metadata = {} + for attn_group in self.draft_attn_groups: + attn_metadata = ( + attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + ) + for layer_name in attn_group.layer_names: + dummy_attn_metadata[layer_name] = attn_metadata + else: + dummy_attn_metadata = None + with set_forward_context( - None, + dummy_attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=slot_mapping_dict, ): if self.supports_mm_inputs: @@ -1623,11 +1689,15 @@ def initialize_attn_backend( def _determine_batch_execution_and_padding( self, num_tokens: int, + uniform_decode: bool = False, use_cudagraphs: bool = True, - ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]: + uniform_decode_query_len: int | None = None, + ) -> tuple[CUDAGraphMode, BatchDescriptor, int, torch.Tensor | None]: cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch( num_tokens, + uniform_decode, valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None), + uniform_decode_query_len=uniform_decode_query_len, ) num_tokens_padded = batch_desc.num_tokens @@ -1662,7 +1732,7 @@ def _determine_batch_execution_and_padding( assert batch_desc.num_tokens == num_tokens_padded num_tokens_across_dp[dp_rank] = num_tokens_padded - return cudagraph_mode, num_tokens_padded, num_tokens_across_dp + return cudagraph_mode, batch_desc, num_tokens_padded, num_tokens_across_dp class EagleProposer(SpecDecodeBaseProposer): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index be7734487791..10a5a0806af7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -383,6 +383,7 @@ class ExecuteModelState(NamedTuple): ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None + batch_desc: BatchDescriptor class GPUModelRunner( @@ -505,6 +506,7 @@ def __init__( self.encoder_cudagraph_manager: EncoderCudaGraphManager | None = None self.use_aux_hidden_state_outputs = False + self.supports_sd_full_graph = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many @@ -519,6 +521,12 @@ def __init__( | MedusaProposer | ExtractHiddenStatesProposer ) + + self.supports_sd_full_graph = ( + self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + if self.speculative_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -2323,11 +2331,13 @@ def _build_attn_group_metadata( for _metadata in attn_metadata.values(): _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] - if spec_decode_common_attn_metadata is not None and ( - num_reqs != num_reqs_padded or num_tokens != num_tokens_padded + if ( + spec_decode_common_attn_metadata is not None + and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded) + and not self.supports_sd_full_graph ): - # Currently the drafter still only uses piecewise cudagraphs (and modifies - # the attention metadata in directly), and therefore does not want to use + # Currently the drafter still only uses piecewise cudagraphs (except for + # Eagle, which supports FULL now), and therefore does not want to use # padded attention metadata. spec_decode_common_attn_metadata = ( spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) @@ -4094,6 +4104,7 @@ def execute_model( ec_connector_output, cudagraph_stats, slot_mappings, + batch_desc, ) self.kv_connector_output = kv_connector_output @@ -4138,6 +4149,7 @@ def sample_tokens( ec_connector_output, cudagraph_stats, slot_mappings, + batch_desc, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -4182,6 +4194,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, slot_mappings, + batch_desc, ) self._copy_draft_token_ids_to_cpu(scheduler_output) @@ -4216,6 +4229,7 @@ def propose_draft_token_ids(sampled_token_ids): self.requests, self.input_batch, self.discard_request_mask.gpu, + spec_decode_common_attn_metadata.num_reqs, ) ) self._copy_valid_sampled_token_count( @@ -4474,6 +4488,7 @@ def propose_draft_token_ids( spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + target_model_batch_desc: BatchDescriptor, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4667,6 +4682,7 @@ def propose_draft_token_ids( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count, + self.input_batch, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -4694,8 +4710,9 @@ def propose_draft_token_ids( target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, token_indices_to_sample=token_indices_to_sample, - sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, + target_model_batch_desc=target_model_batch_desc, + sampling_metadata=sampling_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, @@ -5338,6 +5355,7 @@ def _dummy_run( ) attn_metadata: PerLayerAttnMetadata | None = None + spec_decode_cm: CommonAttentionMetadata | None = None slot_mappings_by_group, slot_mappings = self._get_slot_mappings( num_tokens_padded=num_tokens, @@ -5383,7 +5401,7 @@ def _dummy_run( self.input_batch.block_table.commit_block_table(num_reqs_padded) pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL - attn_metadata, _ = self._build_attention_metadata( + attn_metadata, spec_decode_cm = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs_padded, @@ -5486,13 +5504,14 @@ def _dummy_run( EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, ) assert self.speculative_config is not None - # Eagle currently only supports PIECEWISE cudagraphs. - # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( ( is_graph_capturing - and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + and ( + cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + or self.supports_sd_full_graph + ) ) or ( not is_graph_capturing @@ -5512,6 +5531,7 @@ def _dummy_run( self.drafter.dummy_run( num_tokens, + common_attn_metadata=spec_decode_cm, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, slot_mappings=slot_mappings, @@ -6093,6 +6113,11 @@ def _capture_cudagraphs( # Only rank 0 should print progress bar during capture if is_global_first_rank(): + logger.info( + "Capturing CUDA graphs for %d batches (%s)", + len(batch_descriptors), + ", ".join(str(desc.num_tokens) for desc in batch_descriptors), + ) batch_descriptors = tqdm( batch_descriptors, disable=not self.load_config.use_tqdm_on_load, @@ -6375,7 +6400,6 @@ def _check_and_update_cudagraph_mode( # sizes for decode and mixed prefill-decode. if ( cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine() and self.uniform_decode_query_len > 1 ): self.compilation_config.adjust_cudagraph_sizes_for_spec_decode(