-
-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[Bugfix] Add is_blackwell_class() for SM121/GB10 DGX Spark support #34822
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,24 +102,34 @@ Priority is **1 = highest** (tried first). | |
|
|
||
| | Priority | Backend | | ||
| | -------- | ------- | | ||
| | 1 | `FLASHINFER` | | ||
| | 2 | `FLASH_ATTN` | | ||
| | 1 | `FLASH_ATTN` | | ||
| | 2 | `FLASHINFER` | | ||
| | 3 | `TRITON_ATTN` | | ||
| | 4 | `FLEX_ATTENTION` | | ||
|
|
||
| **Ampere/Hopper (SM 8.x-9.x):** | ||
|
|
||
| | Priority | Backend | | ||
| | -------- | ------- | | ||
| | 1 | `FLASH_ATTN` | | ||
| | 2 | `FLASHINFER` | | ||
| | 1 | `FLASHINFER` | | ||
| | 2 | `FLASH_ATTN` | | ||
| | 3 | `TRITON_ATTN` | | ||
| | 4 | `FLEX_ATTENTION` | | ||
|
|
||
| ### MLA Attention (DeepSeek-style) | ||
|
|
||
| **Blackwell (SM 10.x):** | ||
|
|
||
| | Priority | Backend | | ||
| | -------- | ------- | | ||
| | 1 | `FLASH_ATTN_MLA` | | ||
| | 2 | `FLASHMLA` | | ||
| | 3 | `FLASHINFER_MLA` | | ||
| | 4 | `TRITON_MLA` | | ||
| | 5 | `FLASHMLA_SPARSE` | | ||
|
|
||
| **Ampere/Hopper (SM 8.x-9.x):** | ||
|
|
||
| | Priority | Backend | | ||
| | -------- | ------- | | ||
| | 1 | `FLASHINFER_MLA` | | ||
|
|
@@ -130,16 +140,6 @@ Priority is **1 = highest** (tried first). | |
| | 6 | `FLASHMLA_SPARSE` | | ||
| | 7 | `FLASHINFER_MLA_SPARSE` | | ||
|
|
||
| **Ampere/Hopper (SM 8.x-9.x):** | ||
|
|
||
| | Priority | Backend | | ||
| | -------- | ------- | | ||
| | 1 | `FLASH_ATTN_MLA` | | ||
| | 2 | `FLASHMLA` | | ||
| | 3 | `FLASHINFER_MLA` | | ||
| | 4 | `TRITON_MLA` | | ||
| | 5 | `FLASHMLA_SPARSE` | | ||
|
|
||
| > **Note:** ROCm and CPU platforms have their own selection logic. See the platform-specific documentation for details. | ||
|
|
||
| ## Legend | ||
|
|
@@ -168,7 +168,7 @@ Priority is **1 = highest** (tried first). | |
| | `FLASHINFER` | TRTLLM† | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 64 | 64, 128, 256 | ✅ | ❌ | ✅ | Decoder | 10.x | | ||
| | `FLASH_ATTN` | FA2* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | ||
| | `FLASH_ATTN` | FA3* | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ❌ | ✅ | All | 9.x | | ||
| | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 | | ||
| | `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥8.0 | | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This change is wrong; we only want to use FA4 on Blackwell. |
||
| | `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any | | ||
| | `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any | | ||
| | `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A | | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,176 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Tests for is_blackwell_class() Blackwell-family GPU detection. | ||
|
|
||
| Verifies that the unified Blackwell-class check correctly identifies | ||
| SM10x, SM11x, and SM12x devices while excluding non-Blackwell GPUs. | ||
| """ | ||
|
|
||
| import importlib.util | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm.platforms.interface import DeviceCapability, Platform | ||
|
|
||
|
|
||
| def _has_vllm_c() -> bool: | ||
| """Check if compiled vllm._C extension is available.""" | ||
| return importlib.util.find_spec("vllm._C") is not None | ||
|
|
||
|
|
||
| class _FakePlatform(Platform): | ||
| """Minimal Platform subclass for testing capability methods.""" | ||
|
|
||
| _capability: DeviceCapability | None = None | ||
|
|
||
| @classmethod | ||
| def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None: | ||
| return cls._capability | ||
|
|
||
|
|
||
| # ── is_blackwell_class parametrized tests ────────────────────────── | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("major", "minor", "expected"), | ||
| [ | ||
| # Pre-Blackwell architectures → False | ||
| (7, 0, False), # Volta (V100) | ||
| (7, 5, False), # Turing (RTX 2080) | ||
| (8, 0, False), # Ampere (A100) | ||
| (8, 6, False), # Ampere (RTX 3060) | ||
| (8, 9, False), # Ada Lovelace (RTX 4090) | ||
| (9, 0, False), # Hopper (H100) | ||
| # Blackwell-class architectures → True | ||
| (10, 0, True), # B100/B200 | ||
| (10, 1, True), # B200 variant | ||
| (10, 3, True), # B200 variant | ||
| (11, 0, True), # Future Blackwell | ||
| (12, 0, True), # GB10/DGX Spark (SM120) | ||
| (12, 1, True), # GB10/DGX Spark (SM121) | ||
| # Future / post-Blackwell → False | ||
| (13, 0, False), | ||
| (15, 0, False), | ||
| ], | ||
| ids=lambda v: str(v), | ||
| ) | ||
| def test_is_blackwell_class(major: int, minor: int, expected: bool): | ||
| _FakePlatform._capability = DeviceCapability(major=major, minor=minor) | ||
| assert _FakePlatform.is_blackwell_class() is expected | ||
|
|
||
|
|
||
| def test_is_blackwell_class_none_capability(): | ||
| """is_blackwell_class returns False when no capability is available.""" | ||
| _FakePlatform._capability = None | ||
| assert _FakePlatform.is_blackwell_class() is False | ||
|
|
||
|
|
||
| # ── is_blackwell_capability staticmethod tests ───────────────────── | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("major", "minor", "expected"), | ||
| [ | ||
| (9, 0, False), | ||
| (10, 0, True), | ||
| (11, 0, True), | ||
| (12, 1, True), | ||
| (13, 0, False), | ||
| ], | ||
| ids=lambda v: str(v), | ||
| ) | ||
| def test_is_blackwell_capability_static(major: int, minor: int, expected: bool): | ||
| """Staticmethod works directly on DeviceCapability without device query.""" | ||
| cap = DeviceCapability(major=major, minor=minor) | ||
| assert Platform.is_blackwell_capability(cap) is expected | ||
|
|
||
|
|
||
| def test_is_blackwell_capability_consistency(): | ||
| """Staticmethod and classmethod agree for all Blackwell variants.""" | ||
| for major in (10, 11, 12): | ||
| cap = DeviceCapability(major=major, minor=0) | ||
| _FakePlatform._capability = cap | ||
| assert ( | ||
| Platform.is_blackwell_capability(cap) is _FakePlatform.is_blackwell_class() | ||
| ) | ||
|
|
||
|
|
||
| # ── is_device_capability_family consistency check ────────────────── | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| ("major", "minor", "family"), | ||
| [ | ||
| (10, 0, 100), | ||
| (10, 3, 100), | ||
| (11, 0, 110), | ||
| (12, 0, 120), | ||
| (12, 1, 120), | ||
| ], | ||
| ) | ||
| def test_blackwell_class_covers_all_families(major: int, minor: int, family: int): | ||
| """Every Blackwell family (100, 110, 120) is also blackwell_class.""" | ||
| _FakePlatform._capability = DeviceCapability(major=major, minor=minor) | ||
| assert _FakePlatform.is_device_capability_family(family) is True | ||
| assert _FakePlatform.is_blackwell_class() is True | ||
|
|
||
|
|
||
| # ── Backend priority integration (mocked) ───────────────────────── | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not _has_vllm_c(), | ||
| reason="Requires compiled vllm._C extension", | ||
| ) | ||
| def test_backend_priorities_sm121(): | ||
| """SM121 should get Blackwell backend priorities (FlashInfer first).""" | ||
| from vllm.platforms.cuda import _get_backend_priorities | ||
|
|
||
| cap = DeviceCapability(major=12, minor=1) | ||
| # Non-MLA path: Blackwell should get FlashInfer first | ||
| priorities = _get_backend_priorities(cap, use_mla=False) | ||
| backend_names = [b.name for b in priorities] | ||
| assert "FLASHINFER" in backend_names | ||
| # FlashInfer should be before FlashAttn for Blackwell | ||
| fi_idx = backend_names.index("FLASHINFER") | ||
| if "FLASH_ATTN" in backend_names: | ||
| fa_idx = backend_names.index("FLASH_ATTN") | ||
| assert fi_idx < fa_idx, ( | ||
| f"FlashInfer ({fi_idx}) should come before FlashAttn ({fa_idx}) " | ||
| f"for SM121 Blackwell-class GPU" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not _has_vllm_c(), | ||
| reason="Requires compiled vllm._C extension", | ||
| ) | ||
| def test_backend_priorities_sm100_unchanged(): | ||
| """SM100 (B200) should still get Blackwell backend priorities.""" | ||
| from vllm.platforms.cuda import _get_backend_priorities | ||
|
|
||
| cap = DeviceCapability(major=10, minor=0) | ||
| priorities = _get_backend_priorities(cap, use_mla=False) | ||
| backend_names = [b.name for b in priorities] | ||
| assert "FLASHINFER" in backend_names | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| not _has_vllm_c(), | ||
| reason="Requires compiled vllm._C extension", | ||
| ) | ||
| def test_backend_priorities_hopper_not_blackwell(): | ||
| """SM90 (Hopper) should NOT get Blackwell backend priorities.""" | ||
| from vllm.platforms.cuda import _get_backend_priorities | ||
|
|
||
| cap = DeviceCapability(major=9, minor=0) | ||
| priorities = _get_backend_priorities(cap, use_mla=False) | ||
| backend_names = [b.name for b in priorities] | ||
| # Hopper gets FlashAttn first, not FlashInfer | ||
| if "FLASH_ATTN" in backend_names and "FLASHINFER" in backend_names: | ||
| fa_idx = backend_names.index("FLASH_ATTN") | ||
| fi_idx = backend_names.index("FLASHINFER") | ||
| assert fa_idx < fi_idx, ( | ||
| f"FlashAttn ({fa_idx}) should come before FlashInfer ({fi_idx}) " | ||
| f"for SM90 Hopper GPU" | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,8 +51,11 @@ def _get_backend_priorities( | |
| num_heads: int | None = None, | ||
| ) -> list[AttentionBackendEnum]: | ||
| """Get backend priorities with lazy import to avoid circular dependency.""" | ||
| from vllm.platforms.interface import Platform | ||
|
|
||
| is_blackwell = Platform.is_blackwell_capability(device_capability) | ||
| if use_mla: | ||
| if device_capability.major == 10: | ||
| if is_blackwell: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this actually the desired priority ranking for cc 12 GPUs? |
||
| # Prefer FlashInfer at low head counts (FlashMLA uses padding) | ||
| if num_heads is not None and num_heads <= 16: | ||
| sparse_backends = [ | ||
|
|
@@ -81,7 +84,7 @@ def _get_backend_priorities( | |
| AttentionBackendEnum.FLASHMLA_SPARSE, | ||
| ] | ||
| else: | ||
| if device_capability.major == 10: | ||
| if is_blackwell: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto |
||
| return [ | ||
| AttentionBackendEnum.FLASHINFER, | ||
| AttentionBackendEnum.FLASH_ATTN, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,7 +53,7 @@ def init_oracle_cache(cls) -> None: | |
|
|
||
| cls._oracle_cache = ( # type: ignore | ||
| cls.UE8M0 | ||
| if current_platform.is_device_capability_family(100) | ||
| if current_platform.is_blackwell_class() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DeepGemm does not report support for cc 12 GPUs: https://github.com/deepseek-ai/DeepGEMM#requirements — please either test this or revert this change. |
||
| else cls.FLOAT32_CEIL_UE8M0 | ||
| ) | ||
|
|
||
|
|
@@ -72,7 +72,7 @@ def is_deep_gemm_supported() -> bool: | |
| """ | ||
| is_supported_arch = current_platform.is_cuda() and ( | ||
| current_platform.is_device_capability(90) | ||
| or current_platform.is_device_capability_family(100) | ||
| or current_platform.is_blackwell_class() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. DeepGemm does not report support for cc 12 GPUs: https://github.com/deepseek-ai/DeepGEMM#requirements — please either test this or revert this change. |
||
| ) | ||
| return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,8 +75,10 @@ def get_flash_attn_version( | |
| if device_capability.major == 9 and is_fa_version_supported(3): | ||
| # Hopper (SM90): prefer FA3 | ||
| fa_version = 3 | ||
| elif device_capability.major == 10 and is_fa_version_supported(4): | ||
| # Blackwell (SM100+, restrict to SM100 for now): prefer FA4 | ||
| elif current_platform.is_blackwell_capability( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is FA4 faster than FA2 on cc 12 GPUs? This requires benchmarking. |
||
| device_capability | ||
| ) and is_fa_version_supported(4): | ||
| # Blackwell (SM100-SM121): prefer FA4 | ||
| fa_version = 4 | ||
| else: | ||
| # Fallback to FA2 | ||
|
|
@@ -93,7 +95,10 @@ def get_flash_attn_version( | |
| fa_version = vllm_config.attention_config.flash_attn_version | ||
|
|
||
| # 3. fallback for unsupported combinations | ||
| if device_capability.major >= 10 and fa_version == 3: | ||
| if ( | ||
| current_platform.is_blackwell_capability(device_capability) | ||
| and fa_version == 3 | ||
| ): | ||
| logger.warning_once( | ||
| "Cannot use FA version 3 on Blackwell platform, " | ||
| "defaulting to FA version 4 if supported, otherwise FA2." | ||
|
|
@@ -127,7 +132,7 @@ def get_flash_attn_version( | |
| # See: https://github.com/Dao-AILab/flash-attention/issues/1959 | ||
| if ( | ||
| fa_version == 4 | ||
| and device_capability.major >= 10 | ||
| and current_platform.is_blackwell_capability(device_capability) | ||
|
Comment on lines
134
to
+135
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this restriction apply to cc 12? If you're unsure, then leave as-is or test. |
||
| and head_size is not None | ||
| and head_size > 128 | ||
| ): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -385,7 +385,9 @@ def supports_sink(cls) -> bool: | |
| @classmethod | ||
| def get_required_kv_cache_layout(cls) -> KVCacheLayoutType | None: | ||
| capability = current_platform.get_device_capability() | ||
| if capability is not None and capability.major == 10: | ||
| if capability is not None and current_platform.is_blackwell_capability( | ||
| capability | ||
| ): | ||
| return "HND" | ||
| return None | ||
|
|
||
|
|
@@ -630,7 +632,7 @@ def __init__( | |
| self.paged_kv_indices = self._make_buffer(max_num_pages) | ||
| self.paged_kv_last_page_len = self._make_buffer(max_num_reqs) | ||
|
|
||
| if self.head_dim == 256 and current_platform.is_device_capability_family(100): | ||
| if self.head_dim == 256 and current_platform.is_blackwell_class(): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this restriction apply to cc 12? If you're unsure, then leave as-is or test. |
||
| # https://github.com/flashinfer-ai/flashinfer/issues/1993 reports that | ||
| # head size 256 and block size 16 is not supported on blackwell. | ||
| assert kv_cache_spec.block_size != 16, ( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR seems to have broken
generate_attention_backend_docs.py, please fix it