
Commit ab519de

Query builder for cudagraph support instead of per-layer metadata and hardcoded error
Signed-off-by: luka <[email protected]>
1 parent 2fd221d commit ab519de

5 files changed: +61 −42 lines

vllm/v1/attention/backends/flash_attn.py

Lines changed: 7 additions & 5 deletions
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Attention layer with FlashAttention."""
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Optional
 
 import numpy as np
 import torch
@@ -127,10 +127,6 @@ class LocalAttentionMetadata:
 
     local_attn_metadata: Optional[LocalAttentionMetadata] = None
 
-    # Supported for prefill and decode.
-    # Backend (FA2 vs FA3 vs Triton) support checked separately.
-    cuda_graph_supported: bool = True
-
 
 #
 # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
@@ -311,6 +307,7 @@ def _get_sliding_window_configs(
 
 class FlashAttentionMetadataBuilder(
         AttentionMetadataBuilder[FlashAttentionMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = get_flash_attn_version() == 3
 
     def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
@@ -501,6 +498,11 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
         )
         return attn_metadata
 
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        # Full CUDA Graph always supported (FA2 support checked separately)
+        return True
+
     def use_cascade_attention(self, *args, **kwargs) -> bool:
         return use_cascade_attention(*args, **kwargs)
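The two hunks above split the capability check in time: full_cudagraph_supported is computed once, when the class body executes, so an FA2 installation opts the backend out before any batch is built, while can_run_in_cudagraph is answered per batch and is unconditionally true once FA3 is known to be present. A minimal sketch of that pattern, with fa_version() as a hypothetical stand-in for get_flash_attn_version():

from typing import ClassVar


def fa_version() -> int:
    # Hypothetical stand-in for get_flash_attn_version(); assume FA3 here.
    return 3


class FlashAttentionLikeBuilder:
    # Evaluated once, at class-definition time, not per batch.
    full_cudagraph_supported: ClassVar[bool] = fa_version() == 3

    def can_run_in_cudagraph(self, common_attn_metadata: object) -> bool:
        # With FA3, both prefill and decode run inside the captured graph,
        # so every batch qualifies.
        return True


assert FlashAttentionLikeBuilder.full_cudagraph_supported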

vllm/v1/attention/backends/mla/common.py

Lines changed: 4 additions & 7 deletions
@@ -326,13 +326,6 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
 
-    @property
-    def cuda_graph_supported(self):
-        """
-        Full CUDA Graphs (including attention) only supported for pure decode.
-        """
-        return self.num_prefills == 0
-
 
 M = TypeVar("M", bound=MLACommonMetadata)
 
@@ -598,6 +591,10 @@ def build(self, common_prefix_len: int,
             decode=decode_metadata,
         )
 
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
+
 
 class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     """

vllm/v1/attention/backends/mla/flashmla.py

Lines changed: 2 additions & 1 deletion
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional
 
 import torch
 
@@ -54,6 +54,7 @@ class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
 
 
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
+    full_cudagraph_supported: ClassVar[bool] = True  # Decode-only
 
     def __init__(self, runner, kv_cache_spec: AttentionSpec,
                  block_table: BlockTable):
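FlashMLA only needs the one-line class attribute because the per-batch decode-only test is inherited from MLACommonMetadataBuilder. A simplified sketch of that inheritance, using stand-in classes rather than the real vLLM types:

from typing import ClassVar


class CommonMLABuilderSketch:
    full_cudagraph_supported: ClassVar[bool] = False

    def can_run_in_cudagraph(self, max_query_len: int) -> bool:
        # Shared decode-only rule from the common MLA builder.
        return max_query_len == 1


class FlashMLABuilderSketch(CommonMLABuilderSketch):
    # Opting in is just flipping the class-level flag; the per-batch rule
    # above is reused unchanged.
    full_cudagraph_supported: ClassVar[bool] = True


assert FlashMLABuilderSketch().can_run_in_cudagraph(max_query_len=1)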

vllm/v1/attention/backends/utils.py

Lines changed: 10 additions & 1 deletion
@@ -3,7 +3,7 @@
 import abc
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Generic, TypeVar
+from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
 
 import numpy as np
 import torch
@@ -38,6 +38,8 @@ class CommonAttentionMetadata:
 
 
 class AttentionMetadataBuilder(abc.ABC, Generic[M]):
+    # Does this backend/builder support CUDA Graphs for attention.
+    full_cudagraph_supported: ClassVar[bool] = False
 
     @abstractmethod
     def build(self, common_prefix_len: int,
@@ -48,6 +50,13 @@ def build(self, common_prefix_len: int,
         """
        raise NotImplementedError
 
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        """
+        Can this batch (with given metadata) use CUDA Graphs for attention.
+        """
+        return False
+
     def build_for_cudagraph_capture(
             self, common_attn_metadata: CommonAttentionMetadata) -> M:
         """

vllm/v1/worker/gpu_model_runner.py

Lines changed: 38 additions & 28 deletions
@@ -18,7 +18,6 @@
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.layer import Attention
-from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
@@ -555,7 +554,15 @@ def _get_cumsum_and_arange(
     def _prepare_inputs(
         self,
         scheduler_output: "SchedulerOutput",
-    ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]:
+    ) -> tuple[dict[str, Any], bool, torch.Tensor,
+               Optional[SpecDecodeMetadata]]:
+        """
+        :return: tuple[
+            attn_metadata: layer-to-attention_metadata mapping,
+            attention_cuda_graphs: whether attention can run in captured cudagraph
+            logits_indices, spec_decode_metadata
+        ]
+        """
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         assert total_num_scheduled_tokens > 0
         num_reqs = self.input_batch.num_reqs
@@ -677,27 +684,31 @@ def _prepare_inputs(
         )
 
         attn_metadata: dict[str, Any] = {}
+        attention_cuda_graphs = []
         # Prepare the attention metadata for each KV cache group and make layers
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
 
             # Prepare for cascade attention if enabled & beneficial.
             common_prefix_len = 0
+            builder = self.attn_metadata_builders[kv_cache_group_id]
             if self.cascade_attn_enabled:
                 common_prefix_len = self._compute_cascade_attn_prefix_len(
                     num_scheduled_tokens,
                     scheduler_output.
                     num_common_prefix_blocks[kv_cache_group_id],
                     kv_cache_group_spec.kv_cache_spec,
-                    self.attn_metadata_builders[kv_cache_group_id],
+                    builder,
                 )
 
-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    common_prefix_len=common_prefix_len,
-                    common_attn_metadata=common_attn_metadata,
-                ))
+            attn_metadata_i = (builder.build(
+                common_prefix_len=common_prefix_len,
+                common_attn_metadata=common_attn_metadata,
+            ))
+            attention_cuda_graphs.append(
+                builder.can_run_in_cudagraph(common_attn_metadata))
+
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
@@ -729,7 +740,8 @@ def _prepare_inputs(
         if self.lora_config:
             self.set_active_loras(self.input_batch, num_scheduled_tokens)
 
-        return attn_metadata, logits_indices, spec_decode_metadata
+        return attn_metadata, all(
+            attention_cuda_graphs), logits_indices, spec_decode_metadata
 
     def _compute_cascade_attn_prefix_len(
         self,
@@ -1189,8 +1201,8 @@ def execute_model(
             return self.kv_connector_no_forward(scheduler_output)
 
         # Prepare the decoder inputs.
-        attn_metadata, logits_indices, spec_decode_metadata = (
-            self._prepare_inputs(scheduler_output))
+        (attn_metadata, attention_cuda_graphs, logits_indices,
+         spec_decode_metadata) = (self._prepare_inputs(scheduler_output))
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if (self.use_cuda_graph
                 and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
@@ -1255,11 +1267,9 @@ def execute_model(
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_input_tokens, intermediate_tensors, True)
 
-        # Some attention backends only support CUDA graphs in pure decode.
-        # Assume cuda_graph_supported is false if it does not exist.
-        attention_cuda_graphs = all(
-            getattr(m, "cuda_graph_supported", False)
-            for _, m in attn_metadata.items())
+        # Some attention backends only support CUDA Graphs in pure decode.
+        # If attention doesn't support CUDA Graphs for this batch, but we
+        # compiled with full CUDA graphs, we have to skip them entirely.
         skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
 
         # Run the decoder.
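With the builder hook available, the per-step decision no longer lives on the metadata objects: _prepare_inputs asks each KV-cache group's builder about the batch and returns the conjunction, and execute_model only combines that bool with the compilation setting. A hedged sketch of the decision logic as free functions over stand-in builders, not the runner's actual methods:

def decide_skip_cuda_graphs(full_cuda_graph: bool, builders, common_attn_metadata) -> bool:
    # Replay the captured full graph only if every KV-cache group's builder
    # accepts this batch; otherwise skip the full graph for this step.
    attention_cuda_graphs = all(
        b.can_run_in_cudagraph(common_attn_metadata) for b in builders)
    return full_cuda_graph and not attention_cuda_graphs


class _DecodeOnlyBuilder:
    def can_run_in_cudagraph(self, meta: str) -> bool:
        return meta == "decode"


assert decide_skip_cuda_graphs(True, [_DecodeOnlyBuilder()], "prefill")
assert not decide_skip_cuda_graphs(True, [_DecodeOnlyBuilder()], "decode")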
@@ -2100,20 +2110,20 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
                     "Non-Attention backend is not supported by V1 "
                     "GPUModelRunner.")
 
-            if self.compilation_config.full_cuda_graph:
-                attn_backend_name = attn_backend_i.__name__
-                flash_attn_version = get_flash_attn_version()
-                if ((attn_backend_name != "FlashAttentionBackend"
-                     or flash_attn_version != 3)
-                        and attn_backend_name != "FlashMLABackend"):
-                    raise ValueError(
-                        f"Full CUDAGraph is only supported with FA3 or FlashMLA"
-                        f". Current attention backend is {attn_backend_name}, "
-                        f"FlashAttention version is {flash_attn_version}.")
-
             block_table_i = self.input_batch.block_table[i]
             attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
-                weakref.proxy(self), kv_cache_spec, block_table_i)
+                weakref.proxy(self),
+                kv_cache_spec,
+                block_table_i,
+            )
+
+            if (self.full_cuda_graph
+                    and not attn_metadata_builder_i.full_cudagraph_supported):
+                raise ValueError(
+                    f"Full CUDAGraph not supported for "
+                    f"{attn_backend_i.__name__}. Turn off CompilationConfig."
+                    f"full_cuda_graph or use a different attention backend.")
+
             self.attn_backends.append(attn_backend_i)
             self.attn_metadata_builders.append(attn_metadata_builder_i)
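The hardcoded FA3/FlashMLA allow-list is gone: at initialization the runner now consults the builder's own class-level flag and fails fast with an actionable message. A minimal sketch of that check, with stand-in parameters rather than the runner's real attributes:

def check_full_cudagraph_support(full_cuda_graph: bool, backend_cls: type, builder) -> None:
    # Fail at startup if the config requests full CUDA Graphs but the chosen
    # backend's builder can never provide them.
    if full_cuda_graph and not builder.full_cudagraph_supported:
        raise ValueError(
            f"Full CUDAGraph not supported for {backend_cls.__name__}. "
            "Turn off CompilationConfig.full_cuda_graph or use a different "
            "attention backend.")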
