diff --git a/docs/design/attention_backends.md b/docs/design/attention_backends.md index 40108e490740..65449e6a808c 100644 --- a/docs/design/attention_backends.md +++ b/docs/design/attention_backends.md @@ -102,8 +102,8 @@ Priority is **1 = highest** (tried first). | Priority | Backend | | -------- | ------- | -| 1 | `FLASHINFER` | -| 2 | `FLASH_ATTN` | +| 1 | `FLASH_ATTN` | +| 2 | `FLASHINFER` | | 3 | `TRITON_ATTN` | | 4 | `FLEX_ATTENTION` | @@ -111,8 +111,8 @@ Priority is **1 = highest** (tried first). | Priority | Backend | | -------- | ------- | -| 1 | `FLASH_ATTN` | -| 2 | `FLASHINFER` | +| 1 | `FLASHINFER` | +| 2 | `FLASH_ATTN` | | 3 | `TRITON_ATTN` | | 4 | `FLEX_ATTENTION` | @@ -120,6 +120,16 @@ Priority is **1 = highest** (tried first). **Blackwell (SM 10.x):** +| Priority | Backend | +| -------- | ------- | +| 1 | `FLASH_ATTN_MLA` | +| 2 | `FLASHMLA` | +| 3 | `FLASHINFER_MLA` | +| 4 | `TRITON_MLA` | +| 5 | `FLASHMLA_SPARSE` | + +**Ampere/Hopper (SM 8.x-9.x):** + | Priority | Backend | | -------- | ------- | | 1 | `FLASHINFER_MLA` | @@ -130,16 +140,6 @@ Priority is **1 = highest** (tried first). | 6 | `FLASHMLA_SPARSE` | | 7 | `FLASHINFER_MLA_SPARSE` | -**Ampere/Hopper (SM 8.x-9.x):** - -| Priority | Backend | -| -------- | ------- | -| 1 | `FLASH_ATTN_MLA` | -| 2 | `FLASHMLA` | -| 3 | `FLASHINFER_MLA` | -| 4 | `TRITON_MLA` | -| 5 | `FLASHMLA_SPARSE` | - > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details. ## Legend @@ -168,7 +168,7 @@ Priority is **1 = highest** (tried first). 
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | -| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | +| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | 10.x-12.x | | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | diff --git a/tests/platforms/test_blackwell_class.py b/tests/platforms/test_blackwell_class.py new file mode 100644 index 000000000000..fcbd99d5a2ea --- /dev/null +++ b/tests/platforms/test_blackwell_class.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tests for is_blackwell_class() Blackwell-family GPU detection. + +Verifies that the unified Blackwell-class check correctly identifies +SM10x, SM11x, and SM12x devices while excluding non-Blackwell GPUs. 
+""" + +import importlib.util + +import pytest + +from vllm.platforms.interface import DeviceCapability, Platform + + +def _has_vllm_c() -> bool: + """Check if compiled vllm._C extension is available.""" + return importlib.util.find_spec("vllm._C") is not None + + +class _FakePlatform(Platform): + """Minimal Platform subclass for testing capability methods.""" + + _capability: DeviceCapability | None = None + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: + return cls._capability + + +# ── is_blackwell_class parametrized tests ────────────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "expected"), + [ + # Pre-Blackwell architectures → False + (7, 0, False), # Volta (V100) + (7, 5, False), # Turing (RTX 2080) + (8, 0, False), # Ampere (A100) + (8, 6, False), # Ampere (RTX 3060) + (8, 9, False), # Ada Lovelace (RTX 4090) + (9, 0, False), # Hopper (H100) + # Blackwell-class architectures → True + (10, 0, True), # B100/B200 + (10, 1, True), # B200 variant + (10, 3, True), # B200 variant + (11, 0, True), # Future Blackwell + (12, 0, True), # GB10/DGX Spark (SM120) + (12, 1, True), # GB10/DGX Spark (SM121) + # Future / post-Blackwell → False + (13, 0, False), + (15, 0, False), + ], + ids=lambda v: str(v), +) +def test_is_blackwell_class(major: int, minor: int, expected: bool): + _FakePlatform._capability = DeviceCapability(major=major, minor=minor) + assert _FakePlatform.is_blackwell_class() is expected + + +def test_is_blackwell_class_none_capability(): + """is_blackwell_class returns False when no capability is available.""" + _FakePlatform._capability = None + assert _FakePlatform.is_blackwell_class() is False + + +# ── is_blackwell_capability staticmethod tests ───────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "expected"), + [ + (9, 0, False), + (10, 0, True), + (11, 0, True), + (12, 1, True), + (13, 0, False), + ], + ids=lambda v: str(v), +) +def 
test_is_blackwell_capability_static(major: int, minor: int, expected: bool): + """Staticmethod works directly on DeviceCapability without device query.""" + cap = DeviceCapability(major=major, minor=minor) + assert Platform.is_blackwell_capability(cap) is expected + + +def test_is_blackwell_capability_consistency(): + """Staticmethod and classmethod agree for all Blackwell variants.""" + for major in (10, 11, 12): + cap = DeviceCapability(major=major, minor=0) + _FakePlatform._capability = cap + assert ( + Platform.is_blackwell_capability(cap) is _FakePlatform.is_blackwell_class() + ) + + +# ── is_device_capability_family consistency check ────────────────── + + +@pytest.mark.parametrize( + ("major", "minor", "family"), + [ + (10, 0, 100), + (10, 3, 100), + (11, 0, 110), + (12, 0, 120), + (12, 1, 120), + ], +) +def test_blackwell_class_covers_all_families(major: int, minor: int, family: int): + """Every Blackwell family (100, 110, 120) is also blackwell_class.""" + _FakePlatform._capability = DeviceCapability(major=major, minor=minor) + assert _FakePlatform.is_device_capability_family(family) is True + assert _FakePlatform.is_blackwell_class() is True + + +# ── Backend priority integration (mocked) ───────────────────────── + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_sm121(): + """SM121 should get Blackwell backend priorities (FlashInfer first).""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=12, minor=1) + # Non-MLA path: Blackwell should get FlashInfer first + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + assert "FLASHINFER" in backend_names + # FlashInfer should be before FlashAttn for Blackwell + fi_idx = backend_names.index("FLASHINFER") + if "FLASH_ATTN" in backend_names: + fa_idx = backend_names.index("FLASH_ATTN") + assert fi_idx < fa_idx, ( + f"FlashInfer ({fi_idx}) 
should come before FlashAttn ({fa_idx}) " + f"for SM121 Blackwell-class GPU" + ) + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_sm100_unchanged(): + """SM100 (B200) should still get Blackwell backend priorities.""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=10, minor=0) + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + assert "FLASHINFER" in backend_names + + +@pytest.mark.skipif( + not _has_vllm_c(), + reason="Requires compiled vllm._C extension", +) +def test_backend_priorities_hopper_not_blackwell(): + """SM90 (Hopper) should NOT get Blackwell backend priorities.""" + from vllm.platforms.cuda import _get_backend_priorities + + cap = DeviceCapability(major=9, minor=0) + priorities = _get_backend_priorities(cap, use_mla=False) + backend_names = [b.name for b in priorities] + # Hopper gets FlashAttn first, not FlashInfer + if "FLASH_ATTN" in backend_names and "FLASHINFER" in backend_names: + fa_idx = backend_names.index("FLASH_ATTN") + fi_idx = backend_names.index("FLASHINFER") + assert fa_idx < fi_idx, ( + f"FlashAttn ({fa_idx}) should come before FlashInfer ({fi_idx}) " + f"for SM90 Hopper GPU" + ) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 2025c41ab8d9..c1a7dc3d153e 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -51,8 +51,11 @@ def _get_backend_priorities( num_heads: int | None = None, ) -> list[AttentionBackendEnum]: """Get backend priorities with lazy import to avoid circular dependency.""" + from vllm.platforms.interface import Platform + + is_blackwell = Platform.is_blackwell_capability(device_capability) if use_mla: - if device_capability.major == 10: + if is_blackwell: # Prefer FlashInfer at low head counts (FlashMLA uses padding) if num_heads is not None and num_heads <= 16: sparse_backends = [ @@ -81,7 +84,7 @@ def 
_get_backend_priorities( AttentionBackendEnum.FLASHMLA_SPARSE, ] else: - if device_capability.major == 10: + if is_blackwell: return [ AttentionBackendEnum.FLASHINFER, AttentionBackendEnum.FLASH_ATTN, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index b538524995a5..eccbf2e7638c 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -345,6 +345,19 @@ def is_device_capability_family( return False return (current_capability.to_int() // 10) == (capability // 10) + @staticmethod + def is_blackwell_capability(capability: "DeviceCapability") -> bool: + """Check if a DeviceCapability represents a Blackwell-class GPU.""" + return capability.major in (10, 11, 12) + + @classmethod + def is_blackwell_class(cls, device_id: int = 0) -> bool: + """Check if device is a Blackwell-class GPU (SM10x, SM11x, SM12x).""" + capability = cls.get_device_capability(device_id=device_id) + if capability is None: + return False + return cls.is_blackwell_capability(capability) + @classmethod def get_device_name(cls, device_id: int = 0) -> str: """Get the name of a device.""" diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index ee104a6cc75c..e3259377bd48 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -53,7 +53,7 @@ def init_oracle_cache(cls) -> None: cls._oracle_cache = ( # type: ignore cls.UE8M0 - if current_platform.is_device_capability_family(100) + if current_platform.is_blackwell_class() else cls.FLOAT32_CEIL_UE8M0 ) @@ -72,7 +72,7 @@ def is_deep_gemm_supported() -> bool: """ is_supported_arch = current_platform.is_cuda() and ( current_platform.is_device_capability(90) - or current_platform.is_device_capability_family(100) + or current_platform.is_blackwell_class() ) return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 20502cbf0feb..28a6545e8811 100644 --- 
a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -75,8 +75,10 @@ def get_flash_attn_version( if device_capability.major == 9 and is_fa_version_supported(3): # Hopper (SM90): prefer FA3 fa_version = 3 - elif device_capability.major == 10 and is_fa_version_supported(4): - # Blackwell (SM100+, restrict to SM100 for now): prefer FA4 + elif current_platform.is_blackwell_capability( + device_capability + ) and is_fa_version_supported(4): + # Blackwell (SM100-SM121): prefer FA4 fa_version = 4 else: # Fallback to FA2 @@ -93,7 +95,10 @@ def get_flash_attn_version( fa_version = vllm_config.attention_config.flash_attn_version # 3. fallback for unsupported combinations - if device_capability.major >= 10 and fa_version == 3: + if ( + current_platform.is_blackwell_capability(device_capability) + and fa_version == 3 + ): logger.warning_once( "Cannot use FA version 3 on Blackwell platform, " "defaulting to FA version 4 if supported, otherwise FA2." @@ -127,7 +132,7 @@ def get_flash_attn_version( # See: https://github.com/Dao-AILab/flash-attention/issues/1959 if ( fa_version == 4 - and device_capability.major >= 10 + and current_platform.is_blackwell_capability(device_capability) and head_size is not None and head_size > 128 ): diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 844e8597e5b1..41d59bd070e4 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -385,7 +385,9 @@ def supports_sink(cls) -> bool: @classmethod def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: capability = current_platform.get_device_capability() - if capability is not None and capability.major == 10: + if capability is not None and current_platform.is_blackwell_capability( + capability + ): return "HND" return None @@ -630,7 +632,7 @@ def __init__( self.paged_kv_indices = self._make_buffer(max_num_pages) self.paged_kv_last_page_len = 
self._make_buffer(max_num_reqs) - if self.head_dim == 256 and current_platform.is_device_capability_family(100): + if self.head_dim == 256 and current_platform.is_blackwell_class(): # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that # head size 256 and block size 16 is not supported on blackwell. assert kv_cache_spec.block_size != 16, (