feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference#24692
feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference#24692AliceChenyy wants to merge 10 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces comprehensive support and fallbacks for NVIDIA SM120 (Blackwell Desktop) architectures, which lack certain hardware features like TMEM/tcgen05 required by DeepGEMM and FlashMLA. It implements both pure-PyTorch and optimized Triton fallback kernels for FlashMLA, MQA logits, and MXFP4 MoE operations. Additionally, it updates environment configurations, metadata handling, and server arguments to automatically detect SM120 and route to these compatible implementations. Review feedback highlighted misleading function aliasing in the MoE implementation, the replacement of descriptive assertion messages with "TODO" strings, and inconsistent backend selection logic in the new entrypoints.
| from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( | ||
| mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback, | ||
| ) |
There was a problem hiding this comment.
The alias mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback is highly misleading. It imports the optimized Triton implementation but names it as if it were the fallback. This is especially confusing given that a separate PyTorch fallback implementation exists in mxfp4_moe_fallback.py.
This can cause significant confusion during future maintenance. Please rename the alias to accurately reflect the imported function, for example, by simply importing mxfp4_moe_forward_triton and using that name directly.
| from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( | |
| mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback, | |
| ) | |
| from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import ( | |
| mxfp4_moe_forward_triton, | |
| ) |
There was a problem hiding this comment.
Please address this comment
There was a problem hiding this comment.
Done — now imports mxfp4_moe_forward_triton directly without aliasing.
| assert head_dim == 128, "TODO" | ||
| assert block_size == 64, "TODO" |
There was a problem hiding this comment.
The assert messages have been changed from descriptive explanations to simply "TODO". The previous implementation also had these asserts, but the original messages explained why these values were hardcoded (e.g., "torch reference impl hardcodes DSV4 indexer head_dim=128"). Please consider restoring the more descriptive messages to improve code maintainability.
| assert head_dim == 128, "TODO" | |
| assert block_size == 64, "TODO" | |
| assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128" | |
| assert block_size == 64, "Vectorized torch impl hardcodes block_size=64 cache layout" |
There was a problem hiding this comment.
Done — restored descriptive assert messages.
| _use_triton_flashmla = os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1" | ||
|
|
||
|
|
||
| def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs): |
There was a problem hiding this comment.
In flash_mla_with_kvcache_entrypoint, the backend selection logic for SM120 is confusing. The function takes a backend parameter (from SGLANG_HACK_FLASHMLA_BACKEND), but this parameter is ignored when _is_sm120 is true. Instead, the choice between the Triton kernel and the PyTorch fallback is controlled by a separate environment variable, SGLANG_SM120_TRITON_FLASHMLA.
To improve clarity, it would be better to unify this control flow. Consider using the backend parameter to also control the implementation path on SM120, for instance, by checking for values like "triton" or "torch".
There was a problem hiding this comment.
Done — the SGLANG_HACK_FLASHMLA_BACKEND env var has been removed. On SM120 the backend is now hardcoded to "kernel" (Triton path), with SGLANG_SM120_TRITON_FLASHMLA env var as the only escape hatch to fall back to PyTorch attention if needed for debugging. This simplifies the control flow as the gemini-bot suggested.
8779574 to
a4b24dd
Compare
There was a problem hiding this comment.
Does here need add _is_sm120() ?
There was a problem hiding this comment.
I think is_cuda covers SM120
There was a problem hiding this comment.
I mean and not _is_sm120
There was a problem hiding this comment.
Done — added and not _is_sm120 guards at all DeepGEMM/tilelang paths in nsa_backend.py.
| # ── Graph-safe routing: flatten topk assignments ── | ||
| # token_ids[slot] = which row of A (original token index) | ||
| # expert_ids[slot] = which expert's weights to use | ||
| flat_expert_ids = topk_ids.reshape(-1).contiguous() # [M*topk] |
There was a problem hiding this comment.
topk_ids can be -1 for padded/filtered tokens, but this path passes them directly as Triton expert ids. Could we use safe_expert_ids = topk_ids.clamp_min(0) for loads and set the invalid slots' output to zero?? The PyTorch fallback already skips eid_val < 0
There was a problem hiding this comment.
Handled — see lines 361-363: flat_expert_ids_raw.clamp(min=0) for safe indexing, plus line 441+: invalid slots are zeroed out after the kernel. The Triton kernel itself runs on clamped-to-0 expert IDs (safe for loads), and the output is masked to zero post-kernel.
| ke = ks + ke_offset | ||
| actual_seq_q = torch.cat(actual_seq_q_list, dim=0) | ||
| with self._with_real_sm_count(): | ||
| logits = deep_gemm.fp8_mqa_logits( |
There was a problem hiding this comment.
For SM120, should this CP ragged path explicitly raise NotImplementedError instead of falling through to deep_gemm.fp8_mqa_logits, since DeepGEMM is unsupported here?
There was a problem hiding this comment.
Good catch. Will add if _is_sm120: raise NotImplementedError("CP ragged indexer not supported on SM120")
| return False | ||
| return get_jit_cuda_arch().major >= 9 | ||
| arch = get_jit_cuda_arch() | ||
| # PDL requires SM100+ datacenter (tcgen05/TMEM); SM120 (desktop Blackwell) lacks these |
There was a problem hiding this comment.
Currently, there is a little bug on old CUTLASS version if the kernel is CUTLASS. But can you elaborate on PDL not working?
There was a problem hiding this comment.
PDL is now re-enabled on SM120.
Address all reviewer feedback from PR sgl-project#24692: - Use is_sm120_supported() helper instead of raw sm_version checks - Guard SGLANG_OPT_DEEPGEMM_HC_PRENORM and SGLANG_OPT_USE_TILELANG_MHC_PRE with `not is_sm120_supported()` in deepseek_v4.py - Auto-select marlin MoE backend on SM120 in deepseek_v4_hook.py - Minor cleanups in indexer, metadata, nsa_backend, mxfp4_marlin_moe Fix FlashMLA Triton kernel garbled output on latest sglang:dev image: - Root cause: upstream changed KV cache dtype from float8_e4m3fn to uint8. The Triton kernel's as_strided() preserved the input dtype, so tl.load interpreted FP8 bit patterns as raw integers, corrupting attention scores. - Fix: explicitly view through uint8 → float8_e4m3fn before passing to Triton. Verified on sglang:dev-cu13 (sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130): - GSM8K 5-shot 200q: 99.0% - Decode BS=1: 11.40 tok/s, TPOT 87.7ms Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Address all reviewer feedback from PR sgl-project#24692: - Use is_sm120_supported() helper instead of raw sm_version checks - Guard SGLANG_OPT_DEEPGEMM_HC_PRENORM and SGLANG_OPT_USE_TILELANG_MHC_PRE with `not is_sm120_supported()` in deepseek_v4.py - Auto-select marlin MoE backend on SM120 in deepseek_v4_hook.py - Minor cleanups in indexer, metadata, nsa_backend, mxfp4_marlin_moe Fix FlashMLA Triton kernel garbled output on latest sglang:dev image: - Root cause: upstream changed KV cache dtype from float8_e4m3fn to uint8. The Triton kernel's as_strided() preserved the input dtype, so tl.load interpreted FP8 bit patterns as raw integers, corrupting attention scores. - Fix: explicitly view through uint8 → float8_e4m3fn before passing to Triton. Verified on sglang:dev-cu13 (sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130): - GSM8K 5-shot 200q: 99.0% - Decode BS=1: 11.40 tok/s, TPOT 87.7ms Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
a0b78ed to
61a71a1
Compare
|
Could you please resolve the conflicts with main branch? Just merge main into this branch, no need to create new PR or force push, Thanks! ^ ^ Also, could you give a code lint fix with pre-commit, please? https://github.com/sgl-project/sglang/actions/runs/25803142782/job/75798441697?pr=24692 |
# Conflicts: # python/sglang/srt/layers/attention/dsv4/indexer.py # python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
@sonny-vleisides Thanks for testing on TP=2! We reproduced and resolved the issues: a) Auto-detect not firing: This was fixed — the hook now auto-selects b) CUDA graph crash at TP=2 on 2x RTX PRO 6000 verified working with latest code:
The main constraint on TP=2 is memory — only ~0.9 GB available for KV cache after model loading. |
|
@samuellees Both done:
|
|
/tag-and-rerun-ci +++++++++ |
|
Seems the CI error is relative with this PR https://github.com/sgl-project/sglang/actions/runs/25839602738/job/75934182042?pr=24692#step:11:762. Could you pease take a look? @AliceChenyy |
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
@samuellees Thanks for flagging! The CI failure was in Could you please re-trigger CI with |
|
Do you have any comment on this failing CI case, please? It has been failed many times. @AliceChenyy @b8zhong @Fridge003 |
Hi @samuellees, this CI failure is not related to our SM120 PR. The failing test is test_mxfp4_120b (gpt-oss-120b model on H100), which crashes in _matmul_ogs_NNT_bf16xbf16xmxfp4_128x256x128x1_swiglu with CUDBG_EXCEPTION_WARP_ILLEGAL_ADDRESS. Root cause: PR #24816 (merged May 13) introduced a triton_kernels API incompatibility — the FnSpecs constructor and GatherIndx/RoutingData/ScatterIndx import paths changed, causing the This was a known regression on main, confirmed by: Our branch was rebased on main from May 13 — after #24816 but before the fix in #25335. All our SM120 code paths are guarded by _is_sm120 and don't execute on H100. I'll rebase to latest main (which includes the fix). |
|
Merged latest main (includes #25335 fix) into branch — commit |
b8zhong
left a comment
There was a problem hiding this comment.
Before merging, can you please check:
- Run the same eval on SM90 or SM100 and compare the score to (maybe a small subset of GPQA, because the whole thing will take a really long time)
- 10-11 TPS is very slow, even considering SM120 I think. Can you please share a profile/ verify if there are any significantly slow kernels that can be easily improved (it's ok if it's allreduce or pre-existing kernels, etc)
Thanks! Will do both and update if any progress. |
Skip DeepGEMM transform_sf_into_required_layout (tcgen05 unsupported) and topk_v2 (128KB SMEM exceeds SM120 99KB limit) on SM120. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rted() directly Address b8zhong review: do not define SM120 detection in this file, call the existing util is_sm120_supported() at each usage site instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
is_sm120_supported() already checks is_cuda() internally and is lru_cached, so the redundant `_is_cuda and` prefix is unnecessary. - metadata.py: remove _is_cuda/_is_sm120 module vars, call util directly - flash_mla_sm120_fallback.py: remove _is_cuda, simplify _is_sm120 - nsa_backend.py: remove redundant is_cuda() prefix - nsa_indexer.py: remove redundant _is_cuda prefix Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Conflict in deepseek_v4.py hc_pre: upstream added AMD aiter mhc_pre path; preserved SM120 guard on deepgemm hc_prenorm_gemm. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Cross-architecture GPQA Diamond comparison completed (same 50 questions,
SM120 matches reference accuracy — no regression. The 2-question difference (4%) is within statistical noise (std ≈ 0.46 for n=50). |
|
hi @b8zhong The rough decode breakdown from the profiling runs is:
I do not see one single obviously-broken kernel that could be improved easily. The slowness is distributed across MoE weight streaming, PCIe all-reduce, and SM120 fallback kernels. The bottlenecks are known architectural fallbacks required on SM120 (no TMEM/tcgen05):
These are tracked upstream in flashinfer#3346 — most DSv4 kernels for SM120 are currently listed as "missing" or "not wired" in FlashInfer, confirming that SM120 fallback paths are the expected state. Performance optimization is already happening incrementally on top of this enablement PR — for example #25696 adds SM120 FP8 GEMM autotune configs and reaches 44-52 tok/s on TP=2 (4-5x over our baseline), building directly on this PR's SM120 support. This PR focuses on correctness — making DSv4-Flash work on SM120 with accuracy matching the reference (72% vs 71.2% on GPQA Diamond). |
Summary
Adds full SM120 (RTX PRO 6000 / RTX 5090 / DGX Spark, compute 12.0) support for DeepSeek-V4/V3 on SGLang. SM120 desktop Blackwell GPUs lack TMEM, tcgen05, and DeepGEMM support — this PR provides Triton-based fallback kernels for all critical paths and enables CUDA graph capture.
Key changes
New kernels (7 files):
mxfp4_moe_sm120_triton.py— Triton fused MXFP4 dequant + GEMM for MoE experts (4.1x vs PyTorch per-GEMM)flash_mla_sm120_triton.py— Triton FlashMLA sparse decode kernel (3.2–5.4x vs FlashInfer fallback)sm120_mqa_triton.py— FP8 paged MQA with wq-precompute + vectorized batch (CUDA graph compatible)flash_mla_sm120_fallback.py/sm120_mqa_fallback.py/mxfp4_moe_fallback.py— PyTorch fallback pathstest_sm120_mqa_fallback.py— Unit tests for MQA fallback correctnessSM120 guards (10 modified files):
deepseek_v4_backend.pymarlinbackend on SM120.unique()/.item()/.nonzero()→ vectorized)Bug fix (found during latest-image validation):
uint8dtype (upstream changed fromfloat8_e4m3fn)Results (8× RTX PRO 6000, TP=8, CUDA graph)
On
sglang:dev-cu13(sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130):On older nightly-dev-20260430 (sgl-kernel 0.4.1, PyTorch 2.9.1+cu129):
Motivation
Notes
deepseek_v4branch) ontomainis_sm120_supported()— zero impact on SM100/SM103 pathsSGLANG_SM120_TRITON_FLASHMLA=1(default on),SGLANG_SM120_MQA_FALLBACK=0(default off)Test plan
sglang:dev-cu13imagetest_sm120_mqa_fallback.pyCI States
Latest PR Test (Base): ⏳ Run #26081958271⚠️ Not enabled -- add
Latest PR Test (Extra):
run-ci-extralabel to opt in.