Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,34 @@ Priority is **1 = highest** (tried first).

| Priority | Backend |
| -------- | ------- |
| 1 | `FLASHINFER` |
| 2 | `FLASH_ATTN` |
| 1 | `FLASH_ATTN` |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR seems to have broken generate_attention_backend_docs.py, please fix it

| 2 | `FLASHINFER` |
| 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` |

**Ampere/Hopper (SM 8.x-9.x):**

| Priority | Backend |
| -------- | ------- |
| 1 | `FLASH_ATTN` |
| 2 | `FLASHINFER` |
| 1 | `FLASH_ATTN` |
| 2 | `FLASHINFER` |
| 3 | `TRITON_ATTN` |
| 4 | `FLEX_ATTENTION` |

### MLA Attention (DeepSeek-style)

**Blackwell (SM 10.x):**

| Priority | Backend |
| -------- | ------- |
| 1 | `FLASH_ATTN_MLA` |
| 2 | `FLASHMLA` |
| 3 | `FLASHINFER_MLA` |
| 4 | `TRITON_MLA` |
| 5 | `FLASHMLA_SPARSE` |

**Ampere/Hopper (SM 8.x-9.x):**

| Priority | Backend |
| -------- | ------- |
| 1 | `FLASHINFER_MLA` |
Expand All @@ -130,16 +140,6 @@ Priority is **1 = highest** (tried first).
| 6 | `FLASHMLA_SPARSE` |
| 7 | `FLASHINFER_MLA_SPARSE` |

**Ampere/Hopper (SM 8.x-9.x):**

| Priority | Backend |
| -------- | ------- |
| 1 | `FLASH_ATTN_MLA` |
| 2 | `FLASHMLA` |
| 3 | `FLASHINFER_MLA` |
| 4 | `TRITON_MLA` |
| 5 | `FLASHMLA_SPARSE` |

> **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details.

## Legend
Expand Down Expand Up @@ -168,7 +168,7 @@ Priority is **1 = highest** (tried first).
| `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x |
| `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 |
| `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is wrong, we only want to use FA4 on blackwell

| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
Expand Down
176 changes: 176 additions & 0 deletions tests/platforms/test_blackwell_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for is_blackwell_class() Blackwell-family GPU detection.

