237 changes: 179 additions & 58 deletions vllm/v1/worker/gpu/cudagraph_utils.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Iterable
from collections.abc import Callable
from typing import Any

import numpy as np
@@ -12,7 +12,8 @@
from vllm.config.compilation import CUDAGraphMode
from vllm.distributed import get_dcp_group
from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import (
@@ -36,14 +37,27 @@ def __init__(self, vllm_config: VllmConfig, uses_mrope: bool, device: torch.devi
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.dp_size = vllm_config.parallel_config.data_parallel_size

self.uniform_decode_query_len = 1
spec_config = vllm_config.speculative_config
if spec_config is not None:
self.uniform_decode_query_len += spec_config.num_speculative_tokens
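# e.g. with 2 speculative tokens, each uniform-decode step processes 1 + 2 = 3 query tokens per request.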

self.compilation_config = vllm_config.compilation_config
assert self.compilation_config is not None
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.cudagraph_sizes = get_cudagraph_sizes(

use_uniform_decode_cudagraph = (
self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and self.cudagraph_mode.separate_routine()
)
self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes(
self.compilation_config.cudagraph_capture_sizes,
self.max_num_reqs,
self.max_num_tokens,
self.cudagraph_mode,
self.uniform_decode_query_len,
use_uniform_decode_cudagraph,
)

self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
@@ -56,29 +70,42 @@ def needs_capture(self) -> bool:
return len(self.cudagraph_sizes) > 0

def get_cudagraph_size(
self,
num_tokens_after_padding: int,
num_tokens_per_request: Iterable[int],
self, num_tokens: int, uniform_decode: bool = False
) -> int | None:
return get_cudagraph_size(
num_tokens_after_padding,
num_tokens_per_request,
self.cudagraph_sizes,
self.cudagraph_mode,
)
if uniform_decode and self.uniform_decode_cudagraph_sizes:
return self.uniform_decode_cudagraph_sizes.get(num_tokens)
return self.cudagraph_sizes.get(num_tokens)
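
The lookup itself is a plain dict probe; a minimal standalone sketch of the dispatch (illustrative sizes, not taken from this diff):

mixed_sizes = {1: 2, 2: 2, 3: 4, 4: 4}   # pads any batch up to a captured size
decode_sizes = {2: 2, 3: 4, 4: 4}        # uniform-decode batches probe this first
assert decode_sizes.get(3) == 4          # padded to the captured decode size
assert mixed_sizes.get(5) is None        # no graph for this size -> eager fallback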

def capture_graph(
self,
num_tokens: int,
capture_cg_mode: CUDAGraphMode,
model: nn.Module,
input_buffers: InputBuffers,
mrope_positions: torch.Tensor | None,
inputs_embeds: torch.Tensor | None,
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
uniform_decode: bool = False,
) -> None:
num_reqs = min(num_tokens, self.max_num_reqs)
# Validate the capture mode and select the capture function.
assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], (
f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}"
)
if capture_cg_mode == CUDAGraphMode.PIECEWISE:
capture_fn = self._capture_piecewise_graph
else:
capture_fn = self._capture_full_graph
# Prepare the inputs.
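# For uniform decode, each request contributes uniform_decode_query_len
# tokens, e.g. 96 tokens at 3 tokens per request capture 32 requests.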
if uniform_decode:
num_reqs = min(
cdiv(num_tokens, self.uniform_decode_query_len),
self.max_num_reqs,
)
else:
num_reqs = min(num_tokens, self.max_num_reqs)
input_ids = input_buffers.input_ids[:num_tokens]
positions = input_buffers.positions[:num_tokens]
if self.uses_mrope:
@@ -94,6 +121,9 @@ def capture_graph(
attn_metadata_builders,
self.max_model_len,
kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
)
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)

@@ -114,13 +144,40 @@ def capture_graph(
if self.hidden_states is None:
self.hidden_states = torch.empty_like(hidden_states)

capture_fn(
num_tokens=num_tokens,
num_reqs=num_reqs,
model=model,
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
num_tokens_across_dp=num_tokens_across_dp,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings,
has_lora=has_lora,
)

def _capture_full_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
assert attn_metadata is not None
# Capture the graph.
assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph()
with (
set_forward_context(
attn_metadata,
self.vllm_config,
attn_metadata=attn_metadata,
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp,
@@ -133,9 +190,44 @@ def capture_graph(
positions=positions,
inputs_embeds=inputs_embeds,
)
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
self.graphs[num_tokens] = graph
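
For context, the capture/replay mechanics here are PyTorch's standard CUDA graph pattern; a minimal standalone sketch (assumes a CUDA device, independent of vLLM):

import torch

g = torch.cuda.CUDAGraph()
static_in = torch.zeros(8, device="cuda")
with torch.cuda.graph(g):
    static_out = static_in * 2  # kernels are recorded, not executed
static_in.copy_(torch.arange(8.0, device="cuda"))
g.replay()  # re-runs the recorded kernels on the same static buffers
print(static_out)  # reflects the new input values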

def _capture_piecewise_graph(
self,
num_tokens: int,
num_reqs: int,
model: nn.Module,
input_ids: torch.Tensor,
positions: torch.Tensor,
inputs_embeds: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor,
attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
has_lora: bool = False,
) -> None:
# Create the batch descriptor used as the piecewise cudagraph dispatch key.
batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora)

# Capture run: the CUDAGraphWrapper inside torch.compile captures automatically.
with set_forward_context(
attn_metadata=None,  # piecewise graphs don't need attn_metadata
vllm_config=self.vllm_config,
num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings,
):
hidden_states = model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
assert self.hidden_states is not None
self.hidden_states[:num_tokens] = hidden_states
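
Note that no explicit torch.cuda.CUDAGraph appears here: assuming vLLM's usual piecewise design, the CUDAGraphWrapper around each compiled submodule records a graph keyed by the batch descriptor the first time a shape runs under PIECEWISE mode and replays it on subsequent runs.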

@torch.inference_mode()
def capture(
self,
@@ -146,22 +238,62 @@ def capture(
block_tables: BlockTables,
attn_metadata_builders: list[AttentionMetadataBuilder],
kv_cache_config: KVCacheConfig,
has_lora: bool = False,
) -> None:
capture_graphs(
self.cudagraph_sizes,
self.device,
self.capture_graph,
common_kwargs = dict(
device=self.device,
capture_fn=self.capture_graph,
model=model,
input_buffers=input_buffers,
mrope_positions=mrope_positions,
inputs_embeds=inputs_embeds,
block_tables=block_tables,
attn_metadata_builders=attn_metadata_builders,
kv_cache_config=kv_cache_config,
has_lora=has_lora,
)

def run(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs
# Phase 1: Capture for mixed prefill-decode batches if needed.
mixed_mode = self.cudagraph_mode.mixed_mode()
if mixed_mode != CUDAGraphMode.NONE:
capture_graphs(
cudagraph_sizes=self.cudagraph_sizes,
capture_cudagraph_mode=mixed_mode,
desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})",
uniform_decode=False,
**common_kwargs,
)

# Phase 2: Capture FULL graphs for uniform decode batches if needed.
# This is only needed if we use a separate routine for decode batches
# and the decode_mode is FULL.
if self.uniform_decode_cudagraph_sizes:
capture_graphs(
cudagraph_sizes=self.uniform_decode_cudagraph_sizes,
capture_cudagraph_mode=CUDAGraphMode.FULL,
desc="Capturing CUDA graphs (decode, FULL)",
uniform_decode=True,
**common_kwargs,
)
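
Assuming vLLM's usual CUDAGraphMode semantics, FULL_AND_PIECEWISE runs both phases (piecewise graphs for mixed batches plus full graphs at the uniform-decode sizes), while FULL_DECODE_ONLY has a NONE mixed mode and therefore only runs phase 2.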

def get_cudagraph_runtime_mode(
self, num_reqs: int, num_tokens: int, max_query_len: int
) -> tuple[CUDAGraphMode, int | None]:
is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and (
num_tokens == max_query_len * num_reqs
)

cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode)
if cudagraph_size is None:
cudagraph_mode = CUDAGraphMode.NONE
elif is_uniform_decode:
cudagraph_mode = self.cudagraph_mode.decode_mode()
else:
cudagraph_mode = self.cudagraph_mode.mixed_mode()
return cudagraph_mode, cudagraph_size
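
A worked example of this dispatch, with illustrative numbers and uniform_decode_query_len = 3:

# num_reqs=8, max_query_len=3, num_tokens=24 -> uniform decode:
#   size from uniform_decode_cudagraph_sizes, mode = decode_mode()
# num_reqs=2, max_query_len=100, num_tokens=120 -> mixed batch:
#   size from cudagraph_sizes, mode = mixed_mode()
# no padded size found -> (CUDAGraphMode.NONE, None), i.e. run eagerly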

def run_fullgraph(self, num_tokens: int) -> torch.Tensor:
assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens"
self.graphs[num_tokens].replay()
assert self.hidden_states is not None
return self.hidden_states[:num_tokens]
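
The expected call pattern around replay, as a hedged sketch; pad_and_copy_inputs, run_eager, and the surrounding names are hypothetical, not part of this diff:

size = cudagraphs.get_cudagraph_size(num_tokens, uniform_decode=True)
if size is not None:
    pad_and_copy_inputs(input_buffers, batch, padded_len=size)  # hypothetical helper
    hidden_states = cudagraphs.run_fullgraph(size)
else:
    hidden_states = run_eager(batch)  # hypothetical eager fallback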
@@ -172,68 +304,53 @@ def get_cudagraph_sizes(
max_num_reqs: int,
max_num_tokens: int,
cudagraph_mode: CUDAGraphMode,
) -> dict[int, int]:
if not cudagraph_mode.has_full_cudagraphs():
return {}
uniform_decode_query_len: int = 1,
uniform_decode_cudagraph: bool = False,
) -> tuple[dict[int, int], dict[int, int]]:
# Support both FULL and PIECEWISE cudagraph modes
if cudagraph_mode == CUDAGraphMode.NONE:
return {}, {}
if not capture_sizes:
return {}
return {}, {}

capture_sizes = sorted(capture_sizes)
# Limit the capture sizes to the max number of requests or tokens.
upper_bound = (
max_num_reqs
if cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY
else max_num_tokens
)
capture_sizes = [x for x in capture_sizes if x <= upper_bound]
if not capture_sizes:
return {}
return {}, {}

cudagraph_sizes: dict[int, int] = {}
for i in range(1, capture_sizes[-1] + 1):
for x in capture_sizes:
if i <= x:
cudagraph_sizes[i] = x
break
return cudagraph_sizes


def get_cudagraph_size(
num_tokens_after_dp_padding: int,
num_tokens_per_request: Iterable[int],
cudagraph_sizes: dict[int, int],
cudagraph_mode: CUDAGraphMode,
) -> int | None:
if not cudagraph_mode.has_full_cudagraphs():
# No full CUDA graph is used.
return None

size = cudagraph_sizes.get(num_tokens_after_dp_padding)
if size is None:
# No CUDA graph for this size.
return None

is_mixed = any(x > 1 for x in num_tokens_per_request)
if is_mixed and cudagraph_mode.mixed_mode() != CUDAGraphMode.FULL:
# Prefill is included, and this mode doesn't use CUDA graph for it.
return None
return size
uniform_decode_cudagraph_sizes: dict[int, int] = {}
if uniform_decode_cudagraph:
max_num_tokens = max_num_reqs * uniform_decode_query_len
uniform_decode_cudagraph_sizes = {
k: v
for k, v in cudagraph_sizes.items()
if v <= max_num_tokens and v >= uniform_decode_query_len
}
return cudagraph_sizes, uniform_decode_cudagraph_sizes
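
A worked example of the two maps (illustrative values):

# capture_sizes = [1, 2, 4, 8], uniform_decode_query_len = 2, max_num_reqs = 3
# cudagraph_sizes pads every batch size up to the next captured size:
#   {1: 1, 2: 2, 3: 4, 4: 4, 5: 8, 6: 8, 7: 8, 8: 8}
# uniform_decode_cudagraph_sizes keeps entries with
#   query_len <= padded size <= max_num_reqs * query_len (= 6):
#   {2: 2, 3: 4, 4: 4}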


def capture_graphs(
cudagraph_sizes: dict[int, int],
device: torch.device,
capture_fn: Callable,
capture_cudagraph_mode: CUDAGraphMode,
desc: str = "Capturing CUDA graphs",
**capture_kwargs,
) -> None:
# Capture larger graphs first.
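# (so the shared capture pool, sized by the largest graph, can be reused by smaller ones)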
sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True)
if is_global_first_rank():
sizes_to_capture = tqdm(sizes_to_capture, desc="Capturing CUDA graphs")
sizes_to_capture = tqdm(sizes_to_capture, desc=desc)

with graph_capture(device=device):
for size in sizes_to_capture:
capture_fn(size, **capture_kwargs)
capture_fn(size, capture_cudagraph_mode, **capture_kwargs)


def prepare_inputs_to_capture(
@@ -244,8 +361,12 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int,
kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
num_tokens_per_req = num_tokens // num_reqs
if uniform_decode_query_len > 0:
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs

query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens