From 85e9d81237462f344995e78667e4fe8b14f59d6c Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 20 Mar 2026 04:59:03 -0600 Subject: [PATCH 1/8] Cherry-pick PR #34880: enable FULL CUDAGraph for EAGLE proposer Cherry-pick 409a12e3a to enable FULL CUDAGraph mode for the EAGLE proposer during draft speculative steps, reducing CPU overhead. Signed-off-by: Matthias Gehre --- vllm/v1/cudagraph_dispatcher.py | 38 +++++++- vllm/v1/spec_decode/eagle.py | 140 +++++++++++++++++++++-------- vllm/v1/worker/gpu_model_runner.py | 43 ++++++--- 3 files changed, 172 insertions(+), 49 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 701c97d6de42..fa858b0b227e 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -31,8 +31,9 @@ class CudagraphDispatcher: runnable without cudagraph (if the mode does not match or mode is NONE). """ - def __init__(self, vllm_config: VllmConfig): + def __init__(self, vllm_config: VllmConfig, for_draft_model: bool = False): self.vllm_config = vllm_config + self.for_draft_model = for_draft_model self.compilation_config = vllm_config.compilation_config self.uniform_decode_query_len = ( 1 @@ -131,9 +132,11 @@ def _create_padded_batch_descriptor( uniform_decode: bool, has_lora: bool, num_active_loras: int = 0, + uniform_decode_query_len: int | None = None, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs - uniform_decode_query_len = self.uniform_decode_query_len + if uniform_decode_query_len is None: + uniform_decode_query_len = self.uniform_decode_query_len num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): @@ -226,6 +229,26 @@ def initialize_cudagraph_keys( ), ) + if self.for_draft_model and cudagraph_mode.has_full_cudagraphs(): + max_num_tokens = self.vllm_config.scheduler_config.max_num_seqs + assert self.compilation_config.cudagraph_capture_sizes is not None, ( + "Cudagraph capture sizes must be set when full mode is enabled." + ) + capture_sizes_for_draft_model = [] + for size in self.compilation_config.cudagraph_capture_sizes: + capture_sizes_for_draft_model.append(size) + if size >= max_num_tokens: + break + for bs, num_active_loras in product( + capture_sizes_for_draft_model, lora_cases + ): + self.add_cudagraph_key( + CUDAGraphMode.FULL, + self._create_padded_batch_descriptor( + bs, True, num_active_loras > 0, num_active_loras, 1 + ), + ) + self.keys_initialized = True def dispatch( @@ -236,6 +259,7 @@ def dispatch( num_active_loras: int = 0, valid_modes: AbstractSet[CUDAGraphMode] | None = None, invalid_modes: AbstractSet[CUDAGraphMode] | None = None, + uniform_decode_query_len: int | None = None, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using piecewise only), @@ -293,9 +317,15 @@ def dispatch( ) effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1 - normalized_uniform = uniform_decode and self.cudagraph_mode.separate_routine() + normalized_uniform = uniform_decode and ( + self.cudagraph_mode.separate_routine() or self.for_draft_model + ) batch_desc = self._create_padded_batch_descriptor( - num_tokens, normalized_uniform, has_lora, effective_num_active_loras + num_tokens, + normalized_uniform, + has_lora, + effective_num_active_loras, + uniform_decode_query_len, ) if CUDAGraphMode.FULL in allowed_modes: diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 9414ab598454..a758dae3a5dc 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,13 +9,14 @@ import torch import torch.nn as nn +from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.config import ( CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, ) from vllm.distributed.parallel_state import get_pp_group -from vllm.forward_context import set_forward_context +from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.model_loader import get_model @@ -124,7 +125,7 @@ def __init__( # Keys are initialized later via initialize_cudagraph_keys() called from # gpu_model_runner._check_and_update_cudagraph_mode after # adjust_cudagraph_sizes_for_spec_decode is called. - self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config, True) # persistent buffers for cuda graph self.input_ids = torch.zeros( @@ -360,21 +361,11 @@ def _get_slot_mapping( return {name: view for name in self._draft_attn_layer_names} def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: - """Initialize cudagraph dispatcher keys for eagle. + """Initialize cudagraph dispatcher keys for eagle.""" + if self.speculative_config.enforce_eager: + cudagraph_mode = CUDAGraphMode.NONE - Eagle only supports PIECEWISE cudagraphs (via mixed_mode). - This should be called after adjust_cudagraph_sizes_for_spec_decode. - """ - if ( - not self.speculative_config.enforce_eager - and cudagraph_mode.mixed_mode() - in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] - ): - eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE - else: - eagle_cudagraph_mode = CUDAGraphMode.NONE - - self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) + self.cudagraph_dispatcher.initialize_cudagraph_keys(cudagraph_mode) def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor: """Greedy-sample draft tokens from hidden states.""" @@ -394,6 +385,7 @@ def propose( next_token_ids: torch.Tensor, token_indices_to_sample: torch.Tensor | None, common_attn_metadata: CommonAttentionMetadata, + target_model_batch_desc: BatchDescriptor, sampling_metadata: SamplingMetadata, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None, @@ -404,8 +396,12 @@ def propose( batch_size = common_attn_metadata.batch_size() if self.method == "eagle3": + if isinstance(self.model, CUDAGraphWrapper): + model = self.model.unwrap() + else: + model = self.model assert isinstance( - self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) + model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM) ) target_hidden_states = self.model.combine_hidden_states( target_hidden_states @@ -434,8 +430,9 @@ def propose( for layer_name in attn_group.layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( - self._determine_batch_execution_and_padding(num_tokens) + uniform_decode = target_model_batch_desc.uniform + cudagraph_runtime_mode, batch_desc, num_input_tokens, num_tokens_across_dp = ( + self._determine_batch_execution_and_padding(num_tokens, uniform_decode) ) if self.supports_mm_inputs: @@ -467,6 +464,7 @@ def propose( num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=self._get_slot_mapping( num_input_tokens, common_attn_metadata.slot_mapping ), @@ -520,14 +518,18 @@ def propose( # Generate the remaining draft tokens. draft_token_ids_list = [draft_token_ids] - cudagraph_runtime_mode, input_batch_size, batch_size_across_dp = ( - self._determine_batch_execution_and_padding(batch_size) + cudagraph_runtime_mode, batch_desc, input_batch_size, batch_size_across_dp = ( + self._determine_batch_execution_and_padding( + batch_size, True, uniform_decode_query_len=1 + ) ) common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.max_query_len = 1 - common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] - common_attn_metadata.query_start_loc_cpu = torch.from_numpy( + common_attn_metadata.query_start_loc[: batch_size + 1] = self.arange[ + : batch_size + 1 + ] + common_attn_metadata.query_start_loc_cpu[: batch_size + 1] = torch.from_numpy( self.token_arange_np[: batch_size + 1] ).clone() @@ -631,6 +633,7 @@ def propose( num_tokens=input_batch_size, num_tokens_across_dp=batch_size_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=self._get_slot_mapping(input_batch_size), ): ret_hidden_states = self.model(**model_kwargs) @@ -836,6 +839,8 @@ def prepare_next_token_ids_padded( """ # Precompute get_token_id for when there is no valid next token num_reqs = gpu_input_batch.num_reqs + # NOTE: For CUDA Graph, we need the `num_reqs_padded` here + batch_size = common_attn_metadata.num_reqs self.backup_next_token_ids.np[:num_reqs] = np.array( [ requests[gpu_input_batch.req_ids[i]].get_token_id( @@ -848,14 +853,14 @@ def prepare_next_token_ids_padded( self.backup_next_token_ids.copy_to_gpu(num_reqs) backup_tokens_gpu = self.backup_next_token_ids.gpu - batch_size, num_tokens = sampled_token_ids.shape + _, num_tokens = sampled_token_ids.shape device = sampled_token_ids.device assert discard_request_mask.dtype == torch.bool assert backup_tokens_gpu.dtype == torch.int32 - next_token_ids = torch.empty(batch_size, dtype=torch.int32, device=device) - valid_sampled_tokens_count = next_token_ids.new_empty(batch_size) + next_token_ids = torch.zeros(batch_size, dtype=torch.int32, device=device) + valid_sampled_tokens_count = next_token_ids.new_zeros(batch_size) # Kernel grid: one program per request (row) grid = (batch_size,) @@ -882,6 +887,7 @@ def prepare_inputs_padded( common_attn_metadata: CommonAttentionMetadata, spec_decode_metadata: SpecDecodeMetadata, valid_sampled_tokens_count: torch.Tensor, + gpu_input_batch: InputBatch, ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: """ This function is used to prepare the inputs for speculative decoding @@ -901,6 +907,14 @@ def prepare_inputs_padded( (num_reqs,), dtype=torch.int32, device=device ) + actual_num_reqs = gpu_input_batch.num_reqs + spec_decode_metadata.cu_num_draft_tokens = nn.functional.pad( + spec_decode_metadata.cu_num_draft_tokens, + (0, num_reqs - actual_num_reqs), + mode="constant", + value=spec_decode_metadata.cu_num_draft_tokens[-1], + ) + grid = (num_reqs,) eagle_prepare_inputs_padded_kernel[grid]( spec_decode_metadata.cu_num_draft_tokens, @@ -1241,6 +1255,19 @@ def load_model(self, target_model: nn.Module) -> None: ) self.model = self._get_model() + # wrap the model with full cudagraph wrapper if needed. + cudagraph_mode = self.compilation_config.cudagraph_mode + if ( + cudagraph_mode.has_full_cudagraphs() + and not self.vllm_config.parallel_config.use_ubatching + and not self.speculative_config.disable_padded_drafter_batch + ): + # Currently Ubatch does not support FULL in speculative decoding, unpadded + # drafter batch either due to the dynamic number of tokens. + # We can consider supporting FULL for these cases in the future if needed. + self.model = CUDAGraphWrapper( + self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL + ) # Find draft layers (attention layers added by draft model) all_attn_layers = get_layers_from_vllm_config( @@ -1475,6 +1502,7 @@ def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None: def dummy_run( self, num_tokens: int, + common_attn_metadata: CommonAttentionMetadata | None = None, use_cudagraphs: bool = True, is_graph_capturing: bool = False, slot_mappings: dict[str, torch.Tensor] | None = None, @@ -1482,14 +1510,38 @@ def dummy_run( # FIXME: when using tree-based specdec, adjust number of forward-passes # according to the depth of the tree. for fwd_idx in range( - self.num_speculative_tokens if not is_graph_capturing else 1 + self.num_speculative_tokens + if not is_graph_capturing + else min(self.num_speculative_tokens, 2) ): - if fwd_idx <= 1: - cudagraph_runtime_mode, num_input_tokens, num_tokens_across_dp = ( - self._determine_batch_execution_and_padding( - num_tokens, use_cudagraphs=use_cudagraphs - ) + if fwd_idx > 0 and common_attn_metadata is not None: + # All speculative steps except the first one typically use + # a uniform decode with 1 token per request. + uniform_decode = True + num_tokens = common_attn_metadata.num_reqs + uniform_decode_query_len = 1 + else: + # For the first step, note that for FULL_DECODE_ONLY and + # FULL_AND_PIECEWISE we need to set uniform_decode to True + # while for FULL we don't + mode = self.cudagraph_dispatcher.cudagraph_mode + is_full_sep = ( + mode.decode_mode() == CUDAGraphMode.FULL and mode.separate_routine() ) + uniform_decode = is_full_sep and common_attn_metadata is not None + uniform_decode_query_len = None + + ( + cudagraph_runtime_mode, + batch_desc, + num_input_tokens, + num_tokens_across_dp, + ) = self._determine_batch_execution_and_padding( + num_tokens, + uniform_decode, + use_cudagraphs=use_cudagraphs, + uniform_decode_query_len=uniform_decode_query_len, + ) # Make sure to use EAGLE's own buffer during cudagraph capture. if ( @@ -1501,12 +1553,26 @@ def dummy_run( else: slot_mapping_dict = slot_mappings or {} + if common_attn_metadata is not None: + dummy_attn_metadata = {} + for attn_group in self.draft_attn_groups: + attn_metadata = ( + attn_group.get_metadata_builder().build_for_drafting( + common_attn_metadata=common_attn_metadata, draft_index=0 + ) + ) + for layer_name in attn_group.layer_names: + dummy_attn_metadata[layer_name] = attn_metadata + else: + dummy_attn_metadata = None + with set_forward_context( - None, + dummy_attn_metadata, self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, + batch_descriptor=batch_desc, slot_mapping=slot_mapping_dict, ): if self.supports_mm_inputs: @@ -1629,11 +1695,15 @@ def initialize_attn_backend( def _determine_batch_execution_and_padding( self, num_tokens: int, + uniform_decode: bool = False, use_cudagraphs: bool = True, - ) -> tuple[CUDAGraphMode, int, torch.Tensor | None]: + uniform_decode_query_len: int | None = None, + ) -> tuple[CUDAGraphMode, BatchDescriptor, int, torch.Tensor | None]: cudagraph_mode, batch_desc = self.cudagraph_dispatcher.dispatch( num_tokens, + uniform_decode, valid_modes=({CUDAGraphMode.NONE} if not use_cudagraphs else None), + uniform_decode_query_len=uniform_decode_query_len, ) num_tokens_padded = batch_desc.num_tokens @@ -1668,7 +1738,7 @@ def _determine_batch_execution_and_padding( assert batch_desc.num_tokens == num_tokens_padded num_tokens_across_dp[dp_rank] = num_tokens_padded - return cudagraph_mode, num_tokens_padded, num_tokens_across_dp + return cudagraph_mode, batch_desc, num_tokens_padded, num_tokens_across_dp class EagleProposer(SpecDecodeBaseProposer): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index aab26e90f1bd..e2254c8b230e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -387,6 +387,7 @@ class ExecuteModelState(NamedTuple): ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None + batch_desc: BatchDescriptor class GPUModelRunner( @@ -506,6 +507,7 @@ def __init__( self.late_interaction_runner = LateInteractionRunner() self.use_aux_hidden_state_outputs = False + self.supports_sd_full_graph = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many @@ -520,6 +522,12 @@ def __init__( | MedusaProposer | ExtractHiddenStatesProposer ) + + self.supports_sd_full_graph = ( + self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + if self.speculative_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -2153,11 +2161,13 @@ def _build_attn_group_metadata( for _metadata in attn_metadata.values(): _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] - if spec_decode_common_attn_metadata is not None and ( - num_reqs != num_reqs_padded or num_tokens != num_tokens_padded + if ( + spec_decode_common_attn_metadata is not None + and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded) + and not self.supports_sd_full_graph ): - # Currently the drafter still only uses piecewise cudagraphs (and modifies - # the attention metadata in directly), and therefore does not want to use + # Currently the drafter still only uses piecewise cudagraphs (except for + # Eagle, which supports FULL now), and therefore does not want to use # padded attention metadata. spec_decode_common_attn_metadata = ( spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) @@ -3931,6 +3941,7 @@ def execute_model( ec_connector_output, cudagraph_stats, slot_mappings, + batch_desc, ) self.kv_connector_output = kv_connector_output return None @@ -3969,6 +3980,7 @@ def sample_tokens( ec_connector_output, cudagraph_stats, slot_mappings, + batch_desc, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -4012,6 +4024,7 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, slot_mappings, + batch_desc, ) self._copy_draft_token_ids_to_cpu(scheduler_output) @@ -4301,6 +4314,7 @@ def propose_draft_token_ids( spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + target_model_batch_desc: BatchDescriptor, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4494,6 +4508,7 @@ def propose_draft_token_ids( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count, + self.input_batch, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -4521,8 +4536,9 @@ def propose_draft_token_ids( target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, token_indices_to_sample=token_indices_to_sample, - sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, + target_model_batch_desc=target_model_batch_desc, + sampling_metadata=sampling_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, @@ -5166,6 +5182,7 @@ def _dummy_run( ) attn_metadata: PerLayerAttnMetadata | None = None + spec_decode_cm: CommonAttentionMetadata | None = None slot_mappings_by_group, slot_mappings = self._get_slot_mappings( num_tokens_padded=num_tokens, @@ -5200,7 +5217,7 @@ def _dummy_run( self.query_start_loc.copy_to_gpu() pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL - attn_metadata, _ = self._build_attention_metadata( + attn_metadata, spec_decode_cm = self._build_attention_metadata( num_tokens=num_tokens_unpadded, num_tokens_padded=num_tokens_padded if pad_attn else None, num_reqs=num_reqs_padded, @@ -5303,13 +5320,14 @@ def _dummy_run( EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, ) assert self.speculative_config is not None - # Eagle currently only supports PIECEWISE cudagraphs. - # Therefore only use cudagraphs if the main model uses PIECEWISE # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( ( is_graph_capturing - and cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + and ( + cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE + or self.supports_sd_full_graph + ) ) or ( not is_graph_capturing @@ -5329,6 +5347,7 @@ def _dummy_run( self.drafter.dummy_run( num_tokens, + common_attn_metadata=spec_decode_cm, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, slot_mappings=slot_mappings, @@ -5877,6 +5896,11 @@ def _capture_cudagraphs( # Only rank 0 should print progress bar during capture if is_global_first_rank(): + logger.info( + "Capturing CUDA graphs for %d batches (%s)", + len(batch_descriptors), + ", ".join(str(desc.num_tokens) for desc in batch_descriptors), + ) batch_descriptors = tqdm( batch_descriptors, disable=not self.load_config.use_tqdm_on_load, @@ -6159,7 +6183,6 @@ def _check_and_update_cudagraph_mode( # sizes for decode and mixed prefill-decode. if ( cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine() and self.uniform_decode_query_len > 1 ): self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( From e5b3715e57e6a08431f91be4cb7410345b50ecb3 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 20 Mar 2026 06:12:56 -0600 Subject: [PATCH 2/8] Guard int4/int8 skinny GEMM sweep functions behind VLLM_SKINNY_GEMM_SWEEP The sweep functions in skinny_gemms_int4.cu and skinny_gemms_int8.cu instantiate many template combinations only used for benchmarking. Wrapping them with #ifdef VLLM_SKINNY_GEMM_SWEEP (matching the existing pattern in skinny_gemms.cu) reduces build time from ~1236s to ~213s. Signed-off-by: Matthias Gehre --- csrc/rocm/ops.h | 32 +-- csrc/rocm/skinny_gemms_int4.cu | 405 +++++++++++++++++---------------- csrc/rocm/skinny_gemms_int8.cu | 120 +++++----- csrc/rocm/torch_bindings.cpp | 28 +-- 4 files changed, 299 insertions(+), 286 deletions(-) diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index e310269701b6..1f097c1ad3c7 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -14,6 +14,22 @@ torch::Tensor wvSplitK_int8(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, const int64_t CuCount); +torch::Tensor wvSplitK_int4(const at::Tensor& in_a, const at::Tensor& in_b, + const at::Tensor& in_scale, + const std::optional& in_bias, + const int64_t CuCount); + +torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b, + const at::Tensor& in_scale, + const std::optional& in_bias, + const int64_t CuCount, const int64_t group_size); + +#ifdef VLLM_SKINNY_GEMM_SWEEP +torch::Tensor wvSplitK_sweep(const at::Tensor& in_a, const at::Tensor& in_b, + const std::optional& in_bias, + const int64_t CuCount, const int64_t ytile, + const int64_t unrl); + torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, @@ -22,11 +38,6 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a, const int64_t unrl, const int64_t achunk, const int64_t wvprgrp); -torch::Tensor wvSplitK_int4(const at::Tensor& in_a, const at::Tensor& in_b, - const at::Tensor& in_scale, - const std::optional& in_bias, - const int64_t CuCount); - torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, @@ -35,11 +46,6 @@ torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a, const int64_t unrl, const int64_t achunk, const int64_t wvprgrp); -torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b, - const at::Tensor& in_scale, - const std::optional& in_bias, - const int64_t CuCount, const int64_t group_size); - torch::Tensor wvSplitK_int4g_sweep( const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, const int64_t CuCount, const int64_t group_size, const int64_t ytile, @@ -49,12 +55,6 @@ torch::Tensor wvSplitK_int4g_hf_sweep( const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, const int64_t CuCount, const int64_t group_size, const int64_t ytile, const int64_t unrl, const int64_t achunk, const int64_t wvprgrp); - -#ifdef VLLM_SKINNY_GEMM_SWEEP -torch::Tensor wvSplitK_sweep(const at::Tensor& in_a, const at::Tensor& in_b, - const std::optional& in_bias, - const int64_t CuCount, const int64_t ytile, - const int64_t unrl); #endif torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, diff --git a/csrc/rocm/skinny_gemms_int4.cu b/csrc/rocm/skinny_gemms_int4.cu index 12b65f8bf728..4032ff40e765 100644 --- a/csrc/rocm/skinny_gemms_int4.cu +++ b/csrc/rocm/skinny_gemms_int4.cu @@ -721,6 +721,9 @@ torch::Tensor wvSplitK_int4(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } +// Sweep functions disabled by default to reduce compile time. +// Build with -DVLLM_SKINNY_GEMM_SWEEP to enable. +#ifdef VLLM_SKINNY_GEMM_SWEEP torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, @@ -766,62 +769,62 @@ torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a, const int THRDS = is_gfx11_int4() ? 32 : 64; -#define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \ - { \ - dim3 block(_THRDS, _WVPRGRP); \ - int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ - wvSplitK_int4_hf_sml_ \ - <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ - biasptr, cptr, __wvPrGrp, CuCount); \ - } + #define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \ + { \ + dim3 block(_THRDS, _WVPRGRP); \ + int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ + wvSplitK_int4_hf_sml_ \ + <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ + biasptr, cptr, __wvPrGrp, CuCount); \ + } -#define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \ - switch (N_in) { \ - case 1: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \ - case 2: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \ - case 3: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \ - case 4: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \ - default: \ - TORCH_CHECK(false, "Unsupported N=", N_in); \ - } + #define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \ + switch (N_in) { \ + case 1: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \ + case 2: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \ + case 3: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \ + case 4: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \ + default: \ + TORCH_CHECK(false, "Unsupported N=", N_in); \ + } -#define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \ - if (unrl == 1) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \ - } else if (unrl == 2) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \ - } else if (unrl == 4) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \ - } else { \ - TORCH_CHECK(false, "Unsupported unrl=", unrl); \ - } + #define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \ + if (unrl == 1) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \ + } else if (unrl == 2) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \ + } else if (unrl == 4) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \ + } else { \ + TORCH_CHECK(false, "Unsupported unrl=", unrl); \ + } -#define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \ - if (ytile == 1) { \ - SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \ - } else if (ytile == 2) { \ - SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \ - } else if (ytile == 4) { \ - SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \ - } else { \ - TORCH_CHECK(false, "Unsupported ytile=", ytile); \ - } + #define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \ + if (ytile == 1) { \ + SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \ + } else if (ytile == 2) { \ + SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \ + } else if (ytile == 4) { \ + SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \ + } else { \ + TORCH_CHECK(false, "Unsupported ytile=", ytile); \ + } -#define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \ - if (wvprgrp == 8) { \ - SWEEP_YTILE(_THRDS, 8, _ACHUNK) \ - } else if (wvprgrp == 12) { \ - SWEEP_YTILE(_THRDS, 12, _ACHUNK) \ - } else if (wvprgrp == 16) { \ - SWEEP_YTILE(_THRDS, 16, _ACHUNK) \ - } else { \ - TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ - } + #define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \ + if (wvprgrp == 8) { \ + SWEEP_YTILE(_THRDS, 8, _ACHUNK) \ + } else if (wvprgrp == 12) { \ + SWEEP_YTILE(_THRDS, 12, _ACHUNK) \ + } else if (wvprgrp == 16) { \ + SWEEP_YTILE(_THRDS, 16, _ACHUNK) \ + } else { \ + TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ + } if (THRDS == 32) { if (achunk == 8) { @@ -845,15 +848,17 @@ torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a, } } -#undef SWEEP_LAUNCH -#undef SWEEP_N -#undef SWEEP_UNRL -#undef SWEEP_YTILE -#undef SWEEP_WVPRGRP + #undef SWEEP_LAUNCH + #undef SWEEP_N + #undef SWEEP_UNRL + #undef SWEEP_YTILE + #undef SWEEP_WVPRGRP return out_c; } +#endif // VLLM_SKINNY_GEMM_SWEEP + // Per-group W4A16 skinny GEMM: packed int4 weights with group-wise scales. // in_a: packed int4 weights [M, K/2] (int8) or [M, K/8] (int32) // in_b: activations [N, K] (fp16/bf16) @@ -1000,6 +1005,7 @@ torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } +#ifdef VLLM_SKINNY_GEMM_SWEEP torch::Tensor wvSplitK_int4g_sweep( const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, const int64_t CuCount, const int64_t group_size, const int64_t ytile, @@ -1048,73 +1054,77 @@ torch::Tensor wvSplitK_int4g_sweep( const int THRDS = is_gfx11_int4() ? 32 : 64; -#define SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N, _GS) \ - { \ - dim3 block(_THRDS, _WVPRGRP); \ - int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ - wvSplitK_int4_hf_sml_ \ - <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ - biasptr, cptr, __wvPrGrp, CuCount); \ - } + #define SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N, _GS) \ + { \ + dim3 block(_THRDS, _WVPRGRP); \ + int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ + wvSplitK_int4_hf_sml_ \ + <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ + biasptr, cptr, __wvPrGrp, CuCount); \ + } -#define SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _GS) \ - switch (N_in) { \ - case 1: \ - SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1, _GS) break; \ - case 2: \ - SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2, _GS) break; \ - case 3: \ - SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3, _GS) break; \ - case 4: \ - SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4, _GS) break; \ - default: \ - TORCH_CHECK(false, "Unsupported N=", N_in); \ - } + #define SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _GS) \ + switch (N_in) { \ + case 1: \ + SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1, _GS) \ + break; \ + case 2: \ + SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2, _GS) \ + break; \ + case 3: \ + SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3, _GS) \ + break; \ + case 4: \ + SWEEP_G_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4, _GS) \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported N=", N_in); \ + } -#define SWEEP_G_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _GS) \ - if (unrl == 1) { \ - SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1, _GS) \ - } else if (unrl == 2) { \ - SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2, _GS) \ - } else if (unrl == 4) { \ - SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported unrl=", unrl); \ - } + #define SWEEP_G_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _GS) \ + if (unrl == 1) { \ + SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1, _GS) \ + } else if (unrl == 2) { \ + SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2, _GS) \ + } else if (unrl == 4) { \ + SWEEP_G_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported unrl=", unrl); \ + } -#define SWEEP_G_YTILE(_THRDS, _WVPRGRP, _ACHUNK, _GS) \ - if (ytile == 1) { \ - SWEEP_G_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK, _GS) \ - } else if (ytile == 2) { \ - SWEEP_G_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK, _GS) \ - } else if (ytile == 4) { \ - SWEEP_G_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported ytile=", ytile); \ - } + #define SWEEP_G_YTILE(_THRDS, _WVPRGRP, _ACHUNK, _GS) \ + if (ytile == 1) { \ + SWEEP_G_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK, _GS) \ + } else if (ytile == 2) { \ + SWEEP_G_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK, _GS) \ + } else if (ytile == 4) { \ + SWEEP_G_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported ytile=", ytile); \ + } -#define SWEEP_G_WVPRGRP(_THRDS, _ACHUNK, _GS) \ - if (wvprgrp == 8) { \ - SWEEP_G_YTILE(_THRDS, 8, _ACHUNK, _GS) \ - } else if (wvprgrp == 12) { \ - SWEEP_G_YTILE(_THRDS, 12, _ACHUNK, _GS) \ - } else if (wvprgrp == 16) { \ - SWEEP_G_YTILE(_THRDS, 16, _ACHUNK, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ - } + #define SWEEP_G_WVPRGRP(_THRDS, _ACHUNK, _GS) \ + if (wvprgrp == 8) { \ + SWEEP_G_YTILE(_THRDS, 8, _ACHUNK, _GS) \ + } else if (wvprgrp == 12) { \ + SWEEP_G_YTILE(_THRDS, 12, _ACHUNK, _GS) \ + } else if (wvprgrp == 16) { \ + SWEEP_G_YTILE(_THRDS, 16, _ACHUNK, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ + } -#define SWEEP_G_ACHUNK(_THRDS, _GS) \ - if (achunk == 8) { \ - SWEEP_G_WVPRGRP(_THRDS, 8, _GS) \ - } else if (achunk == 16) { \ - SWEEP_G_WVPRGRP(_THRDS, 16, _GS) \ - } else if (achunk == 32) { \ - SWEEP_G_WVPRGRP(_THRDS, 32, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported achunk=", achunk); \ - } + #define SWEEP_G_ACHUNK(_THRDS, _GS) \ + if (achunk == 8) { \ + SWEEP_G_WVPRGRP(_THRDS, 8, _GS) \ + } else if (achunk == 16) { \ + SWEEP_G_WVPRGRP(_THRDS, 16, _GS) \ + } else if (achunk == 32) { \ + SWEEP_G_WVPRGRP(_THRDS, 32, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported achunk=", achunk); \ + } if (THRDS == 32) { if (group_size == 128) { @@ -1130,12 +1140,12 @@ torch::Tensor wvSplitK_int4g_sweep( } } -#undef SWEEP_G_LAUNCH -#undef SWEEP_G_N -#undef SWEEP_G_UNRL -#undef SWEEP_G_YTILE -#undef SWEEP_G_WVPRGRP -#undef SWEEP_G_ACHUNK + #undef SWEEP_G_LAUNCH + #undef SWEEP_G_N + #undef SWEEP_G_UNRL + #undef SWEEP_G_YTILE + #undef SWEEP_G_WVPRGRP + #undef SWEEP_G_ACHUNK return out_c; } @@ -1188,77 +1198,77 @@ torch::Tensor wvSplitK_int4g_hf_sweep( const int THRDS = is_gfx11_int4() ? 32 : 64; -#define SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N, _GS) \ - { \ - dim3 block(_THRDS, _WVPRGRP); \ - int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ - wvSplitK_int4_hf_ \ - <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ - biasptr, cptr, __wvPrGrp, CuCount); \ - } + #define SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N, _GS) \ + { \ + dim3 block(_THRDS, _WVPRGRP); \ + int __wvPrGrp = mindiv_int4(M_in, CuCount * _YTILE, _WVPRGRP); \ + wvSplitK_int4_hf_ \ + <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ + biasptr, cptr, __wvPrGrp, CuCount); \ + } -#define SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _GS) \ - switch (N_in) { \ - case 1: \ - SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1, _GS) \ - break; \ - case 2: \ - SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2, _GS) \ - break; \ - case 3: \ - SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3, _GS) \ - break; \ - case 4: \ - SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4, _GS) \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported N=", N_in); \ - } + #define SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _GS) \ + switch (N_in) { \ + case 1: \ + SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1, _GS) \ + break; \ + case 2: \ + SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2, _GS) \ + break; \ + case 3: \ + SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3, _GS) \ + break; \ + case 4: \ + SWEEP_GHF_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4, _GS) \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported N=", N_in); \ + } -#define SWEEP_GHF_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _GS) \ - if (unrl == 1) { \ - SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1, _GS) \ - } else if (unrl == 2) { \ - SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2, _GS) \ - } else if (unrl == 4) { \ - SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported unrl=", unrl); \ - } + #define SWEEP_GHF_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _GS) \ + if (unrl == 1) { \ + SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1, _GS) \ + } else if (unrl == 2) { \ + SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2, _GS) \ + } else if (unrl == 4) { \ + SWEEP_GHF_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported unrl=", unrl); \ + } -#define SWEEP_GHF_YTILE(_THRDS, _WVPRGRP, _ACHUNK, _GS) \ - if (ytile == 1) { \ - SWEEP_GHF_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK, _GS) \ - } else if (ytile == 2) { \ - SWEEP_GHF_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK, _GS) \ - } else if (ytile == 4) { \ - SWEEP_GHF_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported ytile=", ytile); \ - } + #define SWEEP_GHF_YTILE(_THRDS, _WVPRGRP, _ACHUNK, _GS) \ + if (ytile == 1) { \ + SWEEP_GHF_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK, _GS) \ + } else if (ytile == 2) { \ + SWEEP_GHF_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK, _GS) \ + } else if (ytile == 4) { \ + SWEEP_GHF_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported ytile=", ytile); \ + } -#define SWEEP_GHF_WVPRGRP(_THRDS, _ACHUNK, _GS) \ - if (wvprgrp == 8) { \ - SWEEP_GHF_YTILE(_THRDS, 8, _ACHUNK, _GS) \ - } else if (wvprgrp == 12) { \ - SWEEP_GHF_YTILE(_THRDS, 12, _ACHUNK, _GS) \ - } else if (wvprgrp == 16) { \ - SWEEP_GHF_YTILE(_THRDS, 16, _ACHUNK, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ - } + #define SWEEP_GHF_WVPRGRP(_THRDS, _ACHUNK, _GS) \ + if (wvprgrp == 8) { \ + SWEEP_GHF_YTILE(_THRDS, 8, _ACHUNK, _GS) \ + } else if (wvprgrp == 12) { \ + SWEEP_GHF_YTILE(_THRDS, 12, _ACHUNK, _GS) \ + } else if (wvprgrp == 16) { \ + SWEEP_GHF_YTILE(_THRDS, 16, _ACHUNK, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ + } -#define SWEEP_GHF_ACHUNK(_THRDS, _GS) \ - if (achunk == 8) { \ - SWEEP_GHF_WVPRGRP(_THRDS, 8, _GS) \ - } else if (achunk == 16) { \ - SWEEP_GHF_WVPRGRP(_THRDS, 16, _GS) \ - } else if (achunk == 32) { \ - SWEEP_GHF_WVPRGRP(_THRDS, 32, _GS) \ - } else { \ - TORCH_CHECK(false, "Unsupported achunk=", achunk); \ - } + #define SWEEP_GHF_ACHUNK(_THRDS, _GS) \ + if (achunk == 8) { \ + SWEEP_GHF_WVPRGRP(_THRDS, 8, _GS) \ + } else if (achunk == 16) { \ + SWEEP_GHF_WVPRGRP(_THRDS, 16, _GS) \ + } else if (achunk == 32) { \ + SWEEP_GHF_WVPRGRP(_THRDS, 32, _GS) \ + } else { \ + TORCH_CHECK(false, "Unsupported achunk=", achunk); \ + } if (THRDS == 32) { if (group_size == 128) { @@ -1278,12 +1288,13 @@ torch::Tensor wvSplitK_int4g_hf_sweep( } } -#undef SWEEP_GHF_LAUNCH -#undef SWEEP_GHF_N -#undef SWEEP_GHF_UNRL -#undef SWEEP_GHF_YTILE -#undef SWEEP_GHF_WVPRGRP -#undef SWEEP_GHF_ACHUNK + #undef SWEEP_GHF_LAUNCH + #undef SWEEP_GHF_N + #undef SWEEP_GHF_UNRL + #undef SWEEP_GHF_YTILE + #undef SWEEP_GHF_WVPRGRP + #undef SWEEP_GHF_ACHUNK return out_c; } +#endif // VLLM_SKINNY_GEMM_SWEEP diff --git a/csrc/rocm/skinny_gemms_int8.cu b/csrc/rocm/skinny_gemms_int8.cu index c2321c67b011..f8e092fcf9db 100644 --- a/csrc/rocm/skinny_gemms_int8.cu +++ b/csrc/rocm/skinny_gemms_int8.cu @@ -415,8 +415,9 @@ torch::Tensor wvSplitK_int8(const at::Tensor& in_a, const at::Tensor& in_b, return out_c; } -// Sweep function: all 4 tuning params dispatched at runtime (fp16 only). -// Used for benchmarking only — not for production. +// Sweep function disabled by default to reduce compile time. +// Build with -DVLLM_SKINNY_GEMM_SWEEP to enable. +#ifdef VLLM_SKINNY_GEMM_SWEEP torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale, @@ -459,62 +460,62 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a, const int THRDS = is_gfx11_int8() ? 32 : 64; -#define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \ - { \ - dim3 block(_THRDS, _WVPRGRP); \ - int __wvPrGrp = mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP); \ - wvSplitK_int8_hf_sml_ \ - <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ - biasptr, cptr, __wvPrGrp, CuCount); \ - } + #define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \ + { \ + dim3 block(_THRDS, _WVPRGRP); \ + int __wvPrGrp = mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP); \ + wvSplitK_int8_hf_sml_ \ + <<>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \ + biasptr, cptr, __wvPrGrp, CuCount); \ + } -#define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \ - switch (N_in) { \ - case 1: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \ - case 2: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \ - case 3: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \ - case 4: \ - SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \ - default: \ - TORCH_CHECK(false, "Unsupported N=", N_in); \ - } + #define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \ + switch (N_in) { \ + case 1: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \ + case 2: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \ + case 3: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \ + case 4: \ + SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \ + default: \ + TORCH_CHECK(false, "Unsupported N=", N_in); \ + } -#define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \ - if (unrl == 1) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \ - } else if (unrl == 2) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \ - } else if (unrl == 4) { \ - SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \ - } else { \ - TORCH_CHECK(false, "Unsupported unrl=", unrl); \ - } + #define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \ + if (unrl == 1) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \ + } else if (unrl == 2) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \ + } else if (unrl == 4) { \ + SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \ + } else { \ + TORCH_CHECK(false, "Unsupported unrl=", unrl); \ + } -#define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \ - if (ytile == 1) { \ - SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \ - } else if (ytile == 2) { \ - SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \ - } else if (ytile == 4) { \ - SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \ - } else { \ - TORCH_CHECK(false, "Unsupported ytile=", ytile); \ - } + #define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \ + if (ytile == 1) { \ + SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \ + } else if (ytile == 2) { \ + SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \ + } else if (ytile == 4) { \ + SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \ + } else { \ + TORCH_CHECK(false, "Unsupported ytile=", ytile); \ + } -#define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \ - if (wvprgrp == 8) { \ - SWEEP_YTILE(_THRDS, 8, _ACHUNK) \ - } else if (wvprgrp == 12) { \ - SWEEP_YTILE(_THRDS, 12, _ACHUNK) \ - } else if (wvprgrp == 16) { \ - SWEEP_YTILE(_THRDS, 16, _ACHUNK) \ - } else { \ - TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ - } + #define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \ + if (wvprgrp == 8) { \ + SWEEP_YTILE(_THRDS, 8, _ACHUNK) \ + } else if (wvprgrp == 12) { \ + SWEEP_YTILE(_THRDS, 12, _ACHUNK) \ + } else if (wvprgrp == 16) { \ + SWEEP_YTILE(_THRDS, 16, _ACHUNK) \ + } else { \ + TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \ + } if (THRDS == 32) { if (achunk == 8) { @@ -538,11 +539,12 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a, } } -#undef SWEEP_LAUNCH -#undef SWEEP_N -#undef SWEEP_UNRL -#undef SWEEP_YTILE -#undef SWEEP_WVPRGRP + #undef SWEEP_LAUNCH + #undef SWEEP_N + #undef SWEEP_UNRL + #undef SWEEP_YTILE + #undef SWEEP_WVPRGRP return out_c; } +#endif // VLLM_SKINNY_GEMM_SWEEP diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 8a79fcff6054..d72c35a8da9d 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -40,13 +40,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "Tensor? in_bias, int CuCount) -> Tensor"); rocm_ops.impl("wvSplitK_int8", torch::kCUDA, &wvSplitK_int8); - // W8A16 skinny GEMM sweep: all tuning params as runtime args (benchmark only) - rocm_ops.def( - "wvSplitK_int8_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, " - "Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, " - "int wvprgrp) -> Tensor"); - rocm_ops.impl("wvSplitK_int8_sweep", torch::kCUDA, &wvSplitK_int8_sweep); - // W4A16 skinny GEMM: packed int4 weights, fp16/bf16 activations, per-channel // scale rocm_ops.def( @@ -54,19 +47,25 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "Tensor? in_bias, int CuCount) -> Tensor"); rocm_ops.impl("wvSplitK_int4", torch::kCUDA, &wvSplitK_int4); - // W4A16 skinny GEMM sweep: all tuning params as runtime args (benchmark only) - rocm_ops.def( - "wvSplitK_int4_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, " - "Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, " - "int wvprgrp) -> Tensor"); - rocm_ops.impl("wvSplitK_int4_sweep", torch::kCUDA, &wvSplitK_int4_sweep); - // W4A16 grouped skinny GEMM: packed int4 weights, per-group scales rocm_ops.def( "wvSplitK_int4_g(Tensor in_a, Tensor in_b, Tensor in_scale, " "Tensor? in_bias, int CuCount, int group_size) -> Tensor"); rocm_ops.impl("wvSplitK_int4_g", torch::kCUDA, &wvSplitK_int4_g); +#ifdef VLLM_SKINNY_GEMM_SWEEP + rocm_ops.def( + "wvSplitK_int8_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, " + "Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, " + "int wvprgrp) -> Tensor"); + rocm_ops.impl("wvSplitK_int8_sweep", torch::kCUDA, &wvSplitK_int8_sweep); + + rocm_ops.def( + "wvSplitK_int4_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, " + "Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, " + "int wvprgrp) -> Tensor"); + rocm_ops.impl("wvSplitK_int4_sweep", torch::kCUDA, &wvSplitK_int4_sweep); + rocm_ops.def( "wvSplitK_int4g_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, " "int CuCount, int group_size, int ytile, int unrl, int achunk, " @@ -79,6 +78,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { "int wvprgrp) -> Tensor"); rocm_ops.impl("wvSplitK_int4g_hf_sweep", torch::kCUDA, &wvSplitK_int4g_hf_sweep); +#endif // VLLM_SKINNY_GEMM_SWEEP // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( From 4b4697be1f4be88377d1e7c7f0e7f07e3b581f82 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 20 Mar 2026 08:19:18 -0600 Subject: [PATCH 3/8] Fix drafter FULL cudagraph capture and add hipGraph dump support - Pass spec_decode_common_attn_metadata to drafter.dummy_run() so the drafter can dispatch uniform_decode=True and match FULL batch keys - Allow any non-NONE cudagraph mode during capture (not just PIECEWISE) so the drafter's FULL CUDAGraphWrapper actually triggers capture - Add hasattr fallback for get_eagle3_aux_hidden_state_layers to support models like Qwen3 that only have the default method - Add _dump_all_full_graphs() call after capture for hipGraph debugging - Re-apply PR #34880 changes lost during merge with awq_gemv_ifdef_sweep Signed-off-by: Matthias Gehre --- vllm/v1/worker/gpu_model_runner.py | 261 ++++++++++------------------- 1 file changed, 88 insertions(+), 173 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e2254c8b230e..ee018696f2a6 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -181,7 +181,6 @@ ) from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.ec_connector_model_runner_mixin import ECConnectorModelRunnerMixin -from vllm.v1.worker.gpu.pool.late_interaction_runner import LateInteractionRunner from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin @@ -197,7 +196,6 @@ from .utils import ( AttentionGroup, - KVBlockZeroer, add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache, prepare_kernel_block_sizes, @@ -387,7 +385,6 @@ class ExecuteModelState(NamedTuple): ec_connector_output: ECConnectorOutput | None cudagraph_stats: CUDAGraphStat | None slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None - batch_desc: BatchDescriptor class GPUModelRunner( @@ -427,12 +424,8 @@ def __init__( self.is_multimodal_raw_input_only_model = ( model_config.is_multimodal_raw_input_only_model ) - # These will be overridden in load_model() + # This will be overridden in load_model() self.is_multimodal_pruning_enabled = False - self.requires_sequential_video_encoding = False - # Set to True after init_routed_experts_capturer() completes. - # Prevents routed experts code from running during profiling/dummy run. - self.routed_experts_initialized = False self.max_model_len = model_config.max_model_len # Always set to false after the first forward pass @@ -504,10 +497,8 @@ def __init__( # mm_hash -> encoder_output self.encoder_cache: dict[str, torch.Tensor] = {} - self.late_interaction_runner = LateInteractionRunner() self.use_aux_hidden_state_outputs = False - self.supports_sd_full_graph = False # Set up speculative decoding. # NOTE(Jiayi): currently we put the entire draft model on # the last PP rank. This is not ideal if there are many @@ -522,12 +513,6 @@ def __init__( | MedusaProposer | ExtractHiddenStatesProposer ) - - self.supports_sd_full_graph = ( - self.speculative_config.use_eagle() - and not self.speculative_config.disable_padded_drafter_batch - ) - if self.speculative_config.method == "ngram": from vllm.v1.spec_decode.ngram_proposer import NgramProposer @@ -754,19 +739,6 @@ def __init__( self.uniform_decode_query_len = 1 + self.num_spec_tokens - # When spec decode is active, the mamba backend classifies requests - # with query_len <= reorder_batch_threshold as "decodes". Prefill - # chunks that fall under this threshold get processed via the decode - # path, which stores intermediate states at sequential slots. We must - # set num_accepted_tokens to the chunk's query_len for those requests - # so the next iteration reads from the correct final-state slot. - # Prefills that went through the actual prefill path should keep the - # default value of 1 (the prefill path stores state at slot 0 only). - self.needs_prefill_as_decode_slots: bool = False - self.prefill_as_decode_num_tokens = self._make_buffer( - self.max_num_reqs, dtype=torch.int32 - ) - # Cudagraph dispatcher for runtime cudagraph dispatching. self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) @@ -865,7 +837,6 @@ def reset_mm_cache(self) -> None: """ if self.mm_budget: self.mm_budget.reset_cache() - self.late_interaction_runner.clear() def reset_encoder_cache(self) -> None: """Clear the GPU-side encoder cache storing vision embeddings. @@ -874,7 +845,6 @@ def reset_encoder_cache(self) -> None: stale embeddings computed with old weights are not reused. """ self.encoder_cache.clear() - self.late_interaction_runner.clear() @torch.inference_mode() def init_fp8_kv_scales(self) -> None: @@ -1014,26 +984,6 @@ def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> None: decode_threshold=self.reorder_batch_threshold, ) - def _init_kv_zero_meta(self) -> None: - """One-time precomputation for _zero_block_ids. - - Delegates to KVBlockZeroer.init_meta with the runner's state. - Called from gpu_worker.py outside the CuMem pool context. - """ - self._kv_block_zeroer = KVBlockZeroer(self.device, self.pin_memory) - self._kv_block_zeroer.init_meta( - attn_groups_iter=self._kv_cache_spec_attn_group_iterator(), - kernel_block_sizes=self._kernel_block_sizes, - cache_dtype=self.cache_config.cache_dtype, - runner_only_attn_layers=self.runner_only_attn_layers, - static_forward_context=(self.compilation_config.static_forward_context), - ) - - def _zero_block_ids(self, block_ids: list[int]) -> None: - """Zero the KV cache memory for the given block IDs.""" - if hasattr(self, "_kv_block_zeroer"): - self._kv_block_zeroer.zero_block_ids(block_ids) - # Note: used for model runner override. def _init_device_properties(self) -> None: """Initialize attributes from torch.cuda.get_device_properties""" @@ -1058,9 +1008,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) self.num_prompt_logprobs.pop(req_id, None) - self.late_interaction_runner.on_requests_finished( - scheduler_output.finished_req_ids - ) # Remove the finished requests from the persistent batch. # NOTE(woosuk): There could be an edge case where finished_req_ids and # scheduled_req_ids overlap. This happens when a request is aborted and @@ -1070,11 +1017,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: for req_id in scheduler_output.finished_req_ids: self.input_batch.remove_request(req_id) - # Zero GPU memory for freshly allocated cache blocks to prevent - # stale NaN/data from corrupting attention or SSM computation. - if scheduler_output.new_block_ids_to_zero: - self._zero_block_ids(scheduler_output.new_block_ids_to_zero) - # Free the cached encoder outputs. for mm_hash in scheduler_output.free_encoder_mm_hashes: self.encoder_cache.pop(mm_hash, None) @@ -1153,7 +1095,6 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: lora_request=new_req_data.lora_request, ) self.requests[req_id] = req_state - self.late_interaction_runner.register_request(req_id, pooling_params) if sampling_params and sampling_params.prompt_logprobs is not None: self.num_prompt_logprobs[req_id] = ( @@ -1383,22 +1324,12 @@ def _update_states_after_model_execute( .int() .argmax(-1) ) - spec_decode_active = bool(scheduler_output.scheduled_spec_decode_tokens) - if self.needs_prefill_as_decode_slots and spec_decode_active: - mamba_utils.update_accepted_tokens_for_prefill_as_decode( - self.input_batch, - self.prefill_as_decode_num_tokens, - self.num_accepted_tokens.gpu, - scheduler_output, - self.reorder_batch_threshold, - num_reqs, - ) - if self.cache_config.mamba_cache_mode == "align": for i, num_tokens in enumerate( self.num_accepted_tokens.gpu[:num_reqs].cpu().numpy() ): self.input_batch.num_accepted_tokens_cpu[i] = num_tokens + mamba_utils.postprocess_mamba( scheduler_output, self.kv_cache_config, @@ -1435,7 +1366,6 @@ def _update_streaming_request( req_state.prompt_embeds = new_req_data.prompt_embeds req_state.sampling_params = new_req_data.sampling_params req_state.pooling_params = new_req_data.pooling_params - self.late_interaction_runner.register_request(req_id, req_state.pooling_params) req_state.block_ids = new_req_data.block_ids req_state.num_computed_tokens = new_req_data.num_computed_tokens req_state.num_prompt_tokens = length_from_prompt_token_ids_or_embeds( @@ -1992,10 +1922,8 @@ def _get_block_table(kv_cache_gid: int): block_table_gid_0 = _get_block_table(0) slot_mapping_gid_0 = slot_mappings[0] - if self.routed_experts_initialized: - attn_gid = self.routed_experts_attn_gid - slot_mapping_attn = slot_mappings[attn_gid] - self.slot_mapping = slot_mapping_attn[:num_tokens].cpu().numpy() + if self.model_config.enable_return_routed_experts: + self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() cm_base = CommonAttentionMetadata( query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], @@ -2062,8 +1990,6 @@ def _build_attn_group_metadata( else 0 ) - if isinstance(builder, Mamba2AttentionMetadataBuilder): - self.needs_prefill_as_decode_slots = True extra_attn_metadata_args = {} if use_spec_decode and isinstance( builder, (Mamba2AttentionMetadataBuilder, GDNAttentionMetadataBuilder) @@ -2161,13 +2087,11 @@ def _build_attn_group_metadata( for _metadata in attn_metadata.values(): _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] - if ( - spec_decode_common_attn_metadata is not None - and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded) - and not self.supports_sd_full_graph + if spec_decode_common_attn_metadata is not None and ( + num_reqs != num_reqs_padded or num_tokens != num_tokens_padded ): - # Currently the drafter still only uses piecewise cudagraphs (except for - # Eagle, which supports FULL now), and therefore does not want to use + # Currently the drafter still only uses piecewise cudagraphs (and modifies + # the attention metadata in directly), and therefore does not want to use # padded attention metadata. spec_decode_common_attn_metadata = ( spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) @@ -2642,23 +2566,17 @@ def _execute_mm_encoder( ): batch_outputs: MultiModalEmbeddings - # EVS and dynamic res video related change. + # EVS-related change. # (ekhvedchenia): Temporary hack to limit peak memory usage when # processing multimodal data. This solves the issue with scheduler # putting too many video samples into a single batch. Scheduler # uses pruned vision tokens count to compare it versus compute # budget which is incorrect (Either input media size or non-pruned # output vision tokens count should be considered) - # dynamic res video for nemotron temporarily uses this hack via - # requires_sequential_video_encoding - # because it doesn't yet support video batching. # TODO(ywang96): Fix memory profiling to take EVS into account and # remove this hack. if ( - ( - self.is_multimodal_pruning_enabled - or self.requires_sequential_video_encoding - ) + self.is_multimodal_pruning_enabled and modality == "video" and num_items > 1 ): @@ -2850,7 +2768,15 @@ def get_supported_pooling_tasks(self) -> list[PoolingTask]: if not is_pooling_model(model): return [] - return list(model.pooler.get_supported_tasks()) + supported_tasks = list(model.pooler.get_supported_tasks()) + + if "score" in supported_tasks: + num_labels = getattr(self.model_config.hf_config, "num_labels", 0) + if num_labels != 1: + supported_tasks.remove("score") + logger.debug_once("Score API is only enabled for num_labels == 1.") + + return supported_tasks def get_supported_tasks(self) -> tuple[SupportedTask, ...]: tasks = list[SupportedTask]() @@ -2943,10 +2869,7 @@ def _pool( pooling_metadata = self.input_batch.get_pooling_metadata() pooling_metadata.build_pooling_cursor( - num_scheduled_tokens_np, - seq_lens_cpu, - device=hidden_states.device, - query_start_loc_gpu=self.query_start_loc.gpu[: num_reqs + 1], + num_scheduled_tokens_np, seq_lens_cpu, device=hidden_states.device ) model = cast(VllmModelForPooling, self.model) @@ -2958,12 +2881,6 @@ def _pool( seq_len == prompt_len for seq_len, prompt_len in zip(seq_lens_cpu, pooling_metadata.prompt_lens) ] - raw_pooler_output = self.late_interaction_runner.postprocess_pooler_output( - raw_pooler_output=raw_pooler_output, - pooling_params=pooling_metadata.pooling_params, - req_ids=self.input_batch.req_ids, - finished_mask=finished_mask, - ) model_runner_output = ModelRunnerOutput( req_ids=self.input_batch.req_ids.copy(), @@ -3622,7 +3539,7 @@ def execute_model( "after execute_model() returns None." ) - if self.routed_experts_initialized: + if self.vllm_config.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: capturer.clear_buffer() # noqa @@ -3646,10 +3563,10 @@ def execute_model( scheduled_spec_decode_tokens=spec_decode_tokens_copy, ) - if has_kv_transfer_group(): - kv_connector_metadata = scheduler_output.kv_connector_metadata - assert kv_connector_metadata is not None - get_kv_transfer_group().handle_preemptions(kv_connector_metadata) + if scheduler_output.preempted_req_ids and has_kv_transfer_group(): + get_kv_transfer_group().handle_preemptions( + scheduler_output.preempted_req_ids + ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with ( @@ -3941,7 +3858,6 @@ def execute_model( ec_connector_output, cudagraph_stats, slot_mappings, - batch_desc, ) self.kv_connector_output = kv_connector_output return None @@ -3980,7 +3896,6 @@ def sample_tokens( ec_connector_output, cudagraph_stats, slot_mappings, - batch_desc, ) = self.execute_model_state # Clear ephemeral state. self.execute_model_state = None @@ -4024,7 +3939,6 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, slot_mappings, - batch_desc, ) self._copy_draft_token_ids_to_cpu(scheduler_output) @@ -4136,7 +4050,7 @@ def propose_draft_token_ids(sampled_token_ids): self.kv_connector_output = None with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): - if self.routed_experts_initialized: + if self.model_config.enable_return_routed_experts: capturer = RoutedExpertsCapturer.get_instance() if capturer is not None: capturer.save_captured_experts(indices=self.slot_mapping) # noqa @@ -4314,7 +4228,6 @@ def propose_draft_token_ids( spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, - target_model_batch_desc: BatchDescriptor, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4330,6 +4243,15 @@ def propose_draft_token_ids( self.input_batch.token_ids_cpu, slot_mappings=slot_mappings, ) + if isinstance(self.drafter, NgramProposer): + assert isinstance(sampled_token_ids, list), ( + "sampled_token_ids should be a python list when ngram is used." + ) + draft_token_ids = self.drafter.propose( + sampled_token_ids, + self.input_batch.num_tokens_no_spec, + self.input_batch.token_ids_cpu, + ) elif spec_config.use_ngram_gpu(): assert isinstance(self.drafter, NgramProposerGPU) ( @@ -4411,12 +4333,23 @@ def propose_draft_token_ids( ) target_hidden_states = [h[:num_scheduled_tokens] for h in aux_hidden_states] - draft_token_ids = self.drafter.propose( + draft_token_ids, drafter_kv_connector_output = self.drafter.propose( sampled_token_ids=sampled_token_ids, target_hidden_states=target_hidden_states, common_attn_metadata=common_attn_metadata, + scheduler_output=scheduler_output, slot_mappings=slot_mappings, ) + # Combine KVConnectorOutputs or select the non-empty one + if self.kv_connector_output and drafter_kv_connector_output: + self.kv_connector_output = KVConnectorOutput.merge( + self.kv_connector_output, drafter_kv_connector_output + ) + else: + self.kv_connector_output = ( + self.kv_connector_output or drafter_kv_connector_output + ) + next_token_ids, valid_sampled_tokens_count = ( self.drafter.prepare_next_token_ids_padded( common_attn_metadata, @@ -4508,7 +4441,6 @@ def propose_draft_token_ids( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count, - self.input_batch, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -4536,9 +4468,8 @@ def propose_draft_token_ids( target_hidden_states=target_hidden_states, next_token_ids=next_token_ids, token_indices_to_sample=token_indices_to_sample, - common_attn_metadata=common_attn_metadata, - target_model_batch_desc=target_model_batch_desc, sampling_metadata=sampling_metadata, + common_attn_metadata=common_attn_metadata, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, @@ -4630,9 +4561,12 @@ def load_model(self, load_dummy_weights: bool = False) -> None: aux_layers, ) else: - aux_layers = ( - self.model.get_eagle3_default_aux_hidden_state_layers() + get = getattr( + self.model, + "get_eagle3_aux_hidden_state_layers", + self.model.get_eagle3_default_aux_hidden_state_layers, ) + aux_layers = get() self.model.set_aux_hidden_state_layers(aux_layers) time_after_load = time.perf_counter() @@ -4666,9 +4600,6 @@ def load_model(self, load_dummy_weights: bool = False) -> None: and mm_config is not None and mm_config.is_multimodal_pruning_enabled() ) - self.requires_sequential_video_encoding = hasattr( - self.get_model(), "requires_sequential_video_encoding" - ) # Temporary hack for dynamic res video w/o support for bs>1 yet if ( is_mixture_of_experts(self.model) @@ -5182,7 +5113,7 @@ def _dummy_run( ) attn_metadata: PerLayerAttnMetadata | None = None - spec_decode_cm: CommonAttentionMetadata | None = None + spec_decode_common_attn_metadata: CommonAttentionMetadata | None = None slot_mappings_by_group, slot_mappings = self._get_slot_mappings( num_tokens_padded=num_tokens, @@ -5217,15 +5148,19 @@ def _dummy_run( self.query_start_loc.copy_to_gpu() pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL - attn_metadata, spec_decode_cm = self._build_attention_metadata( - num_tokens=num_tokens_unpadded, - num_tokens_padded=num_tokens_padded if pad_attn else None, - num_reqs=num_reqs_padded, - max_query_len=max_query_len, - ubatch_slices=(ubatch_slices_padded if pad_attn else ubatch_slices), - for_cudagraph_capture=is_graph_capturing, - slot_mappings=slot_mappings_by_group, - use_spec_decode=self.speculative_config is not None, + attn_metadata, spec_decode_common_attn_metadata = ( + self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens_padded if pad_attn else None, + num_reqs=num_reqs_padded, + max_query_len=max_query_len, + ubatch_slices=( + ubatch_slices_padded if pad_attn else ubatch_slices + ), + for_cudagraph_capture=is_graph_capturing, + slot_mappings=slot_mappings_by_group, + use_spec_decode=self.speculative_config is not None, + ) ) with self.maybe_dummy_run_with_lora( @@ -5320,14 +5255,10 @@ def _dummy_run( EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, ) assert self.speculative_config is not None - # NOTE(lucas): this is a hack, need to clean up. use_cudagraphs = ( ( is_graph_capturing - and ( - cudagraph_runtime_mode == CUDAGraphMode.PIECEWISE - or self.supports_sd_full_graph - ) + and cudagraph_runtime_mode != CUDAGraphMode.NONE ) or ( not is_graph_capturing @@ -5347,7 +5278,7 @@ def _dummy_run( self.drafter.dummy_run( num_tokens, - common_attn_metadata=spec_decode_cm, + common_attn_metadata=spec_decode_common_attn_metadata, use_cudagraphs=use_cudagraphs, is_graph_capturing=is_graph_capturing, slot_mappings=slot_mappings, @@ -5578,14 +5509,13 @@ def profile_run(self) -> None: dummy_modality ] - logger.info_once( + logger.info( "Encoder cache will be initialized with a " "budget of %s tokens, and profiled with " "%s %s items of the maximum feature size.", encoder_budget, max_mm_items_per_batch, dummy_modality, - scope="local", ) # Create dummy batch of multimodal inputs. @@ -5631,14 +5561,16 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: kv_cache_spec = self.get_kv_cache_spec() kv_cache_groups = get_kv_cache_groups(self.vllm_config, kv_cache_spec) min_blocks = self.compilation_config.max_cudagraph_capture_size or 1 + if kv_cache_groups: + page_size = kv_cache_groups[0].kv_cache_spec.page_size_bytes + group_size = max(len(g.layer_names) for g in kv_cache_groups) + available_memory = min_blocks * page_size * group_size + else: + available_memory = 1 # Attention-free model - # Temporarily change num_gpu_blocks_override to allocate a minimal KV cache - saved_override = self.cache_config.num_gpu_blocks_override - self.cache_config.num_gpu_blocks_override = min_blocks minimal_config = get_kv_cache_config_from_groups( - self.vllm_config, kv_cache_groups, available_memory=0 + self.vllm_config, kv_cache_groups, available_memory=available_memory ) - self.cache_config.num_gpu_blocks_override = saved_override self.initialize_kv_cache(minimal_config) self.cache_config.num_gpu_blocks = minimal_config.num_blocks @@ -5829,6 +5761,10 @@ def capture_model(self) -> int: torch.accelerator.synchronize() torch.accelerator.empty_cache() + from vllm.compilation.cuda_graph import _dump_all_full_graphs + + _dump_all_full_graphs() + # Lock workspace to prevent resizing during execution. # Max workspace sizes should have been captured during warmup/profiling. lock_workspace() @@ -5896,11 +5832,6 @@ def _capture_cudagraphs( # Only rank 0 should print progress bar during capture if is_global_first_rank(): - logger.info( - "Capturing CUDA graphs for %d batches (%s)", - len(batch_descriptors), - ", ".join(str(desc.num_tokens) for desc in batch_descriptors), - ) batch_descriptors = tqdm( batch_descriptors, disable=not self.load_config.use_tqdm_on_load, @@ -6183,6 +6114,7 @@ def _check_and_update_cudagraph_mode( # sizes for decode and mixed prefill-decode. if ( cudagraph_mode.decode_mode() == CUDAGraphMode.FULL + and cudagraph_mode.separate_routine() and self.uniform_decode_query_len > 1 ): self.compilation_config.adjust_cudagraph_sizes_for_spec_decode( @@ -6585,7 +6517,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kernel_block_sizes = prepare_kernel_block_sizes( kv_cache_config, self.attn_groups ) - self._kernel_block_sizes = kernel_block_sizes # create metadata builders self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes) @@ -6616,12 +6547,8 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: kv_transfer_group.register_kv_caches(kv_caches) kv_transfer_group.set_host_xfer_buffer_ops(copy_kv_blocks) - def _get_attention_kv_cache_gid(self) -> int: - """Find the KV cache group index for attention layers.""" - for gid, group in enumerate(self.kv_cache_config.kv_cache_groups): - if isinstance(group.kv_cache_spec, AttentionSpec): - return gid - return 0 + if self.model_config.enable_return_routed_experts: + self.init_routed_experts_capturer() def init_routed_experts_capturer(self): logger.info( @@ -6629,29 +6556,17 @@ def init_routed_experts_capturer(self): self.model_config.enable_return_routed_experts, ) routed_experts_capturer = RoutedExpertsCapturer.create() - self.routed_experts_attn_gid = self._get_attention_kv_cache_gid() - min_block_size = min( - [ - group.kv_cache_spec.block_size - for group in self.kv_cache_config.kv_cache_groups - ] - ) - num_groups = len(self.kv_cache_config.kv_cache_groups) + block_size = self.cache_config.block_size self.max_num_kv_tokens = ( - self.kv_cache_config.num_blocks // num_groups - ) * min_block_size - dcp_size = self.vllm_config.parallel_config.decode_context_parallel_size - pcp_size = self.vllm_config.parallel_config.prefill_context_parallel_size - if pcp_size * dcp_size > 1: - self.max_num_kv_tokens *= pcp_size * dcp_size - + self.kv_cache_config.num_blocks // len(self.kv_cache_config.kv_cache_groups) + + 1 + ) * block_size routed_experts_capturer.init_buffer( max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens, max_num_kv_tokens=self.max_num_kv_tokens, vllm_config=self.vllm_config, ) self._bind_routed_experts_capturer(routed_experts_capturer) - self.routed_experts_initialized = True def _bind_routed_experts_capturer(self, capturer: RoutedExpertsCapturer) -> None: from vllm.model_executor.layers.fused_moe.layer import FusedMoE From f2a85e099af51ea946fca2cd118e5a2b83e8c04f Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Fri, 20 Mar 2026 16:59:58 -0600 Subject: [PATCH 4/8] Merged CUDA graph for speculative decoding (Drafter + Target + logits) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Capture Drafter forward, Target forward, and compute_logits as a single CUDA graph for EAGLE speculative decoding. This eliminates GPU launch overhead between the three phases during decode. Key changes: - cuda_graph.py: Add _merged_capture_bypass flag so CUDAGraphWrapper passes through during merged capture; add hipGraph DOT dump utilities; use keep_graph=True for FULL mode to retain raw graph handles. - gpu_model_runner.py: Add _merged_capture() to record the combined [Drafter → Target → compute_logits] graph with persistent buffers; restructure _capture_cudagraphs to capture individual graphs first, then merged graphs; replay merged graph in execute_model when ready. - eagle.py: Handle target_model_batch_desc=None gracefully in propose(). Signed-off-by: Matthias Gehre --- vllm/compilation/cuda_graph.py | 154 +++++++++- vllm/v1/spec_decode/eagle.py | 9 +- vllm/v1/worker/gpu_model_runner.py | 449 ++++++++++++++++++++++++----- 3 files changed, 524 insertions(+), 88 deletions(-) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 00bf4bbc71f1..0473269469e6 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -28,6 +28,148 @@ logger = init_logger(__name__) +_dump_counter = 0 + +_merged_capture_bypass = False + + +def _dump_all_full_graphs(): + """Dump FULL hipGraph DOT files using keep_graph=True handles.""" + import ctypes + import os + + dump_dir = "/scratch/mgehre/tmp/cudagraph_dumps" + os.makedirs(dump_dir, exist_ok=True) + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + hip_path = rocm_path + "/lib/libamdhip64.so" + if not os.path.exists(hip_path): + logger.warning("hipGraph dump: libamdhip64.so not found") + return + + hip = ctypes.CDLL(hip_path) + hipGraphDebugDotPrint = hip.hipGraphDebugDotPrint + hipGraphDebugDotPrint.restype = ctypes.c_int + hipGraphDebugDotPrint.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_uint, + ] + hipGraphGetNodes = hip.hipGraphGetNodes + hipGraphGetNodes.restype = ctypes.c_int + hipGraphGetNodes.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_size_t), + ] + hipGraphGetEdges = hip.hipGraphGetEdges + hipGraphGetEdges.restype = ctypes.c_int + hipGraphGetEdges.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_size_t), + ] + + count = 0 + for wrapper in list(CUDAGraphWrapper._all_instances): + if wrapper.runtime_mode != CUDAGraphMode.FULL: + continue + for bd, entry in wrapper.concrete_cudagraph_entries.items(): + if entry.cudagraph is None: + continue + try: + graph_handle = entry.cudagraph.raw_cuda_graph() + except RuntimeError: + logger.warning("hipGraph dump: graph_%d - no raw handle", count) + count += 1 + continue + + num_nodes = ctypes.c_size_t(0) + hipGraphGetNodes(graph_handle, None, ctypes.byref(num_nodes)) + num_edges = ctypes.c_size_t(0) + hipGraphGetEdges( + graph_handle, + None, + None, + ctypes.byref(num_edges), + ) + + fname = f"graph_FULL_{count}_{bd.num_tokens}t_{bd.num_reqs}r.dot" + dot_path = os.path.join(dump_dir, fname).encode() + ret = hipGraphDebugDotPrint(graph_handle, dot_path, 0xFFFF) + + logger.warning( + "hipGraph dump: graph_%d desc=%s nodes=%d edges=%d dot_ret=%d -> %s", + count, + bd, + num_nodes.value, + num_edges.value, + ret, + fname, + ) + count += 1 + + +def dump_cuda_graph(graph: "torch.cuda.CUDAGraph", label: str): + """Dump a single keep_graph=True CUDAGraph to DOT via HIP API.""" + import ctypes + import os + + dump_dir = "/scratch/mgehre/tmp/cudagraph_dumps" + os.makedirs(dump_dir, exist_ok=True) + rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm") + hip_path = rocm_path + "/lib/libamdhip64.so" + if not os.path.exists(hip_path): + logger.warning("dump_cuda_graph: libamdhip64.so not found") + return + + hip = ctypes.CDLL(hip_path) + hipGraphDebugDotPrint = hip.hipGraphDebugDotPrint + hipGraphDebugDotPrint.restype = ctypes.c_int + hipGraphDebugDotPrint.argtypes = [ + ctypes.c_void_p, + ctypes.c_char_p, + ctypes.c_uint, + ] + hipGraphGetNodes = hip.hipGraphGetNodes + hipGraphGetNodes.restype = ctypes.c_int + hipGraphGetNodes.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_size_t), + ] + hipGraphGetEdges = hip.hipGraphGetEdges + hipGraphGetEdges.restype = ctypes.c_int + hipGraphGetEdges.argtypes = [ + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_void_p), + ctypes.POINTER(ctypes.c_size_t), + ] + + try: + graph_handle = graph.raw_cuda_graph() + except RuntimeError: + logger.warning("dump_cuda_graph: no raw handle for %s", label) + return + + num_nodes = ctypes.c_size_t(0) + hipGraphGetNodes(graph_handle, None, ctypes.byref(num_nodes)) + num_edges = ctypes.c_size_t(0) + hipGraphGetEdges(graph_handle, None, None, ctypes.byref(num_edges)) + + fname = f"graph_{label}.dot" + dot_path = os.path.join(dump_dir, fname).encode() + ret = hipGraphDebugDotPrint(graph_handle, dot_path, 0xFFFF) + logger.warning( + "hipGraph dump: %s nodes=%d edges=%d dot_ret=%d -> %s", + label, + num_nodes.value, + num_edges.value, + ret, + fname, + ) + @dataclasses.dataclass(frozen=True) class CUDAGraphStat: @@ -231,6 +373,9 @@ def clear_graphs(self) -> None: self.concrete_cudagraph_entries.clear() def __call__(self, *args: Any, **kwargs: Any) -> Any | None: + if _merged_capture_bypass: + return self.runnable(*args, **kwargs) + if not is_forward_context_available(): # No forward context means we are outside the normal # inference path (e.g. a vision encoder forward pass). @@ -264,10 +409,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: if entry.cudagraph is None: if self.cudagraph_options.debug_log_enable: - # Since we capture cudagraph for many different shapes and - # capturing is fast, we don't need to log it for every - # shape. E.g. we only log it for the first subgraph in - # piecewise mode. logger.debug( "Capturing a cudagraph on (%s,%s)", self.runtime_mode.name, @@ -280,7 +421,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: x.data_ptr() for x in args if isinstance(x, torch.Tensor) ] entry.input_addresses = input_addresses - cudagraph = torch.cuda.CUDAGraph() + keep = self.runtime_mode == CUDAGraphMode.FULL + cudagraph = torch.cuda.CUDAGraph(keep_graph=keep) with ExitStack() as stack: if self.cudagraph_options.gc_disable: @@ -329,6 +471,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: # here we always use weak ref for the output # to save memory entry.output = weak_ref_tensors(output) + if keep: + cudagraph.instantiate() entry.cudagraph = cudagraph compilation_counter.num_cudagraph_captured += 1 diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index a758dae3a5dc..c93e614d4d8e 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -430,7 +430,11 @@ def propose( for layer_name in attn_group.layer_names: per_layer_attn_metadata[layer_name] = attn_metadata - uniform_decode = target_model_batch_desc.uniform + uniform_decode = ( + target_model_batch_desc.uniform + if target_model_batch_desc is not None + else False + ) cudagraph_runtime_mode, batch_desc, num_input_tokens, num_tokens_across_dp = ( self._determine_batch_execution_and_padding(num_tokens, uniform_decode) ) @@ -1262,9 +1266,6 @@ def load_model(self, target_model: nn.Module) -> None: and not self.vllm_config.parallel_config.use_ubatching and not self.speculative_config.disable_padded_drafter_batch ): - # Currently Ubatch does not support FULL in speculative decoding, unpadded - # drafter batch either due to the dynamic number of tokens. - # We can consider supporting FULL for these cases in the future if needed. self.model = CUDAGraphWrapper( self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ee018696f2a6..595eca129fff 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3760,92 +3760,115 @@ def execute_model( # When spec decode is enabled, defer connector finalization # (wait_for_save + clear metadata) until after draft model runs. defer_kv_connector_finalize = self.speculative_config is not None - with ( - set_forward_context( + + # Check if we have a merged FULL graph for this batch. + has_merged_graph = ( + hasattr(self, "_merged_full_graphs") + and batch_desc is not None + and batch_desc in self._merged_full_graphs + ) + use_merged_graph = has_merged_graph and getattr(self, "_merged_hs_ready", False) + if has_merged_graph and not use_merged_graph: + cudagraph_mode = CUDAGraphMode.NONE + + if use_merged_graph: + with set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens_padded, - num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_mode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, batch_descriptor=batch_desc, - ubatch_slices=ubatch_slices_padded, slot_mapping=slot_mappings, - skip_compiled=has_encoder_input, - ), - record_function_or_nullcontext("gpu_model_runner: forward"), - self.maybe_get_kv_connector_output( - scheduler_output, - defer_finalize=defer_kv_connector_finalize, - ) as kv_connector_output, - ): - model_output = self._model_forward( - input_ids=input_ids, - positions=positions, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds, - **model_kwargs, - ) - - with record_function_or_nullcontext("gpu_model_runner: postprocess"): - if self.use_aux_hidden_state_outputs: - # True when EAGLE 3 is used. - hidden_states, aux_hidden_states = model_output - else: - # Common case. - hidden_states = model_output - aux_hidden_states = None - - if not self.broadcast_pp_output: - # Common case. - if not get_pp_group().is_last_rank: - # Return the intermediate tensors. - assert isinstance(hidden_states, IntermediateTensors) - hidden_states.kv_connector_output = kv_connector_output - self.kv_connector_output = kv_connector_output - return hidden_states - - if self.is_pooling_model: - # Return the pooling output. - return self._pool( - hidden_states, - num_scheduled_tokens, - num_scheduled_tokens_np, - kv_connector_output, - ) + ): + self._merged_full_graphs[batch_desc].replay() + hidden_states = self._merged_full_hidden_states + aux_hidden_states = None + logits = self._merged_full_logits + sample_hidden_states = hidden_states[logits_indices] + kv_connector_output = None + self._merged_replay_active = True + else: + self._merged_replay_active = False + with ( + set_forward_context( + attn_metadata, + self.vllm_config, + num_tokens=num_tokens_padded, + num_tokens_across_dp=num_tokens_across_dp, + cudagraph_runtime_mode=cudagraph_mode, + batch_descriptor=batch_desc, + ubatch_slices=ubatch_slices_padded, + slot_mapping=slot_mappings, + skip_compiled=has_encoder_input, + ), + record_function_or_nullcontext("gpu_model_runner: forward"), + self.maybe_get_kv_connector_output( + scheduler_output, + defer_finalize=defer_kv_connector_finalize, + ) as kv_connector_output, + ): + model_output = self._model_forward( + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + **model_kwargs, + ) - sample_hidden_states = hidden_states[logits_indices] - with self._prefer_hipblaslt_for_logits(): - logits = self.model.compute_logits(sample_hidden_states) - else: - # Rare case. - assert not self.is_pooling_model - - sample_hidden_states = hidden_states[logits_indices] - if not get_pp_group().is_last_rank: - all_gather_tensors = { - "residual": not is_residual_scattered_for_sp( - self.vllm_config, num_tokens_padded - ) - } - get_pp_group().send_tensor_dict( - hidden_states.tensors, - all_gather_group=get_tp_group(), - all_gather_tensors=all_gather_tensors, - ) - logits = None + with record_function_or_nullcontext("gpu_model_runner: postprocess"): + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output else: + hidden_states = model_output + aux_hidden_states = None + + if not self.broadcast_pp_output: + if not get_pp_group().is_last_rank: + assert isinstance(hidden_states, IntermediateTensors) + hidden_states.kv_connector_output = kv_connector_output + self.kv_connector_output = kv_connector_output + return hidden_states + + if self.is_pooling_model: + return self._pool( + hidden_states, + num_scheduled_tokens, + num_scheduled_tokens_np, + kv_connector_output, + ) + + sample_hidden_states = hidden_states[logits_indices] with self._prefer_hipblaslt_for_logits(): logits = self.model.compute_logits(sample_hidden_states) + else: + assert not self.is_pooling_model + sample_hidden_states = hidden_states[logits_indices] + if not get_pp_group().is_last_rank: + all_gather_tensors = { + "residual": not is_residual_scattered_for_sp( + self.vllm_config, num_tokens_padded + ) + } + get_pp_group().send_tensor_dict( + hidden_states.tensors, + all_gather_group=get_tp_group(), + all_gather_tensors=all_gather_tensors, + ) + logits = None + else: + with self._prefer_hipblaslt_for_logits(): + logits = self.model.compute_logits(sample_hidden_states) - model_output_broadcast_data: dict[str, Any] = {} - if logits is not None: - model_output_broadcast_data["logits"] = logits.contiguous() + model_output_broadcast_data: dict[str, Any] = {} + if logits is not None: + model_output_broadcast_data["logits"] = logits.contiguous() - broadcasted = get_pp_group().broadcast_tensor_dict( - model_output_broadcast_data, src=len(get_pp_group().ranks) - 1 - ) - assert broadcasted is not None - logits = broadcasted["logits"] + broadcasted = get_pp_group().broadcast_tensor_dict( + model_output_broadcast_data, + src=len(get_pp_group().ranks) - 1, + ) + assert broadcasted is not None + logits = broadcasted["logits"] self.execute_model_state = ExecuteModelState( scheduler_output, @@ -3859,6 +3882,7 @@ def execute_model( cudagraph_stats, slot_mappings, ) + self._last_batch_desc = batch_desc self.kv_connector_output = kv_connector_output return None @@ -3897,6 +3921,7 @@ def sample_tokens( cudagraph_stats, slot_mappings, ) = self.execute_model_state + batch_desc = getattr(self, "_last_batch_desc", None) # Clear ephemeral state. self.execute_model_state = None @@ -3926,6 +3951,39 @@ def sample_tokens( self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None + merged_replay_active = getattr(self, "_merged_replay_active", False) + if merged_replay_active: + self._draft_token_ids = self._merged_draft_token_ids + self._copy_draft_token_ids_to_cpu(scheduler_output) + + if ( + self.valid_sampled_token_count_event is not None + and spec_decode_common_attn_metadata is not None + ): + sampled_tids = sampler_output.sampled_token_ids + next_tids, valid_counts = self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_tids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + self._copy_valid_sampled_token_count( + next_tids, + valid_counts, + ) + + self._merged_replay_active = False + + if hasattr(self, "_merged_full_graphs"): + self._update_merged_drafter_inputs( + sampler_output, + hidden_states, + aux_hidden_states, + scheduler_output, + merged_replayed=merged_replay_active, + ) + def propose_draft_token_ids(sampled_token_ids): assert spec_decode_common_attn_metadata is not None with record_function_or_nullcontext("gpu_model_runner: draft"): @@ -3939,12 +3997,13 @@ def propose_draft_token_ids(sampled_token_ids): spec_decode_metadata, spec_decode_common_attn_metadata, slot_mappings, + target_model_batch_desc=batch_desc, ) self._copy_draft_token_ids_to_cpu(scheduler_output) spec_config = self.speculative_config propose_drafts_after_bookkeeping = False - if spec_config is not None: + if spec_config is not None and not merged_replay_active: input_fits_in_drafter = spec_decode_common_attn_metadata is not None and ( spec_decode_common_attn_metadata.max_seq_len + self.num_spec_tokens <= self.effective_drafter_max_model_len @@ -3955,8 +4014,6 @@ def propose_draft_token_ids(sampled_token_ids): or spec_config.uses_extract_hidden_states() ) and not spec_config.disable_padded_drafter_batch if use_gpu_toks: - # EAGLE/DraftModel speculative decoding can use the GPU sampled tokens - # as inputs, and does not need to wait for bookkeeping to finish. assert isinstance( self.drafter, EagleProposer | DraftModelProposer | ExtractHiddenStatesProposer, @@ -4227,7 +4284,10 @@ def propose_draft_token_ids( aux_hidden_states: list[torch.Tensor] | None, spec_decode_metadata: SpecDecodeMetadata | None, common_attn_metadata: CommonAttentionMetadata, - slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None, + slot_mappings: dict[str, torch.Tensor] + | list[dict[str, torch.Tensor]] + | None = None, + target_model_batch_desc: "BatchDescriptor | None" = None, ) -> list[list[int]] | torch.Tensor: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens spec_config = self.speculative_config @@ -4441,6 +4501,7 @@ def propose_draft_token_ids( common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count, + self.input_batch, ) total_num_tokens = common_attn_metadata.num_actual_tokens # When padding the batch, token_indices is just a range @@ -4470,6 +4531,7 @@ def propose_draft_token_ids( token_indices_to_sample=token_indices_to_sample, sampling_metadata=sampling_metadata, common_attn_metadata=common_attn_metadata, + target_model_batch_desc=target_model_batch_desc, mm_embed_inputs=mm_embed_inputs, num_rejected_tokens_gpu=num_rejected_tokens_gpu, slot_mappings=slot_mappings, @@ -4477,6 +4539,43 @@ def propose_draft_token_ids( return draft_token_ids + def _update_merged_drafter_inputs( + self, + sampler_output, + hidden_states: torch.Tensor, + aux_hidden_states, + scheduler_output: "SchedulerOutput", + merged_replayed: bool = False, + ): + """Update persistent buffers so the next merged-graph drafter + reads correct hidden_states and next_token_ids.""" + if not hasattr(self, "_merged_hs_buf"): + return + + buf = self._merged_hs_buf + + if not merged_replayed: + # Only update hs buffer from the normal (non-merged) path. + # When the merged graph replayed, it already updated + # _merged_hs_buf via the captured copy_ operation. + n = scheduler_output.total_num_scheduled_tokens + if n > buf.shape[0]: + return + if self.use_aux_hidden_state_outputs and aux_hidden_states: + hs = torch.cat([h[:n] for h in aux_hidden_states], dim=-1) + else: + hs = hidden_states[:n] + buf[:n].copy_(hs) + + if n == buf.shape[0]: + self._merged_hs_ready = True + + accepted = sampler_output.sampled_token_ids + if isinstance(accepted, torch.Tensor): + tok = accepted[:, 0].int() if accepted.ndim == 2 else accepted.int() + nt = self._merged_next_token_ids + nt[: tok.shape[0]].copy_(tok) + def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): @@ -5792,6 +5891,7 @@ def _warmup_and_capture( if num_warmups is None: num_warmups = self.compilation_config.cudagraph_num_of_warmups force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + for _ in range(num_warmups): self._dummy_run( desc.num_tokens, @@ -5803,6 +5903,7 @@ def _warmup_and_capture( remove_lora=False, num_active_loras=desc.num_active_loras, ) + self._dummy_run( desc.num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, @@ -5815,6 +5916,177 @@ def _warmup_and_capture( profile_seq_lens=profile_seq_lens, ) + @torch.inference_mode() + def _merged_capture(self, desc: BatchDescriptor): + """Capture [Drafter → Target → compute_logits] in one CUDA graph. + + In the reordered flow, the drafter runs first (using the PREVIOUS + target forward's hidden_states), then the target forward verifies + the PREVIOUS step's draft tokens, then compute_logits produces + logits for the rejection sampler (which runs outside the graph). + + Between graph replays, the rejection sampler and bookkeeping run + on the CPU/GPU outside the graph. + """ + import vllm.compilation.cuda_graph as cg + + assert isinstance(self.drafter, EagleProposer) + num_tokens = desc.num_tokens + + # Additional warmup so attention layers see the padded shapes. + self._dummy_run( + num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=True, + uniform_decode=desc.uniform, + skip_eplb=True, + remove_lora=False, + num_active_loras=0, + ) + + # ── Build metadata (same layout as _dummy_run for FULL) ── + max_query_len = self.uniform_decode_query_len + num_reqs = desc.num_reqs if desc.num_reqs is not None else num_tokens + num_scheduled_tokens_list = [max_query_len] * num_reqs + if num_tokens % max_query_len != 0: + num_scheduled_tokens_list[-1] = num_tokens % max_query_len + num_scheduled_tokens_np = np.array(num_scheduled_tokens_list, dtype=np.int32) + num_tokens_unpadded = int(num_scheduled_tokens_np.sum()) + + self.seq_lens.np[:num_reqs] = max_query_len + self.seq_lens.np[num_reqs:] = 0 + self.seq_lens.copy_to_gpu() + cum_num_tokens, _ = self._get_cumsum_and_arange(num_scheduled_tokens_np) + self.query_start_loc.np[1 : num_reqs + 1] = cum_num_tokens + self.query_start_loc.copy_to_gpu() + + slot_mappings_by_group, slot_mappings = self._get_slot_mappings( + num_tokens_padded=num_tokens, + num_reqs_padded=num_reqs, + num_tokens_unpadded=num_tokens_unpadded, + ubatch_slices=None, + ) + + target_attn_metadata, spec_decode_cm = self._build_attention_metadata( + num_tokens=num_tokens_unpadded, + num_tokens_padded=num_tokens, + num_reqs=num_reqs, + max_query_len=max_query_len, + ubatch_slices=None, + for_cudagraph_capture=True, + slot_mappings=slot_mappings_by_group, + use_spec_decode=True, + ) + + input_ids = self.input_ids.gpu[:num_tokens] + positions = self._get_positions(num_tokens) + + # ── Persistent buffers for drafter ↔ target data bridge ── + hidden_size = self.model_config.get_hidden_size() + if self.use_aux_hidden_state_outputs: + get = getattr( + self.model, + "get_eagle3_aux_hidden_state_layers", + self.model.get_eagle3_default_aux_hidden_state_layers, + ) + n_aux = len(get()) + hs_dim = hidden_size * n_aux + else: + hs_dim = hidden_size + + self._merged_hs_buf = torch.zeros( + num_tokens, + hs_dim, + dtype=self.model_config.dtype, + device=self.device, + ) + self._merged_next_token_ids = torch.zeros( + num_reqs, + dtype=torch.int32, + device=self.device, + ) + + # ── Bypass individual CUDAGraphWrappers ── + cg._merged_capture_bypass = True + + if not hasattr(self, "_merged_graph_pool"): + self._merged_graph_pool = torch.cuda.graph_pool_handle() + graph_pool = self._merged_graph_pool + merged_graph = torch.cuda.CUDAGraph(keep_graph=True) + + assert spec_decode_cm is not None + + with torch.cuda.graph( + merged_graph, + pool=graph_pool, + stream=torch.cuda.current_stream(), + ): + # ═══ DRAFTER ═══ + draft_token_ids = self.drafter.propose( + target_token_ids=input_ids, + target_positions=positions, + target_hidden_states=self._merged_hs_buf, + next_token_ids=self._merged_next_token_ids, + token_indices_to_sample=None, + common_attn_metadata=spec_decode_cm, + target_model_batch_desc=desc, + sampling_metadata=self.input_batch.sampling_metadata, + num_rejected_tokens_gpu=None, + slot_mappings=slot_mappings, + ) + + # ═══ TARGET FORWARD ═══ + with set_forward_context( + target_attn_metadata, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + batch_descriptor=desc, + slot_mapping=slot_mappings, + ): + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hs = self.model( + input_ids=input_ids, + positions=positions, + ) + else: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + ) + + # ═══ COMPUTE LOGITS ═══ + # For spec decode, logits are needed at ALL positions + # (not just the last per request), so the rejection sampler + # can compare target vs. draft at every token. + sample_hidden = hidden_states[:num_tokens] + logits = self.model.compute_logits(sample_hidden) + + # ═══ Update persistent hs buffer for next replay ═══ + if self.use_aux_hidden_state_outputs: + hs_for_drafter = torch.cat([h[:num_tokens] for h in aux_hs], dim=-1) + else: + hs_for_drafter = hidden_states[:num_tokens] + self._merged_hs_buf.copy_(hs_for_drafter) + + cg._merged_capture_bypass = False + merged_graph.instantiate() + + # ── Store handles for replay / dump ── + if not hasattr(self, "_merged_full_graphs"): + self._merged_full_graphs = {} + self._merged_full_graphs[desc] = merged_graph + self._merged_full_logits = logits + self._merged_full_hidden_states = hidden_states + self._merged_draft_token_ids = draft_token_ids + + from vllm.compilation.cuda_graph import dump_cuda_graph + + dump_cuda_graph( + merged_graph, + f"MERGED_{desc.num_tokens}t_{desc.num_reqs}r", + ) + def _capture_cudagraphs( self, batch_descriptors: list[BatchDescriptor], @@ -5842,6 +6114,7 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to record dummy metrics + merged_descs: list[BatchDescriptor] = [] for batch_desc in batch_descriptors: # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens @@ -5863,6 +6136,24 @@ def _capture_cudagraphs( allow_microbatching=allow_microbatching, ) torch.accelerator.synchronize() + + use_merged = ( + cudagraph_runtime_mode == CUDAGraphMode.FULL + and self.speculative_config is not None + and self.speculative_config.use_eagle() + and not self.speculative_config.disable_padded_drafter_batch + ) + if use_merged: + merged_descs.append(batch_desc) + + # Merged graphs are captured LAST, after all individual + # CUDAGraphWrapper graphs for all batch sizes are finalized. + # This prevents later individual captures from reallocating + # torch.compile internal buffers that the merged graph references. + for desc in merged_descs: + self._merged_capture(desc) + torch.accelerator.synchronize() + self.maybe_remove_all_loras(self.lora_config) def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: From 626c6c094b50a1e616018c891ea1459a07faee8c Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Sun, 22 Mar 2026 04:18:29 -0600 Subject: [PATCH 5/8] =?UTF-8?q?Reorder=20merged=20CUDA=20graph:=20Target?= =?UTF-8?q?=E2=86=92Logits=E2=86=92Drafter=20with=20in-graph=20rejection?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous ordering (Drafter→Target→Logits) caused the drafter to use stale next_token_ids from the previous step, collapsing the acceptance rate from ~37% to ~2%. Reorder to Target→Logits→Drafter so the drafter receives fresh hidden states and next_token_ids computed via argmax(logits) inside the graph. Key changes: - Compute in-graph greedy rejection: compare argmax(logits) with draft tokens in input_ids to derive token_indices_to_sample, num_rejected_tokens_gpu, and next_token_ids — all within the captured CUDA graph. - Remove prev_input_ids/prev_positions/prev_seq_lens/prev_slot_mapping persistent buffers that were needed for the old ordering. - The drafter now uses the current step's target hidden states and the correctly computed bonus token, matching the eager flow exactly. Results (Qwen3-4B + EAGLE3, 2 spec tokens, Strix Halo): Baseline (no merge): TPOT 8.98ms, acceptance 37.1% Merged graph: TPOT 9.31ms, acceptance 37.2% Signed-off-by: Matthias Gehre --- vllm/v1/worker/gpu_model_runner.py | 216 +++++++++++++++++++++-------- 1 file changed, 160 insertions(+), 56 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 595eca129fff..6ef5a4af904d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4,6 +4,7 @@ import functools import gc import itertools +import os import threading import time from collections import defaultdict @@ -3772,6 +3773,32 @@ def execute_model( cudagraph_mode = CUDAGraphMode.NONE if use_merged_graph: + # Compute token_indices_to_sample for the merged drafter, + # accounting for rejected tokens from the previous step. + num_reqs_m = batch_desc.num_reqs or num_tokens_padded + if spec_decode_metadata is not None: + from vllm.v1.spec_decode.utils import ( + eagle_prepare_inputs_padded_kernel, + ) + + cu_draft = spec_decode_metadata.cu_num_draft_tokens + actual_reqs = self.input_batch.num_reqs + if cu_draft.shape[0] < num_reqs_m: + cu_draft = nn.functional.pad( + cu_draft, + (0, num_reqs_m - actual_reqs), + mode="constant", + value=cu_draft[-1].item(), + ) + eagle_prepare_inputs_padded_kernel[(num_reqs_m,)]( + cu_draft, + self._merged_prev_valid_counts, + self.query_start_loc.gpu, + self._merged_token_indices_to_sample, + self._merged_num_rejected_tokens_gpu, + num_reqs_m, + ) + with set_forward_context( attn_metadata, self.vllm_config, @@ -3781,6 +3808,13 @@ def execute_model( slot_mapping=slot_mappings, ): self._merged_full_graphs[batch_desc].replay() + + if not hasattr(self, "_merged_draft_ids_safe"): + self._merged_draft_ids_safe = torch.empty_like( + self._merged_draft_token_ids + ) + self._merged_draft_ids_safe.copy_(self._merged_draft_token_ids) + hidden_states = self._merged_full_hidden_states aux_hidden_states = None logits = self._merged_full_logits @@ -3953,7 +3987,7 @@ def sample_tokens( merged_replay_active = getattr(self, "_merged_replay_active", False) if merged_replay_active: - self._draft_token_ids = self._merged_draft_token_ids + self._draft_token_ids = self._merged_draft_ids_safe self._copy_draft_token_ids_to_cpu(scheduler_output) if ( @@ -3972,6 +4006,11 @@ def sample_tokens( next_tids, valid_counts, ) + # Save for next merged replay's token_indices_to_sample + # computation and correct next_token_ids. + self._merged_prev_valid_counts[: valid_counts.shape[0]].copy_( + valid_counts + ) self._merged_replay_active = False @@ -4461,6 +4500,11 @@ def propose_draft_token_ids( self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) + # Save valid_counts for merged graph's tis/nrej computation. + if hasattr(self, "_merged_prev_valid_counts"): + self._merged_prev_valid_counts[ + : valid_sampled_tokens_count.shape[0] + ].copy_(valid_sampled_tokens_count) num_rejected_tokens_gpu = None if spec_decode_metadata is None: @@ -4547,35 +4591,19 @@ def _update_merged_drafter_inputs( scheduler_output: "SchedulerOutput", merged_replayed: bool = False, ): - """Update persistent buffers so the next merged-graph drafter - reads correct hidden_states and next_token_ids.""" + """No-op in the new Target→Logits→Drafter ordering. + + The graph already computes next_token_ids via argmax(logits) and + feeds current-step hidden states to the drafter. This method is + kept for compatibility but only sets _merged_hs_ready.""" if not hasattr(self, "_merged_hs_buf"): return - buf = self._merged_hs_buf - if not merged_replayed: - # Only update hs buffer from the normal (non-merged) path. - # When the merged graph replayed, it already updated - # _merged_hs_buf via the captured copy_ operation. n = scheduler_output.total_num_scheduled_tokens - if n > buf.shape[0]: - return - if self.use_aux_hidden_state_outputs and aux_hidden_states: - hs = torch.cat([h[:n] for h in aux_hidden_states], dim=-1) - else: - hs = hidden_states[:n] - buf[:n].copy_(hs) - - if n == buf.shape[0]: + if n == self._merged_hs_buf.shape[0]: self._merged_hs_ready = True - accepted = sampler_output.sampled_token_ids - if isinstance(accepted, torch.Tensor): - tok = accepted[:, 0].int() if accepted.ndim == 2 else accepted.int() - nt = self._merged_next_token_ids - nt[: tok.shape[0]].copy_(tok) - def update_config(self, overrides: dict[str, Any]) -> None: allowed_config_names = {"load_config", "model_config"} for config_name, config_overrides in overrides.items(): @@ -5918,12 +5946,14 @@ def _warmup_and_capture( @torch.inference_mode() def _merged_capture(self, desc: BatchDescriptor): - """Capture [Drafter → Target → compute_logits] in one CUDA graph. + """Capture [Target → compute_logits → Drafter] in one CUDA graph. - In the reordered flow, the drafter runs first (using the PREVIOUS - target forward's hidden_states), then the target forward verifies - the PREVIOUS step's draft tokens, then compute_logits produces - logits for the rejection sampler (which runs outside the graph). + The target forward runs first with the CURRENT step's inputs, + then compute_logits produces logits. next_token_ids is derived + from argmax(logits) inside the graph so the drafter gets the + correct accepted token. The drafter then predicts draft tokens + using the CURRENT step's hidden states and the freshly computed + next_token_ids. Between graph replays, the rejection sampler and bookkeeping run on the CPU/GPU outside the graph. @@ -5933,7 +5963,6 @@ def _merged_capture(self, desc: BatchDescriptor): assert isinstance(self.drafter, EagleProposer) num_tokens = desc.num_tokens - # Additional warmup so attention layers see the padded shapes. self._dummy_run( num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, @@ -5981,7 +6010,7 @@ def _merged_capture(self, desc: BatchDescriptor): input_ids = self.input_ids.gpu[:num_tokens] positions = self._get_positions(num_tokens) - # ── Persistent buffers for drafter ↔ target data bridge ── + # ── Persistent buffers ── hidden_size = self.model_config.get_hidden_size() if self.use_aux_hidden_state_outputs: get = getattr( @@ -6006,12 +6035,46 @@ def _merged_capture(self, desc: BatchDescriptor): device=self.device, ) + # Buffers to save/restore shared state that the drafter modifies + # in-place (seq_lens, query_start_loc) after the target is done. + self._merged_save_seq_lens = torch.empty( + num_reqs, dtype=self.seq_lens.gpu.dtype, device=self.device + ) + self._merged_save_qsl = torch.empty( + num_reqs + 1, + dtype=self.query_start_loc.gpu.dtype, + device=self.device, + ) + + blk_table = self.input_batch.block_table[0] + self._merged_slot_mapping_ref = blk_table.slot_mapping.gpu[:num_tokens] + self._merged_save_slot_mapping = torch.empty( + num_tokens, + dtype=self._merged_slot_mapping_ref.dtype, + device=self.device, + ) + + self._merged_token_indices_to_sample = torch.full( + (num_reqs,), + num_tokens - 1, + dtype=torch.int32, + device=self.device, + ) + self._merged_num_rejected_tokens_gpu = torch.zeros( + num_reqs, + dtype=torch.int32, + device=self.device, + ) + self._merged_prev_valid_counts = torch.ones( + num_reqs, + dtype=torch.int32, + device=self.device, + ) + # ── Bypass individual CUDAGraphWrappers ── cg._merged_capture_bypass = True - if not hasattr(self, "_merged_graph_pool"): - self._merged_graph_pool = torch.cuda.graph_pool_handle() - graph_pool = self._merged_graph_pool + graph_pool = current_platform.get_global_graph_pool() merged_graph = torch.cuda.CUDAGraph(keep_graph=True) assert spec_decode_cm is not None @@ -6021,20 +6084,6 @@ def _merged_capture(self, desc: BatchDescriptor): pool=graph_pool, stream=torch.cuda.current_stream(), ): - # ═══ DRAFTER ═══ - draft_token_ids = self.drafter.propose( - target_token_ids=input_ids, - target_positions=positions, - target_hidden_states=self._merged_hs_buf, - next_token_ids=self._merged_next_token_ids, - token_indices_to_sample=None, - common_attn_metadata=spec_decode_cm, - target_model_batch_desc=desc, - sampling_metadata=self.input_batch.sampling_metadata, - num_rejected_tokens_gpu=None, - slot_mappings=slot_mappings, - ) - # ═══ TARGET FORWARD ═══ with set_forward_context( target_attn_metadata, @@ -6056,20 +6105,70 @@ def _merged_capture(self, desc: BatchDescriptor): ) # ═══ COMPUTE LOGITS ═══ - # For spec decode, logits are needed at ALL positions - # (not just the last per request), so the rejection sampler - # can compare target vs. draft at every token. sample_hidden = hidden_states[:num_tokens] - logits = self.model.compute_logits(sample_hidden) - - # ═══ Update persistent hs buffer for next replay ═══ + with self._prefer_hipblaslt_for_logits(): + logits = self.model.compute_logits(sample_hidden) + + # ═══ IN-GRAPH GREEDY REJECTION + next_token_ids ═══ + # Compute acceptance pattern by comparing argmax(logits) + # with draft tokens (already in input_ids from _prepare_inputs). + greedy_all = logits.argmax(dim=-1) # [num_tokens] + greedy_per_req = greedy_all[:num_tokens].view(num_reqs, max_query_len) + ids_per_req = input_ids[:num_tokens].view(num_reqs, max_query_len) + # Draft tokens at positions 1..max_query_len-1 + num_draft = max_query_len - 1 + accepted = torch.zeros(num_reqs, dtype=torch.int32, device=self.device) + running_match = torch.ones(num_reqs, dtype=torch.int32, device=self.device) + for k in range(num_draft): + match_k = (greedy_per_req[:, k] == ids_per_req[:, k + 1]).int() + running_match = running_match * match_k + accepted = accepted + running_match + + offsets = ( + torch.arange(num_reqs, dtype=torch.int32, device=self.device) + * max_query_len + ) + self._merged_token_indices_to_sample[:num_reqs] = offsets + accepted + self._merged_num_rejected_tokens_gpu[:num_reqs] = num_draft - accepted + bonus = greedy_per_req.gather(1, accepted.unsqueeze(1).long()).squeeze(1) + self._merged_next_token_ids[:num_reqs] = bonus.int() + + # ═══ PREPARE DRAFTER INPUTS ═══ + # Save target state before drafter's in-place modifications. + self._merged_save_seq_lens.copy_(self.seq_lens.gpu[:num_reqs]) + self._merged_save_qsl.copy_(self.query_start_loc.gpu[: num_reqs + 1]) + self._merged_save_slot_mapping.copy_(self._merged_slot_mapping_ref) + + # Prepare hidden states for the drafter. if self.use_aux_hidden_state_outputs: hs_for_drafter = torch.cat([h[:num_tokens] for h in aux_hs], dim=-1) else: hs_for_drafter = hidden_states[:num_tokens] self._merged_hs_buf.copy_(hs_for_drafter) + # ═══ DRAFTER ═══ + # Uses CURRENT step's hidden states and freshly computed + # next_token_ids from argmax(logits). + draft_token_ids = self.drafter.propose( + target_token_ids=input_ids, + target_positions=positions, + target_hidden_states=self._merged_hs_buf, + next_token_ids=self._merged_next_token_ids, + token_indices_to_sample=self._merged_token_indices_to_sample, + common_attn_metadata=spec_decode_cm, + target_model_batch_desc=desc, + sampling_metadata=self.input_batch.sampling_metadata, + num_rejected_tokens_gpu=self._merged_num_rejected_tokens_gpu, + slot_mappings=slot_mappings, + ) + + # Restore target state after drafter's in-place modifications. + self.seq_lens.gpu[:num_reqs].copy_(self._merged_save_seq_lens) + self.query_start_loc.gpu[: num_reqs + 1].copy_(self._merged_save_qsl) + self._merged_slot_mapping_ref.copy_(self._merged_save_slot_mapping) + cg._merged_capture_bypass = False + merged_graph.instantiate() # ── Store handles for replay / dump ── @@ -6150,9 +6249,14 @@ def _capture_cudagraphs( # CUDAGraphWrapper graphs for all batch sizes are finalized. # This prevents later individual captures from reallocating # torch.compile internal buffers that the merged graph references. - for desc in merged_descs: - self._merged_capture(desc) - torch.accelerator.synchronize() + if not os.environ.get("VLLM_DISABLE_MERGED_GRAPH"): + for desc in merged_descs: + self._merged_capture(desc) + torch.accelerator.synchronize() + if os.environ.get("VLLM_DISABLE_MERGED_REPLAY") and hasattr( + self, "_merged_full_graphs" + ): + self._merged_full_graphs.clear() self.maybe_remove_all_loras(self.lora_config) From bd7764e5b64301bd1a9100acfdda0c62920a3676 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Sun, 22 Mar 2026 13:19:15 -0600 Subject: [PATCH 6/8] Remove superseded pre-replay kernel and _merged_prev_valid_counts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The in-graph greedy rejection (added in the previous commit) overwrites _merged_token_indices_to_sample and _merged_num_rejected_tokens_gpu before the drafter reads them, making the pre-replay eagle_prepare_inputs_padded_kernel redundant. Removing it also eliminates a hidden cu_draft[-1].item() GPU→CPU sync that cost ~0.3ms per step. Results (Qwen3-4B + EAGLE3, 2 spec tokens, Strix Halo): Before: TPOT 9.31ms After: TPOT 9.00ms (baseline 8.98ms) Signed-off-by: Matthias Gehre --- vllm/v1/worker/gpu_model_runner.py | 42 ------------------------------ 1 file changed, 42 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6ef5a4af904d..c7d0fec5ae06 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3773,32 +3773,6 @@ def execute_model( cudagraph_mode = CUDAGraphMode.NONE if use_merged_graph: - # Compute token_indices_to_sample for the merged drafter, - # accounting for rejected tokens from the previous step. - num_reqs_m = batch_desc.num_reqs or num_tokens_padded - if spec_decode_metadata is not None: - from vllm.v1.spec_decode.utils import ( - eagle_prepare_inputs_padded_kernel, - ) - - cu_draft = spec_decode_metadata.cu_num_draft_tokens - actual_reqs = self.input_batch.num_reqs - if cu_draft.shape[0] < num_reqs_m: - cu_draft = nn.functional.pad( - cu_draft, - (0, num_reqs_m - actual_reqs), - mode="constant", - value=cu_draft[-1].item(), - ) - eagle_prepare_inputs_padded_kernel[(num_reqs_m,)]( - cu_draft, - self._merged_prev_valid_counts, - self.query_start_loc.gpu, - self._merged_token_indices_to_sample, - self._merged_num_rejected_tokens_gpu, - num_reqs_m, - ) - with set_forward_context( attn_metadata, self.vllm_config, @@ -4006,11 +3980,6 @@ def sample_tokens( next_tids, valid_counts, ) - # Save for next merged replay's token_indices_to_sample - # computation and correct next_token_ids. - self._merged_prev_valid_counts[: valid_counts.shape[0]].copy_( - valid_counts - ) self._merged_replay_active = False @@ -4500,11 +4469,6 @@ def propose_draft_token_ids( self._copy_valid_sampled_token_count( next_token_ids, valid_sampled_tokens_count ) - # Save valid_counts for merged graph's tis/nrej computation. - if hasattr(self, "_merged_prev_valid_counts"): - self._merged_prev_valid_counts[ - : valid_sampled_tokens_count.shape[0] - ].copy_(valid_sampled_tokens_count) num_rejected_tokens_gpu = None if spec_decode_metadata is None: @@ -6065,12 +6029,6 @@ def _merged_capture(self, desc: BatchDescriptor): dtype=torch.int32, device=self.device, ) - self._merged_prev_valid_counts = torch.ones( - num_reqs, - dtype=torch.int32, - device=self.device, - ) - # ── Bypass individual CUDAGraphWrappers ── cg._merged_capture_bypass = True From a6bb2b9e3d4203315981bd2543ea76b4e0ef8740 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Sun, 22 Mar 2026 16:16:49 -0600 Subject: [PATCH 7/8] Fast-path post-replay: skip prepare_next_token_ids_padded for greedy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For greedy decoding, the in-graph rejection already computes next_token_ids and num_rejected_tokens_gpu via the same argmax comparison used by the real rejection sampler. Derive valid_sampled_tokens_count directly from the in-graph results instead of calling prepare_next_token_ids_padded, which involves a Python loop (backup token computation), a CPU→GPU copy, tensor allocations, and a Triton kernel launch. Falls back to the full prepare_next_token_ids_padded for non-greedy sampling where the real sampler may produce different acceptance patterns than the greedy in-graph rejection. Signed-off-by: Matthias Gehre --- vllm/v1/worker/gpu_model_runner.py | 48 ++++++++++++++++++++++-------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index c7d0fec5ae06..618b4350b719 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3968,18 +3968,37 @@ def sample_tokens( self.valid_sampled_token_count_event is not None and spec_decode_common_attn_metadata is not None ): - sampled_tids = sampler_output.sampled_token_ids - next_tids, valid_counts = self.drafter.prepare_next_token_ids_padded( - spec_decode_common_attn_metadata, - sampled_tids, - self.requests, - self.input_batch, - self.discard_request_mask.gpu, - ) - self._copy_valid_sampled_token_count( - next_tids, - valid_counts, - ) + sampling_metadata = self.input_batch.sampling_metadata + if sampling_metadata.all_greedy: + # Fast path: reuse in-graph rejection results directly. + # The in-graph greedy rejection already computed + # next_token_ids and num_rejected_tokens_gpu via the + # same argmax comparison the real rejection sampler uses. + n = self._merged_next_token_ids.shape[0] + self._merged_valid_counts[:n] = ( + self.num_spec_tokens + + 1 + - self._merged_num_rejected_tokens_gpu[:n] + ) + self._copy_valid_sampled_token_count( + self._merged_next_token_ids, + self._merged_valid_counts, + ) + else: + sampled_tids = sampler_output.sampled_token_ids + next_tids, valid_counts = ( + self.drafter.prepare_next_token_ids_padded( + spec_decode_common_attn_metadata, + sampled_tids, + self.requests, + self.input_batch, + self.discard_request_mask.gpu, + ) + ) + self._copy_valid_sampled_token_count( + next_tids, + valid_counts, + ) self._merged_replay_active = False @@ -6029,6 +6048,11 @@ def _merged_capture(self, desc: BatchDescriptor): dtype=torch.int32, device=self.device, ) + self._merged_valid_counts = torch.zeros( + num_reqs, + dtype=torch.int32, + device=self.device, + ) # ── Bypass individual CUDAGraphWrappers ── cg._merged_capture_bypass = True From 612385ca25985de97282f17251fe5f1f0345300e Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 23 Mar 2026 05:19:58 -0600 Subject: [PATCH 8/8] Skip external rejection sampler for greedy merged graph replay When all conditions are met (greedy decoding, no logprobs, no penalties, no bad words, no constrained decoding, no non-argmax-invariant logits processors, no grammar output), construct sampled_token_ids directly from the in-graph rejection results instead of running the external rejection sampler. This eliminates the bonus argmax, target logits extraction, and rejection kernel while maintaining identical output. Signed-off-by: Matthias Gehre --- vllm/v1/worker/gpu_model_runner.py | 61 ++++++++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 618b4350b719..c976a969bddc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3939,8 +3939,28 @@ def sample_tokens( scheduler_output, grammar_output, self.input_batch, logits ) - with record_function_or_nullcontext("gpu_model_runner: sample"): - sampler_output = self._sample(logits, spec_decode_metadata) + merged_replay_active = getattr(self, "_merged_replay_active", False) + skip_external_sampler = False + if merged_replay_active and grammar_output is None: + sm = self.input_batch.sampling_metadata + skip_external_sampler = ( + sm.all_greedy + and sm.max_num_logprobs is None + and sm.no_penalties + and not sm.bad_words_token_ids + and sm.allowed_token_ids_mask is None + and not sm.logitsprocs.non_argmax_invariant + ) + + if skip_external_sampler: + n = self.input_batch.num_reqs + sampler_output = SamplerOutput( + sampled_token_ids=self._build_sampled_token_ids(n), + logprobs_tensors=None, + ) + else: + with record_function_or_nullcontext("gpu_model_runner: sample"): + sampler_output = self._sample(logits, spec_decode_metadata) self._update_states_after_model_execute( sampler_output.sampled_token_ids, scheduler_output @@ -3959,7 +3979,6 @@ def sample_tokens( self._draft_token_req_ids = None self.input_batch.prev_sampled_token_ids = None - merged_replay_active = getattr(self, "_merged_replay_active", False) if merged_replay_active: self._draft_token_ids = self._merged_draft_ids_safe self._copy_draft_token_ids_to_cpu(scheduler_output) @@ -4270,6 +4289,33 @@ def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]: self.draft_token_ids_event.synchronize() return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids + def _build_sampled_token_ids(self, n: int) -> torch.Tensor: + """Build sampled_token_ids from in-graph rejection results. + + Reproduces the output format of rejection_greedy_sample_kernel: + positions 0..accepted-1 contain the (accepted) draft tokens, + position accepted contains the bonus/replacement token, and + remaining positions are -1 (PLACEHOLDER_TOKEN_ID). + """ + mqlen = self.num_spec_tokens + 1 + num_draft = self.num_spec_tokens + rejected = self._merged_num_rejected_tokens_gpu[:n] + accepted = num_draft - rejected + + out = self._merged_sampled_token_ids[:n] + out.fill_(-1) + + # Draft tokens live in input_ids at positions 1..num_draft per request. + # For greedy acceptance, draft_token == target_argmax. + ids = self.input_ids.gpu[: n * mqlen].view(n, mqlen) + valid_mask = self._merged_col_indices[:, :num_draft] < accepted.unsqueeze(1) + out[:, :num_draft][valid_mask] = ids[:, 1:][valid_mask].int() + + # Bonus/replacement token at position `accepted`. + row_idx = torch.arange(n, dtype=torch.long, device=self.device) + out[row_idx, accepted.long()] = self._merged_next_token_ids[:n] + return out + def _copy_valid_sampled_token_count( self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor ) -> None: @@ -6053,6 +6099,15 @@ def _merged_capture(self, desc: BatchDescriptor): dtype=torch.int32, device=self.device, ) + self._merged_sampled_token_ids = torch.full( + (num_reqs, max_query_len), + -1, + dtype=torch.int32, + device=self.device, + ) + self._merged_col_indices = torch.arange( + max_query_len, dtype=torch.int32, device=self.device + ).unsqueeze(0) # ── Bypass individual CUDAGraphWrappers ── cg._merged_capture_bypass = True