
[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths#35568

Open
blake-snc wants to merge 3 commits into vllm-project:main from blake-snc:fix/marlin-sm12x-capability-check

Conversation


@blake-snc blake-snc commented Feb 28, 2026

Summary

SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) — both support native mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. However, SM121 is excluded from all Marlin and CUTLASS FP8 codepaths by exact-match arch guards (== 120, in [89, 120], enable_sm120_only).

This fixes 8 locations across codegen, runtime, dispatch, and tests using bounded SM12x family checks (arch // 10 == 12, major_capability == 12, enable_sm120_family, is_device_capability_family(120)):

Codegen (FP8 kernel template generation):

  • csrc/quantization/marlin/generate_kernels.py: arch in [89, 120] → arch == 89 or arch // 10 == 12
  • csrc/moe/marlin_moe_wna16/generate_kernels.py: same fix

Runtime (FP8 activation gate):

  • csrc/moe/marlin_moe_wna16/ops.cu: == 120 → major_capability == 12

CUTLASS FP8 dispatch (kernel wrapper):

  • csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh: enable_sm120_only → enable_sm120_family
  • csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh: same fix

Tests (FP8 test case generation):

  • tests/kernels/moe/test_moe.py: get_device_capability() not in [89, 120] → proper is_device_capability(89) / is_device_capability_family(120) API calls
  • tests/kernels/quantization/test_marlin_gemm.py: same fix

Python-side FP8 input validation:

  • vllm/model_executor/layers/quantization/utils/marlin_utils.py: is_device_capability(120) → is_device_capability_family(120)

All checks use bounded SM12x family matching (covers SM120/SM121 but won't accidentally match future SM13x).
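The codegen-side guard can be sketched in Python (the predicate matches the patch's description; the function name and standalone form here are illustrative, not the actual contents of generate_kernels.py):

```python
# Bounded SM12x family check: SM89 stays an exact match, while the
# SM12x family is matched by integer division rather than equality.
def fp8_kernel_enabled(arch: int) -> bool:
    """Generate FP8 kernel templates for SM89 and the SM12x family."""
    # Old guard was `arch in [89, 120]`, which silently dropped SM121.
    return arch == 89 or arch // 10 == 12

assert fp8_kernel_enabled(89)       # Ada: allowed via exact match
assert not fp8_kernel_enabled(90)   # Hopper: blocked
assert fp8_kernel_enabled(120)      # RTX 5090
assert fp8_kernel_enabled(121)      # DGX Spark GB10
assert not fp8_kernel_enabled(130)  # future SM13x: not matched
```

The `// 10` form is what keeps the check bounded: `arch >= 120` would also admit SM130 and later families.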

The enable_sm120_only → enable_sm120_family change in the CUTLASS dispatch headers also resolves the CUTLASS FP4 GEMM failure on SM121 reported in #30163 ("Failed to run cutlass FP4 gemm on sm120. Error: Error Internal"), since enable_sm120_only uses __CUDA_ARCH__ == 1200 which excludes SM121 (__CUDA_ARCH__ == 1210), while enable_sm120_family uses >= 1200 && < 1300.
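A minimal Python model of the two compile-time gates (the real checks are C++ constexpr predicates on __CUDA_ARCH__ in the CUTLASS dispatch headers; the function bodies here just mirror the ranges described above):

```python
# __CUDA_ARCH__ encodes compute capability * 100, e.g. SM121 -> 1210.
def enable_sm120_only(cuda_arch: int) -> bool:
    # Exact match: excludes SM121 (1210) -- the bug behind #30163.
    return cuda_arch == 1200

def enable_sm120_family(cuda_arch: int) -> bool:
    # Bounded range: covers SM120/SM121 without matching SM13x.
    return 1200 <= cuda_arch < 1300

assert enable_sm120_only(1200) and not enable_sm120_only(1210)
assert enable_sm120_family(1200) and enable_sm120_family(1210)
assert not enable_sm120_family(1300)
```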

Validation

Tested on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):

Marlin FP4 GEMM (all 5 configs including N=100544): PASS
CUTLASS FP4 dispatch: cutlass_scaled_mm_supports_fp4(121) = True
Capability check logic:

SM89 (Ada):   allowed via exact match ✓
SM90 (Hopper): blocked ✓
SM120 (RTX 5090): allowed ✓
SM121 (DGX Spark): allowed ✓
SM130 (future): not matched ✓

Subsumes #35803. Fixes #35432. Fixes #30163. Relates to #30135.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Test plan

  • Validated on SM121a hardware (DGX Spark)
  • Marlin FP4 GEMM passes all 5 test configs
  • enable_sm120_family verified in common.hpp with correct >= 1200 && < 1300 range guard
  • is_device_capability_family(120) verified: uses to_int() // 10 == 120 // 10
  • Pre-commit hooks pass
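The is_device_capability_family(120) semantics verified above can be modeled like this (DeviceCapability here is a stand-in for vLLM's (major, minor) capability NamedTuple; treat the exact class shape as an assumption, only the `to_int() // 10` comparison is taken from the test plan):

```python
from typing import NamedTuple

class DeviceCapability(NamedTuple):
    # Stand-in for vLLM's capability tuple, e.g. (12, 1) for SM121.
    major: int
    minor: int

    def to_int(self) -> int:
        # (12, 1) -> 121
        return self.major * 10 + self.minor

def is_device_capability_family(cap: DeviceCapability, family: int) -> bool:
    # Family match: 121 // 10 == 120 // 10, so SM121 counts as SM12x.
    return cap.to_int() // 10 == family // 10

assert is_device_capability_family(DeviceCapability(12, 0), 120)  # RTX 5090
assert is_device_capability_family(DeviceCapability(12, 1), 120)  # GB10
assert not is_device_capability_family(DeviceCapability(13, 0), 120)
```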

🤖 Generated with Claude Code

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the bug Something isn't working label Feb 28, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request updates the device capability check for Marlin W4A8-FP8 support to include newer GPU architectures. The check is changed from an exact match for compute capability 12.0 (is_device_capability(120)) to a check for 12.0 or higher (has_device_capability(120)). This is intended to enable support on devices such as Blackwell variants that report compute capabilities like 12.1. The error message is also updated to reflect this change, now indicating support for SM120+ devices.


`get_marlin_input_dtype()` uses `is_device_capability(120)` which is an
exact match — SM121 devices (DGX Spark GB10) return capability (12, 1)
and fail the check, blocking Marlin W4A8-FP8 with a misleading
"only support SM89 or SM120" error.

Changed to `has_device_capability(120)` which uses >= comparison,
allowing SM120 and all Blackwell variants (SM121, SM121a, etc.) while
still correctly blocking SM90 (Hopper) where Marlin FP8 is slower than
W4A16.

The SM89 (Ada) check remains as `is_device_capability(89)` since there
are no Ada variants.

Validated on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):
- Before: `is_device_capability(120)` → False → ValueError raised
- After:  `has_device_capability(120)` → True  → FP8 dtype returned
- SM90 still correctly blocked (has_device_capability(120) → False)
- SM89 still correctly allowed (is_device_capability(89) → True)

Fixes vllm-project#35432
Relates to vllm-project#30135

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
@blake-snc blake-snc force-pushed the fix/marlin-sm12x-capability-check branch from 30d8763 to f4b19a7 Compare February 28, 2026 02:33
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 2, 2026
Change is_device_capability(120) to has_device_capability(120) so
SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support.
is_device_capability checks for exact match only.

Ref: vllm-project#35568
blake-snc added a commit to blake-snc/vllm that referenced this pull request Mar 2, 2026
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120
(RTX 5090) but is excluded by exact-match arch guards throughout the
Marlin and CUTLASS FP8 codepaths. This fixes 8 locations:

- generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89
  or arch >= 120` so SM121 FP8 kernel templates are generated
- ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation
  gate
- scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only`
  → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121
- test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper
  `is_device_capability(89)` / `is_device_capability_family(120)` APIs
  instead of broken `get_device_capability() not in [89, 120]`
  (NamedTuple vs int comparison)
- marlin_utils.py: `is_device_capability(120)` →
  `is_device_capability_family(120)` for Python-side FP8 input check

Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in
marlin.cu.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
blake-snc added a commit to blake-snc/vllm that referenced this pull request Mar 3, 2026
@blake-snc blake-snc requested a review from WoosukKwon as a code owner March 3, 2026 05:58
@mergify mergify bot added the nvidia label Mar 3, 2026
@blake-snc blake-snc changed the title [Bugfix] Fix Marlin W4A8-FP8 check for SM121+ Blackwell variants [Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths Mar 3, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
Cherry-pick upstream fixes for GB10 Spark (SM121):

- PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8
  kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py)
- PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4
  by using ReplicatedLinear with quant_config=None
- PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds
  on-the-fly FP8 dequantization in Triton kernels
- PR vllm-project#35936: tool_choice="required" falls back to tool_parser for
  non-JSON (XML) tool calls from Qwen3 models

Local patches:
- Patch FlashInfer TRTLLM JIT to compile for SM12x
  (supported_major_versions=[10] → [10, 12])
- Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
- Remove VLLM_TEST_FORCE_FP8_MARLIN=1 (CUTLASS FP8 now works on SM121
  via enable_sm120_family from PR vllm-project#35568)
- Make VLLM_USE_FLASHINFER_MOE_FP4 overridable (default still 0) so
  users can test FlashInfer TRTLLM MoE on SM121 after JIT patch
- Add auto-kill of existing vLLM server before launch (prevents GPU OOM
  on GB10 unified memory)
- Skip VLLM_TEST_FORCE_FP8_MARLIN in NVFP4 MoE oracle (not SM121-ready
  for that path)
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 4, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 5, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 5, 2026
blake-snc added a commit to blake-snc/vllm that referenced this pull request Mar 5, 2026
blake-snc and others added 2 commits March 12, 2026 14:35
Address review feedback: arch >= 120 would incorrectly match future
arch families (SM130+). Use arch // 10 == 12 for codegen and
major_capability == 12 for runtime to scope checks to the SM12x family.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
@blake-snc blake-snc force-pushed the fix/marlin-sm12x-capability-check branch from 2cb48d7 to 8092825 Compare March 12, 2026 21:36
@blake-snc
Author

Updated — DCO sign-off has been added to all commits. Ready for review.

@blake-snc
Author

@scottgl9 I see you have cherry-picked a good bit of this PR - is there anything left in this PR worth keeping it open for from your end?

scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 18, 2026
scottgl9 added a commit to scottgl9/vllm that referenced this pull request Mar 18, 2026
@blake-snc
Author

blake-snc commented Mar 24, 2026

Verified that these changes are not yet in main — marlin/generate_kernels.py still has if arch in [89, 120] and scaled_mm.cuh still uses enable_sm120_only. The SM121 exclusion is still live. This PR should be good to merge as-is. @scottgl9 happy to rebase if there are conflicts.

@johnnynunez
Contributor

cc @mgoin

Member

@mgoin mgoin left a comment


LGTM, thank you for separating this

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 25, 2026
@mgoin mgoin added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 25, 2026
@mgoin mgoin self-assigned this Mar 25, 2026

Labels

bug Something isn't working nvidia ready ONLY add when PR is ready to merge/full CI is needed

Projects

Status: Ready

3 participants