From 5af3bda596ece9546a910afe6616cef9b55d50e6 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 4 Dec 2025 21:19:36 +0000 Subject: [PATCH 01/18] wip Signed-off-by: Lucas Wilkinson --- .../models/language/generation/test_hybrid.py | 7 +- tests/v1/cudagraph/test_cudagraph_dispatch.py | 13 +-- vllm/config/compilation.py | 57 --------- vllm/v1/cudagraph_dispatcher.py | 36 +++++- vllm/v1/spec_decode/eagle.py | 109 ++++++++---------- vllm/v1/worker/gpu_model_runner.py | 18 +-- 6 files changed, 96 insertions(+), 144 deletions(-) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 37830093cd3c..dd5b7fdc3e1c 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -9,6 +9,7 @@ from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from ...utils import check_logprobs_close, check_outputs_equal @@ -172,7 +173,11 @@ def test_mamba_cache_cg_padding( tensor dimensions aren't compatible. """ vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): + cudagraph_dispatcher = CudagraphDispatcher(vllm_config) + while ( + len(example_prompts) + == cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens + ): example_prompts.append(example_prompts[0]) try: diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index f9d3e8d0532b..1ff1b8ed4c59 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -61,9 +61,6 @@ def _create_vllm_config( ) compilation_config.post_init_cudagraph_sizes() - mock_config.pad_for_cudagraph = ( - lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size] - ) return mock_config @@ -167,8 +164,8 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): # 4. disable_full should have a fall back mode (e.g., cascade attention) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) rt_mode, key = dispatcher.dispatch( - num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True - ) + num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True) + if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() @@ -360,7 +357,7 @@ def test_capture_replay_bypass_logic(self): ): full_wrapper(input_1) - rt_mode, key = self.dispatcher.dispatch(desc_1) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens) # 1. Capture first shape action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "capture_global" @@ -369,7 +366,7 @@ def test_capture_replay_bypass_logic(self): action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "replay" - rt_mode, key = self.dispatcher.dispatch(desc_2) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens) # 3. Capture second shape action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key) assert action == "capture_global" @@ -381,7 +378,7 @@ def test_capture_replay_bypass_logic(self): assert action == "replay" # 5. Bypass if no key match - rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens) assert rt_mode == CUDAGraphMode.NONE action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key) assert action == "bypass" diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 035aa24e33c7..5ad47dcbb8da 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -581,15 +581,6 @@ class CompilationConfig: local_cache_dir: str = field(default=None, init=False) # type: ignore """local cache dir for each rank""" - bs_to_padded_graph_size: list[int] = field( - default=None, # type: ignore - init=False, - ) - """optimization: - Intuitively, bs_to_padded_graph_size should be dict[int, int]. - since we know all keys are in a range [0, max_cudagraph_capture_size], - we can optimize it to list[int] for better lookup performance.""" - # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are enabled""" @@ -639,7 +630,6 @@ def compute_hash(self) -> str: "debug_dump_path", "cache_dir", "local_cache_dir", - "bs_to_padded_graph_size", "traced_files", "compilation_time", "static_forward_context", @@ -661,7 +651,6 @@ def __repr__(self) -> str: "enabled_custom_ops": True, "disabled_custom_ops": True, "compilation_time": True, - "bs_to_padded_graph_size": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, @@ -882,7 +871,6 @@ def post_init_cudagraph_sizes(self) -> None: """To complete the initialization after cudagraph related configs are set. This includes: - initialize compile_sizes - - pre-compute the mapping bs_to_padded_graph_size """ computed_compile_sizes = [] @@ -906,23 +894,6 @@ def post_init_cudagraph_sizes(self) -> None: if self.cudagraph_capture_sizes: assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size - # May get recomputed in the model runner if adjustment is needed for spec-decode - self.compute_bs_to_padded_graph_size() - - # Validate that compile_sizes won't be changed by padding. - # Only validate when cudagraphs are actually being used. - if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE: - for size in self.compile_sizes: - if size <= self.max_cudagraph_capture_size: - padded = self.bs_to_padded_graph_size[size] - if padded != size: - raise ValueError( - f"compile_sizes contains {size} which would be " - f"padded to {padded}. All compile_sizes must be " - "values that won't be changed by cudagraph padding. " - "Use values from cudagraph_capture_sizes." - ) - def set_splitting_ops_for_v1( self, all2all_backend: str, data_parallel_size: int = 1 ): @@ -1133,31 +1104,3 @@ def adjust_cudagraph_sizes_for_spec_decode( self.max_cudagraph_capture_size = rounded_sizes[-1] self.cudagraph_capture_sizes = rounded_sizes - - # Recompute after adjusting the cudagraph sizes - self.compute_bs_to_padded_graph_size() - - def compute_bs_to_padded_graph_size(self): - # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_cudagraph_capture_size + 1) - ] - for end, start in zip( - self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1], - [0] + self.cudagraph_capture_sizes, - ): - for bs in range(start, end): - if bs == start: - self.bs_to_padded_graph_size[bs] = start - else: - self.bs_to_padded_graph_size[bs] = end - - def get_compile_ranges(self) -> list[Range]: - """Get the compile ranges for the compilation config.""" - if self.compile_ranges_split_points is None: - return [] - split_points = sorted(set(self.compile_ranges_split_points)) - return [ - Range(start=s + 1, end=e) - for s, e in zip([0] + split_points[:-1], split_points) - ] diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 8a3500c0aac6..56ba671388eb 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -58,12 +58,27 @@ def __init__(self, vllm_config: VllmConfig): self.keys_initialized = False + def _compute_bs_to_padded_graph_size(self) -> None: + """Pre-compute the mapping from batch size to padded graph size.""" + max_size = self.compilation_config.max_cudagraph_capture_size + capture_sizes = self.compilation_config.cudagraph_capture_sizes + self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1) + for end, start in zip( + capture_sizes + [max_size + 1], + [0] + capture_sizes, + ): + for bs in range(start, end): + if bs == start: + self._bs_to_padded_graph_size[bs] = start + else: + self._bs_to_padded_graph_size[bs] = end + def _create_padded_batch_descriptor( self, num_tokens: int, uniform_decode: bool, has_lora: bool ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len - num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens) + num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): num_reqs = num_tokens_padded // uniform_decode_query_len @@ -140,18 +155,29 @@ def initialize_cudagraph_keys( self.keys_initialized = True + self._compute_bs_to_padded_graph_size() + def dispatch( self, num_tokens: int, - uniform_decode: bool, - has_lora: bool, + uniform_decode: bool = False, + has_lora: bool = False, disable_full: bool = False, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ - Given conditions(e.g.,batch descriptor and if using cascade attention), + Given conditions(e.g.,batch descriptor and if using piecewise only), dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). + + Args: + num_tokens: Number of tokens in the batch. + uniform_decode: Whether the batch is uniform decode (i.e. uniform and query + length is uniform_decode_query_len). + has_lora: Whether LoRA is active. + piecewise_or_eager_only: If True, skip FULL cudagraph checks and + return PIECEWISE or NONE only. (can be used for features cascade + attention that are not supported by full cudagraphs) """ if ( not self.keys_initialized @@ -165,7 +191,7 @@ def dispatch( ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() - if not disable_full: +s if not disable_full: # check if key exists for full cudagraph if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: return CUDAGraphMode.FULL, batch_desc diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index b7693f4f733d..3fe8b220f77c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,7 +9,6 @@ import torch.nn as nn from vllm.config import ( - CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, @@ -36,6 +35,11 @@ TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata +from vllm.v1.attention.backends.utils import ( + AttentionMetadataBuilder, + CommonAttentionMetadata, +) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS @@ -94,24 +98,13 @@ def __init__( self._get_eagle3_use_aux_hidden_state_from_config() ) - self.use_cuda_graph = False - self.compilation_config = self.vllm_config.compilation_config - if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: - cudagraph_mode = self.compilation_config.cudagraph_mode - if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( - CUDAGraphMode.PIECEWISE - ): - logger.warning( - "Currently the eagle proposer only supports cudagraph_mode " - "PIECEWISE, if you want the drafter to use cuda graphs, " - "please set compilation_config.cudagraph_mode to PIECEWISE " - "or FULL_AND_PIECEWISE" - ) - self.use_cuda_graph = ( - cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) + + # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle. + # Keys are initialized later via initialize_cudagraph_keys() called from + # gpu_model_runner._check_and_update_cudagraph_mode after + # adjust_cudagraph_sizes_for_spec_decode is called. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) # persistent buffers for cuda graph self.input_ids = torch.zeros( @@ -229,6 +222,21 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): else: self.positions[:num_tokens] = positions + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: + """Initialize cudagraph dispatcher keys for eagle. + + Eagle only supports PIECEWISE cudagraphs (via mixed_mode). + This should be called after adjust_cudagraph_sizes_for_spec_decode. + """ + eagle_cudagraph_mode = ( + cudagraph_mode.mixed_mode() + if not self.speculative_config.enforce_eager + else CUDAGraphMode.NONE + ) + self.cudagraph_dispatcher.initialize_cudagraph_keys( + eagle_cudagraph_mode, uniform_decode_query_len=1 + ) + def propose( self, # [num_tokens] @@ -298,16 +306,10 @@ def propose( num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens ) - cudagraph_runtime_mode = CUDAGraphMode.NONE - if ( - self.use_cuda_graph - and num_tokens_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens_dp_padded + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens @@ -401,16 +403,10 @@ def propose( num_tokens_unpadded=batch_size, num_tokens_padded=batch_size ) - if ( - self.use_cuda_graph - and batch_size_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - input_batch_size = batch_size_dp_padded - cudagraph_runtime_mode = CUDAGraphMode.NONE + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + batch_size_dp_padded + ) + input_batch_size = batch_desc.num_tokens if batch_size_across_dp is not None: batch_size_across_dp[self.dp_rank] = input_batch_size @@ -827,15 +823,10 @@ def propose_tree( self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if ( - self.use_cuda_graph - and num_tokens <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens - cudagraph_runtime_mode = CUDAGraphMode.NONE + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens + ) + num_input_tokens = batch_desc.num_tokens # Run the model. with set_forward_context( per_layer_attn_metadata, @@ -1173,9 +1164,6 @@ def dummy_run( use_cudagraphs: bool = True, is_graph_capturing: bool = False, ) -> None: - # Determine if CUDA graphs should be used for this run. - cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph - # FIXME: when using tree-based specdec, adjust number of forward-passes # according to the depth of the tree. for fwd_idx in range( @@ -1185,16 +1173,10 @@ def dummy_run( num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens ) - if ( - cudagraphs_enabled - and num_tokens_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens_dp_padded - ) - else: - num_input_tokens = num_tokens_dp_padded + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens @@ -1203,9 +1185,7 @@ def dummy_run( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE - if cudagraphs_enabled - else CUDAGraphMode.NONE, + cudagraph_runtime_mode=cudagraph_runtime_mode, ): if self.supports_mm_inputs: input_ids = None @@ -1295,7 +1275,8 @@ def _pad_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, parallel_config=self.vllm_config.parallel_config, allow_microbatching=False, - allow_dp_padding=self.use_cuda_graph, + allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode + != CUDAGraphMode.NONE, num_tokens_padded=num_tokens_padded, uniform_decode=None, num_scheduled_tokens_per_request=None, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 00e401f41f3e..9fc8c04246c4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2129,15 +2129,11 @@ def _prepare_kv_sharing_fast_prefill( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( logits_indices[-1].item() ) - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1] - ): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) - else: - num_logits_padded = num_logits + # Dispatch for the decoder portion of the model. + _, batch_desc = self.cudagraph_dispatcher.dispatch( + num_logits, piecewise_or_eager_only=True + ) + num_logits_padded = batch_desc.num_tokens logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ :num_logits_padded ] @@ -5184,6 +5180,10 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) + # Initialize eagle's cudagraph dispatcher if using eagle spec decode. + if self.speculative_config and self.speculative_config.use_eagle(): + self.drafter.initialize_cudagraph_keys(cudagraph_mode) + def calculate_reorder_batch_threshold(self) -> None: """ Choose the minimum reorder batch threshold from all attention groups. From 720cb5c579669b6f052bbf1e46acd4896865c355 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 5 Dec 2025 16:53:40 +0000 Subject: [PATCH 02/18] fix precommit Signed-off-by: Lucas Wilkinson --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9fc8c04246c4..3ac72c544ba3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5182,6 +5182,7 @@ def _check_and_update_cudagraph_mode( # Initialize eagle's cudagraph dispatcher if using eagle spec decode. if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) self.drafter.initialize_cudagraph_keys(cudagraph_mode) def calculate_reorder_batch_threshold(self) -> None: From 1c200e0eb558f37a56ea72590890098f75874b78 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 11 Dec 2025 04:53:23 +0000 Subject: [PATCH 03/18] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 56ba671388eb..9d9db914c918 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -108,6 +108,7 @@ def initialize_cudagraph_keys( # This should be called only after attention backend is initialized. So we can # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode + self._compute_bs_to_padded_graph_size() # LoRA activation cases to specialize the cuda graphs on if self.vllm_config.lora_config: @@ -155,8 +156,6 @@ def initialize_cudagraph_keys( self.keys_initialized = True - self._compute_bs_to_padded_graph_size() - def dispatch( self, num_tokens: int, @@ -191,7 +190,7 @@ def dispatch( ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() -s if not disable_full: + if not disable_full: # check if key exists for full cudagraph if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: return CUDAGraphMode.FULL, batch_desc From af19ccc3949af6cf4fc4f99bb8b02cf66118d65d Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 11 Dec 2025 04:53:23 +0000 Subject: [PATCH 04/18] fix Signed-off-by: Lucas Wilkinson --- tests/v1/cudagraph/test_cudagraph_dispatch.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 1ff1b8ed4c59..cdac67a2d650 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -164,7 +164,8 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): # 4. disable_full should have a fall back mode (e.g., cascade attention) desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False) rt_mode, key = dispatcher.dispatch( - num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True) + num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True + ) if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE From d8f15eabfd50faba164adc94db007e1575a6dc48 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 12 Dec 2025 11:37:41 -0800 Subject: [PATCH 05/18] fix Signed-off-by: Lucas Wilkinson --- vllm/compilation/piecewise_backend.py | 2 +- vllm/compilation/sequence_parallelism.py | 2 +- vllm/config/compilation.py | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 29d6f89990cd..ebf2715a109b 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -10,7 +10,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig -from vllm.config.compilation import Range +from vllm.config.utils import Range from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 34ff2ab47e56..b35c192dfd23 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -11,7 +11,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig -from vllm.config.compilation import Range +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5ad47dcbb8da..7f7ce037ab52 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,7 +14,6 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import ( - Range, config, get_hash_factors, hash_factors, From 7487a07f4a7ffcf9cfe8b414bd84fcaa176f9cb1 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 12 Dec 2025 20:39:27 +0000 Subject: [PATCH 06/18] fix Signed-off-by: Lucas Wilkinson --- vllm/config/compilation.py | 11 +++++++++++ vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 7f7ce037ab52..c5e99cd5f23f 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -14,6 +14,7 @@ import vllm.envs as envs from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.config.utils import ( + Range, config, get_hash_factors, hash_factors, @@ -1103,3 +1104,13 @@ def adjust_cudagraph_sizes_for_spec_decode( self.max_cudagraph_capture_size = rounded_sizes[-1] self.cudagraph_capture_sizes = rounded_sizes + + def get_compile_ranges(self) -> list[Range]: + """Get the compile ranges for the compilation config.""" + if self.compile_ranges_split_points is None: + return [] + split_points = sorted(set(self.compile_ranges_split_points)) + return [ + Range(start=s + 1, end=e) + for s, e in zip([0] + split_points[:-1], split_points) + ] \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 3ac72c544ba3..73bc870b36c0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2131,7 +2131,7 @@ def _prepare_kv_sharing_fast_prefill( ) # Dispatch for the decoder portion of the model. _, batch_desc = self.cudagraph_dispatcher.dispatch( - num_logits, piecewise_or_eager_only=True + num_logits, disable_full=True ) num_logits_padded = batch_desc.num_tokens logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ From 8c763a9a62e99dc924f2bcf8a7ebde7f892d216e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 12 Dec 2025 20:41:47 +0000 Subject: [PATCH 07/18] reveiw comment Signed-off-by: Lucas Wilkinson --- vllm/v1/spec_decode/eagle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3fe8b220f77c..1089ab6ce158 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -229,8 +229,8 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: This should be called after adjust_cudagraph_sizes_for_spec_decode. """ eagle_cudagraph_mode = ( - cudagraph_mode.mixed_mode() - if not self.speculative_config.enforce_eager + CUDAGraphMode.PIECEWISE + if not self.speculative_config.enforce_eager and cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE else CUDAGraphMode.NONE ) self.cudagraph_dispatcher.initialize_cudagraph_keys( From 1bc1b9481488c9e9a66c921380fd23bb190f770a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 12 Dec 2025 20:43:35 +0000 Subject: [PATCH 08/18] format Signed-off-by: Lucas Wilkinson --- vllm/v1/spec_decode/eagle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 1089ab6ce158..eb0fa5cfa802 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -230,7 +230,8 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: """ eagle_cudagraph_mode = ( CUDAGraphMode.PIECEWISE - if not self.speculative_config.enforce_eager and cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE + if not self.speculative_config.enforce_eager + and cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE else CUDAGraphMode.NONE ) self.cudagraph_dispatcher.initialize_cudagraph_keys( From 804f3fc312d6d52bb7a80a600ccab97610a99715 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Fri, 12 Dec 2025 20:51:44 +0000 Subject: [PATCH 09/18] format Signed-off-by: Lucas Wilkinson --- vllm/config/compilation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index c5e99cd5f23f..8ce31ad18133 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1113,4 +1113,4 @@ def get_compile_ranges(self) -> list[Range]: return [ Range(start=s + 1, end=e) for s, e in zip([0] + split_points[:-1], split_points) - ] \ No newline at end of file + ] From 6038121ea4be53a6bb4028b3c33a432ef23f733b Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 8 Jan 2026 17:27:05 +0000 Subject: [PATCH 10/18] fix doc Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 9d9db914c918..5a01016db390 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -174,9 +174,9 @@ def dispatch( uniform_decode: Whether the batch is uniform decode (i.e. uniform and query length is uniform_decode_query_len). has_lora: Whether LoRA is active. - piecewise_or_eager_only: If True, skip FULL cudagraph checks and - return PIECEWISE or NONE only. (can be used for features cascade - attention that are not supported by full cudagraphs) + disable_full: If True, skip FULL cudagraph checks and + return PIECEWISE or NONE only. (can be used for features like + cascade attention that are not supported by full cudagraphs) """ if ( not self.keys_initialized From d5ce3dca0d8ad92bce4bad01505d69a77f8bf74e Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 Jan 2026 15:40:58 +0000 Subject: [PATCH 11/18] add back warning Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 5a01016db390..cbd3c9689bf9 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -73,6 +73,23 @@ def _compute_bs_to_padded_graph_size(self) -> None: else: self._bs_to_padded_graph_size[bs] = end + # Validate that compile_sizes won't be changed by padding. + # Only validate when cudagraphs are actually being used. + if ( + self.compilation_config.compile_sizes + and self.cudagraph_mode != CUDAGraphMode.NONE + ): + for size in self.compilation_config.compile_sizes: + if size <= self.compilation_config.max_cudagraph_capture_size: + padded = self._bs_to_padded_graph_size[size] + if padded != size: + raise ValueError( + f"compile_sizes contains {size} which would be " + f"padded to {padded}. All compile_sizes must be " + "values that won't be changed by cudagraph padding. " + "Use values from cudagraph_capture_sizes." + ) + def _create_padded_batch_descriptor( self, num_tokens: int, uniform_decode: bool, has_lora: bool ) -> BatchDescriptor: From 7d6b35c1c6cbbd7365f7a4086cf12d89c4a40b43 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 Jan 2026 22:48:22 +0000 Subject: [PATCH 12/18] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/spec_decode/eagle.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index eb0fa5cfa802..79d2a07851f0 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -35,10 +35,6 @@ TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata -from vllm.v1.attention.backends.utils import ( - AttentionMetadataBuilder, - CommonAttentionMetadata, -) from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata From 78f12a900ee15eea999abd19412838b65e34b0dd Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 Jan 2026 22:53:02 +0000 Subject: [PATCH 13/18] clean Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 2 +- vllm/v1/spec_decode/eagle.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index cbd3c9689bf9..f9eb8ff3d0fa 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -120,7 +120,7 @@ def add_cudagraph_key( self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys( - self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1 ): # This should be called only after attention backend is initialized. So we can # get the correct cudagraph mode after backend support is resolved. diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 79d2a07851f0..1f100cedea88 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -224,15 +224,16 @@ def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: Eagle only supports PIECEWISE cudagraphs (via mixed_mode). This should be called after adjust_cudagraph_sizes_for_spec_decode. """ - eagle_cudagraph_mode = ( - CUDAGraphMode.PIECEWISE - if not self.speculative_config.enforce_eager - and cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE - else CUDAGraphMode.NONE - ) - self.cudagraph_dispatcher.initialize_cudagraph_keys( - eagle_cudagraph_mode, uniform_decode_query_len=1 - ) + if ( + not self.speculative_config.enforce_eager + and cudagraph_mode.mixed_mode() + in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] + ): + eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE + else: + eagle_cudagraph_mode = CUDAGraphMode.NONE + + self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) def propose( self, From 3e92b8b9b4643a086cae7a5f9bb92be010d28b35 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 Jan 2026 23:54:08 +0000 Subject: [PATCH 14/18] fix Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index f9eb8ff3d0fa..f54c2d5c1d0e 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -57,6 +57,8 @@ def __init__(self, vllm_config: VllmConfig): ) self.keys_initialized = False + # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called + self.cudagraph_mode = CUDAGraphMode.NONE def _compute_bs_to_padded_graph_size(self) -> None: """Pre-compute the mapping from batch size to padded graph size.""" @@ -125,6 +127,12 @@ def initialize_cudagraph_keys( # This should be called only after attention backend is initialized. So we can # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode + + # Early exit if cudagraphs are disabled (e.g., on CPU platforms) + if cudagraph_mode == CUDAGraphMode.NONE: + self.keys_initialized = True + return + self._compute_bs_to_padded_graph_size() # LoRA activation cases to specialize the cuda graphs on From 012f2d95dbedd86a0ea8eb62e3a6942633a5fad4 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Thu, 15 Jan 2026 23:54:21 +0000 Subject: [PATCH 15/18] clean Signed-off-by: Lucas Wilkinson --- vllm/v1/cudagraph_dispatcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index f54c2d5c1d0e..3e47e98c150c 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -128,7 +128,7 @@ def initialize_cudagraph_keys( # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode - # Early exit if cudagraphs are disabled (e.g., on CPU platforms) + # Early exit if cudagraphs are disabled if cudagraph_mode == CUDAGraphMode.NONE: self.keys_initialized = True return From a3930d43c994913c74fab278bb6b1792858985c0 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 19 Jan 2026 17:58:15 -0500 Subject: [PATCH 16/18] Update test Signed-off-by: Matthew Bonanni --- tests/compile/test_config.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 67e6c718fd05..8493cbf4708f 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -2,14 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy from contextlib import nullcontext -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from pydantic import ValidationError from vllm.compilation.counter import compilation_counter from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + CUDAGraphMode, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.config.compilation import CompilationMode, PassConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -17,6 +23,7 @@ _is_torch_equal_or_newer, is_torch_equal, ) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 @@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config): assert "torch.ops._C.static_scaled_fp8_quant.default(" in code +def _create_vllm_config_for_validation( + compilation_config: CompilationConfig, +) -> MagicMock: + """Helper to create a mock VllmConfig for padding validation testing.""" + mock_config = MagicMock(spec=VllmConfig) + mock_config.compilation_config = compilation_config + mock_config.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8) + mock_config.parallel_config = ParallelConfig() + mock_config.speculative_config = None + mock_config.lora_config = None + return mock_config + + def test_compile_sizes_padding_validation(): """Test that compile_sizes with values that would be padded raises an error.""" # cudagraph_capture_sizes=[1, 2, 4, 8] means: @@ -490,6 +510,8 @@ def test_compile_sizes_padding_validation(): compile_sizes=[3], ) config.post_init_cudagraph_sizes() + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) with pytest.raises(ValueError, match="would be padded to"): config = CompilationConfig( @@ -498,6 +520,8 @@ def test_compile_sizes_padding_validation(): compile_sizes=[5], ) config.post_init_cudagraph_sizes() + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 4, 8], @@ -506,6 +530,8 @@ def test_compile_sizes_padding_validation(): ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [1, 2, 4, 8] + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 4, 8], @@ -535,3 +561,5 @@ def test_compile_sizes_padding_validation(): ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [3, 5, 7] + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise From 7e0fe861f68f7144ab5e95e8a76116f3676f49b3 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 20 Jan 2026 07:51:54 -0500 Subject: [PATCH 17/18] Fix None value error Signed-off-by: Matthew Bonanni --- tests/compile/test_config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 8493cbf4708f..f1170b1b8b7b 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -508,6 +508,7 @@ def test_compile_sizes_padding_validation(): cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[3], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) @@ -518,6 +519,7 @@ def test_compile_sizes_padding_validation(): cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[5], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) @@ -527,6 +529,7 @@ def test_compile_sizes_padding_validation(): cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[1, 2, 4, 8], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [1, 2, 4, 8] @@ -537,6 +540,7 @@ def test_compile_sizes_padding_validation(): cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=["cudagraph_capture_sizes"], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [1, 2, 4, 8] From fa9034be22dc6889fd1b81d089bfc7f0707946b4 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Tue, 20 Jan 2026 17:00:39 +0000 Subject: [PATCH 18/18] Fix hanging test Signed-off-by: Matthew Bonanni --- tests/models/language/generation/test_hybrid.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index dd5b7fdc3e1c..c3e6d7899e2e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -174,6 +174,9 @@ def test_mamba_cache_cg_padding( """ vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() cudagraph_dispatcher = CudagraphDispatcher(vllm_config) + cudagraph_dispatcher.initialize_cudagraph_keys( + vllm_config.compilation_config.cudagraph_mode + ) while ( len(example_prompts) == cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens