Commit b30dfa0

[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <[email protected]>
Signed-off-by: Matthew Bonanni <[email protected]>
Co-authored-by: Luka Govedič <[email protected]>
1 parent 2e78150 commit b30dfa0


61 files changed: +1333 −997 lines changed
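The common thread in the diffs below is a rename of the internal `_Backend` enum to `AttentionBackendEnum` (still exported from `vllm.attention.backends.registry`), plus tighter CUDA backend-selection checks in the attention selector test. A minimal before/after sketch of the rename as it appears in these hunks; the import path and member names are taken from this commit, the surrounding script is illustrative only:

# Old spelling, removed by this commit:
#   from vllm.attention.backends.registry import _Backend
#   backend = _Backend.TRITON_ATTN

# New spelling used throughout the updated tests:
from vllm.attention.backends.registry import AttentionBackendEnum

backend = AttentionBackendEnum.TRITON_ATTN
if backend == AttentionBackendEnum.FLASHINFER:
    print("FlashInfer selected")
else:
    print(f"{backend.name} selected")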

.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 0 deletions
@@ -890,11 +890,16 @@ steps:
   - vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py
   - vllm/model_executor/layers/quantization/utils/flashinfer_utils.py
   - vllm/v1/attention/backends/flashinfer.py
+  - vllm/v1/attention/backends/mla/cutlass_mla.py
+  - vllm/v1/attention/backends/mla/flashinfer_mla.py
+  - vllm/platforms/cuda.py
+  - vllm/attention/selector.py
  commands:
  - nvidia-smi
  - python3 examples/offline_inference/basic/chat.py
  # Attention
  # num_heads2 broken by https://github.com/flashinfer-ai/flashinfer/issues/1353
+ - pytest -v -s tests/kernels/attention/test_attention_selector.py
  - pytest -v -s tests/kernels/attention/test_flashinfer.py -k 'not num_heads2'
  - pytest -v -s tests/kernels/attention/test_flashinfer_trtllm_attention.py
  - pytest -v -s tests/kernels/attention/test_cutlass_mla_decode.py

tests/compile/test_fusion_attn.py

Lines changed: 16 additions & 15 deletions
@@ -10,7 +10,7 @@
 from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
 from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
 from vllm.attention import Attention, AttentionMetadata
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import global_force_attn_backend_context_manager
 from vllm.compilation.fusion_attn import ATTN_OP, AttnFusionPass
 from vllm.compilation.fx_utils import find_op_nodes
@@ -104,7 +104,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:

         # TODO(luka) use get_kv_cache_stride_order
         # Create dummy KV cache for the selected backend
-        if backend == _Backend.ROCM_ATTN:
+        if backend == AttentionBackendEnum.ROCM_ATTN:
             # k/v as 1st dimention
             # HND: [num_blocks, num_kv_heads, block_size, head_size]
             kv_cache = torch.zeros(
@@ -116,7 +116,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
                 dtype=self.kv_cache_dtype,
                 device=self.device,
             )
-        elif backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
+        elif backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
             # k/v as 1st dimention
             # NHD: [num_blocks, block_size, num_kv_heads, head_size]
             kv_cache = torch.zeros(
@@ -128,7 +128,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
                 dtype=self.kv_cache_dtype,
                 device=self.device,
             )
-        elif backend == _Backend.TRITON_ATTN:
+        elif backend == AttentionBackendEnum.TRITON_ATTN:
             # k/v as 2nd dimention
             # NHD: [num_blocks, block_size, num_kv_heads, head_size]
             kv_cache = torch.zeros(
@@ -140,7 +140,7 @@ def build_attn_metadata(self, batch_size: int) -> AttentionMetadata:
                 dtype=self.kv_cache_dtype,
                 device=self.device,
             )
-        elif backend == _Backend.FLASHINFER:
+        elif backend == AttentionBackendEnum.FLASHINFER:
             kv_cache = torch.zeros(
                 num_blocks,
                 2,
@@ -244,8 +244,8 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
 MODELS_FP4: list[tuple[str, type]] = []
 HEADS: list[tuple[int, int]] = []
 SPLIT_ATTENTION: list[bool] = []
-BACKENDS_FP8: list[_Backend] = []
-BACKENDS_FP4: list[_Backend] = []
+BACKENDS_FP8: list[AttentionBackendEnum] = []
+BACKENDS_FP4: list[AttentionBackendEnum] = []

 if current_platform.is_cuda():
     HEADS = [(64, 8), (40, 8)]
@@ -261,18 +261,18 @@ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
             TestAttentionNvfp4QuantPatternModel,
         )
     ]
-    BACKENDS_FP8 = [_Backend.TRITON_ATTN, _Backend.FLASHINFER]
-    BACKENDS_FP4 = [_Backend.FLASHINFER]
+    BACKENDS_FP8 = [AttentionBackendEnum.TRITON_ATTN, AttentionBackendEnum.FLASHINFER]
+    BACKENDS_FP4 = [AttentionBackendEnum.FLASHINFER]

 elif current_platform.is_rocm():
     HEADS = [(32, 8), (40, 8)]
     MODELS_FP8 = [
         ("amd/Llama-3.1-8B-Instruct-FP8-KV", TestAttentionFp8StaticQuantPatternModel)
     ]
     BACKENDS = [
-        _Backend.ROCM_AITER_UNIFIED_ATTN,
-        _Backend.ROCM_ATTN,
-        _Backend.TRITON_ATTN,
+        AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
+        AttentionBackendEnum.ROCM_ATTN,
+        AttentionBackendEnum.TRITON_ATTN,
     ]


@@ -302,18 +302,19 @@ def test_attention_quant_pattern(
     custom_ops: str,
     model_name: str,
     model_class: type[AttentionQuantPatternModel],
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     dist_init,
 ):
     """Test AttentionStaticQuantPattern fusion pass"""
-    if backend == _Backend.FLASHINFER and (
+    if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")

     custom_ops_list = custom_ops.split(",") if custom_ops else []

     device = torch.device("cuda:0")
+    torch.set_default_dtype(dtype)
     torch.manual_seed(42)

     vllm_config = VllmConfig(
@@ -402,7 +403,7 @@ def test_attention_quant_pattern(

     result_fused_1 = model_compiled(q, k, v)

-    if backend == _Backend.FLASHINFER:
+    if backend == AttentionBackendEnum.FLASHINFER:
         # With the Flashinfer backend after the 1st round of the forward
         # pass, output quant scale should be loaded into the attn layer's
         # _o_scale_float, the 2nd round should reuse the loaded
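Beyond the rename, the hunks above keep the per-backend dummy KV cache layouts that their inline comments describe (HND vs. NHD, and whether the stacked k/v pair is the 1st or 2nd dimension). A stand-alone sketch of those layouts; the leading-dimension placement is inferred from the comments rather than copied verbatim from the truncated torch.zeros(...) calls:

import torch

num_blocks, block_size, num_kv_heads, head_size = 32, 16, 8, 64

# ROCM_ATTN: k/v stacked as the 1st dimension, HND layout
# [num_blocks, num_kv_heads, block_size, head_size]
rocm_attn_kv = torch.zeros(2, num_blocks, num_kv_heads, block_size, head_size)

# ROCM_AITER_UNIFIED_ATTN: k/v as the 1st dimension, NHD layout
# [num_blocks, block_size, num_kv_heads, head_size]
rocm_aiter_kv = torch.zeros(2, num_blocks, block_size, num_kv_heads, head_size)

# TRITON_ATTN: k/v as the 2nd dimension, NHD layout
triton_kv = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

# FLASHINFER: the hunk shows torch.zeros(num_blocks, 2, ...), i.e. a paged
# cache with the k/v pair in the 2nd dimension as well
flashinfer_kv = torch.zeros(num_blocks, 2, block_size, num_kv_heads, head_size)

for label, kv in [("ROCM_ATTN", rocm_attn_kv), ("TRITON_ATTN", triton_kv)]:
    print(label, tuple(kv.shape))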

tests/compile/test_fusions_e2e.py

Lines changed: 12 additions & 12 deletions
@@ -11,7 +11,7 @@
 import pytest
 import regex as re

-from tests.v1.attention.utils import _Backend
+from tests.v1.attention.utils import AttentionBackendEnum
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
@@ -24,7 +24,7 @@
 class ModelBackendTestCase(NamedTuple):
     model_name: str
     model_kwargs: dict[str, Any]
-    backend: _Backend
+    backend: AttentionBackendEnum
     attention_fusions: int
     allreduce_fusions: int | None = None

@@ -39,14 +39,14 @@ class ModelBackendTestCase(NamedTuple):
         # Use smaller model for L40s in CI
         model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.TRITON_ATTN,
+        backend=AttentionBackendEnum.TRITON_ATTN,
         attention_fusions=32,
         allreduce_fusions=65,
     ),
     ModelBackendTestCase(
         model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-        backend=_Backend.FLASHINFER,
+        backend=AttentionBackendEnum.FLASHINFER,
         attention_fusions=48,
         allreduce_fusions=96,
     ),
@@ -56,7 +56,7 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
         model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-        backend=_Backend.FLASHINFER,
+        backend=AttentionBackendEnum.FLASHINFER,
         attention_fusions=32,
         allreduce_fusions=65,
     ),
@@ -67,7 +67,7 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="meta-llama/Llama-3.1-8B-Instruct",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.TRITON_ATTN,
+        backend=AttentionBackendEnum.TRITON_ATTN,
         attention_fusions=0,
         allreduce_fusions=65,
     ),
@@ -85,19 +85,19 @@ class ModelBackendTestCase(NamedTuple):
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.TRITON_ATTN,
+        backend=AttentionBackendEnum.TRITON_ATTN,
         attention_fusions=32,
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.ROCM_ATTN,
+        backend=AttentionBackendEnum.ROCM_ATTN,
         attention_fusions=32,
     ),
     ModelBackendTestCase(
         model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
         model_kwargs=dict(max_model_len=1024),
-        backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
+        backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
         attention_fusions=32,
     ),
 ]
@@ -117,15 +117,15 @@ class ModelBackendTestCase(NamedTuple):
 def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,
     inductor_graph_partition: bool,
     caplog_mp_spawn,
     monkeypatch,
 ):
-    if backend == _Backend.FLASHINFER and (
+    if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 def test_tp2_attn_quant_allreduce_rmsnorm(
     model_name: str,
     model_kwargs: dict,
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,

tests/config/test_multimodal_config.py

Lines changed: 3 additions & 3 deletions
@@ -3,13 +3,13 @@

 import pytest

-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config.multimodal import MultiModalConfig


 def test_mm_encoder_attn_backend_str_conversion():
     config = MultiModalConfig(mm_encoder_attn_backend="FLASH_ATTN")
-    assert config.mm_encoder_attn_backend == _Backend.FLASH_ATTN
+    assert config.mm_encoder_attn_backend == AttentionBackendEnum.FLASH_ATTN


 def test_mm_encoder_attn_backend_invalid():
@@ -20,6 +20,6 @@ def test_mm_encoder_attn_backend_invalid():
 def test_mm_encoder_attn_backend_hash_updates():
     base_hash = MultiModalConfig().compute_hash()
     overridden_hash = MultiModalConfig(
-        mm_encoder_attn_backend=_Backend.FLASH_ATTN
+        mm_encoder_attn_backend=AttentionBackendEnum.FLASH_ATTN
     ).compute_hash()
     assert base_hash != overridden_hash
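These tests pin down two behaviours: a backend given as a string ("FLASH_ATTN") is converted to the corresponding `AttentionBackendEnum` member, and overriding the backend changes the config hash. A self-contained sketch of the string-to-enum conversion pattern being tested, using a stand-in enum rather than vLLM's real `AttentionBackendEnum`; only the member name `FLASH_ATTN` comes from the diff:

from enum import Enum


class BackendEnumStandIn(Enum):
    # Stand-in members; the real AttentionBackendEnum defines many more.
    FLASH_ATTN = "FLASH_ATTN"
    TRITON_ATTN = "TRITON_ATTN"


def resolve_backend(value):
    """Accept an enum member or its name, as the config test expects."""
    if value is None or isinstance(value, BackendEnumStandIn):
        return value
    try:
        return BackendEnumStandIn[value]  # name lookup, e.g. "FLASH_ATTN"
    except KeyError as exc:
        raise ValueError(f"Unknown attention backend: {value!r}") from exc


assert resolve_backend("FLASH_ATTN") is BackendEnumStandIn.FLASH_ATTN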

tests/kernels/attention/test_attention_selector.py

Lines changed: 45 additions & 30 deletions
@@ -120,12 +120,13 @@ def test_env(

     elif device == "cuda":
         with patch("vllm.platforms.current_platform", CudaPlatform()):
+            capability = torch.cuda.get_device_capability()
             if use_mla:
                 # CUDA MLA backend logic:
                 # - CUTLASS_MLA: only supported with block_size == 128
-                #   and Blackwell GPUs (SM 10.0), V1 only
+                #   and Blackwell GPUs (SM 10.x), V1 only
                 # - FLASHINFER_MLA: only supported on Blackwell GPUs
-                #   (SM 10.0+), V1 only
+                #   (SM 10.x), V1 only
                 # - FLASHMLA: only supported with block_size == 64
                 # - FLASH_ATTN_MLA: V1 only
                 # - TRITON_MLA: fallback for other cases
@@ -134,58 +135,72 @@ def test_env(
                     if block_size != 128:
                         # CUTLASS_MLA only supports block_size == 128
                         pytest.skip("CUTLASS_MLA only supports block_size 128")
-                    else:
-                        backend = get_attn_backend(
-                            16, torch.float16, None, block_size, use_mla=use_mla
-                        )
-                        expected = "CUTLASS_MLA"
-                        assert backend.get_name() == expected
+                    if capability[0] != 10:
+                        pytest.skip("CUTLASS MLA is not supported on this platform")
+                    backend = get_attn_backend(
+                        576, torch.float16, None, block_size, use_mla=use_mla
+                    )
+                    expected = "CUTLASS_MLA"
+                    assert backend.get_name() == expected
                 elif name == "FLASHINFER_MLA":
+                    if capability[0] != 10:
+                        pytest.skip(
+                            "FlashInfer MLA is not supported on this platform"
+                        )
                     if block_size not in [32, 64]:
                         # FlashInfer MLA only supports block_size 32 or 64
                         pytest.skip(
                             "FlashInfer MLA only supports block_size 32 or 64"
                         )
-                    else:
-                        backend = get_attn_backend(
-                            16, torch.float16, None, block_size, use_mla=use_mla
-                        )
-                        expected = "FLASHINFER_MLA"
-                        assert backend.get_name() == expected
+                    backend = get_attn_backend(
+                        576, torch.float16, None, block_size, use_mla=use_mla
+                    )
+                    expected = "FLASHINFER_MLA"
+                    assert backend.get_name() == expected
                 elif name == "FLASHMLA":
                     if block_size != 64:
                         # FlashMLA only supports block_size == 64
                         pytest.skip("FlashMLA only supports block_size 64")
-                    else:
-                        from vllm.v1.attention.backends.mla.flashmla import (
-                            is_flashmla_dense_supported,
-                        )
+                    from vllm.v1.attention.backends.mla.flashmla import (
+                        is_flashmla_dense_supported,
+                    )

-                        is_supported, _ = is_flashmla_dense_supported()
-                        if not is_supported:
-                            pytest.skip("FlashMLA not supported on this platform")
-                        else:
-                            backend = get_attn_backend(
-                                16, torch.float16, None, block_size, use_mla=use_mla
-                            )
-                            expected = name
-                            assert backend.get_name() == expected
+                    is_supported, _ = is_flashmla_dense_supported()
+                    if not is_supported:
+                        pytest.skip("FlashMLA not supported on this platform")
+                    backend = get_attn_backend(
+                        576,
+                        torch.float16,
+                        None,
+                        block_size,
+                        use_mla=use_mla,
+                    )
+                    expected = name
+                    assert backend.get_name() == expected
                 elif name == "FLASH_ATTN_MLA":
+                    from vllm.attention.utils.fa_utils import (
+                        flash_attn_supports_mla,
+                    )
+
+                    if not flash_attn_supports_mla():
+                        pytest.skip(
+                            "FlashAttention MLA not supported on this platform"
+                        )
                     backend = get_attn_backend(
-                        16, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, block_size, use_mla=use_mla
                     )
                     expected = "FLASH_ATTN_MLA"
                     assert backend.get_name() == expected
                 else:
                     # TRITON_MLA or other fallback
                     backend = get_attn_backend(
-                        16, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, block_size, use_mla=use_mla
                     )
                     expected = "TRITON_MLA"
                     assert backend.get_name() == expected
             elif name == "FLASHINFER":
                 backend = get_attn_backend(
-                    16, torch.float16, None, block_size, use_mla=use_mla
+                    64, torch.float16, None, block_size, use_mla=use_mla
                 )
                 expected = "FLASHINFER"
                 assert backend.get_name() == expected
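The updated test now queries the real device capability and passes MLA-shaped head sizes to `get_attn_backend` (576 for the MLA paths, 64 for plain FLASHINFER, instead of the old 16). Its skip conditions also spell out what each forced backend requires: CUTLASS_MLA needs SM 10.x and block_size 128, FLASHINFER_MLA needs SM 10.x and block_size 32 or 64, FLASHMLA needs block_size 64 plus `is_flashmla_dense_supported()`, FLASH_ATTN_MLA needs `flash_attn_supports_mla()`, and TRITON_MLA is the fallback. A small restatement of those conditions, written from this test rather than from the selector in `vllm/platforms/cuda.py`:

def mla_backend_constraints_ok(
    name: str,
    sm_major: int,
    block_size: int,
    flashmla_dense_supported: bool = True,
    flash_attn_mla_supported: bool = True,
) -> bool:
    """Mirror the skip conditions test_env applies to a forced MLA backend."""
    if name == "CUTLASS_MLA":
        return sm_major == 10 and block_size == 128
    if name == "FLASHINFER_MLA":
        return sm_major == 10 and block_size in (32, 64)
    if name == "FLASHMLA":
        return block_size == 64 and flashmla_dense_supported
    if name == "FLASH_ATTN_MLA":
        return flash_attn_mla_supported
    # TRITON_MLA (and anything else) is exercised with no extra constraints.
    return True


# Forcing CUTLASS_MLA on Hopper (SM 9.x) is skipped; FlashMLA at block_size 64 is not.
assert not mla_backend_constraints_ok("CUTLASS_MLA", sm_major=9, block_size=128)
assert mla_backend_constraints_ok("FLASHMLA", sm_major=9, block_size=64)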
