From b0092f53c237043702f2e2bd10f915ab020f1f29 Mon Sep 17 00:00:00 2001 From: Andrew Mello Date: Wed, 18 Feb 2026 08:17:46 -0800 Subject: [PATCH 1/3] [Bugfix] Add is_blackwell_class() for SM121/GB10 DGX Spark support SM121 (GB10, DGX Spark) has capability major=12, which was not recognized by the existing is_device_capability_family(100) checks (major=10 only). This caused SM121 to fall into non-Blackwell code paths, selecting wrong attention backends and KV cache layouts. Add is_blackwell_class() to Platform that returns True for major in {10, 11, 12} (the full Blackwell architecture family). Update key code paths: - Backend priorities: SM121 gets Blackwell priority list (FlashInfer) - FA3 fallback: SM121 correctly falls back to FA2 - FlashInfer KV cache: SM121 gets HND layout - FlashInfer head_dim=256 guard: applies to all Blackwell-class - DeepGemm: SM121 recognized as Blackwell for oracle and support check This is a minimal pure-Python fix; no C++/CUDA recompilation needed. CMakeLists.txt changes for native SM121 kernel compilation are left for a follow-up PR. Related: #31740, #33313 Signed-off-by: Andrew Mello --- docs/design/attention_backends.md | 72 ++++++++++++------------ vllm/platforms/cuda.py | 5 +- vllm/platforms/interface.py | 8 +++ vllm/utils/deep_gemm.py | 4 +- vllm/v1/attention/backends/fa_utils.py | 14 +++-- vllm/v1/attention/backends/flashinfer.py | 4 +- 6 files changed, 60 insertions(+), 47 deletions(-) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 40108e490740..97f31902ded0 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -101,18 +101,18 @@ Priority is **1 = highest** (tried first). **Blackwell (SM 10.x):** | Priority | Backend | -| -------- | ------- | -| 1 | `FLASHINFER` | -| 2 | `FLASH_ATTN` | +|----------|---------| +| 1 | `FLASH_ATTN` | +| 2 | `FLASHINFER` | | 3 | `TRITON_ATTN` | | 4 | `FLEX_ATTENTION` | **Ampere/Hopper (SM 8.x-9.x):** | Priority | Backend | -| -------- | ------- | -| 1 | `FLASH_ATTN` | -| 2 | `FLASHINFER` | +|----------|---------| +| 1 | `FLASHINFER` | +| 2 | `FLASH_ATTN` | | 3 | `TRITON_ATTN` | | 4 | `FLEX_ATTENTION` | @@ -121,7 +121,17 @@ Priority is **1 = highest** (tried first). **Blackwell (SM 10.x):** | Priority | Backend | -| -------- | ------- | +|----------|---------| +| 1 | `FLASH_ATTN_MLA` | +| 2 | `FLASHMLA` | +| 3 | `FLASHINFER_MLA` | +| 4 | `TRITON_MLA` | +| 5 | `FLASHMLA_SPARSE` | + +**Ampere/Hopper (SM 8.x-9.x):** + +| Priority | Backend | +|----------|---------| | 1 | `FLASHINFER_MLA` | | 2 | `CUTLASS_MLA` | | 3 | `FLASH_ATTN_MLA` | @@ -130,22 +140,12 @@ Priority is **1 = highest** (tried first). | 6 | `FLASHMLA_SPARSE` | | 7 | `FLASHINFER_MLA_SPARSE` | -**Ampere/Hopper (SM 8.x-9.x):** - -| Priority | Backend | -| -------- | ------- | -| 1 | `FLASH_ATTN_MLA` | -| 2 | `FLASHMLA` | -| 3 | `FLASHINFER_MLA` | -| 4 | `TRITON_MLA` | -| 5 | `FLASHMLA_SPARSE` | - > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details. ## Legend | Column | Description | -| ------ | ----------- | +|--------|-------------| | **Dtypes** | Supported model data types (fp16, bf16, fp32) | | **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) | | **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) | @@ -162,24 +162,23 @@ Priority is **1 = highest** (tried first). ## Standard Attention (MHA, MQA, GQA) Backends | Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. | -| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | -| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | +|---------|---------|--------|-----------|-------------|------------|------|-----------|-----|-----------------|--------------| +| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | -| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | -| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | -| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | -| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | -| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | -| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A | -| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | -| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | +| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | +| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | +| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A | +| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | +| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > -> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise. +> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise. ## MLA (Multi-head Latent Attention) Backends @@ -191,10 +190,10 @@ The prefill backend is selected at runtime based on hardware and configuration. | Backend | Description | Compute Cap. | Enable | Disable | Notes | -| ------- | ----------- | ------------ | ------ | ------- | ----- | +|---------|-------------|--------------|--------|---------|-------| | TRT-LLM Ragged‡ | TensorRT-LLM ragged attention | 10.x | Default on SM100 | `-ac.use_trtllm_ragged_deepseek_prefill=0` | DeepSeek R1 dims only | | FlashInfer | FlashInfer CUTLASS backend | 10.x | `-ac.disable_flashinfer_prefill=0` | `-ac.disable_flashinfer_prefill=1` | DeepSeek R1 dims only | -| cuDNN | cuDNN-based attention | 10.x | `-ac.use_cudnn_prefill=1` | `-ac.use_cudnn_prefill=0` | | +| cuDNN | cuDNN-based attention | 10.x | `-ac.use_cudnn_prefill=1` | `-ac.use_cudnn_prefill=0` | | | FlashAttention | FlashAttention varlen (FA2/FA3) | Any | Default fallback | Use other backends | FA3 on SM90, FA2 otherwise | > **‡** TRT-LLM Ragged is the default on Blackwell (SM100). @@ -203,15 +202,14 @@ configuration. ### Decode Backends | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. | -| ------- | ------ | --------- | ----------- | ---------- | ---- | ------ | --------- | --- | --------------- | ------------ | +|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------| | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | -| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | +| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | -| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | -| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2025c41ab8d9..2efc9d2105d2 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -51,8 +51,9 @@ def _get_backend_priorities( num_heads: int | None = None, ) -> list[AttentionBackendEnum]: """Get backend priorities with lazy import to avoid circular dependency.""" + is_blackwell = device_capability.major in (10, 11, 12) if use_mla: - if device_capability.major == 10: + if is_blackwell: # Prefer FlashInfer at low head counts (FlashMLA uses padding) if num_heads is not None and num_heads <= 16: sparse_backends = [ @@ -81,7 +82,7 @@ def _get_backend_priorities( AttentionBackendEnum.FLASHMLA_SPARSE, ] else: - if device_capability.major == 10: + if is_blackwell: return [ AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASH_ATTN, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b538524995a5..f793f3817e34 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -345,6 +345,14 @@ def is_device_capability_family( return False return (current_capability.to_int() // 10) == (capability // 10) + @classmethod + def is_blackwell_class(cls, device_id: int = 0) -> bool: + """Check if device is a Blackwell-class GPU (SM10x, SM11x, SM12x).""" + capability = cls.get_device_capability(device_id=device_id) + if capability is None: + return False + return capability.major in (10, 11, 12) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index ee104a6cc75c..e3259377bd48 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -53,7 +53,7 @@ def init_oracle_cache(cls) -> None: cls._oracle_cache = ( # type: ignore cls.UE8M0 - if current_platform.is_device_capability_family(100) + if current_platform.is_blackwell_class() else cls.FLOAT32_CEIL_UE8M0 ) @@ -72,7 +72,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) + or current_platform.is_blackwell_class() ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 20502cbf0feb..e9826eda2af9 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -75,8 +75,11 @@ def get_flash_attn_version( if device_capability.major == 9 and is_fa_version_supported(3): # Hopper (SM90): prefer FA3 fa_version = 3 - elif device_capability.major == 10 and is_fa_version_supported(4): - # Blackwell (SM100+, restrict to SM100 for now): prefer FA4 + elif ( + current_platform.is_blackwell_capability(device_capability) + and is_fa_version_supported(4) + ): + # Blackwell (SM100-SM121): prefer FA4 fa_version = 4 else: # Fallback to FA2 @@ -93,7 +96,10 @@ def get_flash_attn_version( fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations - if device_capability.major >= 10 and fa_version == 3: + if ( + current_platform.is_blackwell_capability(device_capability) + and fa_version == 3 + ): logger.warning_once( "Cannot use FA version 3 on Blackwell platform, " "defaulting to FA version 4 if supported, otherwise FA2." @@ -127,7 +133,7 @@ def get_flash_attn_version( # See: https://github.com/Dao-AILab/flash-attention/issues/1959 if ( fa_version == 4 - and device_capability.major >= 10 + and current_platform.is_blackwell_capability(device_capability) and head_size is not None and head_size > 128 ): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 844e8597e5b1..4452848cb0c5 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -385,7 +385,7 @@ def supports_sink(cls) -> bool: @classmethod def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: capability = current_platform.get_device_capability() - if capability is not None and capability.major == 10: + if capability is not None and capability.major in (10, 11, 12): return "HND" return None @@ -630,7 +630,7 @@ def __init__( self.paged_kv_indices = self._make_buffer(max_num_pages) self.paged_kv_last_page_len = self._make_buffer(max_num_reqs) - if self.head_dim == 256 and current_platform.is_device_capability_family(100): + if self.head_dim == 256 and current_platform.is_blackwell_class(): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, ( From ff9c38ac7e032cf2cfe47d037ef12a56bc150aa7 Mon Sep 17 00:00:00 2001 From: Andrew Mello Date: Wed, 18 Feb 2026 09:12:35 -0800 Subject: [PATCH 2/3] [Bugfix] Add tests for is_blackwell_class() platform detection Unit tests for Blackwell-family GPU detection covering: - Parametrized capability matrix (Volta through post-Blackwell) - None capability handling - Consistency with is_device_capability_family for all Blackwell families - Backend priority integration tests (skipped without compiled _C extension) Signed-off-by: Andrew Mello --- tests/platforms/test_blackwell_class.py | 146 ++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 tests/platforms/test_blackwell_class.py diff --git a/tests/platforms/test_blackwell_class.py b/tests/platforms/test_blackwell_class.py new file mode 100644 index 000000000000..c37466d107f8 --- /dev/null +++ b/tests/platforms/test_blackwell_class.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for is_blackwell_class() Blackwell-family GPU detection. + +Verifies that the unified Blackwell-class check correctly identifies +SM10x, SM11x, and SM12x devices while excluding non-Blackwell GPUs. +""" + +import importlib.util + +import pytest + +from vllm.platforms.interface import DeviceCapability, Platform + + +def _has_vllm_c() -> bool: + """Check if compiled vllm._C extension is available.""" + return importlib.util.find_spec("vllm._C") is not None + + +class _FakePlatform(Platform): + """Minimal Platform subclass for testing capability methods.""" + + _capability: DeviceCapability | None = None + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + return cls._capability + + +# ── is_blackwell_class parametrized tests ────────────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "expected"), + [ + # Pre-Blackwell architectures → False + (7, 0, False), # Volta (V100) + (7, 5, False), # Turing (RTX 2080) + (8, 0, False), # Ampere (A100) + (8, 6, False), # Ampere (RTX 3060) + (8, 9, False), # Ada Lovelace (RTX 4090) + (9, 0, False), # Hopper (H100) + # Blackwell-class architectures → True + (10, 0, True), # B100/B200 + (10, 1, True), # B200 variant + (10, 3, True), # B200 variant + (11, 0, True), # Future Blackwell + (12, 0, True), # GB10/DGX Spark (SM120) + (12, 1, True), # GB10/DGX Spark (SM121) + # Future / post-Blackwell → False + (13, 0, False), + (15, 0, False), + ], + ids=lambda v: str(v), +) +def test_is_blackwell_class(major: int, minor: int, expected: bool): + _FakePlatform._capability = DeviceCapability(major=major, minor=minor) + assert _FakePlatform.is_blackwell_class() is expected + + +def test_is_blackwell_class_none_capability(): + """is_blackwell_class returns False when no capability is available.""" + _FakePlatform._capability = None + assert _FakePlatform.is_blackwell_class() is False + + +# ── is_device_capability_family consistency check ────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "family"), + [ + (10, 0, 100), + (10, 3, 100), + (11, 0, 110), + (12, 0, 120), + (12, 1, 120), + ], +) +def test_blackwell_class_covers_all_families(major: int, minor: int, family: int): + """Every Blackwell family (100, 110, 120) is also blackwell_class.""" + _FakePlatform._capability = DeviceCapability(major=major, minor=minor) + assert _FakePlatform.is_device_capability_family(family) is True + assert _FakePlatform.is_blackwell_class() is True + + +# ── Backend priority integration (mocked) ───────────────────────── + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_sm121(): + """SM121 should get Blackwell backend priorities (FlashInfer first).""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=12, minor=1) + # Non-MLA path: Blackwell should get FlashInfer first + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + assert "FLASHINFER" in backend_names + # FlashInfer should be before FlashAttn for Blackwell + fi_idx = backend_names.index("FLASHINFER") + if "FLASH_ATTN" in backend_names: + fa_idx = backend_names.index("FLASH_ATTN") + assert fi_idx < fa_idx, ( + f"FlashInfer ({fi_idx}) should come before FlashAttn ({fa_idx}) " + f"for SM121 Blackwell-class GPU" + ) + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_sm100_unchanged(): + """SM100 (B200) should still get Blackwell backend priorities.""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=10, minor=0) + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + assert "FLASHINFER" in backend_names + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_hopper_not_blackwell(): + """SM90 (Hopper) should NOT get Blackwell backend priorities.""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=9, minor=0) + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + # Hopper gets FlashAttn first, not FlashInfer + if "FLASH_ATTN" in backend_names and "FLASHINFER" in backend_names: + fa_idx = backend_names.index("FLASH_ATTN") + fi_idx = backend_names.index("FLASHINFER") + assert fa_idx < fi_idx, ( + f"FlashAttn ({fa_idx}) should come before FlashInfer ({fi_idx}) " + f"for SM90 Hopper GPU" + ) From de01ee1a5104ef0bbd3296cfff5b035a22f1e9ab Mon Sep 17 00:00:00 2001 From: Andrew Mello Date: Wed, 18 Feb 2026 09:17:22 -0800 Subject: [PATCH 3/3] [Bugfix] Add is_blackwell_capability() staticmethod per review feedback Address review suggestions from @gemini-code-assist and @amadhan882: - Add Platform.is_blackwell_capability(cap) @staticmethod that takes a DeviceCapability directly, avoiding redundant device queries - Refactor is_blackwell_class() to delegate to the new staticmethod - Update cuda.py, fa_utils.py, flashinfer.py to use the staticmethod where a DeviceCapability object is already available - Add tests for staticmethod and consistency with classmethod Signed-off-by: Andrew Mello --- docs/design/attention_backends.md | 46 ++++++++++++------------ tests/platforms/test_blackwell_class.py | 30 ++++++++++++++++ vllm/platforms/cuda.py | 4 ++- vllm/platforms/interface.py | 7 +++- vllm/v1/attention/backends/fa_utils.py | 7 ++-- vllm/v1/attention/backends/flashinfer.py | 4 ++- 6 files changed, 69 insertions(+), 29 deletions(-) diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 97f31902ded0..65449e6a808c 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -101,7 +101,7 @@ Priority is **1 = highest** (tried first). **Blackwell (SM 10.x):** | Priority | Backend | -|----------|---------| +| -------- | ------- | | 1 | `FLASH_ATTN` | | 2 | `FLASHINFER` | | 3 | `TRITON_ATTN` | @@ -110,7 +110,7 @@ Priority is **1 = highest** (tried first). **Ampere/Hopper (SM 8.x-9.x):** | Priority | Backend | -|----------|---------| +| -------- | ------- | | 1 | `FLASHINFER` | | 2 | `FLASH_ATTN` | | 3 | `TRITON_ATTN` | @@ -121,7 +121,7 @@ Priority is **1 = highest** (tried first). **Blackwell (SM 10.x):** | Priority | Backend | -|----------|---------| +| -------- | ------- | | 1 | `FLASH_ATTN_MLA` | | 2 | `FLASHMLA` | | 3 | `FLASHINFER_MLA` | @@ -131,7 +131,7 @@ Priority is **1 = highest** (tried first). **Ampere/Hopper (SM 8.x-9.x):** | Priority | Backend | -|----------|---------| +| -------- | ------- | | 1 | `FLASHINFER_MLA` | | 2 | `CUTLASS_MLA` | | 3 | `FLASH_ATTN_MLA` | @@ -145,7 +145,7 @@ Priority is **1 = highest** (tried first). ## Legend | Column | Description | -|--------|-------------| +| ------ | ----------- | | **Dtypes** | Supported model data types (fp16, bf16, fp32) | | **KV Dtypes** | Supported KV cache data types (`auto`, `fp8`, `fp8_e4m3`, etc.) | | **Block Sizes** | Supported KV cache block sizes (%N means multiples of N) | @@ -162,23 +162,24 @@ Priority is **1 = highest** (tried first). ## Standard Attention (MHA, MQA, GQA) Backends | Backend | Version | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | MM Prefix | DCP | Attention Types | Compute Cap. | -|---------|---------|--------|-----------|-------------|------------|------|-----------|-----|-----------------|--------------| -| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | +| ------- | ------- | ------ | --------- | ----------- | ---------- | ---- | --------- | --- | --------------- | ------------ | +| `CPU_ATTN` | | fp16, bf16, fp32 | `auto` | Any | 32, 64, 80, 96, 112, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | All | N/A | | `FLASHINFER` | Native† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ❌ | ❌ | ✅ | Decoder | 7.x-9.x | | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | -| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | -| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | -| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | N/A | -| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | -| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | +| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | +| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | +| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | +| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | +| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A | +| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A | +| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any | +| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any | > **†** FlashInfer uses TRTLLM attention on Blackwell (SM100), which supports sinks. Disable via `--attention-config.use_trtllm_attention=0`. > -> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2` or `3`. Default is FA3 on SM90, FA2 otherwise. +> **\*** Specify the FlashAttention version via `--attention-config.flash_attn_version=2`, `3`, or `4`. Default is FA4 on SM100+ (Blackwell), FA3 on SM90 (Hopper), FA2 otherwise. ## MLA (Multi-head Latent Attention) Backends @@ -190,10 +191,10 @@ The prefill backend is selected at runtime based on hardware and configuration. | Backend | Description | Compute Cap. | Enable | Disable | Notes | -|---------|-------------|--------------|--------|---------|-------| +| ------- | ----------- | ------------ | ------ | ------- | ----- | | TRT-LLM Ragged‡ | TensorRT-LLM ragged attention | 10.x | Default on SM100 | `-ac.use_trtllm_ragged_deepseek_prefill=0` | DeepSeek R1 dims only | | FlashInfer | FlashInfer CUTLASS backend | 10.x | `-ac.disable_flashinfer_prefill=0` | `-ac.disable_flashinfer_prefill=1` | DeepSeek R1 dims only | -| cuDNN | cuDNN-based attention | 10.x | `-ac.use_cudnn_prefill=1` | `-ac.use_cudnn_prefill=0` | | +| cuDNN | cuDNN-based attention | 10.x | `-ac.use_cudnn_prefill=1` | `-ac.use_cudnn_prefill=0` | | | FlashAttention | FlashAttention varlen (FA2/FA3) | Any | Default fallback | Use other backends | FA3 on SM90, FA2 otherwise | > **‡** TRT-LLM Ragged is the default on Blackwell (SM100). @@ -202,14 +203,15 @@ configuration. ### Decode Backends | Backend | Dtypes | KV Dtypes | Block Sizes | Head Sizes | Sink | Sparse | MM Prefix | DCP | Attention Types | Compute Cap. | -|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------| +| ------- | ------ | --------- | ----------- | ---------- | ---- | ------ | --------- | --- | --------------- | ------------ | | `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x | | `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x | -| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | +| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x | | `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x | | `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x | | `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x | -| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto` | Any | 576 | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | +| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A | | `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A | -| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any | +| `XPU_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | Any | diff --git a/tests/platforms/test_blackwell_class.py b/tests/platforms/test_blackwell_class.py index c37466d107f8..fcbd99d5a2ea 100644 --- a/tests/platforms/test_blackwell_class.py +++ b/tests/platforms/test_blackwell_class.py @@ -65,6 +65,36 @@ def test_is_blackwell_class_none_capability(): assert _FakePlatform.is_blackwell_class() is False +# ── is_blackwell_capability staticmethod tests ───────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "expected"), + [ + (9, 0, False), + (10, 0, True), + (11, 0, True), + (12, 1, True), + (13, 0, False), + ], + ids=lambda v: str(v), +) +def test_is_blackwell_capability_static(major: int, minor: int, expected: bool): + """Staticmethod works directly on DeviceCapability without device query.""" + cap = DeviceCapability(major=major, minor=minor) + assert Platform.is_blackwell_capability(cap) is expected + + +def test_is_blackwell_capability_consistency(): + """Staticmethod and classmethod agree for all Blackwell variants.""" + for major in (10, 11, 12): + cap = DeviceCapability(major=major, minor=0) + _FakePlatform._capability = cap + assert ( + Platform.is_blackwell_capability(cap) is _FakePlatform.is_blackwell_class() + ) + + # ── is_device_capability_family consistency check ────────────────── diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2efc9d2105d2..c1a7dc3d153e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -51,7 +51,9 @@ def _get_backend_priorities( num_heads: int | None = None, ) -> list[AttentionBackendEnum]: """Get backend priorities with lazy import to avoid circular dependency.""" - is_blackwell = device_capability.major in (10, 11, 12) + from vllm.platforms.interface import Platform + + is_blackwell = Platform.is_blackwell_capability(device_capability) if use_mla: if is_blackwell: # Prefer FlashInfer at low head counts (FlashMLA uses padding) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f793f3817e34..eccbf2e7638c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -345,13 +345,18 @@ def is_device_capability_family( return False return (current_capability.to_int() // 10) == (capability // 10) + @staticmethod + def is_blackwell_capability(capability: "DeviceCapability") -> bool: + """Check if a DeviceCapability represents a Blackwell-class GPU.""" + return capability.major in (10, 11, 12) + @classmethod def is_blackwell_class(cls, device_id: int = 0) -> bool: """Check if device is a Blackwell-class GPU (SM10x, SM11x, SM12x).""" capability = cls.get_device_capability(device_id=device_id) if capability is None: return False - return capability.major in (10, 11, 12) + return cls.is_blackwell_capability(capability) @classmethod def get_device_name(cls, device_id: int = 0) -> str: diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index e9826eda2af9..28a6545e8811 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -75,10 +75,9 @@ def get_flash_attn_version( if device_capability.major == 9 and is_fa_version_supported(3): # Hopper (SM90): prefer FA3 fa_version = 3 - elif ( - current_platform.is_blackwell_capability(device_capability) - and is_fa_version_supported(4) - ): + elif current_platform.is_blackwell_capability( + device_capability + ) and is_fa_version_supported(4): # Blackwell (SM100-SM121): prefer FA4 fa_version = 4 else: diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4452848cb0c5..41d59bd070e4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -385,7 +385,9 @@ def supports_sink(cls) -> bool: @classmethod def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: capability = current_platform.get_device_capability() - if capability is not None and capability.major in (10, 11, 12): + if capability is not None and current_platform.is_blackwell_capability( + capability + ): return "HND" return None