Changes from all commits (35 commits)
b7d124c
[V1][CUDA Graph] Fix attention metadata tensor sizes for padded batches
ayushsatyam146 Sep 19, 2025
a42e3ce
review comments
LucasWilkinson Oct 5, 2025
f586e2b
Merge branch 'main' into cudagraph-fix
LucasWilkinson Oct 6, 2025
93071cf
format
LucasWilkinson Oct 6, 2025
f6b7d63
Merge branch 'main' into cudagraph-fix
LucasWilkinson Oct 6, 2025
9acd529
cleanup
LucasWilkinson Oct 6, 2025
d6a5a9a
cleanup
LucasWilkinson Oct 7, 2025
4a0e9df
more cleanup
LucasWilkinson Oct 7, 2025
0895405
Merge remote-tracking branch 'origin/main' into cudagraph-fix
LucasWilkinson Oct 7, 2025
3f495e2
cleanup
LucasWilkinson Oct 7, 2025
01db86e
clean up
LucasWilkinson Oct 7, 2025
5780e9a
no need for lora change
LucasWilkinson Oct 7, 2025
00c3280
review comments
LucasWilkinson Oct 8, 2025
5d838ed
Merge remote-tracking branch 'origin/main' into cudagraph-fix
LucasWilkinson Oct 8, 2025
370e700
more refactoring
LucasWilkinson Oct 8, 2025
a335b80
unifiy build attention metadata
LucasWilkinson Oct 8, 2025
89ca98e
clean-up
LucasWilkinson Oct 8, 2025
bbbc8fb
refactor
LucasWilkinson Oct 14, 2025
997f71e
wip
LucasWilkinson Oct 14, 2025
c907ef3
cleanup
LucasWilkinson Oct 14, 2025
d88842f
cleanup
LucasWilkinson Oct 14, 2025
39562aa
fix
LucasWilkinson Oct 14, 2025
0d997cf
cleanup
LucasWilkinson Oct 14, 2025
ccfb764
cleanup
LucasWilkinson Oct 14, 2025
021882c
fix
LucasWilkinson Oct 14, 2025
1d533bc
clean up
LucasWilkinson Oct 14, 2025
e48ab54
fix docs error
LucasWilkinson Oct 14, 2025
f3d09f5
Merge remote-tracking branch 'nm/lwilkinson/seperate-build-attn-metad…
LucasWilkinson Oct 15, 2025
d23b110
Fix merge conflicts: add missing imports and fix indentation
LucasWilkinson Oct 15, 2025
c3eba9b
Merge remote-tracking branch 'origin/main' into pr/ayushsatyam146/24002
LucasWilkinson Nov 11, 2025
d00b29c
wip
LucasWilkinson Nov 12, 2025
df7edde
wip
LucasWilkinson Nov 12, 2025
f95af46
Merge branch 'lwilkinson/pad-before-metadata' into pr/ayushsatyam146/…
LucasWilkinson Nov 12, 2025
6160f1c
update docs
LucasWilkinson Nov 12, 2025
cebcc54
cleanup
LucasWilkinson Nov 12, 2025
8 changes: 5 additions & 3 deletions docs/design/cuda_graphs.md
@@ -84,12 +84,14 @@
```python
class BatchDescriptor(NamedTuple):
num_tokens: int
uniform_decode: bool = False
num_reqs: int
uniform: bool = False
has_lora: bool = False
```

where `num_tokens` can be the padded token length, and `uniform_decode` is determined by whether the `max_query_len` of a batch equals the desired `max_query_len` of a uniform decode and the `num_scheduled_tokens` is divisible by that desired `max_query_len`.
where `num_tokens` can be the padded token length, and `uniform` indicates whether all the requests have the same query length. Many attention backends only support full cudagraphs when the batch is uniform; pure decode batches are uniform but may not have a query length of 1 (i.e. `num_tokens == num_reqs`), as in the validation pass of spec-decode, where "decode" batches have a query length of `1 + num_spec_tokens`.

The goal of this structure is to uniquely identify a (padded) batch with the fewest possible items corresponding to a CUDA Graphs item. We can safely exclude items like `uniform_query_len` because it is currently a runtime constant for a given setup; for example, it is either `1` for the common pure-decode case or `1+num_spec_tokens` for the validation phase of speculative decode.
The goal of this structure is to uniquely identify a (padded) batch with the fewest possible items corresponding to a CUDA Graphs item.

!!! note
The prototype of `BatchDescriptor` may be extended for more general situations in the future, e.g., include more items, like `uniform_query_len` to support multiple different uniform decode lengths settings (<https://github.com/vllm-project/vllm/pull/23679>), or other modifications needed to support CUDA Graphs for models whose inputs are not necessarily token length aware (for example, some multi-modal inputs).
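
To make the new fields concrete, here is a minimal, self-contained sketch of the descriptor as shown above; the request counts, token counts, and spec-decode length used below are made-up values, not ones taken from this PR.

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False


# Uniform spec-decode batch: 8 requests, each with query length
# 1 + num_spec_tokens == 3, padded up to a captured size of 32 tokens.
uniform_key = BatchDescriptor(num_tokens=32, num_reqs=8, uniform=True)

# Mixed prefill/decode batch: a PIECEWISE graph does not care about the
# request count, so num_reqs stays None and uniform stays False.
mixed_key = BatchDescriptor(num_tokens=256)

print(uniform_key, mixed_key)
```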
18 changes: 11 additions & 7 deletions vllm/forward_context.py
@@ -35,23 +35,27 @@ class BatchDescriptor(NamedTuple):
"""

num_tokens: int
uniform_decode: bool = False
num_reqs: int | None = None
"""
False can also be used for a uniform decode batch to dispatch to the
cudagraph supporting non-uniform batches.
Number of requests in the batch. Can be None for PIECEWISE cudagraphs,
where the cudagraphs can handle any number of requests.
"""
uniform: bool = False
"""
True if all the requests in the batch have the same number of tokens.
"""
has_lora: bool = False
"""
Whether this batch has active LoRA adapters.
"""

@property
def non_uniform(self) -> "BatchDescriptor":
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a non-uniform version of the current batch descriptor.
Return a relaxed version of the current batch descriptor that is still
compatible with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, uniform_decode=False, has_lora=self.has_lora
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
)
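
As a rough usage sketch of the relaxation above (the class is re-declared locally here rather than imported from vLLM, and the sizes are made up), relaxing a uniform-decode key drops the request count and uniformity so the same batch can match a mixed-batch (PIECEWISE) graph:

```python
from typing import NamedTuple


class BatchDescriptor(NamedTuple):
    num_tokens: int
    num_reqs: int | None = None
    uniform: bool = False
    has_lora: bool = False

    def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
        # Keep only what a PIECEWISE (or mixed prefill-decode) graph is
        # keyed on: the padded token count and whether LoRA is active.
        return BatchDescriptor(
            self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
        )


key = BatchDescriptor(num_tokens=64, num_reqs=16, uniform=True)
relaxed = key.relax_for_mixed_batch_cudagraphs()
assert relaxed == BatchDescriptor(num_tokens=64)
```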


21 changes: 1 addition & 20 deletions vllm/v1/attention/backends/flashinfer.py
@@ -708,31 +708,12 @@ def build(

if num_decodes > 0:
pure_decode = num_prefills == 0
# possible required padding for cudagraph replay
use_cudagraph = (
self.enable_cuda_graph
and pure_decode
and num_decode_tokens <= self._decode_cudagraph_max_bs
)
if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph(
num_decode_tokens
)
# Carefully fulfill the padding region with reasonable value
# on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing
self.paged_kv_indptr_cpu[
1 + num_decodes : 1 + num_input_tokens
].fill_(paged_kv_indptr_cpu[-1])
# Fill the remaining paged_kv_last_page_len_cpu with 1.
# This is because flashinfer treats 0 as a full page
# instead of empty.
self.paged_kv_last_page_len_cpu[num_decodes:num_input_tokens].fill_(
1
)

else:
num_input_tokens = num_decode_tokens
num_input_tokens = num_decode_tokens

attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/utils.py
@@ -845,7 +845,9 @@ def split_decodes_and_prefills(
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold
# A query length of 0 indicates a padded request; leave these at the back
# of the batch with the prefills.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)

if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0
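
For illustration only (an isolated snippet, not the vLLM helper itself; `decode_threshold` and the sample query lengths are invented), the mask groups zero-length padded requests with the prefills at the back of the batch:

```python
import torch

# Hypothetical per-request query lengths: three decodes, one prefill,
# then two padded (zero-length) requests at the back of the batch.
query_lens = torch.tensor([1, 1, 1, 7, 0, 0])
decode_threshold = 1

# Parentheses matter: `|` binds tighter than `>` and `==` in Python.
is_prefill = (query_lens > decode_threshold) | (query_lens == 0)
print(is_prefill)  # tensor([False, False, False,  True,  True,  True])
```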
85 changes: 68 additions & 17 deletions vllm/v1/cudagraph_dispatcher.py
@@ -4,6 +4,9 @@

from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger

logger = init_logger(__name__)


class CudagraphDispatcher:
@@ -29,6 +32,11 @@ def __init__(self, vllm_config: VllmConfig):
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.cudagraph_mode = self.compilation_config.cudagraph_mode
self.uniform_decode_query_len = (
1
if not self.vllm_config.speculative_config
else 1 + self.vllm_config.speculative_config.num_speculative_tokens
)

# Dict to store valid cudagraph dispatching keys.
self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
@@ -55,6 +63,32 @@ def __init__(self, vllm_config: VllmConfig):

self.keys_initialized = False

def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)

if uniform_decode:
num_reqs = num_tokens // uniform_decode_query_len
assert num_tokens % uniform_decode_query_len == 0
assert num_reqs <= max_num_seqs
return BatchDescriptor(
num_tokens=num_tokens_padded,
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
)
num_reqs = min(num_tokens_padded, max_num_seqs)

return BatchDescriptor(
num_tokens=num_tokens_padded,
num_reqs=num_reqs,
uniform=uniform_decode,
has_lora=has_lora,
)

def add_cudagraph_key(
self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
):
@@ -86,9 +120,7 @@ def initialize_cudagraph_keys(
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
BatchDescriptor(
num_tokens=bs, uniform_decode=False, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, False, has_lora),
)

# if decode cudagraph mode is FULL, and we don't already have mixed
@@ -109,40 +141,59 @@ def initialize_cudagraph_keys(
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases):
self.add_cudagraph_key(
CUDAGraphMode.FULL,
BatchDescriptor(
num_tokens=bs, uniform_decode=True, has_lora=has_lora
),
self._create_padded_batch_descriptor(bs, True, has_lora),
)

self.keys_initialized = True

def _is_compatible(
self, batch_descriptor: BatchDescriptor, candidate: BatchDescriptor
) -> bool:
"""Check if candidate cudagraph can handle the batch request."""
if candidate.num_reqs is None:
return True
assert batch_descriptor.num_reqs is not None
return candidate.num_reqs >= batch_descriptor.num_reqs

def dispatch(
self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
self,
num_tokens: int,
num_reqs: int,
uniform_decode: bool,
has_lora: bool,
use_cascade_attn: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
"""
Given conditions (e.g., the batch descriptor and whether cascade attention is
used), dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).

`num_reqs` is reserved for future use; it ensures call sites have access to
this information.
"""
# if not initialized, just skip dispatching.
if not self.keys_initialized:
return CUDAGraphMode.NONE, None

non_uniform_key = batch_descriptor.non_uniform
# if a batch uses cascade attention, bypass checking full cudagraphs
batch_descriptor = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora
)
relaxed_batch_descriptor = batch_descriptor.relax_for_mixed_batch_cudagraphs()

if not use_cascade_attn:
# check if key exists for full cudagraph
if batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_descriptor

# otherwise, check if non-uniform key exists
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, non_uniform_key
# otherwise, check if the relaxed key exists
if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_descriptor

# also check if non-uniform key exists for more "general"
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if non_uniform_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, non_uniform_key
if relaxed_batch_descriptor in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_descriptor

# finally, just return no cudagraphs
return CUDAGraphMode.NONE, None
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
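
To summarize the fallback order implemented above, here is a simplified, self-contained model of the dispatch logic; plain tuples of `(num_tokens, num_reqs, uniform, has_lora)` stand in for `BatchDescriptor`, padding and cascade attention are ignored, and every captured key below is hypothetical.

```python
# Hypothetical captured keys (normally built from cudagraph_capture_sizes).
FULL_KEYS = {(32, 8, True, False)}            # uniform-decode graphs
PIECEWISE_KEYS = {(32, None, False, False),   # mixed-batch graphs
                  (256, None, False, False)}


def dispatch(num_tokens: int, num_reqs: int, uniform: bool, has_lora: bool):
    exact = (num_tokens, num_reqs, uniform, has_lora)
    relaxed = (num_tokens, None, False, has_lora)
    if exact in FULL_KEYS:          # exact (possibly uniform) FULL match
        return "FULL", exact
    if relaxed in FULL_KEYS:        # relaxed FULL match (mixed FA graphs)
        return "FULL", relaxed
    if relaxed in PIECEWISE_KEYS:   # more general piecewise graph
        return "PIECEWISE", relaxed
    return "NONE", None             # real code returns a trivial BatchDescriptor here


print(dispatch(32, 8, True, False))   # -> ('FULL', (32, 8, True, False))
print(dispatch(32, 5, False, False))  # -> ('PIECEWISE', (32, None, False, False))
print(dispatch(48, 3, False, False))  # -> ('NONE', None)
```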