feat: SM120 (Blackwell Desktop) support for DeepSeek-V4 inference #24692

Open
AliceChenyy wants to merge 10 commits into sgl-project:main from AliceChenyy:sm120-dsv4-rebase
Conversation

@AliceChenyy AliceChenyy commented May 8, 2026

Summary

Adds full SM120 (RTX PRO 6000 / RTX 5090 / DGX Spark, compute 12.0) support for DeepSeek-V4/V3 on SGLang. SM120 desktop Blackwell GPUs lack TMEM, tcgen05, and DeepGEMM support — this PR provides Triton-based fallback kernels for all critical paths and enables CUDA graph capture.

Key changes

New kernels (7 files):

  • mxfp4_moe_sm120_triton.py — Triton fused MXFP4 dequant + GEMM for MoE experts (4.1x vs PyTorch per-GEMM)
  • flash_mla_sm120_triton.py — Triton FlashMLA sparse decode kernel (3.2–5.4x vs FlashInfer fallback)
  • sm120_mqa_triton.py — FP8 paged MQA with wq-precompute + vectorized batch (CUDA graph compatible)
  • flash_mla_sm120_fallback.py / sm120_mqa_fallback.py / mxfp4_moe_fallback.py — PyTorch fallback paths
  • test_sm120_mqa_fallback.py — Unit tests for MQA fallback correctness

SM120 guards (10 modified files):

  • DeepGEMM / tilelang MHC disabled on SM120 (no TMEM/tcgen05)
  • NSA backend: tilelang default, skip DeepGEMM metadata allocation
  • FlashMLA: SM120 adapter in deepseek_v4_backend.py
  • MoE: auto-select marlin backend on SM120
  • 3 CUDA-graph-breaking paths fixed (.unique()/.item()/.nonzero() → vectorized)

Bug fix (found during latest-image validation):

  • FlashMLA Triton kernel: handle KV cache uint8 dtype (upstream changed from float8_e4m3fn)
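To illustrate why the dtype matters, the sketch below decodes a single byte as float8_e4m3fn in pure Python and contrasts it with reading the same byte as an unsigned integer. It is not the PR's kernel code: the helper name decode_e4m3fn is hypothetical, and the bit layout used (1 sign bit, 4 exponent bits, 3 mantissa bits, bias 7, NaN but no infinities) is the standard E4M3FN format.

```python
import math


def decode_e4m3fn(byte: int) -> float:
    """Decode one byte as float8_e4m3fn (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    if not 0 <= byte <= 0xFF:
        raise ValueError("expected a single byte")
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    mant = byte & 0x7
    if exp == 0xF and mant == 0x7:
        return math.nan  # E4M3FN reserves this pattern for NaN and has no infinities
    if exp == 0:
        # Subnormal: no implicit leading 1, fixed exponent of -6
        return sign * (mant / 8.0) * 2.0 ** -6
    return sign * (1.0 + mant / 8.0) * 2.0 ** (exp - 7)


raw = 0x38                   # one stored byte, read two ways
as_integer = raw             # what a uint8 load yields: 56
as_fp8 = decode_e4m3fn(raw)  # what the attention math expects: 1.0
print(as_integer, as_fp8)
```

The same bits give 56 under one interpretation and 1.0 under the other, which is the kind of mismatch that corrupts attention scores. In PyTorch, a same-width reinterpretation of this kind is typically done with Tensor.view(dtype), which changes how the bytes are read without copying them.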

Results (8× RTX PRO 6000, TP=8, CUDA graph)

On sglang:dev-cu13 (sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130):

| Metric | Value |
| --- | --- |
| GSM8K 5-shot (200q) | 99.0% accuracy |
| Decode (BS=1) | 11.40 tok/s (TPOT = 87.7ms) |
| CUDA graph capture | ✅ all batch sizes captured |

On older nightly-dev-20260430 (sgl-kernel 0.4.1, PyTorch 2.9.1+cu129):

| Metric | Value |
| --- | --- |
| GSM8K 5-shot (200q) | 98.0% accuracy |
| Decode (BS=1) | 10.26 tok/s (TPOT = 97.5ms) |
| CUDA graph speedup | 2.4× vs without (4.36 → 10.26 tok/s) |
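
The throughput and latency figures in the two tables are related by TPOT (ms) = 1000 / (tokens per second). A quick check of that arithmetic, using only the numbers reported above (the function and variable names are illustrative):

```python
def tpot_ms(tokens_per_second: float) -> float:
    """Time per output token in milliseconds for a given decode throughput."""
    return 1000.0 / tokens_per_second


new_tpot = tpot_ms(11.40)       # sglang:dev-cu13 image
old_tpot = tpot_ms(10.26)       # nightly-dev-20260430 image
graph_speedup = 10.26 / 4.36    # with vs without CUDA graph, older image
image_speedup = 11.40 / 10.26   # newer vs older image

print(f"{new_tpot:.1f} ms, {old_tpot:.1f} ms, {graph_speedup:.2f}x, {image_speedup:.2f}x")
```

The results round to 87.7 ms, 97.5 ms, roughly 2.4× and roughly 1.11×, consistent with the reported TPOT values and speedups.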

Motivation

  • SM120 is desktop Blackwell (RTX 5090, RTX PRO 6000) — no server-class features (TMEM, tcgen05, NVSwitch)
  • Prior to this PR, SGLang cannot run DSv4 on SM120 at all (DeepGEMM JIT crash, no MXFP4 MoE support)
  • Enables developer/researcher access to DSv4 on workstation GPUs

Notes

  • This is a rebase of #24047 (feat: SM120 support for DeepSeek-V4 inference), which targeted the deepseek_v4 branch, onto main
  • All SM120 kernel code is identical between the two PRs (verified by diff)
  • SM120 kernels are guarded by is_sm120_supported() — zero impact on SM100/SM103 paths
  • Environment variables: SGLANG_SM120_TRITON_FLASHMLA=1 (default on), SGLANG_SM120_MQA_FALLBACK=0 (default off)
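
A minimal sketch of the guard pattern described in these notes, assuming a cached capability check that gates every SM120-specific branch. The names is_sm120_sketch, _query_compute_capability and pick_gemm_path are hypothetical and are not SGLang's actual helpers; the real is_sm120_supported() may use different detection logic.

```python
from functools import lru_cache
from typing import Optional, Tuple


def _query_compute_capability() -> Optional[Tuple[int, int]]:
    """Hypothetical stand-in for a device query; returns None when no CUDA device exists."""
    return (12, 0)


@lru_cache(maxsize=1)
def is_sm120_sketch() -> bool:
    """Cached check: True only for a CUDA device reporting compute capability 12.x."""
    capability = _query_compute_capability()
    return capability is not None and capability[0] == 12


def pick_gemm_path() -> str:
    # SM120-only code sits behind the guard, so other architectures keep their existing path.
    return "triton_fallback" if is_sm120_sketch() else "deepgemm"
```

Because the result is cached, calling the guard at every usage site costs one device query in total, and code paths on other architectures never reach the SM120 branch.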

Test plan

  • GSM8K 5-shot 200q: 99.0% on latest sglang:dev-cu13 image
  • GSM8K 5-shot 200q: 98.0% on older nightly-dev-20260430 image
  • Single request sanity check (15+27=42, capital of France, Chinese translation)
  • TPOT benchmark: 87.7ms BS=1 on dev-cu13 (11% faster than old image)
  • CUDA graph capture: all batch sizes captured successfully
  • Unit tests: test_sm120_mqa_fallback.py
  • PDL enabled on SM120: verified correct (GSM8K 10/10)
  • CI (no SM120 runner in CI — tested on local 8× RTX PRO 6000)

CI States

Latest PR Test (Base): ⏳ Run #26081958271
Latest PR Test (Extra): ⚠️ Not enabled -- add run-ci-extra label to opt in.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces comprehensive support and fallbacks for NVIDIA SM120 (Blackwell Desktop) architectures, which lack certain hardware features like TMEM/tcgen05 required by DeepGEMM and FlashMLA. It implements both pure-PyTorch and optimized Triton fallback kernels for FlashMLA, MQA logits, and MXFP4 MoE operations. Additionally, it updates environment configurations, metadata handling, and server arguments to automatically detect SM120 and route to these compatible implementations. Review feedback highlighted misleading function aliasing in the MoE implementation, the replacement of descriptive assertion messages with "TODO" strings, and inconsistent backend selection logic in the new entrypoints.

Comment on lines +135 to +137
from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import (
    mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback,
)
Contributor

high

The alias mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback is highly misleading. It imports the optimized Triton implementation but names it as if it were the fallback. This is especially confusing given that a separate PyTorch fallback implementation exists in mxfp4_moe_fallback.py.

This can cause significant confusion during future maintenance. Please rename the alias to accurately reflect the imported function, for example, by simply importing mxfp4_moe_forward_triton and using that name directly.

Suggested change
- from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import (
-     mxfp4_moe_forward_triton as mxfp4_moe_forward_fallback,
- )
+ from sglang.srt.layers.moe.fused_moe_triton.mxfp4_moe_sm120_triton import (
+     mxfp4_moe_forward_triton,
+ )

Contributor

Please address this comment

Author

Done — now imports mxfp4_moe_forward_triton directly without aliasing.

Comment on lines +67 to +68
assert head_dim == 128, "TODO"
assert block_size == 64, "TODO"
Contributor

medium

The assert messages have been changed from descriptive explanations to simply "TODO". The previous implementation also had these asserts, but the original messages explained why these values were hardcoded (e.g., "torch reference impl hardcodes DSV4 indexer head_dim=128"). Please consider restoring the more descriptive messages to improve code maintainability.

Suggested change
- assert head_dim == 128, "TODO"
- assert block_size == 64, "TODO"
+ assert head_dim == 128, "Vectorized torch impl hardcodes DSV4 indexer head_dim=128"
+ assert block_size == 64, "Vectorized torch impl hardcodes block_size=64 cache layout"

Contributor

Ditto.

Author

Done — restored descriptive assert messages.

_use_triton_flashmla = os.environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1"


def flash_mla_with_kvcache_entrypoint(backend: str, **kwargs):
Contributor

medium

In flash_mla_with_kvcache_entrypoint, the backend selection logic for SM120 is confusing. The function takes a backend parameter (from SGLANG_HACK_FLASHMLA_BACKEND), but this parameter is ignored when _is_sm120 is true. Instead, the choice between the Triton kernel and the PyTorch fallback is controlled by a separate environment variable, SGLANG_SM120_TRITON_FLASHMLA.

To improve clarity, it would be better to unify this control flow. Consider using the backend parameter to also control the implementation path on SM120, for instance, by checking for values like "triton" or "torch".

Contributor

Ditto.

Author

Done — the SGLANG_HACK_FLASHMLA_BACKEND env var has been removed. On SM120 the backend is now hardcoded to "kernel" (Triton path), with SGLANG_SM120_TRITON_FLASHMLA env var as the only escape hatch to fall back to PyTorch attention if needed for debugging. This simplifies the control flow as the gemini-bot suggested.
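
The resulting control flow can be sketched roughly as follows. Only the environment variable name and its default of "1" come from the thread; the function name select_flashmla_path and the returned labels are hypothetical. The snippet under review reads the variable once at import time, whereas this sketch reads it per call so the behavior is easy to exercise.

```python
import os


def select_flashmla_path(is_sm120: bool, environ=os.environ) -> str:
    """Return a label for the attention path; the labels are illustrative only."""
    if not is_sm120:
        return "native"  # non-SM120 GPUs keep their existing kernel
    # Default "1" means the Triton kernel is used unless explicitly disabled.
    use_triton = environ.get("SGLANG_SM120_TRITON_FLASHMLA", "1") == "1"
    return "triton" if use_triton else "torch_fallback"
```

With a single switch, the only way to reach the PyTorch path on SM120 is to set SGLANG_SM120_TRITON_FLASHMLA to something other than "1", which matches the escape-hatch behavior described above.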

@AliceChenyy AliceChenyy force-pushed the sm120-dsv4-rebase branch 3 times, most recently from 8779574 to a4b24dd Compare May 9, 2026 04:20
Contributor

Does this need an _is_sm120() check here?

Collaborator

I think is_cuda covers SM120

Contributor

I mean adding "and not _is_sm120" to the condition.

Author

Done — added and not _is_sm120 guards at all DeepGEMM/tilelang paths in nsa_backend.py.

# ── Graph-safe routing: flatten topk assignments ──
# token_ids[slot] = which row of A (original token index)
# expert_ids[slot] = which expert's weights to use
flat_expert_ids = topk_ids.reshape(-1).contiguous() # [M*topk]
Contributor

topk_ids can be -1 for padded/filtered tokens, but this path passes them directly as Triton expert ids. Could we use safe_expert_ids = topk_ids.clamp_min(0) for loads and set the invalid slots' output to zero? The PyTorch fallback already skips eid_val < 0.

Author

Handled — see lines 361-363: flat_expert_ids_raw.clamp(min=0) for safe indexing, plus line 441+: invalid slots are zeroed out after the kernel. The Triton kernel itself runs on clamped-to-0 expert IDs (safe for loads), and the output is masked to zero post-kernel.
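
The clamp-then-mask idea can be shown with a scalar stand-in, assuming each "expert" is just a scale factor. This is a sketch of the pattern only, not the PR's Triton kernel, and masked_expert_scale is a hypothetical name. Every slot performs the same lookup and multiply whether or not it is valid, so no slot is skipped and the output shape never depends on the data.

```python
from typing import List


def masked_expert_scale(topk_ids: List[int], expert_scale: List[float],
                        values: List[float]) -> List[float]:
    """Apply a per-expert scale to each slot; slots with a negative id produce 0.0."""
    out = []
    for eid, v in zip(topk_ids, values):
        safe_eid = max(eid, 0)            # clamp so the lookup is always in range
        valid = 1.0 if eid >= 0 else 0.0  # mask computed from the raw, unclamped id
        out.append(expert_scale[safe_eid] * v * valid)
    return out


print(masked_expert_scale([2, -1, 0], [10.0, 20.0, 30.0], [1.0, 1.0, 1.0]))
# prints [30.0, 0.0, 10.0]
```

The padded slot (id -1) reads expert 0's weights safely and is then zeroed, instead of being filtered out with a data-dependent operation such as .nonzero() that would break CUDA graph capture.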

ke = ks + ke_offset
actual_seq_q = torch.cat(actual_seq_q_list, dim=0)
with self._with_real_sm_count():
logits = deep_gemm.fp8_mqa_logits(
Contributor

For SM120, should this CP ragged path explicitly raise NotImplementedError instead of falling through to deep_gemm.fp8_mqa_logits, since DeepGEMM is unsupported here?

Author

Good catch. Will add if _is_sm120: raise NotImplementedError("CP ragged indexer not supported on SM120")
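
A minimal sketch of such a fail-fast guard, assuming the unsupported kernel call is passed in as a callable. The function name cp_ragged_mqa_logits and the compute_logits parameter are hypothetical placeholders for the real call site; only the error message is taken from the reply above.

```python
def cp_ragged_mqa_logits(is_sm120: bool, compute_logits):
    """Run the ragged context-parallel logits step, refusing early on SM120."""
    if is_sm120:
        # Fail with a clear message instead of crashing later inside an unsupported kernel.
        raise NotImplementedError("CP ragged indexer not supported on SM120")
    return compute_logits()
```

Raising before the kernel call turns an opaque JIT or launch failure into an explicit, searchable error.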

Comment thread python/sglang/jit_kernel/utils.py Outdated
return False
return get_jit_cuda_arch().major >= 9
arch = get_jit_cuda_arch()
# PDL requires SM100+ datacenter (tcgen05/TMEM); SM120 (desktop Blackwell) lacks these
Collaborator

There is currently a small bug with older CUTLASS versions when the kernel is a CUTLASS kernel. But can you elaborate on why PDL does not work?

Author

PDL is now re-enabled on SM120.

AliceChenyy added a commit to AliceChenyy/sglang that referenced this pull request May 13, 2026
Address all reviewer feedback from PR sgl-project#24692:
- Use is_sm120_supported() helper instead of raw sm_version checks
- Guard SGLANG_OPT_DEEPGEMM_HC_PRENORM and SGLANG_OPT_USE_TILELANG_MHC_PRE
  with `not is_sm120_supported()` in deepseek_v4.py
- Auto-select marlin MoE backend on SM120 in deepseek_v4_hook.py
- Minor cleanups in indexer, metadata, nsa_backend, mxfp4_marlin_moe

Fix FlashMLA Triton kernel garbled output on latest sglang:dev image:
- Root cause: upstream changed KV cache dtype from float8_e4m3fn to uint8.
  The Triton kernel's as_strided() preserved the input dtype, so tl.load
  interpreted FP8 bit patterns as raw integers, corrupting attention scores.
- Fix: explicitly view through uint8 → float8_e4m3fn before passing to Triton.

Verified on sglang:dev-cu13 (sgl-kernel 0.4.2.post1, PyTorch 2.11+cu130):
- GSM8K 5-shot 200q: 99.0%
- Decode BS=1: 11.40 tok/s, TPOT 87.7ms

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor

samuellees commented May 13, 2026

Could you please resolve the conflicts with the main branch? Just merge main into this branch; there is no need to create a new PR or force-push. Thanks! ^ ^

Also, could you please fix the lint errors with pre-commit? https://github.com/sgl-project/sglang/actions/runs/25803142782/job/75798441697?pr=24692

AliceChenyy and others added 2 commits May 13, 2026 20:15
# Conflicts:
#	python/sglang/srt/layers/attention/dsv4/indexer.py
#	python/sglang/srt/layers/quantization/mxfp4_marlin_moe.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AliceChenyy
Author

@sonny-vleisides Thanks for testing on TP=2! We reproduced and resolved the issues:

a) Auto-detect not firing: This was fixed — the hook now auto-selects marlin MoE backend on SM120 regardless of TP size. No manual env vars needed.

b) CUDA graph crash at deep_gemm.tf32_hc_prenorm_gemm: Fixed — added not is_sm120_supported() guards on SGLANG_OPT_DEEPGEMM_HC_PRENORM and SGLANG_OPT_USE_TILELANG_MHC_PRE so SM120 falls back to PyTorch.

TP=2 on 2x RTX PRO 6000 verified working with latest code:

  • Key param: --mem-fraction-static 0.93 (model uses 86.57 GB/card, needs high fraction)
  • --cuda-graph-max-bs 1
  • GSM8K 5-shot 10q: 100% (10/10)
  • TPOT BS=1: ~443ms (vs ~88ms on TP=8)

The main constraint on TP=2 is memory — only ~0.9 GB available for KV cache after model loading.

@AliceChenyy
Author

@samuellees Both done:

  1. Merged main into branch (no force push) ✅
  2. Pre-commit lint fixed — lint CI now passes

Contributor

samuellees commented May 14, 2026

/tag-and-rerun-ci +++++++++
0519+

Contributor

@samuellees samuellees left a comment

@samuellees
Contributor

It seems the CI error is related to this PR: https://github.com/sgl-project/sglang/actions/runs/25839602738/job/75934182042?pr=24692#step:11:762. Could you please take a look? @AliceChenyy

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AliceChenyy
Author

@samuellees Thanks for flagging! The CI failure was in test_no_bare_pytest_main — our test_sm120_mqa_fallback.py had a bare pytest.main(...) without sys.exit() wrapper. Fixed in 385a8d4.

Could you please re-trigger CI with /tag-and-rerun-ci? Thanks!

Contributor

samuellees commented May 18, 2026

Do you have any comments on this failing CI case, please? It has failed many times. @AliceChenyy @b8zhong @Fridge003
https://github.com/sgl-project/sglang/actions/runs/25845935332/job/76407592232?pr=24692#step:11:5175

Author

AliceChenyy commented May 18, 2026

Do you have any comments on this failing CI case, please? It has failed many times. @AliceChenyy @b8zhong @Fridge003 https://github.com/sgl-project/sglang/actions/runs/25845935332/job/76407592232?pr=24692#step:11:5175

Hi @samuellees, this CI failure is not related to our SM120 PR.

The failing test is test_mxfp4_120b (gpt-oss-120b model on H100), which crashes in _matmul_ogs_NNT_bf16xbf16xmxfp4_128x256x128x1_swiglu with CUDBG_EXCEPTION_WARP_ILLEGAL_ADDRESS.

Root cause: PR #24816 (merged May 13) introduced a triton_kernels API incompatibility: the FnSpecs constructor and GatherIndx/RoutingData/ScatterIndx import paths changed, causing the matmul_ogs kernel to receive incorrect swiglu parameters → illegal memory access.

This was a known regression on main, confirmed by:
- PR #25329 (merged May 15): explicitly skipped #24816's broken tests
- PR #25335 (merged May 15): fixed the API calls + bumped sgl-kernel to 0.4.2.post2
- [rerun-test] test_gpt_oss_4gpu.py on fix_flashinfer branch: ✅ passed (runs https://github.com/sgl-project/sglang/actions/runs/25905118937)

Our branch was rebased on main from May 13 — after #24816 but before the fix in #25335. All our SM120 code paths are guarded by _is_sm120 and don't execute on H100.

I'll rebase to latest main (which includes the fix).

@AliceChenyy
Author

Merged latest main (includes #25335 fix) into branch — commit 07c5e4896. Could you please re-trigger CI with /tag-and-rerun-ci? Thanks!

Collaborator

@b8zhong b8zhong left a comment

Before merging, can you please check:

  1. Run the same eval on SM90 or SM100 and compare the scores (maybe on a small subset of GPQA, because the whole thing will take a really long time)
  2. 10-11 TPS is very slow, even for SM120, I think. Can you please share a profile and verify whether there are any significantly slow kernels that can be easily improved? (It's OK if it's allreduce or pre-existing kernels, etc.)

@AliceChenyy
Author

Before merging, can you please check:

  1. Run the same eval on SM90 or SM100 and compare the scores (maybe on a small subset of GPQA, because the whole thing will take a really long time)
  2. 10-11 TPS is very slow, even for SM120, I think. Can you please share a profile and verify whether there are any significantly slow kernels that can be easily improved? (It's OK if it's allreduce or pre-existing kernels, etc.)

Thanks!

Will do both and update if any progress.

AliceChenyy and others added 4 commits May 18, 2026 23:53
Skip DeepGEMM transform_sf_into_required_layout (tcgen05 unsupported)
and topk_v2 (128KB SMEM exceeds SM120 99KB limit) on SM120.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…rted() directly

Address b8zhong review: do not define SM120 detection in this file,
call the existing util is_sm120_supported() at each usage site instead.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
is_sm120_supported() already checks is_cuda() internally and is
lru_cached, so the redundant `_is_cuda and` prefix is unnecessary.

- metadata.py: remove _is_cuda/_is_sm120 module vars, call util directly
- flash_mla_sm120_fallback.py: remove _is_cuda, simplify _is_sm120
- nsa_backend.py: remove redundant is_cuda() prefix
- nsa_indexer.py: remove redundant _is_cuda prefix

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Conflict in deepseek_v4.py hc_pre: upstream added AMD aiter mhc_pre
path; preserved SM120 guard on deepgemm hc_prenorm_gemm.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@AliceChenyy
Author

Cross-architecture GPQA Diamond comparison completed (same 50 questions, random.Random(0), temperature=0, max_tokens=1024):

| GPU | TP | Score |
| --- | --- | --- |
| RTX 6000D (SM120) | 8 | 72.0% (36/50) |
| H100 NVL (SM90) | 4 | 68.0% (34/50) |
| DeepSeek reference | - | 71.2% |

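SM120 matches the reference accuracy, with no regression. The 2-question difference from H100 (4 percentage points) is within statistical noise: the per-question standard deviation is ≈ 0.46, so the standard error of a 50-question accuracy is about 6.5 percentage points.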
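
One way to make the noise argument concrete is a two-proportion z-score on the reported counts. This is a rough sketch that treats the two runs as independent binomial samples; since both runs used the same 50 questions, a paired test would be the stricter choice, so read the result as indicative only. The function name two_proportion_z is illustrative.

```python
import math


def two_proportion_z(correct_a: int, correct_b: int, n: int) -> float:
    """z-score for the difference between two accuracies, each measured on n questions."""
    p_a, p_b = correct_a / n, correct_b / n
    # Standard error of the difference, treating the runs as independent binomial samples
    se = math.sqrt(p_a * (1 - p_a) / n + p_b * (1 - p_b) / n)
    return (p_a - p_b) / se


z = two_proportion_z(36, 34, 50)  # SM120 vs SM90 counts from the table above
print(f"z = {z:.2f}")
```

The z-score comes out well below the conventional 1.96 threshold, so a 36/50 vs 34/50 split does not distinguish the two platforms at this sample size.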

Author

AliceChenyy commented May 19, 2026

hi @b8zhong

The rough decode breakdown from the profiling runs is:

  • MXFP4 MoE path: ~35-40 ms/token, around 40% of TPOT
  • NCCL/all-reduce on PCIe: ~18-22 ms/token, around 20-25%
  • Dense linear / existing FP8 GEMM path: ~15-20 ms/token
  • FlashMLA sparse decode fallback/Triton path: ~10-12 ms/token
  • Other ops, e.g. RMSNorm/RoPE/metadata/HC-related pieces: ~5-8 ms/token
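
As a sanity check, the per-component ranges above can be summed and compared with the measured TPOT. The sketch assumes the profile corresponds to the 87.7 ms BS=1 figure reported for the dev-cu13 image, which the comment does not state explicitly; the dictionary keys are illustrative labels.

```python
# (low, high) estimates in ms/token, copied from the breakdown above
breakdown_ms = {
    "mxfp4_moe": (35, 40),
    "nccl_allreduce_pcie": (18, 22),
    "dense_fp8_gemm": (15, 20),
    "flashmla_sparse_decode": (10, 12),
    "other_ops": (5, 8),
}
measured_tpot_ms = 87.7  # assumption: BS=1 TPOT on dev-cu13 from the PR description

low_total = sum(low for low, _ in breakdown_ms.values())
high_total = sum(high for _, high in breakdown_ms.values())
moe_low, moe_high = breakdown_ms["mxfp4_moe"]

print(f"component total: {low_total}-{high_total} ms/token")
print(f"MoE share of TPOT: {moe_low / measured_tpot_ms:.0%}-{moe_high / measured_tpot_ms:.0%}")
```

The component ranges add up to 83-102 ms/token, which brackets the measured TPOT, and the MoE path alone accounts for roughly 40-46% of it under this assumption.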

I do not see one single obviously-broken kernel that could be improved easily. The slowness is distributed across MoE weight streaming, PCIe all-reduce, and SM120 fallback kernels.

The bottlenecks are known architectural fallbacks required on SM120 (no TMEM/tcgen05):

  • hc_pre: PyTorch fallback (34 kernel launches vs 1 fused DeepGEMM op on SM90)
  • MQA logits: Triton fallback instead of FlashMLA native
  • MoE: Triton MxFP4 fallback instead of CUTLASS

These are tracked upstream in flashinfer#3346 — most DSv4 kernels for SM120 are currently listed as "missing" or "not wired" in FlashInfer, confirming that SM120 fallback paths are the expected state.

Performance optimization is already happening incrementally on top of this enablement PR — for example #25696 adds SM120 FP8 GEMM autotune configs and reaches 44-52 tok/s on TP=2 (4-5x over our baseline), building directly on this PR's SM120 support.

This PR focuses on correctness — making DSv4-Flash work on SM120 with accuracy matching the reference (72% vs 71.2% on GPQA Diamond).
