[Bugfix] Fix SM121 (DGX Spark) exclusion from Marlin/CUTLASS FP8 paths#35803
blake-snc wants to merge 2 commits into vllm-project:main
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090) but is excluded by exact-match arch guards throughout the Marlin and CUTLASS FP8 codepaths.

This fixes 8 locations:
- generate_kernels.py (Marlin + MoE): `arch in [89, 120]` → `arch == 89 or arch >= 120` so SM121 FP8 kernel templates are generated
- ops.cu (MoE Marlin): `== 120` → `>= 120` in runtime FP8 activation gate
- scaled_mm_sm120_fp8_dispatch.cuh + scaled_mm.cuh: `enable_sm120_only` → `enable_sm120_family` so CUTLASS FP8 GEMM kernels run on SM121
- test_moe.py + test_marlin_gemm.py: fix FP8 test skip using proper `is_device_capability(89)` / `is_device_capability_family(120)` APIs instead of broken `get_device_capability() not in [89, 120]` (NamedTuple vs int comparison)
- marlin_utils.py: `is_device_capability(120)` → `is_device_capability_family(120)` for Python-side FP8 input check

Companion to vllm-project#35568 which fixes the runtime Marlin FP8 gate in marlin.cu.

Contributed by Second Nature Computing (https://joinsecondnature.com)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Code Review
This pull request correctly addresses the exclusion of the SM121 (DGX Spark) architecture from Marlin/CUTLASS FP8 codepaths by replacing exact architecture checks with more general ones. The changes across Python, CUDA C++, and test files are mostly consistent and well-aligned with the goal. However, I've identified an inconsistency in the implementation of these checks. Some parts of the code use unbounded checks (e.g., >= 120), which assume forward compatibility with all future architectures, while other parts use safer, bounded checks for the SM12x family. I've added comments to align all checks to be bounded for consistency and to prevent potential issues on future hardware.
      # SM90 and SM100 can use this PTX, but it’s simulated
      # with FP16 MMA, so it cannot achieve any acceleration.
    - if arch in [89, 120]:
    + if arch == 89 or arch >= 120:
The check `arch >= 120` is unbounded, which assumes all future architectures (SM13x, SM14x, etc.) will support this specific MMA instruction. This might not be true and could lead to issues on future hardware. Other parts of this PR use a bounded check for the SM12x family (e.g., `is_device_capability_family(120)` in Python, which is equivalent to `arch // 10 == 12`, or `enable_sm120_family` in CUDA C++, which checks `__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300`). For consistency and safety, it would be better to use a bounded check here as well.
Suggested change:

    - if arch == 89 or arch >= 120:
    + if arch == 89 or (arch // 10 == 12):
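To make the difference between the three guard styles concrete, here is a minimal Python sketch. The function names and the list of architectures (including 130 as a stand-in for a hypothetical future SM13x part) are assumptions made for illustration; none of them come from the PR or from vLLM.

```python
# Hypothetical helpers for illustration only; these names do not exist in vLLM.

def exact_guard(arch: int) -> bool:
    # Original check: only SM89 and SM120 pass, so SM121 is excluded.
    return arch in [89, 120]


def unbounded_guard(arch: int) -> bool:
    # First revision: admits SM121, but also every later architecture.
    return arch == 89 or arch >= 120


def bounded_guard(arch: int) -> bool:
    # Suggested check: SM89 plus the SM12x family (120..129) only.
    return arch == 89 or arch // 10 == 12


# 130 stands in for a hypothetical future architecture.
ARCHS = [89, 90, 100, 120, 121, 130]


def accepted(guard) -> list[int]:
    # Architectures from ARCHS that the given guard lets through.
    return [a for a in ARCHS if guard(a)]


if __name__ == "__main__":
    for guard in (exact_guard, unbounded_guard, bounded_guard):
        print(f"{guard.__name__}: {accepted(guard)}")
```

Only the bounded form accepts 121 while still rejecting the placeholder 130, which is the behavior the review comment asks for.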
csrc/moe/marlin_moe_wna16/ops.cu
Outdated
      major_capability * 10 + minor_capability == 89 ||
    - major_capability * 10 + minor_capability == 120,
    - "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than "
    + major_capability * 10 + minor_capability >= 120,
Similar to the kernel generation scripts, this check `major_capability * 10 + minor_capability >= 120` is unbounded and assumes forward compatibility for future architectures beyond the SM12x family. This is inconsistent with other changes in this PR that use bounded checks (e.g., `enable_sm120_family`). To prevent potential issues on future hardware, a bounded check would be safer. Checking if `major_capability == 12` is a clean way to target the entire SM12.x family.
major_capability == 12,
      # SM90 and SM100 can use this PTX, but it’s simulated
      # with FP16 MMA, so it cannot achieve any acceleration.
    - if arch in [89, 120]:
    + if arch == 89 or arch >= 120:
The check `arch >= 120` is unbounded, which assumes all future architectures (SM13x, SM14x, etc.) will support this specific MMA instruction. This might not be true and could lead to issues on future hardware. Other parts of this PR use a bounded check for the SM12x family (e.g., `is_device_capability_family(120)` in Python, which is equivalent to `arch // 10 == 12`, or `enable_sm120_family` in CUDA C++, which checks `__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300`). For consistency and safety, it would be better to use a bounded check here as well.
Suggested change:

    - if arch == 89 or arch >= 120:
    + if arch == 89 or (arch // 10 == 12):
Address review feedback: `arch >= 120` would incorrectly match future arch families (SM130+). Use `arch // 10 == 12` for codegen and `major_capability == 12` for runtime to scope checks to the SM12x family.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Consolidated into #35568, which now contains all changes from this PR (8 files, bounded SM12x family checks). Closing this one.
Summary
SM121 (DGX Spark GB10) shares the same FP8 MMA capabilities as SM120 (RTX 5090): both support native `mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32`. However, SM121 is excluded from all Marlin and CUTLASS FP8 codepaths by exact-match arch guards (`== 120`, `in [89, 120]`, `enable_sm120_only`).

This fixes 8 locations across codegen, runtime, dispatch, and tests:

Codegen (FP8 kernel template generation):
- `csrc/quantization/marlin/generate_kernels.py`: `arch in [89, 120]` → `arch == 89 or arch >= 120`
- `csrc/moe/marlin_moe_wna16/generate_kernels.py`: same fix

Runtime (FP8 activation gate):
- `csrc/moe/marlin_moe_wna16/ops.cu`: `== 120` → `>= 120`

CUTLASS FP8 dispatch (kernel wrapper):
- `csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh`: `enable_sm120_only` → `enable_sm120_family`
- `csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh`: same fix

Tests (FP8 test case generation):
- `tests/kernels/moe/test_moe.py`: `get_device_capability() not in [89, 120]` → proper `is_device_capability(89)` / `is_device_capability_family(120)` API calls (the old pattern compared a `DeviceCapability` NamedTuple against ints, which never matched)
- `tests/kernels/quantization/test_marlin_gemm.py`: same fix

Python-side FP8 input validation:
- `vllm/model_executor/layers/quantization/utils/marlin_utils.py`: `is_device_capability(120)` → `is_device_capability_family(120)`

Companion to #35568 which fixes the runtime Marlin FP8 gate in `marlin.cu`.

Test plan
- `enable_sm120_family` already exists in `common.hpp` with correct `>= 1200 && < 1300` range guard
- `launch_bounds_utils.h` already correctly handles SM121 (`__CUDA_ARCH__ == 1210`)
- `is_device_capability_family(120)` uses `to_int() // 10 == 120 // 10` which covers SM120/SM121

Contributed by Second Nature Computing (https://joinsecondnature.com)
🤖 Generated with Claude Code
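
The test-skip bug and the family check described in the summary can be sketched in a few lines of Python. `DeviceCapability` and `is_family` below are simplified stand-ins written for this example; only the rule `to_int() // 10 == 120 // 10` is taken from the test plan, and the real vLLM class and helper signatures may differ.

```python
from typing import NamedTuple


class DeviceCapability(NamedTuple):
    # Simplified stand-in for illustration, not vLLM's actual class.
    major: int
    minor: int

    def to_int(self) -> int:
        # (12, 1) -> 121; assumes a single-digit minor version.
        return self.major * 10 + self.minor


def is_family(cap: DeviceCapability, family: int) -> bool:
    # Hypothetical helper mirroring the family rule quoted in the test plan.
    return cap.to_int() // 10 == family // 10


sm89 = DeviceCapability(8, 9)
sm120 = DeviceCapability(12, 0)
sm121 = DeviceCapability(12, 1)
devices = (sm89, sm120, sm121)

# Broken pattern: a tuple never compares equal to an int, so "not in"
# evaluates to True for every device, including the supported ones.
broken_skip = [cap not in [89, 120] for cap in devices]

# Fixed pattern: compare integers, with a family check for SM12x.
fixed_skip = [
    not (cap.to_int() == 89 or is_family(cap, 120)) for cap in devices
]

if __name__ == "__main__":
    print("broken skip flags:", broken_skip)
    print("fixed skip flags: ", fixed_skip)
```

With the tuple-versus-int comparison, the skip condition is true on every device; with the integer and family comparisons, SM89, SM120, and SM121 are all treated as supported.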