[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths#35568
[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths#35568blake-snc wants to merge 3 commits intovllm-project:mainfrom
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run You ask your reviewers to trigger select CI tests on top of Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀 |
There was a problem hiding this comment.
Code Review
This pull request updates the device capability check for Marlin W4A8-FP8 support to include newer GPU architectures. The check is changed from an exact match for compute capability 12.0 (is_device_capability(120)) to a check for 12.0 or higher (has_device_capability(120)). This is intended to enable support on devices such as Blackwell variants that report compute capabilities like 12.1. The error message is also updated to reflect this change, now indicating support for SM120+ devices.
…ariants) `get_marlin_input_dtype()` uses `is_device_capability(120)` which is an exact match — SM121 devices (DGX Spark GB10, RTX 5090) return capability (12, 1) and fail the check, blocking Marlin W4A8-FP8 with a misleading "only support SM89 or SM120" error. Changed to `has_device_capability(120)` which uses >= comparison, allowing SM120 and all Blackwell variants (SM121, SM121a, etc.) while still correctly blocking SM90 (Hopper) where Marlin FP8 is slower than W4A16. The SM89 (Ada) check remains as `is_device_capability(89)` since there are no Ada variants. Validated on DGX Spark (NVIDIA GB10, SM121a / capability 12.1): - Before: `is_device_capability(120)` → False → ValueError raised - After: `has_device_capability(120)` → True → FP8 dtype returned - SM90 still correctly blocked (has_device_capability(120) → False) - SM89 still correctly allowed (is_device_capability(89) → True) Fixes vllm-project#35432 Relates to vllm-project#30135 Contributed by Second Nature Computing (https://joinsecondnature.com) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
30d8763 to
f4b19a7
Compare
Change is_device_capability(120) to has_device_capability(120) so SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support. is_device_capability checks for exact match only. Ref: vllm-project#35568
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) but is excluded by exact-match arch guards throughout the Marlin and CUTLASS FP8 codepaths. This fixes 8 locations: - generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89 or arch >= 120` so SM121 FP8 kernel templates are generated - ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation gate - scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only` → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121 - test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper `is_device_capability(89)` / `is_device_capability_family(120)` APIs instead of broken `get_device_capability() not in [89, 120]` (NamedTuple vs int comparison) - marlin_utils.py: `is_device_capability(120)` → `is_device_capability_family(120)` for Python-side FP8 input check Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in marlin.cu. Contributed by Second Nature Computing (https://joinsecondnature.com) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) but is excluded by exact-match arch guards throughout the Marlin and CUTLASS FP8 codepaths. This fixes 8 locations: - generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89 or arch >= 120` so SM121 FP8 kernel templates are generated - ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation gate - scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only` → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121 - test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper `is_device_capability(89)` / `is_device_capability_family(120)` APIs instead of broken `get_device_capability() not in [89, 120]` (NamedTuple vs int comparison) - marlin_utils.py: `is_device_capability(120)` → `is_device_capability_family(120)` for Python-side FP8 input check Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in marlin.cu. Contributed by Second Nature Computing (https://joinsecondnature.com) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Cherry-pick upstream fixes for GB10 Spark (SM121): - PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py) - PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None - PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds on-the-fly FP8 dequantization in Triton kernels - PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models Local patches: - Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12]) - Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
- Remove VLLM_TEST_FORCE_FP8_MARLIN=1 (CUTLASS FP8 now works on SM121 via enable_sm120_family from PR vllm-project#35568) - Make VLLM_USE_FLASHINFER_MOE_FP4 overridable (default still 0) so users can test FlashInfer TRTLLM MoE on SM121 after JIT patch - Add auto-kill of existing vLLM server before launch (prevents GPU OOM on GB10 unified memory) - Skip VLLM_TEST_FORCE_FP8_MARLIN in NVFP4 MoE oracle (not SM121-ready for that path)
Change is_device_capability(120) to has_device_capability(120) so SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support. is_device_capability checks for exact match only. Ref: vllm-project#35568
Cherry-pick upstream fixes for GB10 Spark (SM121): - PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py) - PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None - PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds on-the-fly FP8 dequantization in Triton kernels - PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models Local patches: - Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12]) - Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
- Remove VLLM_TEST_FORCE_FP8_MARLIN=1 (CUTLASS FP8 now works on SM121 via enable_sm120_family from PR vllm-project#35568) - Make VLLM_USE_FLASHINFER_MOE_FP4 overridable (default still 0) so users can test FlashInfer TRTLLM MoE on SM121 after JIT patch - Add auto-kill of existing vLLM server before launch (prevents GPU OOM on GB10 unified memory) - Skip VLLM_TEST_FORCE_FP8_MARLIN in NVFP4 MoE oracle (not SM121-ready for that path)
Change is_device_capability(120) to has_device_capability(120) so SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support. is_device_capability checks for exact match only. Ref: vllm-project#35568
Cherry-pick upstream fixes for GB10 Spark (SM121): - PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py) - PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None - PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds on-the-fly FP8 dequantization in Triton kernels - PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models Local patches: - Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12]) - Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) but is excluded by exact-match arch guards throughout the Marlin and CUTLASS FP8 codepaths. This fixes 8 locations: - generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89 or arch >= 120` so SM121 FP8 kernel templates are generated - ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation gate - scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only` → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121 - test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper `is_device_capability(89)` / `is_device_capability_family(120)` APIs instead of broken `get_device_capability() not in [89, 120]` (NamedTuple vs int comparison) - marlin_utils.py: `is_device_capability(120)` → `is_device_capability_family(120)` for Python-side FP8 input check Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in marlin.cu. Contributed by Second Nature Computing (https://joinsecondnature.com) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) but is excluded by exact-match arch guards throughout the Marlin and CUTLASS FP8 codepaths. This fixes 8 locations: - generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89 or arch >= 120` so SM121 FP8 kernel templates are generated - ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation gate - scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only` → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121 - test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper `is_device_capability(89)` / `is_device_capability_family(120)` APIs instead of broken `get_device_capability() not in [89, 120]` (NamedTuple vs int comparison) - marlin_utils.py: `is_device_capability(120)` → `is_device_capability_family(120)` for Python-side FP8 input check Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in marlin.cu. Contributed by Second Nature Computing (https://joinsecondnature.com) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
Address review feedback: arch >= 120 would incorrectly match future arch families (SM130+). Use arch // 10 == 12 for codegen and major_capability == 12 for runtime to scope checks to the SM12x family. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Blake Ledden <blake@secondnaturecomputing.com>
2cb48d7 to
8092825
Compare
|
Updated — DCO sign-off has been added to all commits. Ready for review. |
|
@scottgl9 I see you have cherry-picked a good bit of this PR - is there anything left in this PR worth keeping it open for from your end? |
Change is_device_capability(120) to has_device_capability(120) so SM121 (GB10) passes the >= comparison for Marlin W4A8-FP8 support. is_device_capability checks for exact match only. Ref: vllm-project#35568
Cherry-pick upstream fixes for GB10 Spark (SM121): - PR vllm-project#35568: Recognize SM121 as SM120 family for Marlin/CUTLASS FP8 kernels (generate_kernels.py, ops.cu, scaled_mm*.cuh, marlin_utils.py) - PR vllm-project#35675: Fix Qwen3.5 MTP fc layer weight shape mismatch with NVFP4 by using ReplicatedLinear with quant_config=None - PR vllm-project#35833: FP8 KV cache for Triton MLA decode on Blackwell — adds on-the-fly FP8 dequantization in Triton kernels - PR vllm-project#35936: tool_choice="required" falls back to tool_parser for non-JSON (XML) tool calls from Qwen3 models Local patches: - Patch FlashInfer TRTLLM JIT to compile for SM12x (supported_major_versions=[10] → [10, 12]) - Skip VLLM_TEST_FORCE_FP8_MARLIN for NVFP4 MoE (not SM121-ready)
|
Verified that these changes are not yet in main — |
|
cc @mgoin |
mgoin
left a comment
There was a problem hiding this comment.
LGTM, thank you for separating this
Summary
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) — both support native
mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32. However, SM121 is excluded from all Marlin and CUTLASS FP8 codepaths by exact-match arch guards (== 120,in [89, 120],enable_sm120_only).This fixes 8 locations across codegen, runtime, dispatch, and tests using bounded SM12x family checks (
arch // 10 == 12,major_capability == 12,enable_sm120_family,is_device_capability_family(120)):Codegen (FP8 kernel template generation):
csrc/quantization/marlin/generate_kernels.py:arch in [89, 120]→arch == 89 or arch // 10 == 12csrc/moe/marlin_moe_wna16/generate_kernels.py: same fixRuntime (FP8 activation gate):
csrc/moe/marlin_moe_wna16/ops.cu:== 120→major_capability == 12CUTLASS FP8 dispatch (kernel wrapper):
csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh:enable_sm120_only→enable_sm120_familycsrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh: same fixTests (FP8 test case generation):
tests/kernels/moe/test_moe.py:get_device_capability() not in [89, 120]→ properis_device_capability(89)/is_device_capability_family(120)API callstests/kernels/quantization/test_marlin_gemm.py: same fixPython-side FP8 input validation:
vllm/model_executor/layers/quantization/utils/marlin_utils.py:is_device_capability(120)→is_device_capability_family(120)All checks use bounded SM12x family matching (covers SM120/SM121 but won't accidentally match future SM13x).
The
enable_sm120_only→enable_sm120_familychange in the CUTLASS dispatch headers also resolves the CUTLASS FP4 GEMM failure on SM121 reported in #30163 ("Failed to run cutlass FP4 gemm on sm120. Error: Error Internal"), sinceenable_sm120_onlyuses__CUDA_ARCH__ == 1200which excludes SM121 (__CUDA_ARCH__ == 1210), whileenable_sm120_familyuses>= 1200 && < 1300.Validation
Tested on DGX Spark (NVIDIA GB10, SM121a / capability 12.1):
Marlin FP4 GEMM (all 5 configs including N=100544): PASS
CUTLASS FP4 dispatch:
cutlass_scaled_mm_supports_fp4(121) = TrueCapability check logic:
Subsumes #35803. Fixes #35432. Fixes #30163. Relates to #30135.
Contributed by Second Nature Computing (https://joinsecondnature.com)
Test plan
enable_sm120_familyverified incommon.hppwith correct>= 1200 && < 1300range guardis_device_capability_family(120)verified: usesto_int() // 10 == 120 // 10🤖 Generated with Claude Code