Verifies that the unified Blackwell-class check correctly identifies
SM10x, SM11x, and SM12x devices while excluding non-Blackwell GPUs.
"""

import importlib.util

import pytest

from vllm.platforms.interface import DeviceCapability, Platform


def _has_vllm_c() -> bool:
"""Check if compiled vllm._C extension is available."""
return importlib.util.find_spec("vllm._C") is not None


class _FakePlatform(Platform):
    """Stub ``Platform`` whose reported capability is set directly by tests.

    Assign ``_capability`` before invoking any capability helper; the
    ``device_id`` argument is accepted for interface parity but ignored.
    """

    # Capability that get_device_capability() reports; None means "unknown".
    _capability: DeviceCapability | None = None

    @classmethod
    def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
        # Every "device" reports the same stubbed value.
        return cls._capability


# ── is_blackwell_class parametrized tests ──────────────────────────


@pytest.mark.parametrize(
    ("major", "minor", "expected"),
    [
        # Everything before Blackwell must be rejected.
        (7, 0, False),  # Volta (V100)
        (7, 5, False),  # Turing (RTX 2080)
        (8, 0, False),  # Ampere (A100)
        (8, 6, False),  # Ampere (RTX 3060)
        (8, 9, False),  # Ada Lovelace (RTX 4090)
        (9, 0, False),  # Hopper (H100)
        # The Blackwell family proper: SM10x, SM11x, SM12x.
        (10, 0, True),  # B100/B200
        (10, 1, True),  # B200 variant
        (10, 3, True),  # B200 variant
        (11, 0, True),  # Future Blackwell
        (12, 0, True),  # GB10/DGX Spark (SM120)
        (12, 1, True),  # GB10/DGX Spark (SM121)
        # Hypothetical later architectures are not Blackwell.
        (13, 0, False),
        (15, 0, False),
    ],
    ids=lambda v: str(v),
)
def test_is_blackwell_class(major: int, minor: int, expected: bool):
    """is_blackwell_class() returns the expected verdict for each SM version."""
    cap = DeviceCapability(major=major, minor=minor)
    _FakePlatform._capability = cap
    assert _FakePlatform.is_blackwell_class() is expected


def test_is_blackwell_class_none_capability():
    """With no capability reported, the check must be conservatively False."""
    _FakePlatform._capability = None
    result = _FakePlatform.is_blackwell_class()
    assert result is False


# ── is_blackwell_capability staticmethod tests ─────────────────────


@pytest.mark.parametrize(
    ("major", "minor", "expected"),
    [
        (9, 0, False),
        (10, 0, True),
        (11, 0, True),
        (12, 1, True),
        (13, 0, False),
    ],
    ids=lambda v: str(v),
)
def test_is_blackwell_capability_static(major: int, minor: int, expected: bool):
    """The staticmethod classifies a bare DeviceCapability; no device query."""
    cap = DeviceCapability(major=major, minor=minor)
    verdict = Platform.is_blackwell_capability(cap)
    assert verdict is expected


def test_is_blackwell_capability_consistency():
    """Staticmethod and classmethod agree for Blackwell AND non-Blackwell majors.

    The original loop covered only majors 10-12, where both helpers return
    True, so a disagreement on non-Blackwell hardware would never be caught.
    Cover both sides of the boundary: pre-Blackwell (8, 9), Blackwell
    (10, 11, 12), and post-Blackwell (13).
    """
    for major in (8, 9, 10, 11, 12, 13):
        cap = DeviceCapability(major=major, minor=0)
        _FakePlatform._capability = cap
        assert (
            Platform.is_blackwell_capability(cap) is _FakePlatform.is_blackwell_class()
        )


# ── is_device_capability_family consistency check ──────────────────


@pytest.mark.parametrize(
    ("major", "minor", "family"),
    [
        (10, 0, 100),
        (10, 3, 100),
        (11, 0, 110),
        (12, 0, 120),
        (12, 1, 120),
    ],
)
def test_blackwell_class_covers_all_families(major: int, minor: int, family: int):
    """Every Blackwell family (100, 110, 120) is also blackwell_class."""
    cap = DeviceCapability(major=major, minor=minor)
    _FakePlatform._capability = cap
    # Family membership and the unified Blackwell-class check must agree.
    assert _FakePlatform.is_device_capability_family(family) is True
    assert _FakePlatform.is_blackwell_class() is True


# ── Backend priority integration (mocked) ─────────────────────────


@pytest.mark.skipif(
    not _has_vllm_c(),
    reason="Requires compiled vllm._C extension",
)
def test_backend_priorities_sm121():
    """SM121 should get Blackwell backend priorities (FlashInfer first)."""
    from vllm.platforms.cuda import _get_backend_priorities

    cap = DeviceCapability(major=12, minor=1)
    # Non-MLA path: a Blackwell-class device must list FlashInfer.
    backend_names = [
        backend.name for backend in _get_backend_priorities(cap, use_mla=False)
    ]
    assert "FLASHINFER" in backend_names

    # Ordering check only applies when FlashAttn is also offered.
    if "FLASH_ATTN" not in backend_names:
        return
    fi_idx = backend_names.index("FLASHINFER")
    fa_idx = backend_names.index("FLASH_ATTN")
    assert fi_idx < fa_idx, (
        f"FlashInfer ({fi_idx}) should come before FlashAttn ({fa_idx}) "
        f"for SM121 Blackwell-class GPU"
    )


@pytest.mark.skipif(
    not _has_vllm_c(),
    reason="Requires compiled vllm._C extension",
)
def test_backend_priorities_sm100_unchanged():
    """SM100 (B200) should still get Blackwell backend priorities.

    The original assertion only checked FlashInfer membership, which does
    not demonstrate "Blackwell priorities"; also verify the ordering, as
    the SM121 test does, to prove the refactor left SM100 unchanged.
    """
    from vllm.platforms.cuda import _get_backend_priorities

    cap = DeviceCapability(major=10, minor=0)
    priorities = _get_backend_priorities(cap, use_mla=False)
    backend_names = [b.name for b in priorities]
    assert "FLASHINFER" in backend_names
    # FlashInfer must outrank FlashAttn on Blackwell-class hardware.
    if "FLASH_ATTN" in backend_names:
        fi_idx = backend_names.index("FLASHINFER")
        fa_idx = backend_names.index("FLASH_ATTN")
        assert fi_idx < fa_idx, (
            f"FlashInfer ({fi_idx}) should come before FlashAttn ({fa_idx}) "
            f"for SM100 Blackwell GPU"
        )


@pytest.mark.skipif(
    not _has_vllm_c(),
    reason="Requires compiled vllm._C extension",
)
def test_backend_priorities_hopper_not_blackwell():
    """SM90 (Hopper) should NOT get Blackwell backend priorities."""
    from vllm.platforms.cuda import _get_backend_priorities

    cap = DeviceCapability(major=9, minor=0)
    backend_names = [
        backend.name for backend in _get_backend_priorities(cap, use_mla=False)
    ]
    # Hopper prefers FlashAttn; only check ordering when both are offered.
    if "FLASH_ATTN" in backend_names and "FLASHINFER" in backend_names:
        fa_idx = backend_names.index("FLASH_ATTN")
        fi_idx = backend_names.index("FLASHINFER")
        assert fa_idx < fi_idx, (
            f"FlashAttn ({fa_idx}) should come before FlashInfer ({fi_idx}) "
            f"for SM90 Hopper GPU"
        )
7 changes: 5 additions & 2 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,11 @@ def _get_backend_priorities(
num_heads: int | None = None,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.platforms.interface import Platform

is_blackwell = Platform.is_blackwell_capability(device_capability)
if use_mla:
if device_capability.major == 10:
if is_blackwell:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this actually the desired priority ranking for cc 12 GPUs?

# Prefer FlashInfer at low head counts (FlashMLA uses padding)
if num_heads is not None and num_heads <= 16:
sparse_backends = [
Expand Down Expand Up @@ -81,7 +84,7 @@ def _get_backend_priorities(
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
if device_capability.major == 10:
if is_blackwell:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

return [
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN,
Expand Down
13 changes: 13 additions & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,19 @@ def is_device_capability_family(
return False
return (current_capability.to_int() // 10) == (capability // 10)

@staticmethod
def is_blackwell_capability(capability: "DeviceCapability") -> bool:
    """Check if a DeviceCapability is Blackwell-class (SM major 10, 11, or 12)."""
    return 10 <= capability.major <= 12

@classmethod
def is_blackwell_class(cls, device_id: int = 0) -> bool:
    """Check if device is a Blackwell-class GPU (SM10x, SM11x, SM12x).

    Returns False when the device's capability cannot be determined.
    """
    capability = cls.get_device_capability(device_id=device_id)
    return capability is not None and cls.is_blackwell_capability(capability)

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
"""Get the name of a device."""
Expand Down
4 changes: 2 additions & 2 deletions vllm/utils/deep_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def init_oracle_cache(cls) -> None:

cls._oracle_cache = ( # type: ignore
cls.UE8M0
if current_platform.is_device_capability_family(100)
if current_platform.is_blackwell_class()
Copy link
Copy Markdown
Collaborator

@MatthewBonanni MatthewBonanni Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeepGemm does not report support for cc 12 GPUs: https://github.com/deepseek-ai/DeepGEMM#requirements

Please either test this or revert this change

else cls.FLOAT32_CEIL_UE8M0
)

Expand All @@ -72,7 +72,7 @@ def is_deep_gemm_supported() -> bool:
"""
is_supported_arch = current_platform.is_cuda() and (
current_platform.is_device_capability(90)
or current_platform.is_device_capability_family(100)
or current_platform.is_blackwell_class()
Copy link
Copy Markdown
Collaborator

@MatthewBonanni MatthewBonanni Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeepGemm does not report support for cc 12 GPUs: https://github.com/deepseek-ai/DeepGEMM#requirements

Please either test this or revert this change

)
return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch

Expand Down
13 changes: 9 additions & 4 deletions vllm/v1/attention/backends/fa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ def get_flash_attn_version(
if device_capability.major == 9 and is_fa_version_supported(3):
# Hopper (SM90): prefer FA3
fa_version = 3
elif device_capability.major == 10 and is_fa_version_supported(4):
# Blackwell (SM100+, restrict to SM100 for now): prefer FA4
elif current_platform.is_blackwell_capability(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is FA4 faster than FA2 on cc 12 GPUs? This requires benchmarking

device_capability
) and is_fa_version_supported(4):
# Blackwell (SM100-SM121): prefer FA4
fa_version = 4
else:
# Fallback to FA2
Expand All @@ -93,7 +95,10 @@ def get_flash_attn_version(
fa_version = vllm_config.attention_config.flash_attn_version

# 3. fallback for unsupported combinations
if device_capability.major >= 10 and fa_version == 3:
if (
current_platform.is_blackwell_capability(device_capability)
and fa_version == 3
):
logger.warning_once(
"Cannot use FA version 3 on Blackwell platform, "
"defaulting to FA version 4 if supported, otherwise FA2."
Expand Down Expand Up @@ -127,7 +132,7 @@ def get_flash_attn_version(
# See: https://github.com/Dao-AILab/flash-attention/issues/1959
if (
fa_version == 4
and device_capability.major >= 10
and current_platform.is_blackwell_capability(device_capability)
Comment on lines 134 to +135
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this restriction apply to cc 12? If you're unsure, then leave as-is or test.

and head_size is not None
and head_size > 128
):
Expand Down
6 changes: 4 additions & 2 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,9 @@ def supports_sink(cls) -> bool:
@classmethod
def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None:
    """Return "HND" on Blackwell-class GPUs; no layout constraint elsewhere."""
    cap = current_platform.get_device_capability()
    if cap is None:
        return None
    return "HND" if current_platform.is_blackwell_capability(cap) else None

Expand Down Expand Up @@ -630,7 +632,7 @@ def __init__(
self.paged_kv_indices = self._make_buffer(max_num_pages)
self.paged_kv_last_page_len = self._make_buffer(max_num_reqs)

if self.head_dim == 256 and current_platform.is_device_capability_family(100):
if self.head_dim == 256 and current_platform.is_blackwell_class():
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this restriction apply to cc 12? If you're unsure, then leave as-is or test.

# https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that
# head size 256 and block size 16 is not supported on blackwell.
assert kv_cache_spec.block_size != 16, (
Expand Down
Loading