diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index b829c31e7dbe..1e32e9061885 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -97,6 +97,9 @@ def is_valid_runtime_mode(self) -> bool: def __str__(self) -> str: return self.name + def __bool__(self) -> bool: + return self != CUDAGraphMode.NONE + @config class PassConfig: diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py index b06a35805799..5a1edc076b4a 100644 --- a/vllm/v1/worker/gpu/block_table.py +++ b/vllm/v1/worker/gpu/block_table.py @@ -104,19 +104,24 @@ def apply_staged_writes(self) -> None: self.num_blocks.copy_to_uva() def gather_block_tables( - self, idx_mapping: torch.Tensor + self, + idx_mapping: torch.Tensor, + num_reqs_padded: int, ) -> tuple[torch.Tensor, ...]: num_reqs = idx_mapping.shape[0] - _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs)]( + # Launch kernel with num_reqs_padded to fuse zeroing of padded rows. + _gather_block_tables_kernel[(self.num_kv_cache_groups, num_reqs_padded)]( idx_mapping, self.block_table_ptrs, self.input_block_table_ptrs, self.block_table_strides, self.num_blocks.gpu, self.num_blocks.gpu.stride(0), + num_reqs, + self.input_block_tables[0].shape[1], # max_num_blocks BLOCK_SIZE=1024, # type: ignore ) - return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) + return tuple(bt[:num_reqs_padded] for bt in self.input_block_tables) def get_dummy_block_tables(self, num_reqs: int) -> tuple[torch.Tensor, ...]: # NOTE(woosuk): The output may be used for CUDA graph capture. 
@@ -130,6 +135,7 @@ def compute_slot_mappings( idx_mapping: torch.Tensor, query_start_loc: torch.Tensor, positions: torch.Tensor, + num_tokens_padded: int, ) -> torch.Tensor: num_reqs = idx_mapping.shape[0] num_tokens = positions.shape[0] @@ -151,7 +157,7 @@ def compute_slot_mappings( PAD_ID=PAD_SLOT_ID, TRITON_BLOCK_SIZE=1024, # type: ignore ) - return self.slot_mappings[:, :num_tokens] + return self.slot_mappings[:, :num_tokens_padded] def get_dummy_slot_mappings(self, num_tokens: int) -> torch.Tensor: # Fill the entire slot_mappings tensor, not just the first `num_tokens` entries. @@ -173,21 +179,31 @@ def _gather_block_tables_kernel( block_table_strides, # [num_kv_cache_groups] num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs] num_blocks_stride, + num_reqs, # actual number of requests (for padding) + max_num_blocks, # entries per block-table row (loop bound for zeroing padded rows) BLOCK_SIZE: tl.constexpr, ): # kv cache group id group_id = tl.program_id(0) batch_idx = tl.program_id(1) - req_idx = tl.load(batch_idx_to_req_idx + batch_idx) + stride = tl.load(block_table_strides + group_id) + dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32) + dst_row_ptr = dst_block_table_ptr + batch_idx * stride + + if batch_idx >= num_reqs: + # Zero out padded rows. 
+ for i in tl.range(0, max_num_blocks, BLOCK_SIZE): + offset = i + tl.arange(0, BLOCK_SIZE) + tl.store(dst_row_ptr + offset, 0, mask=offset < max_num_blocks) + return + + req_idx = tl.load(batch_idx_to_req_idx + batch_idx) group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride num_blocks = tl.load(group_num_blocks_ptr + req_idx) - stride = tl.load(block_table_strides + group_id) src_block_table_ptr = _load_ptr(src_block_table_ptrs + group_id, tl.int32) src_row_ptr = src_block_table_ptr + req_idx * stride - dst_block_table_ptr = _load_ptr(dst_block_table_ptrs + group_id, tl.int32) - dst_row_ptr = dst_block_table_ptr + batch_idx * stride for i in tl.range(0, num_blocks, BLOCK_SIZE): offset = i + tl.arange(0, BLOCK_SIZE) diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index b4e7773cd4c0..2ec3cb2a2e15 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections import defaultdict from collections.abc import Callable +from dataclasses import dataclass from typing import Any import torch @@ -11,235 +13,260 @@ from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import BatchDescriptor, set_forward_context +from vllm.logger import init_logger from vllm.model_executor.offloader.base import get_offloader -from vllm.utils.math_utils import cdiv +from vllm.platforms import current_platform from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens -from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import 
InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.utils import AttentionGroup +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class BatchExecutionDescriptor: + """Describes the shape of the batch and CG mode to run; this is used to make shape + matches between the capture and runtime.""" + + cg_mode: CUDAGraphMode + num_tokens: int + num_reqs: int | None # None means no request padding is needed (PIECEWISE graphs) + uniform_token_count: int | None = None + + +def _is_compatible( + desc: BatchExecutionDescriptor, + num_reqs: int, + num_tokens: int, + uniform_token_count: int | None, +) -> bool: + # desc.uniform_token_count=None (PIECEWISE) can handle any uniform_token_count + # desc.num_reqs=None means no request padding needed (PIECEWISE) + return ( + ( + desc.uniform_token_count is None + or desc.uniform_token_count == uniform_token_count + ) + and (desc.num_reqs is None or desc.num_reqs >= num_reqs) + and desc.num_tokens >= num_tokens + ) + + +def get_uniform_token_count( + num_reqs: int, + num_tokens: int, + max_query_len: int, +) -> int | None: + """ + Return the uniform token count if batch is uniform, else None. + A batch is uniform if all requests have the same number of tokens. 
+ """ + if (max_query_len == num_tokens // num_reqs) and ( + num_tokens == max_query_len * num_reqs + ): + return max_query_len + return None + class CudaGraphManager: def __init__( self, vllm_config: VllmConfig, - use_aux_hidden_state_outputs: bool, device: torch.device, + cudagraph_mode: CUDAGraphMode, + decode_query_len: int, ): self.vllm_config = vllm_config - self.scheduler_config = vllm_config.scheduler_config - self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs self.device = device - - self.max_model_len = vllm_config.model_config.max_model_len - self.max_num_reqs = self.scheduler_config.max_num_seqs - self.max_num_tokens = self.scheduler_config.max_num_batched_tokens + self.max_num_reqs = vllm_config.scheduler_config.max_num_seqs + self.compilation_config = vllm_config.compilation_config + assert self.compilation_config is not None + self.cudagraph_mode = cudagraph_mode + self.decode_query_len = decode_query_len self.dp_size = vllm_config.parallel_config.data_parallel_size - self.uniform_decode_query_len = 1 - spec_config = vllm_config.speculative_config - if spec_config is not None: - self.uniform_decode_query_len += spec_config.num_speculative_tokens + self.graphs: dict[BatchExecutionDescriptor, torch.cuda.CUDAGraph] = {} + self.pool = current_platform.get_global_graph_pool() if cudagraph_mode else None - self.compilation_config = vllm_config.compilation_config - assert self.compilation_config is not None - self.cudagraph_mode = self.compilation_config.cudagraph_mode + self._graphs_captured = False + self._candidates: list[list[BatchExecutionDescriptor]] = [] + self._capture_descs: dict[CUDAGraphMode, list[BatchExecutionDescriptor]] = {} + self._init_candidates() - use_uniform_decode_cudagraph = ( - self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and self.cudagraph_mode.separate_routine() - ) - self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes( - self.compilation_config.cudagraph_capture_sizes, - 
self.max_num_reqs, - self.max_num_tokens, - self.cudagraph_mode, - self.uniform_decode_query_len, - use_uniform_decode_cudagraph, - ) + def _init_candidates(self) -> None: + """Build priority-ordered candidate lists for each token count.""" + capture_sizes = self.compilation_config.cudagraph_capture_sizes + if not (self.cudagraph_mode and capture_sizes): + return - self.graphs: dict[int, torch.cuda.CUDAGraph] = {} - self.pool = None - if self.cudagraph_mode != CUDAGraphMode.NONE: - self.pool = torch.cuda.graph_pool_handle() - self.hidden_states: torch.Tensor | None = None - self.aux_hidden_states: list[torch.Tensor] = [] + capture_sizes = sorted(capture_sizes) + max_decode_tokens = self.max_num_reqs * self.decode_query_len + decode_mode = self.cudagraph_mode.decode_mode() + mixed_mode = self.cudagraph_mode.mixed_mode() + separate_decode_routine = self.cudagraph_mode.separate_routine() + + descs_by_token_count = defaultdict(list) + descs_by_mode = defaultdict(list) + + for num_tokens in capture_sizes: + # Capture uniform decode specific graphs if required + # (i.e. separate decode routine) + if ( + separate_decode_routine + and decode_mode + and self.decode_query_len <= num_tokens <= max_decode_tokens + ): + desc = BatchExecutionDescriptor( + cg_mode=decode_mode, + num_tokens=num_tokens, + num_reqs=num_tokens // self.decode_query_len, + uniform_token_count=self.decode_query_len, + ) + descs_by_mode[decode_mode].append(desc) + descs_by_token_count[num_tokens].append(desc) + + if mixed_mode: + # for PIECEWISE graphs there is no limit on requests when replaying + # i.e. 
no request padding is needed + # so we leave it as None + num_reqs = ( + min(num_tokens, self.max_num_reqs) + if mixed_mode == CUDAGraphMode.FULL + else None + ) + desc = BatchExecutionDescriptor( + cg_mode=mixed_mode, + num_tokens=num_tokens, + num_reqs=num_reqs, + ) + descs_by_mode[mixed_mode].append(desc) + descs_by_token_count[num_tokens].append(desc) + + if not descs_by_token_count: + return + + sorted_padded = sorted(descs_by_token_count.keys()) + self._candidates = [[] for _ in range(sorted_padded[-1] + 1)] + + current_range_start = 0 + for cg_size in sorted_padded: + for i in range(current_range_start, cg_size + 1): + self._candidates[i] = descs_by_token_count[cg_size] + current_range_start = cg_size + 1 + + for mode, descs in descs_by_mode.items(): + descs.sort(key=lambda d: d.num_tokens, reverse=True) + self._capture_descs[mode] = descs def needs_capture(self) -> bool: - return len(self.cudagraph_sizes) > 0 - - def get_cudagraph_size( - self, num_tokens: int, uniform_decode: bool = False - ) -> int | None: - if uniform_decode and self.uniform_decode_cudagraph_sizes: - return self.uniform_decode_cudagraph_sizes.get(num_tokens) - return self.cudagraph_sizes.get(num_tokens) + return len(self._capture_descs) > 0 - def capture_graph( + @torch.inference_mode() + def capture( self, - num_tokens: int, - capture_cg_mode: CUDAGraphMode, - model: nn.Module, - model_state: ModelState, - input_buffers: InputBuffers, - block_tables: BlockTables, - attn_groups: list[list[AttentionGroup]], - kv_cache_config: KVCacheConfig, - has_lora: bool = False, - uniform_decode: bool = False, + create_forward_fn: Callable[ + [BatchExecutionDescriptor], Callable[[CUDAGraphMode], None] + ], + progress_bar_desc: str = "Capturing CUDA graphs", ) -> None: - # select and check capture function - assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( - f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}" - ) - if capture_cg_mode == CUDAGraphMode.PIECEWISE: - 
capture_fn = self._capture_piecewise_graph - else: - capture_fn = self._capture_full_graph - # prepare inputs - if uniform_decode: - num_reqs = min( - cdiv(num_tokens, self.uniform_decode_query_len), - self.max_num_reqs, - ) - else: - num_reqs = min(num_tokens, self.max_num_reqs) - - model_inputs = { - "input_ids": input_buffers.input_ids[:num_tokens], - "positions": input_buffers.positions[:num_tokens], - # NOTE: Values returned by `prepare_dummy_inputs` will override the - # default values above. - **model_state.prepare_dummy_inputs(num_reqs, num_tokens), - } - - attn_metadata, slot_mappings = prepare_inputs_to_capture( - num_reqs, - num_tokens, - model_state, - input_buffers, - block_tables, - attn_groups, - kv_cache_config, - ) - num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) - - # Warm up. - with set_forward_context( - attn_metadata, - self.vllm_config, - num_tokens=num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - num_tokens_across_dp=num_tokens_across_dp, - slot_mapping=slot_mappings, - ): - model_output = model(**model_inputs) - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Allocate output buffers if not already done. - if self.hidden_states is None: - self.hidden_states = torch.empty_like(hidden_states) - if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: - self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states] - - capture_fn( - num_tokens=num_tokens, - num_reqs=num_reqs, - model=model, - model_inputs=model_inputs, - num_tokens_across_dp=num_tokens_across_dp, - attn_metadata=attn_metadata, - slot_mappings=slot_mappings, - has_lora=has_lora, - ) - - def _capture_full_graph( + """Capture CUDA graphs. + + Args: + create_forward_fn: Factory that prepares inputs (OUTSIDE graph) and + returns a function that runs forward with a given CUDAGraphMode. 
+ """ + with graph_capture(device=self.device): + # Capture in order: PIECEWISE first, then FULL. PIECEWISE has larger + # activations so FULL activations should fit in already allocated + # buffers in the graph pool. + for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]: + if mode not in self._capture_descs: + continue + + descs = self._capture_descs[mode] + if is_global_first_rank(): + descs = tqdm(descs, desc=f"{progress_bar_desc} ({mode.name})") + for desc in descs: + # Prepare inputs and get forward function + forward_fn = create_forward_fn(desc) + + # Warmup + forward_fn(CUDAGraphMode.NONE) + + # Capture + logger.debug( + "CG Capture: mode=%s, batch_desc=%s", desc.cg_mode.name, desc + ) + if desc.cg_mode == CUDAGraphMode.PIECEWISE: + forward_fn(CUDAGraphMode.PIECEWISE) + else: + assert desc not in self.graphs, ( + f"Graph already captured for {desc}" + ) + graph = torch.cuda.CUDAGraph() + # Sync offloader's copy stream before capture. + # Ensure any pre-capture prefetches from offloader are complete. + get_offloader().sync_prev_onload() + with torch.cuda.graph(graph, self.pool): + forward_fn(CUDAGraphMode.NONE) + # Join offloader's copy stream after forward to avoid + # unjoined stream error. The last layer's start_prefetch + # forks copy_stream, but wait_prefetch only happens in + # the next forward pass. + get_offloader().join_after_forward() + self.graphs[desc] = graph + self._graphs_captured = True + + def dispatch( self, - num_tokens: int, num_reqs: int, - model: nn.Module, - model_inputs: dict[str, torch.Tensor | None], - num_tokens_across_dp: torch.Tensor, - attn_metadata: dict[str, Any] | None, - slot_mappings: dict[str, torch.Tensor] | None, - has_lora: bool = False, - ) -> None: - assert attn_metadata is not None - # Capture the graph. 
- assert num_tokens not in self.graphs - graph = torch.cuda.CUDAGraph() + num_tokens: int, + uniform_token_count: int | None, + ) -> BatchExecutionDescriptor: + """Find matching cudagraph descriptor from priority-ordered candidates.""" + if self._graphs_captured and 0 < num_tokens < len(self._candidates): + for desc in self._candidates[num_tokens]: + if _is_compatible(desc, num_reqs, num_tokens, uniform_token_count): + return desc + return BatchExecutionDescriptor( + cg_mode=CUDAGraphMode.NONE, num_tokens=num_tokens, num_reqs=num_reqs + ) - # Sync offloader's copy stream before capture. - # Ensure any pre-capture prefetches from offloader are complete. + def run_fullgraph(self, desc: BatchExecutionDescriptor): + """Replay a captured FULL cudagraph.""" + assert desc.cg_mode == CUDAGraphMode.FULL, ( + f"Expected FULL mode, got {desc.cg_mode}" + ) + assert desc in self.graphs, f"No cudagraph for {desc}" + # Sync offloader before replay - needed when transitioning from + # eager/piecewise to full cudagraph (e.g., prefill → decode). + # The previous eager iteration's start_prefetch may have queued + # H2D copies on copy_stream that the graph's captured events + # cannot see. Without this, replay could overwrite static buffers + # while those copies are still in flight. get_offloader().sync_prev_onload() + self.graphs[desc].replay() + + +class ModelCudaGraphManager(CudaGraphManager): + """CudaGraphManager with model-specific capture and hidden state management.""" - with ( - set_forward_context( - attn_metadata=attn_metadata, - vllm_config=self.vllm_config, - num_tokens=num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - num_tokens_across_dp=num_tokens_across_dp, - slot_mapping=slot_mappings, - ), - torch.cuda.graph(graph, self.pool), - ): - model_output = model(**model_inputs) - - # Join offloader's copy stream after forward to avoid unjoined - # stream error. 
The last layer's start_prefetch forks copy_stream, - # but wait_prefetch only happens in the next forward pass. - get_offloader().join_after_forward() - - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = model_output - else: - hidden_states = model_output - aux_hidden_states = None - - # Copy outputs to the output buffers. - assert self.hidden_states is not None - self.hidden_states[:num_tokens] = hidden_states - if self.use_aux_hidden_state_outputs: - for i, aux_hidden in enumerate(aux_hidden_states): - self.aux_hidden_states[i][:num_tokens] = aux_hidden - self.graphs[num_tokens] = graph - - def _capture_piecewise_graph( + def __init__( self, - num_tokens: int, - num_reqs: int, - model: nn.Module, - model_inputs: dict[str, torch.Tensor | None], - num_tokens_across_dp: torch.Tensor, - attn_metadata: dict[str, Any] | None, - slot_mappings: dict[str, torch.Tensor] | None, - has_lora: bool = False, - ) -> None: - # create batch descriptor for piecewise cudagraph dispatch key - batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora) - - # Capture run - CUDAGraphWrapper inside torch.compile will auto capture. 
- with set_forward_context( - attn_metadata=None, # piecewise no need attn_metadata - vllm_config=self.vllm_config, - num_tokens=num_tokens, - cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, - num_tokens_across_dp=num_tokens_across_dp, - batch_descriptor=batch_descriptor, - slot_mapping=slot_mappings, - ): - model(**model_inputs) + vllm_config: VllmConfig, + device: torch.device, + cudagraph_mode: CUDAGraphMode, + decode_query_len: int, + ): + super().__init__(vllm_config, device, cudagraph_mode, decode_query_len) + self.hidden_states: torch.Tensor | None = None + self.aux_hidden_states: list[torch.Tensor] = [] + self.use_aux_hidden_state_outputs = False - @torch.inference_mode() def capture( self, model: nn.Module, @@ -249,139 +276,82 @@ def capture( attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, has_lora: bool = False, + use_aux_hidden_state_outputs: bool = False, + progress_bar_desc: str = "Capturing CUDA graphs", ) -> None: - common_kwargs = dict( - device=self.device, - capture_fn=self.capture_graph, - model=model, - model_state=model_state, - input_buffers=input_buffers, - block_tables=block_tables, - attn_groups=attn_groups, - kv_cache_config=kv_cache_config, - has_lora=has_lora, - ) + """Capture CUDA graphs for model forward pass.""" + self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs - # Phase 1: Capture for mixed prefill-decode batches if needed. 
- mixed_mode = self.cudagraph_mode.mixed_mode() - if mixed_mode != CUDAGraphMode.NONE: - capture_graphs( - cudagraph_sizes=self.cudagraph_sizes, - capture_cudagraph_mode=mixed_mode, - desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})", - uniform_decode=False, - **common_kwargs, + def create_forward_fn( + desc: BatchExecutionDescriptor, + ) -> Callable[[CUDAGraphMode], None]: + num_tokens = desc.num_tokens + num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs) + num_tokens_across_dp = ( + torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu") + if self.dp_size > 1 + else None ) - - # Phase 2: Capture FULL graphs for uniform decode batches if needed. - # This is only needed if we use a separate routine for decode batches - # and the decode_mode is FULL. - if self.uniform_decode_cudagraph_sizes: - capture_graphs( - cudagraph_sizes=self.uniform_decode_cudagraph_sizes, - capture_cudagraph_mode=CUDAGraphMode.FULL, - desc="Capturing CUDA graphs (decode, FULL)", - uniform_decode=True, - **common_kwargs, + attn_metadata, slot_mappings = prepare_inputs_to_capture( + num_reqs, + num_tokens, + model_state, + input_buffers, + block_tables, + attn_groups, + kv_cache_config, ) - def get_cudagraph_runtime_mode( - self, num_reqs: int, num_tokens: int, max_query_len: int - ) -> tuple[CUDAGraphMode, int | None]: - is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( - num_tokens == max_query_len * num_reqs - ) - - cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode) - if cudagraph_size is None: - cudagraph_mode = CUDAGraphMode.NONE - elif is_uniform_decode: - cudagraph_mode = self.cudagraph_mode.decode_mode() - else: - cudagraph_mode = self.cudagraph_mode.mixed_mode() - - if ( - cudagraph_mode == CUDAGraphMode.FULL - and cudagraph_size is not None - and cudagraph_size not in self.graphs - ): - # If graph wasn't captured yet, fall back to eager. - # This might happen when the dummy run is called before capture. 
- cudagraph_mode = CUDAGraphMode.NONE - cudagraph_size = None - return cudagraph_mode, cudagraph_size + def forward_fn(cg_mode: CUDAGraphMode) -> None: + batch_descriptor = ( + BatchDescriptor(num_tokens=num_tokens) + if cg_mode == CUDAGraphMode.PIECEWISE + else None + ) + with set_forward_context( + attn_metadata if cg_mode != CUDAGraphMode.PIECEWISE else None, + self.vllm_config, + num_tokens=num_tokens, + cudagraph_runtime_mode=cg_mode, + num_tokens_across_dp=num_tokens_across_dp, + slot_mapping=slot_mappings, + batch_descriptor=batch_descriptor, + ): + model_inputs = { + "input_ids": input_buffers.input_ids[:num_tokens], + "positions": input_buffers.positions[:num_tokens], + **model_state.prepare_dummy_inputs(num_reqs, num_tokens), + } + model_output = model(**model_inputs) + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = model_output + else: + hidden_states = model_output + aux_hidden_states = [] + if self.hidden_states is None: + self.hidden_states = torch.empty_like(hidden_states) + if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: + self.aux_hidden_states = [ + torch.empty_like(x) for x in aux_hidden_states + ] + self.hidden_states[:num_tokens] = hidden_states + for i, aux in enumerate(aux_hidden_states): + self.aux_hidden_states[i][:num_tokens] = aux + + return forward_fn + + super().capture(create_forward_fn, progress_bar_desc) def run_fullgraph( - self, num_tokens: int + self, desc: BatchExecutionDescriptor ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: - assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens" - # Sync offloader before replay - needed when transitioning from - # eager/piecewise to full cudagraph (e.g., prefill → decode). - # The previous eager iteration's start_prefetch may have queued - # H2D copies on copy_stream that the graph's captured events - # cannot see. Without this, replay could overwrite static buffers - # while those copies are still in flight. 
- get_offloader().sync_prev_onload() - self.graphs[num_tokens].replay() + """Replay a captured FULL cudagraph and return hidden states.""" + super().run_fullgraph(desc) assert self.hidden_states is not None - hidden_states = self.hidden_states[:num_tokens] + hidden_states = self.hidden_states[: desc.num_tokens] if not self.use_aux_hidden_state_outputs: return hidden_states - return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states] - - -def get_cudagraph_sizes( - capture_sizes: list[int] | None, - max_num_reqs: int, - max_num_tokens: int, - cudagraph_mode: CUDAGraphMode, - uniform_decode_query_len: int = 1, - uniform_decode_cudagraph: bool = False, -) -> tuple[dict[int, int], dict[int, int]]: - # Support both FULL and PIECEWISE cudagraph modes - if cudagraph_mode == CUDAGraphMode.NONE: - return {}, {} - if not capture_sizes: - return {}, {} - - capture_sizes = sorted(capture_sizes) - if not capture_sizes: - return {}, {} - - cudagraph_sizes: dict[int, int] = {} - for i in range(1, capture_sizes[-1] + 1): - for x in capture_sizes: - if i <= x: - cudagraph_sizes[i] = x - break - - uniform_decode_cudagraph_sizes: dict[int, int] = {} - if uniform_decode_cudagraph: - max_num_tokens = max_num_reqs * uniform_decode_query_len - uniform_decode_cudagraph_sizes = { - k: v - for k, v in cudagraph_sizes.items() - if v <= max_num_tokens and v >= uniform_decode_query_len - } - return cudagraph_sizes, uniform_decode_cudagraph_sizes - - -def capture_graphs( - cudagraph_sizes: dict[int, int], - device: torch.device, - capture_fn: Callable, - capture_cudagraph_mode: CUDAGraphMode, - desc: str = "Capturing CUDA graphs", - **capture_kwargs, -) -> None: - # Capture larger graphs first. 
- sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True) - if is_global_first_rank(): - sizes_to_capture = tqdm(sizes_to_capture, desc=desc) - - with graph_capture(device=device): - for size in sizes_to_capture: - capture_fn(size, capture_cudagraph_mode, **capture_kwargs) + return hidden_states, [x[: desc.num_tokens] for x in self.aux_hidden_states] def prepare_inputs_to_capture( diff --git a/vllm/v1/worker/gpu/dp_utils.py b/vllm/v1/worker/gpu/dp_utils.py index 724a6c39f90c..f0e2bfcf54b8 100644 --- a/vllm/v1/worker/gpu/dp_utils.py +++ b/vllm/v1/worker/gpu/dp_utils.py @@ -1,9 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from __future__ import annotations + import torch import torch.distributed as dist +from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import get_dp_group +from vllm.v1.worker.gpu.cudagraph_utils import ( + BatchExecutionDescriptor, + CudaGraphManager, +) def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None: @@ -12,66 +19,63 @@ def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | N return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu") -def get_batch_metadata_across_dp( +def sync_cudagraph_and_dp_padding( + cudagraph_manager: CudaGraphManager, + desired_batch_desc: BatchExecutionDescriptor, num_tokens: int, - cudagraph_size: int, - cudagraph_runtime_mode: int, + num_reqs: int, + uniform_token_count: int | None, dp_size: int, dp_rank: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert dp_size > 1 - # Use CPU group to avoid CPU-GPU synchronization. +) -> tuple[BatchExecutionDescriptor, torch.Tensor | None]: + """ + Coordinates the batch descriptor and DP padding across all ranks. + + Returns (synced_batch_desc, num_tokens_across_dp). 
+ """ + assert dp_size > 1, "DP size must be greater than 1" group = get_dp_group().cpu_group tensor = torch.zeros(3, dp_size, dtype=torch.int32, device="cpu") tensor[0][dp_rank] = num_tokens - tensor[1][dp_rank] = cudagraph_size - tensor[2][dp_rank] = cudagraph_runtime_mode + tensor[1][dp_rank] = desired_batch_desc.cg_mode.value + tensor[2][dp_rank] = uniform_token_count or 0 # (0 means None) dist.all_reduce(tensor, group=group) - return tensor[0], tensor[1], tensor[2] + num_tokens_across_dp = tensor[0] + cg_mode_across_dp = tensor[1] + uniform_token_counts_across_dp = tensor[2] -def get_cudagraph_and_dp_padding( - num_tokens: int, - cudagraph_size: int | None, - cudagraph_runtime_mode: int, - dp_size: int, - dp_rank: int, -) -> tuple[int, torch.Tensor | None, int]: - if dp_size == 1: - if cudagraph_size is not None: - return cudagraph_size, None, cudagraph_runtime_mode - else: - return num_tokens, None, cudagraph_runtime_mode + if torch.all(num_tokens_across_dp == 0).item(): + synced_desc = BatchExecutionDescriptor( + cg_mode=CUDAGraphMode.NONE, num_tokens=0, num_reqs=0 + ) + return synced_desc, None - # Convert None to -1 for sync (indicates no cudagraph available) - if num_tokens == 0: - cudagraph_size = 0 - elif cudagraph_size is None: - cudagraph_size = -1 + synced_cg_mode = CUDAGraphMode(int(cg_mode_across_dp.min().item())) - num_tokens_across_dp, cudagraph_size_across_dp, cudagraph_mode_across_dp = ( - get_batch_metadata_across_dp( - num_tokens, cudagraph_size, cudagraph_runtime_mode, dp_size, dp_rank - ) + # If any rank wants to run eager, all ranks run eager + if synced_cg_mode == CUDAGraphMode.NONE: + return BatchExecutionDescriptor( + cg_mode=CUDAGraphMode.NONE, + num_tokens=num_tokens, + num_reqs=num_reqs, + ), num_tokens_across_dp + + synced_num_tokens = int(num_tokens_across_dp.max().item()) + synced_uniform_token_count = uniform_token_counts_across_dp[0] + # If ranks disagree on the uniform token count, or its 0 (means None) set to None + if 
synced_uniform_token_count == 0 or not torch.all( + uniform_token_counts_across_dp == synced_uniform_token_count + ): + synced_uniform_token_count = None + + # Dispatch for the final synced values, use num_reqs instead of synced_num_reqs + # so we don't perform request padding for PIECEWISE graphs + synced_desc = cudagraph_manager.dispatch( + num_reqs, synced_num_tokens, synced_uniform_token_count ) - if torch.all(num_tokens_across_dp == 0).item(): - # All ranks have zero tokens to run. - return 0, None, 0 - # Synchronize cudagraph_runtime_mode across ranks by taking the minimum. - synced_cudagraph_mode = int(cudagraph_mode_across_dp.min().item()) - # Check if all ranks have valid cudagraph_size. - all_have_cudagraph = torch.all(cudagraph_size_across_dp != -1).item() + # Update num_tokens_across_dp to reflect padded size. + num_tokens_across_dp[:] = synced_desc.num_tokens - if synced_cudagraph_mode != 0 and all_have_cudagraph: - # All ranks use cudagraph. Pad to max cudagraph_size. - max_cudagraph_size = int(cudagraph_size_across_dp.max().item()) - num_tokens_across_dp[:] = max_cudagraph_size - return max_cudagraph_size, num_tokens_across_dp, synced_cudagraph_mode - else: - # Fall back to eager mode (no cudagraph). - # Either some rank doesn't have cudagraph size or mode is NONE. 
- synced_cudagraph_mode = 0 - num_tokens_across_dp = torch.clamp(num_tokens_across_dp, min=1) - num_tokens_after_padding = int(num_tokens_across_dp[dp_rank].item()) - return num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode + return synced_desc, num_tokens_across_dp diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py index 1ca87612edf7..9b8707075fd6 100644 --- a/vllm/v1/worker/gpu/input_batch.py +++ b/vllm/v1/worker/gpu/input_batch.py @@ -37,6 +37,7 @@ class InputBatch: # batch_idx -> req_id req_ids: list[str] num_reqs: int + num_reqs_after_padding: int # batch_idx -> req_state_idx idx_mapping: torch.Tensor @@ -123,6 +124,7 @@ def make_dummy( return cls( req_ids=req_ids, num_reqs=num_reqs, + num_reqs_after_padding=num_reqs, idx_mapping=idx_mapping, idx_mapping_np=idx_mapping_np, expanded_idx_mapping=expanded_idx_mapping, @@ -330,7 +332,8 @@ def combine_sampled_and_draft_tokens( cu_num_logits: torch.Tensor, num_logits: int, ) -> torch.Tensor: - num_reqs = seq_lens.shape[0] + # use idx_mapping.shape[0] for actual request count + num_reqs = idx_mapping.shape[0] num_speculative_steps = draft_tokens.shape[-1] logits_indices = torch.empty( diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 30ab27d190ae..41c2f37042ca 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -40,7 +40,6 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors from vllm.tasks import SupportedTask -from vllm.utils.math_utils import cdiv from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput @@ -57,8 +56,12 @@ from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.buffer_utils import async_copy_to_gpu from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens 
-from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager -from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding +from vllm.v1.worker.gpu.cudagraph_utils import ( + BatchExecutionDescriptor, + ModelCudaGraphManager, + get_uniform_token_count, +) +from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import ( InputBatch, InputBuffers, @@ -137,6 +140,10 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.is_first_pp_rank = True self.is_last_pp_rank = True + # Data parallelism. + self.dp_size = self.parallel_config.data_parallel_size + self.dp_rank = self.parallel_config.data_parallel_rank + # Decode context parallelism. self.dcp_size = self.parallel_config.decode_context_parallel_size self.use_dcp = self.dcp_size > 1 @@ -193,10 +200,12 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) # CUDA graphs. - self.cudagraph_manager = CudaGraphManager( + self.decode_query_len = self.num_speculative_steps + 1 + self.cudagraph_manager = ModelCudaGraphManager( self.vllm_config, - self.use_aux_hidden_state_outputs, self.device, + self.compilation_config.cudagraph_mode, + decode_query_len=self.decode_query_len, ) # Structured outputs worker. self.structured_outputs_worker = StructuredOutputsWorker( @@ -331,17 +340,18 @@ def _dummy_run( **kwargs, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: # Create a dummy scheduler output. + num_reqs = min(num_tokens, self.max_num_reqs) if uniform_decode: - # Align tokens to uniform_decode_query_len for cudagraph - # compatibility across DP ranks. 
- query_len = self.cudagraph_manager.uniform_decode_query_len - num_reqs = min(cdiv(num_tokens, query_len), self.max_num_reqs) - num_tokens = num_reqs * query_len - num_tokens_per_request = [query_len] * num_reqs - else: - num_reqs = min(num_tokens, self.max_num_reqs) - num_tokens_per_request = [num_tokens // num_reqs] * num_reqs - num_tokens_per_request[-1] += num_tokens % num_reqs + # HACK(lucas): for now since the worker is shared between MRV1 and MRV2, + # and for spec-decode with MTP we want to make sure the dummy runs use + # 1+num_speculative_tokens we use max here, this will likely be eventually + # changed in the worker: https://github.com/vllm-project/vllm/pull/35243 + num_tokens = max(num_tokens, self.decode_query_len) + num_reqs = num_tokens // self.decode_query_len + assert num_tokens % self.decode_query_len == 0 + num_tokens_per_request = [num_tokens // num_reqs] * num_reqs + num_tokens_per_request[-1] += num_tokens % num_reqs + assert sum(num_tokens_per_request) == num_tokens num_scheduled_tokens = { f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request) @@ -498,13 +508,14 @@ def capture_model(self) -> int: with self.maybe_setup_dummy_loras(self.lora_config): self.cudagraph_manager.capture( - model=self.model, - model_state=self.model_state, - input_buffers=self.input_buffers, - block_tables=self.block_tables, - attn_groups=self.attn_groups, - kv_cache_config=self.kv_cache_config, + self.model, + self.model_state, + self.input_buffers, + self.block_tables, + self.attn_groups, + self.kv_cache_config, has_lora=self.lora_config is not None, + use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs, ) if self.speculator is not None: self.speculator.capture_model() @@ -592,9 +603,10 @@ def update_requests(self, scheduler_output: SchedulerOutput) -> None: ) def prepare_inputs( - self, scheduler_output: SchedulerOutput, num_tokens_after_padding: int + self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor ) -> 
InputBatch: num_tokens = scheduler_output.total_num_scheduled_tokens + num_tokens_after_padding = batch_desc.num_tokens assert num_tokens > 0 num_tokens_per_req = scheduler_output.num_scheduled_tokens num_reqs = len(num_tokens_per_req) @@ -644,6 +656,8 @@ def prepare_inputs( ) # Get query_start_loc. + # num_reqs_padded is None for PIECEWISE graphs (no request padding needed) + num_reqs_padded = batch_desc.num_reqs or num_reqs query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) query_start_loc_np[0] = 0 np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1]) @@ -651,8 +665,8 @@ def prepare_inputs( # Some attention backends like FA3 require query_start_loc to be non-decreasing. query_start_loc_np[num_reqs + 1 :] = num_tokens async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) - query_start_loc_np = query_start_loc_np[: num_reqs + 1] - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1] + query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1] # Get prefill tokens if any. if self.req_states.any_prefills(idx_mapping_np): @@ -674,7 +688,7 @@ def prepare_inputs( self.input_buffers.positions, self.input_buffers.seq_lens, ) - seq_lens = self.input_buffers.seq_lens[:num_reqs] + seq_lens = self.input_buffers.seq_lens[:num_reqs_padded] dcp_local_seq_lens = None if self.use_dcp: @@ -687,7 +701,7 @@ def prepare_inputs( self.dcp_rank, self.cp_interleave, ) - dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] + dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded] # Some input token ids are directly read from the last sampled tokens # and draft tokens. Also, get the logits indices to sample tokens from. 
@@ -706,6 +720,7 @@ def prepare_inputs( return InputBatch( req_ids=req_ids, num_reqs=num_reqs, + num_reqs_after_padding=num_reqs_padded, idx_mapping=idx_mapping, idx_mapping_np=idx_mapping_np, expanded_idx_mapping=expanded_idx_mapping, @@ -729,13 +744,18 @@ def prepare_inputs( def prepare_attn( self, input_batch: InputBatch ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]: - # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks] - block_tables = self.block_tables.gather_block_tables(input_batch.idx_mapping) - # Compute slot mappings: [num_kv_cache_groups, num_tokens] + # Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks]. + block_tables = self.block_tables.gather_block_tables( + input_batch.idx_mapping, + num_reqs_padded=input_batch.num_reqs_after_padding, + ) + # Slot mappings: [num_kv_cache_groups, num_tokens_padded]. + # Kernel pads beyond num_tokens with PAD_SLOT_ID. slot_mappings = self.block_tables.compute_slot_mappings( input_batch.idx_mapping, input_batch.query_start_loc, input_batch.positions, + num_tokens_padded=input_batch.num_tokens_after_padding, ) return block_tables, slot_mappings @@ -851,27 +871,29 @@ def execute_model( empty_output = self.kv_connector.no_forward(scheduler_output) return empty_output - # Get local cudagraph mode and size. - local_cudagraph_mode, local_cudagraph_size = ( - self.cudagraph_manager.get_cudagraph_runtime_mode( - num_reqs=len(scheduler_output.num_scheduled_tokens), - num_tokens=scheduler_output.total_num_scheduled_tokens, - max_query_len=max(scheduler_output.num_scheduled_tokens.values()), - ) + # Get batch descriptor and sync across DP ranks. 
+ num_reqs = len(scheduler_output.num_scheduled_tokens) + num_toks = scheduler_output.total_num_scheduled_tokens + max_query_len = max(scheduler_output.num_scheduled_tokens.values()) + uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len) + + batch_desc = self.cudagraph_manager.dispatch( + num_reqs, num_toks, uniform_tok_count ) + num_tokens_across_dp = None - # DP sync: num_tokens + cudagraph_size + cudagraph_mode - num_tokens_after_padding, num_tokens_across_dp, synced_cudagraph_mode = ( - get_cudagraph_and_dp_padding( - scheduler_output.total_num_scheduled_tokens, - local_cudagraph_size, - local_cudagraph_mode.value, - self.parallel_config.data_parallel_size, - self.parallel_config.data_parallel_rank, + if self.dp_size > 1: + batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( + self.cudagraph_manager, + batch_desc, + num_toks, + num_reqs, + uniform_tok_count, + self.dp_size, + self.dp_rank, ) - ) - cudagraph_runtime_mode = CUDAGraphMode(synced_cudagraph_mode) - if num_tokens_after_padding == 0: + + if batch_desc.num_tokens == 0: # All DP ranks have zero tokens to run. empty_output = self.kv_connector.no_forward(scheduler_output) return empty_output @@ -879,9 +901,7 @@ def execute_model( if not dummy_run: # Common case. # Prepare all the inputs and copy to the input buffers. - input_batch = self.prepare_inputs( - scheduler_output, num_tokens_after_padding - ) + input_batch = self.prepare_inputs(scheduler_output, batch_desc) block_tables, slot_mappings = self.prepare_attn(input_batch) if self.lora_config: @@ -894,9 +914,10 @@ def execute_model( self._set_active_loras(*lora_inputs) else: # No actual tokens to run. A dummy run for DP or memory profiling. 
- num_reqs = min(num_tokens_after_padding, self.max_num_reqs) input_batch = InputBatch.make_dummy( - num_reqs, num_tokens_after_padding, self.input_buffers + batch_desc.num_reqs or num_reqs, + batch_desc.num_tokens, + self.input_buffers, ) if not skip_attn_for_dummy_run: block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) @@ -948,14 +969,12 @@ def execute_model( model_inputs["intermediate_tensors"] = intermediate_tensors # Run model. - if cudagraph_runtime_mode == CUDAGraphMode.FULL: + if batch_desc.cg_mode == CUDAGraphMode.FULL: # Use explicit cudagraph replay for FULL mode. # NOTE(woosuk): Here, we don't need to pass the input tensors, # because they are already copied to the CUDA graph input buffers. self.kv_connector.pre_forward(scheduler_output) - model_output = self.cudagraph_manager.run_fullgraph( - input_batch.num_tokens_after_padding - ) + model_output = self.cudagraph_manager.run_fullgraph(batch_desc) if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output else: @@ -972,7 +991,7 @@ def execute_model( attn_metadata, self.vllm_config, num_tokens=input_batch.num_tokens_after_padding, - cudagraph_runtime_mode=cudagraph_runtime_mode, + cudagraph_runtime_mode=batch_desc.cg_mode, num_tokens_across_dp=num_tokens_across_dp, batch_descriptor=batch_descriptor, slot_mapping=slot_mappings_by_layer, diff --git a/vllm/v1/worker/gpu/model_states/default.py b/vllm/v1/worker/gpu/model_states/default.py index e27916b40663..f0b0e20c5a2e 100644 --- a/vllm/v1/worker/gpu/model_states/default.py +++ b/vllm/v1/worker/gpu/model_states/default.py @@ -142,12 +142,15 @@ def prepare_attn( attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, ) -> dict[str, Any]: + # Use padded sizes - padding is handled by model_runner.prepare_attn. 
+ num_reqs = input_batch.num_reqs_after_padding + num_tokens = input_batch.num_tokens_after_padding query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) max_query_len = input_batch.num_scheduled_tokens.max().item() attn_metadata = build_attn_metadata( attn_groups=attn_groups, - num_reqs=input_batch.num_reqs, - num_tokens=input_batch.num_tokens, + num_reqs=num_reqs, + num_tokens=num_tokens, query_start_loc_gpu=input_batch.query_start_loc, query_start_loc_cpu=query_start_loc_cpu, max_query_len=max_query_len, diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py index 157ed1182485..1e75c48966b2 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/cudagraph.py @@ -1,214 +1,91 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable -from typing import Any import torch from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode -from vllm.model_executor.offloader.base import get_offloader from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.cudagraph_utils import ( - capture_graphs, - get_cudagraph_sizes, + BatchExecutionDescriptor, + CudaGraphManager, prepare_inputs_to_capture, ) -from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.utils import AttentionGroup -class EagleCudaGraphManager: - def __init__(self, vllm_config: VllmConfig, device: torch.device): - self.vllm_config = vllm_config - self.scheduler_config = vllm_config.scheduler_config - self.device = device +class EagleCudaGraphManager(CudaGraphManager): + """CudaGraphManager for Eagle speculative decoding (FULL mode only).""" - 
self.max_model_len = vllm_config.model_config.max_model_len - self.max_num_reqs = self.scheduler_config.max_num_seqs - self.max_num_tokens = self.scheduler_config.max_num_batched_tokens - self.dp_size = vllm_config.parallel_config.data_parallel_size - self.compilation_config = vllm_config.compilation_config - assert self.compilation_config is not None - - # NOTE(woosuk): For Eagle, we only use CUDA graphs for decode. - self.cudagraph_mode = self.compilation_config.cudagraph_mode.decode_mode() - - # only need to capture uniform decode cudagraph sizes (the 2nd return value) - _, self.cudagraph_sizes = get_cudagraph_sizes( - self.compilation_config.cudagraph_capture_sizes, - self.max_num_reqs, - self.max_num_tokens, - self.cudagraph_mode, - uniform_decode_query_len=1, - uniform_decode_cudagraph=True, + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + cudagraph_mode: CUDAGraphMode, + draft_tokens: torch.Tensor, + ): + assert not cudagraph_mode.has_mode(CUDAGraphMode.PIECEWISE), ( + "EagleCudaGraphManager does not support PIECEWISE mode yet" ) - - self.graphs: dict[int, torch.cuda.CUDAGraph] = {} - self.pool = None - if self.cudagraph_mode != CUDAGraphMode.NONE: + # Eagle always uses uniform decode with query_len=1 + super().__init__(vllm_config, device, cudagraph_mode, decode_query_len=1) + self.draft_tokens = draft_tokens + + # Use a dedicated pool for Eagle to avoid memory overlap with the main + # model's cudagraph. The base class uses a shared global pool, but Eagle's + # internal allocations (e.g., gumbel_sample temporaries) can conflict with + # the main model's allocations when sharing the same pool. 
+ if cudagraph_mode: self.pool = torch.cuda.graph_pool_handle() - def get_cudagraph_size(self, num_tokens: int) -> int | None: - return self.cudagraph_sizes.get(num_tokens) - - def get_cudagraph_runtime_mode( - self, num_tokens: int - ) -> tuple[CUDAGraphMode, int | None]: - cudagraph_size = self.get_cudagraph_size(num_tokens) - if cudagraph_size is None: - cudagraph_mode = CUDAGraphMode.NONE - else: - cudagraph_mode = self.cudagraph_mode - - if ( - cudagraph_mode == CUDAGraphMode.FULL - and cudagraph_size is not None - and cudagraph_size not in self.graphs - ): - # If graph wasn't captured yet, fall back to eager. - # This might happen when the dummy run is called before capture. - cudagraph_mode = CUDAGraphMode.NONE - cudagraph_size = None - return cudagraph_mode, cudagraph_size - - def capture_graph( + def capture( self, - num_tokens: int, - capture_cg_mode: CUDAGraphMode, generate_fn: Callable, model_state: ModelState, input_buffers: InputBuffers, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, + progress_bar_desc: str = "Capturing CUDA graphs", ) -> None: - assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( - f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}" - ) - if capture_cg_mode == CUDAGraphMode.PIECEWISE: - capture_fn = self._capture_piecewise_graph - else: - capture_fn = self._capture_full_graph - - num_reqs = min(num_tokens, self.max_num_reqs) - attn_metadata, slot_mappings = prepare_inputs_to_capture( - num_reqs, - num_tokens, - model_state, - input_buffers, - block_tables, - attn_groups, - kv_cache_config, - ) - num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) - - # Warm up. - generate_fn( - num_reqs, - num_tokens, - attn_metadata, - slot_mappings, - num_tokens_across_dp, - CUDAGraphMode.NONE, - ) - - # Capture the graph. 
- capture_fn( - num_reqs=num_reqs, - num_tokens=num_tokens, - generate_fn=generate_fn, - attn_metadata=attn_metadata, - slot_mappings=slot_mappings, - num_tokens_across_dp=num_tokens_across_dp, - ) - - def _capture_full_graph( - self, - num_reqs: int, - num_tokens: int, - generate_fn: Callable, - attn_metadata: dict[str, Any], - slot_mappings: dict[str, torch.Tensor], - num_tokens_across_dp: torch.Tensor, - ) -> None: - assert num_tokens not in self.graphs - graph = torch.cuda.CUDAGraph() - - # Sync offloader's copy stream before capture. - # Ensure any pre-capture prefetches from offloader are complete. - get_offloader().sync_prev_onload() + """Capture CUDA graphs for Eagle speculative decoding (FULL mode only).""" + + def create_forward_fn( + desc: BatchExecutionDescriptor, + ) -> Callable[[CUDAGraphMode], None]: + num_tokens = desc.num_tokens + num_reqs = desc.num_reqs or min(num_tokens, self.max_num_reqs) + num_tokens_across_dp = ( + torch.full((self.dp_size,), num_tokens, dtype=torch.int32, device="cpu") + if self.dp_size > 1 + else None + ) + attn_metadata, slot_mappings = prepare_inputs_to_capture( + num_reqs, + num_tokens, + model_state, + input_buffers, + block_tables, + attn_groups, + kv_cache_config, + ) - with torch.cuda.graph(graph, self.pool): - generate_fn( + return lambda cg_mode: generate_fn( num_reqs, num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp, - CUDAGraphMode.NONE, + cg_mode, ) - # Join offloader's copy stream after forward to avoid unjoined - # stream error. The last layer's start_prefetch forks copy_stream, - # but wait_prefetch only happens in the next forward pass. 
- get_offloader().join_after_forward() - self.graphs[num_tokens] = graph - - def _capture_piecewise_graph( - self, - num_reqs: int, - num_tokens: int, - generate_fn: Callable, - attn_metadata: dict[str, Any], - slot_mappings: dict[str, torch.Tensor], - num_tokens_across_dp: torch.Tensor, - ) -> None: - generate_fn( - num_reqs, - num_tokens, - attn_metadata, - slot_mappings, - num_tokens_across_dp, - CUDAGraphMode.PIECEWISE, - ) - - @torch.inference_mode() - def capture( - self, - generate_fn: Callable, - model_state: ModelState, - input_buffers: InputBuffers, - block_tables: BlockTables, - attn_groups: list[list[AttentionGroup]], - kv_cache_config: KVCacheConfig, - ) -> None: - if self.cudagraph_mode == CUDAGraphMode.NONE: - return - capture_graphs( - self.cudagraph_sizes, - self.device, - self.capture_graph, - capture_cudagraph_mode=self.cudagraph_mode, - desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})", - generate_fn=generate_fn, - model_state=model_state, - input_buffers=input_buffers, - block_tables=block_tables, - attn_groups=attn_groups, - kv_cache_config=kv_cache_config, - ) + super().capture(create_forward_fn, progress_bar_desc) - def run_fullgraph(self, num_tokens: int) -> None: - assert num_tokens in self.graphs - # Sync offloader before replay - needed when transitioning from - # eager/piecewise to full cudagraph (e.g., prefill → decode). - # The previous eager iteration's start_prefetch may have queued - # H2D copies on copy_stream that the graph's captured events - # cannot see. Without this, replay could overwrite static buffers - # while those copies are still in flight. 
- get_offloader().sync_prev_onload() - self.graphs[num_tokens].replay() + def run_fullgraph(self, desc: BatchExecutionDescriptor) -> torch.Tensor: + """Replay a captured FULL cudagraph and return draft tokens.""" + super().run_fullgraph(desc) + return self.draft_tokens diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index 9185850dcb62..8d3c3ba8e9ef 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -16,7 +16,7 @@ build_slot_mappings_by_layer, ) from vllm.v1.worker.gpu.block_table import BlockTables -from vllm.v1.worker.gpu.dp_utils import get_cudagraph_and_dp_padding +from vllm.v1.worker.gpu.dp_utils import sync_cudagraph_and_dp_padding from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample @@ -75,7 +75,16 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): device=device, ) - self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device) + # currently we don't support PIECEWISE for Eagle. 
+ cudagraph_mode = vllm_config.compilation_config.cudagraph_mode + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL: + cudagraph_mode = CUDAGraphMode.FULL_DECODE_ONLY + else: + cudagraph_mode = CUDAGraphMode.NONE + + self.cudagraph_manager = EagleCudaGraphManager( + vllm_config, device, cudagraph_mode, self.draft_tokens + ) def load_model(self, target_model: nn.Module) -> None: self.model = load_eagle_model(target_model, self.vllm_config) @@ -171,7 +180,7 @@ def generate_draft( ) if attn_metadata is not None: self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos + idx_mapping, query_start_loc, pos, num_tokens_padded ) def capture_model(self) -> None: @@ -185,6 +194,7 @@ def capture_model(self) -> None: self.block_tables, self.attn_groups, self.kv_cache_config, + progress_bar_desc="Capturing eagle CUDA graphs", ) @torch.inference_mode() @@ -251,6 +261,7 @@ def propose( logits = self.model.compute_logits(sample_hidden_states) num_reqs = input_batch.num_reqs + num_reqs_padded = input_batch.num_reqs_after_padding # NOTE(woosuk): For draft sampling, we only consider the temperature # and ignore the other sampling parameters such as top_k and top_p, # for simplicity and performance. @@ -292,48 +303,52 @@ def propose( self.max_num_reqs, ) - if not (dummy_run and skip_attn_for_dummy_run): - query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] - slot_mappings = self.block_tables.compute_slot_mappings( - idx_mapping, query_start_loc, pos - ) + # Get batch descriptor and sync across DP ranks. 
+ # Eagle uses FULL-only mode, dispatch with uniform_token_count=1 for decode - cudagraph_mode, cudagraph_size = ( - self.cudagraph_manager.get_cudagraph_runtime_mode(num_reqs) - ) - num_tokens_padded, num_tokens_across_dp, synced_cudagraph_mode = ( - get_cudagraph_and_dp_padding( + batch_desc = self.cudagraph_manager.dispatch(num_reqs, num_reqs, 1) + num_tokens_across_dp = None + + if self.dp_size > 1: + batch_desc, num_tokens_across_dp = sync_cudagraph_and_dp_padding( + self.cudagraph_manager, + batch_desc, num_reqs, - cudagraph_size, - cudagraph_mode.value, + num_reqs, + 1, # uniform_token_count self.dp_size, self.dp_rank, ) - ) - cudagraph_mode = CUDAGraphMode(synced_cudagraph_mode) - if cudagraph_mode == CUDAGraphMode.FULL: - # Run full CUDA graph. - self.cudagraph_manager.run_fullgraph(num_tokens_padded) - return self.draft_tokens[:num_reqs] + + if not (dummy_run and skip_attn_for_dummy_run): + query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] + slot_mappings = self.block_tables.compute_slot_mappings( + idx_mapping, query_start_loc, pos, batch_desc.num_tokens + ) + + if batch_desc.cg_mode == CUDAGraphMode.FULL: + return self.cudagraph_manager.run_fullgraph(batch_desc)[:num_reqs] # Run eager or piecewise CUDA graph. attn_metadata_updated = None slot_mappings_updated = None if not (dummy_run and skip_attn_for_dummy_run): query_start_loc_cpu = torch.arange( - num_reqs + 1, dtype=torch.int32, device="cpu" + num_reqs_padded + 1, dtype=torch.int32, device="cpu" ) - block_tables = [x[:num_reqs] for x in self.block_tables.input_block_tables] + block_tables = [ + x[:num_reqs_padded] for x in self.block_tables.input_block_tables + ] # FIXME(woosuk): This is UNSAFE!! 
attn_metadata_updated = build_attn_metadata( attn_groups=self.attn_groups, - num_reqs=num_reqs, - num_tokens=num_reqs, + num_reqs=num_reqs_padded, + num_tokens=num_reqs_padded, query_start_loc_gpu=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, max_query_len=1, - seq_lens=self.input_buffers.seq_lens[:num_reqs], + seq_lens=self.input_buffers.seq_lens[:num_reqs_padded], max_seq_len=self.max_model_len, block_tables=block_tables, slot_mappings=slot_mappings, @@ -345,11 +360,11 @@ def propose( self.generate_draft( num_reqs, - num_tokens_padded, + batch_desc.num_tokens, attn_metadata_updated, slot_mappings_updated, num_tokens_across_dp=num_tokens_across_dp, - cudagraph_runtime_mode=cudagraph_mode, + cudagraph_runtime_mode=batch_desc.cg_mode, ) return self.draft_tokens[:num_reqs]