diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 67e6c718fd05..f1170b1b8b7b 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -2,14 +2,20 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy from contextlib import nullcontext -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from pydantic import ValidationError from vllm.compilation.counter import compilation_counter from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig +from vllm.config import ( + CompilationConfig, + CUDAGraphMode, + ParallelConfig, + SchedulerConfig, + VllmConfig, +) from vllm.config.compilation import CompilationMode, PassConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform @@ -17,6 +23,7 @@ _is_torch_equal_or_newer, is_torch_equal, ) +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # This import automatically registers `torch.ops.silly.attention` from . import silly_attention # noqa: F401 @@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config): assert "torch.ops._C.static_scaled_fp8_quant.default(" in code +def _create_vllm_config_for_validation( + compilation_config: CompilationConfig, +) -> MagicMock: + """Helper to create a mock VllmConfig for padding validation testing.""" + mock_config = MagicMock(spec=VllmConfig) + mock_config.compilation_config = compilation_config + mock_config.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8) + mock_config.parallel_config = ParallelConfig() + mock_config.speculative_config = None + mock_config.lora_config = None + return mock_config + + def test_compile_sizes_padding_validation(): """Test that compile_sizes with values that would be padded raises an error.""" # cudagraph_capture_sizes=[1, 2, 4, 8] means: @@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation(): cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[3], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) with pytest.raises(ValueError, match="would be padded to"): config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[5], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=[1, 2, 4, 8], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [1, 2, 4, 8] + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise config = CompilationConfig( cudagraph_capture_sizes=[1, 2, 4, 8], max_cudagraph_capture_size=8, compile_sizes=["cudagraph_capture_sizes"], + cudagraph_mode=CUDAGraphMode.FULL, ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [1, 2, 4, 8] @@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation(): ) config.post_init_cudagraph_sizes() assert sorted(config.compile_sizes) == [3, 5, 7] + dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config)) + dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index 37830093cd3c..c3e6d7899e2e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -9,6 +9,7 @@ from tests.utils import multi_gpu_test from vllm.engine.arg_utils import EngineArgs from vllm.sampling_params import SamplingParams +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from ...utils import check_logprobs_close, check_outputs_equal @@ -172,7 +173,14 @@ def test_mamba_cache_cg_padding( tensor dimensions aren't compatible. """ vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config() - while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)): + cudagraph_dispatcher = CudagraphDispatcher(vllm_config) + cudagraph_dispatcher.initialize_cudagraph_keys( + vllm_config.compilation_config.cudagraph_mode + ) + while ( + len(example_prompts) + == cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens + ): example_prompts.append(example_prompts[0]) try: diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index f9d3e8d0532b..cdac67a2d650 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -61,9 +61,6 @@ def _create_vllm_config( ) compilation_config.post_init_cudagraph_sizes() - mock_config.pad_for_cudagraph = ( - lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size] - ) return mock_config @@ -169,6 +166,7 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): rt_mode, key = dispatcher.dispatch( num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True ) + if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == CUDAGraphMode.PIECEWISE assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() @@ -360,7 +358,7 @@ def test_capture_replay_bypass_logic(self): ): full_wrapper(input_1) - rt_mode, key = self.dispatcher.dispatch(desc_1) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens) # 1. Capture first shape action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "capture_global" @@ -369,7 +367,7 @@ def test_capture_replay_bypass_logic(self): action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key) assert action == "replay" - rt_mode, key = self.dispatcher.dispatch(desc_2) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens) # 3. Capture second shape action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key) assert action == "capture_global" @@ -381,7 +379,7 @@ def test_capture_replay_bypass_logic(self): assert action == "replay" # 5. Bypass if no key match - rt_mode, key = self.dispatcher.dispatch(desc_3_unseen) + rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens) assert rt_mode == CUDAGraphMode.NONE action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key) assert action == "bypass" diff --git a/vllm/compilation/piecewise_backend.py b/vllm/compilation/piecewise_backend.py index 29d6f89990cd..ebf2715a109b 100644 --- a/vllm/compilation/piecewise_backend.py +++ b/vllm/compilation/piecewise_backend.py @@ -10,7 +10,7 @@ from vllm.compilation.backends import VllmBackend from vllm.compilation.monitor import end_monitoring_torch_compile from vllm.config import VllmConfig -from vllm.config.compilation import Range +from vllm.config.utils import Range from vllm.logger import init_logger logger = init_logger(__name__) diff --git a/vllm/compilation/sequence_parallelism.py b/vllm/compilation/sequence_parallelism.py index 34ff2ab47e56..b35c192dfd23 100644 --- a/vllm/compilation/sequence_parallelism.py +++ b/vllm/compilation/sequence_parallelism.py @@ -11,7 +11,7 @@ from torch._inductor.pattern_matcher import PatternMatcherPass from vllm.config import VllmConfig -from vllm.config.compilation import Range +from vllm.config.utils import Range from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size from vllm.logger import init_logger diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 035aa24e33c7..8ce31ad18133 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -581,15 +581,6 @@ class CompilationConfig: local_cache_dir: str = field(default=None, init=False) # type: ignore """local cache dir for each rank""" - bs_to_padded_graph_size: list[int] = field( - default=None, # type: ignore - init=False, - ) - """optimization: - Intuitively, bs_to_padded_graph_size should be dict[int, int]. - since we know all keys are in a range [0, max_cudagraph_capture_size], - we can optimize it to list[int] for better lookup performance.""" - # keep track of enabled and disabled custom ops enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False) """custom ops that are enabled""" @@ -639,7 +630,6 @@ def compute_hash(self) -> str: "debug_dump_path", "cache_dir", "local_cache_dir", - "bs_to_padded_graph_size", "traced_files", "compilation_time", "static_forward_context", @@ -661,7 +651,6 @@ def __repr__(self) -> str: "enabled_custom_ops": True, "disabled_custom_ops": True, "compilation_time": True, - "bs_to_padded_graph_size": True, "traced_files": True, "inductor_compile_config": { "post_grad_custom_post_pass": True, @@ -882,7 +871,6 @@ def post_init_cudagraph_sizes(self) -> None: """To complete the initialization after cudagraph related configs are set. This includes: - initialize compile_sizes - - pre-compute the mapping bs_to_padded_graph_size """ computed_compile_sizes = [] @@ -906,23 +894,6 @@ def post_init_cudagraph_sizes(self) -> None: if self.cudagraph_capture_sizes: assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size - # May get recomputed in the model runner if adjustment is needed for spec-decode - self.compute_bs_to_padded_graph_size() - - # Validate that compile_sizes won't be changed by padding. - # Only validate when cudagraphs are actually being used. - if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE: - for size in self.compile_sizes: - if size <= self.max_cudagraph_capture_size: - padded = self.bs_to_padded_graph_size[size] - if padded != size: - raise ValueError( - f"compile_sizes contains {size} which would be " - f"padded to {padded}. All compile_sizes must be " - "values that won't be changed by cudagraph padding. " - "Use values from cudagraph_capture_sizes." - ) - def set_splitting_ops_for_v1( self, all2all_backend: str, data_parallel_size: int = 1 ): @@ -1134,24 +1105,6 @@ def adjust_cudagraph_sizes_for_spec_decode( self.max_cudagraph_capture_size = rounded_sizes[-1] self.cudagraph_capture_sizes = rounded_sizes - # Recompute after adjusting the cudagraph sizes - self.compute_bs_to_padded_graph_size() - - def compute_bs_to_padded_graph_size(self): - # pre-compute the mapping from batch size to padded graph size - self.bs_to_padded_graph_size = [ - 0 for i in range(self.max_cudagraph_capture_size + 1) - ] - for end, start in zip( - self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1], - [0] + self.cudagraph_capture_sizes, - ): - for bs in range(start, end): - if bs == start: - self.bs_to_padded_graph_size[bs] = start - else: - self.bs_to_padded_graph_size[bs] = end - def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" if self.compile_ranges_split_points is None: diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 8a3500c0aac6..3e47e98c150c 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -57,13 +57,47 @@ def __init__(self, vllm_config: VllmConfig): ) self.keys_initialized = False + # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called + self.cudagraph_mode = CUDAGraphMode.NONE + + def _compute_bs_to_padded_graph_size(self) -> None: + """Pre-compute the mapping from batch size to padded graph size.""" + max_size = self.compilation_config.max_cudagraph_capture_size + capture_sizes = self.compilation_config.cudagraph_capture_sizes + self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1) + for end, start in zip( + capture_sizes + [max_size + 1], + [0] + capture_sizes, + ): + for bs in range(start, end): + if bs == start: + self._bs_to_padded_graph_size[bs] = start + else: + self._bs_to_padded_graph_size[bs] = end + + # Validate that compile_sizes won't be changed by padding. + # Only validate when cudagraphs are actually being used. + if ( + self.compilation_config.compile_sizes + and self.cudagraph_mode != CUDAGraphMode.NONE + ): + for size in self.compilation_config.compile_sizes: + if size <= self.compilation_config.max_cudagraph_capture_size: + padded = self._bs_to_padded_graph_size[size] + if padded != size: + raise ValueError( + f"compile_sizes contains {size} which would be " + f"padded to {padded}. All compile_sizes must be " + "values that won't be changed by cudagraph padding. " + "Use values from cudagraph_capture_sizes." + ) def _create_padded_batch_descriptor( self, num_tokens: int, uniform_decode: bool, has_lora: bool ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len - num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens) + num_tokens_padded = self._bs_to_padded_graph_size[num_tokens] if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL): num_reqs = num_tokens_padded // uniform_decode_query_len @@ -88,12 +122,19 @@ def add_cudagraph_key( self.cudagraph_keys[runtime_mode].add(batch_descriptor) def initialize_cudagraph_keys( - self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int + self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1 ): # This should be called only after attention backend is initialized. So we can # get the correct cudagraph mode after backend support is resolved. self.cudagraph_mode = cudagraph_mode + # Early exit if cudagraphs are disabled + if cudagraph_mode == CUDAGraphMode.NONE: + self.keys_initialized = True + return + + self._compute_bs_to_padded_graph_size() + # LoRA activation cases to specialize the cuda graphs on if self.vllm_config.lora_config: if self.compilation_config.cudagraph_specialize_lora: @@ -143,15 +184,24 @@ def initialize_cudagraph_keys( def dispatch( self, num_tokens: int, - uniform_decode: bool, - has_lora: bool, + uniform_decode: bool = False, + has_lora: bool = False, disable_full: bool = False, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ - Given conditions(e.g.,batch descriptor and if using cascade attention), + Given conditions(e.g.,batch descriptor and if using piecewise only), dispatch to a cudagraph runtime mode and the valid batch descriptor. A new batch descriptor is returned as we might dispatch a uniform batch to a graph that supports a more general batch (uniform to non-uniform). + + Args: + num_tokens: Number of tokens in the batch. + uniform_decode: Whether the batch is uniform decode (i.e. uniform and query + length is uniform_decode_query_len). + has_lora: Whether LoRA is active. + disable_full: If True, skip FULL cudagraph checks and + return PIECEWISE or NONE only. (can be used for features like + cascade attention that are not supported by full cudagraphs) """ if ( not self.keys_initialized diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index ff34afb168da..8ce53a2933de 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -9,7 +9,6 @@ import torch.nn as nn from vllm.config import ( - CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, @@ -36,6 +35,7 @@ TreeAttentionMetadataBuilder, ) from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata +from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.sampler import _SAMPLING_EPS @@ -100,24 +100,13 @@ def __init__( self._get_eagle3_use_aux_hidden_state_from_config() ) - self.use_cuda_graph = False - self.compilation_config = self.vllm_config.compilation_config - if self.compilation_config.mode == CompilationMode.VLLM_COMPILE: - cudagraph_mode = self.compilation_config.cudagraph_mode - if cudagraph_mode != CUDAGraphMode.NONE and not cudagraph_mode.has_mode( - CUDAGraphMode.PIECEWISE - ): - logger.warning( - "Currently the eagle proposer only supports cudagraph_mode " - "PIECEWISE, if you want the drafter to use cuda graphs, " - "please set compilation_config.cudagraph_mode to PIECEWISE " - "or FULL_AND_PIECEWISE" - ) - self.use_cuda_graph = ( - cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE) - and not self.speculative_config.enforce_eager - ) + + # Cudagraph dispatcher for PIECEWISE-only dispatching in eagle. + # Keys are initialized later via initialize_cudagraph_keys() called from + # gpu_model_runner._check_and_update_cudagraph_mode after + # adjust_cudagraph_sizes_for_spec_decode is called. + self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) # persistent buffers for cuda graph self.input_ids = torch.zeros( @@ -234,6 +223,23 @@ def _set_positions(self, num_tokens: int, positions: torch.Tensor): else: self.positions[:num_tokens] = positions + def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: + """Initialize cudagraph dispatcher keys for eagle. + + Eagle only supports PIECEWISE cudagraphs (via mixed_mode). + This should be called after adjust_cudagraph_sizes_for_spec_decode. + """ + if ( + not self.speculative_config.enforce_eager + and cudagraph_mode.mixed_mode() + in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL] + ): + eagle_cudagraph_mode = CUDAGraphMode.PIECEWISE + else: + eagle_cudagraph_mode = CUDAGraphMode.NONE + + self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) + def propose( self, # [num_tokens] @@ -304,16 +310,10 @@ def propose( num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens ) - cudagraph_runtime_mode = CUDAGraphMode.NONE - if ( - self.use_cuda_graph - and num_tokens_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens_dp_padded) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens_dp_padded + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens @@ -412,16 +412,10 @@ def propose( num_tokens_unpadded=batch_size, num_tokens_padded=batch_size ) - if ( - self.use_cuda_graph - and batch_size_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size_dp_padded) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - input_batch_size = batch_size_dp_padded - cudagraph_runtime_mode = CUDAGraphMode.NONE + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + batch_size_dp_padded + ) + input_batch_size = batch_desc.num_tokens if batch_size_across_dp is not None: batch_size_across_dp[self.dp_rank] = input_batch_size @@ -870,15 +864,10 @@ def propose_tree( self.positions[:num_tokens] = tree_positions.view(-1) self.hidden_states[:num_tokens] = tree_hidden_states.view(num_tokens, -1) - if ( - self.use_cuda_graph - and num_tokens <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) - cudagraph_runtime_mode = CUDAGraphMode.PIECEWISE - else: - num_input_tokens = num_tokens - cudagraph_runtime_mode = CUDAGraphMode.NONE + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens + ) + num_input_tokens = batch_desc.num_tokens # Run the model. with set_forward_context( per_layer_attn_metadata, @@ -1216,9 +1205,6 @@ def dummy_run( use_cudagraphs: bool = True, is_graph_capturing: bool = False, ) -> None: - # Determine if CUDA graphs should be used for this run. - cudagraphs_enabled = use_cudagraphs and self.use_cuda_graph - # FIXME: when using tree-based specdec, adjust number of forward-passes # according to the depth of the tree. for fwd_idx in range( @@ -1228,16 +1214,10 @@ def dummy_run( num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp( num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens ) - if ( - cudagraphs_enabled - and num_tokens_dp_padded - <= self.compilation_config.max_cudagraph_capture_size - ): - num_input_tokens = self.vllm_config.pad_for_cudagraph( - num_tokens_dp_padded - ) - else: - num_input_tokens = num_tokens_dp_padded + cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( + num_tokens_dp_padded + ) + num_input_tokens = batch_desc.num_tokens if num_tokens_across_dp is not None: num_tokens_across_dp[self.dp_rank] = num_input_tokens @@ -1246,9 +1226,7 @@ def dummy_run( self.vllm_config, num_tokens=num_input_tokens, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE - if cudagraphs_enabled - else CUDAGraphMode.NONE, + cudagraph_runtime_mode=cudagraph_runtime_mode, ): if self.supports_mm_inputs: input_ids = None @@ -1340,7 +1318,8 @@ def _pad_batch_across_dp( num_tokens_unpadded=num_tokens_unpadded, parallel_config=self.vllm_config.parallel_config, allow_microbatching=False, - allow_dp_padding=self.use_cuda_graph, + allow_dp_padding=self.cudagraph_dispatcher.cudagraph_mode + != CUDAGraphMode.NONE, num_tokens_padded=num_tokens_padded, uniform_decode=None, num_scheduled_tokens_per_request=None, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 12f74dbcaa8e..1fdc689f0c07 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2140,15 +2140,11 @@ def _prepare_kv_sharing_fast_prefill( self.kv_sharing_fast_prefill_logits_indices[num_logits:].fill_( logits_indices[-1].item() ) - if ( - self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE - and num_logits <= self.cudagraph_batch_sizes[-1] - ): - # Use piecewise CUDA graphs. - # Add padding to the batch size. - num_logits_padded = self.vllm_config.pad_for_cudagraph(num_logits) - else: - num_logits_padded = num_logits + # Dispatch for the decoder portion of the model. + _, batch_desc = self.cudagraph_dispatcher.dispatch( + num_logits, disable_full=True + ) + num_logits_padded = batch_desc.num_tokens logits_indices_padded = self.kv_sharing_fast_prefill_logits_indices[ :num_logits_padded ] @@ -5210,6 +5206,11 @@ def _check_and_update_cudagraph_mode( cudagraph_mode, self.uniform_decode_query_len ) + # Initialize eagle's cudagraph dispatcher if using eagle spec decode. + if self.speculative_config and self.speculative_config.use_eagle(): + assert isinstance(self.drafter, EagleProposer) + self.drafter.initialize_cudagraph_keys(cudagraph_mode) + def calculate_reorder_batch_threshold(self) -> None: """ Choose the minimum reorder batch threshold from all attention groups.