Feature/sm120 deepseek v4 highspeed inference support#24303
Feature/sm120 deepseek v4 highspeed inference support#24303AdamPlatin123 wants to merge 10 commits into
Conversation
SM120 (consumer Blackwell, e.g. RTX PRO 6000) has only ~99KB shared memory per block (101376 bytes), compared to SM100's 227KB. Multiple kernels exceed this limit and fail with: "Failed to set the allowed dynamic shared memory size to N" Changes: - topk.cu: Reduce dynamic shared memory from 128KB to 80KB (10240 entries per buffer, still 5x the TopK=2048 requirement) - topk_v2.cuh: Skip cluster path (~144KB) on SM120, use Small/Medium paths (~84KB) instead - tilelang_kernel.py: Add SM120 detection for sparse attention, use smaller block_I=32 to reduce shared memory - mhc.py: Use hidden_block=128 on SM120 instead of 256 to keep splitk kernel well within the 99KB limit
Enable DeepSeek-V4-Flash inference on SM120 GPUs (RTX PRO 6000 Blackwell, CC 12.0, TP=2) with CUDA Graph, achieving ~43 tok/s decode throughput (8.3x from 5.2 tok/s baseline). Key changes: General improvements (all platforms benefit): - indexer.py: Eliminate GPU->CPU sync points (.item(), .any()) in fp8_paged_mqa_logits_torch() for CUDA graph compatibility - cuda_graph_runner.py: Add warmup sync and optional debug mode for more reliable graph capture - fused_kv_gather_triton.py: New Triton kernel for fused KV gather+dequant in MLA decode (replaces ~16 PyTorch ops) - fp4_gemv_triton.py: New Triton kernel for FP4 dequant+GEMV with batched multi-expert support for MoE decode SM120-specific (runtime detection via is_sm120_supported()): - utils.py: Fix PDL detection - PDL is SM90-only, not >=SM90 - mhc.py: PyTorch fallback for hc_split_sinkhorn, reduce hidden_block to 128 on SM120 (SMEM constraint) - debug_flash_mla_adapter.py: Full PyTorch MLA decode fallback for SM120 (flash_mla only has sm_90a/sm_100f cubins) - deepseek_v4.py: Decode-optimized hc_pre path (4 kernels vs 34 tilelang launches), type-consistency fixes in hc_post - fp8.py: FP4 MoE PyTorch/Triton fallback dispatch when DeepGEMM unavailable (generates sm_120a cubins on SM120) - MoE Triton autotune configs for RTX PRO 6000 Blackwell Tested on: 2x RTX PRO 6000 Blackwell (SM120, 96GB each, TP=2) Benchmark: ~21 tok/s single-shot, ~43 tok/s at concurrency=8 Container config: --cuda-graph-bs 1 2 4 8 --fp8-gemm-backend triton --disable-custom-all-reduce --disable-flashinfer-autotune --disable-overlap-schedule --mem-fraction-static 0.95 --tp 2
|
Tested this on 2x RTX PRO 6000 Blackwells (SM 120, 96 GB each) running DeepSeek-V4-Flash. The PR fixes the FP8 expert OOM, FP4 hidden size mismatch & CPU offload weight loader shape errors I was hitting without it. Server gets to Image is First one... is in Pops during CUDA graph capture at bs=8. The E2M1 LUT gets built inline with Second one... is fires in Can run a target SHA if you ship one. |
1. Move FP4 E2M1 LUT to module-level cache to avoid host→device allocation during CUDA graph capture (fixes cudaErrorStreamCaptureUnsupported) 2. Wrap tf32_hc_prenorm_gemm call in try/except with PyTorch fallback for architectures where DeepGEMM sgl-kernel binary is unsupported (SM120)
…llback 1. configurer.py: Use is_sm100_supported() instead of is_blackwell_supported() for DEEPGEMM_BLACKWELL. SM120 (consumer Blackwell) lacks tcgen05/TMEM instructions required by DeepGEMM SM100 kernels. 2. indexer.py: Wrap DeepGEMM fp8_paged_mqa_logits import in try/except with PyTorch fallback for architectures where DeepGEMM is unavailable.
SM120 consumer Blackwell lacks WGMMA (SM90) and tcgen05 (SM100) instructions that DeepGEMM kernels require. Disable ENABLE_JIT_DEEPGEMM on SM120 to prevent tcgen05.fence JIT compilation errors.
SM120 (consumer Blackwell, CC 12.0) lacks tcgen05/WGMMA instructions required by DeepGEMM kernels. While configurer.py already disables ENABLE_JIT_DEEPGEMM for SM120, several files imported deep_gemm directly without checking the flag, causing tcgen05.fence NVCC compilation errors at runtime. Changes: - nsa_indexer.py: skip deep_gemm import on SM120, guard get_num_sms() - paged_mqa_logits.py: gate import behind ENABLE_JIT_DEEPGEMM - metadata.py: skip deep_gemm when ENABLE_JIT_DEEPGEMM is False - nsa_backend.py: catch RuntimeError in addition to ImportError
Replace deep_gemm with an empty module stub on SM120 to prevent _C.init() from triggering NVCC JIT compilation of tcgen05 kernels. Also catch AttributeError in hc_pre fallback for the stub case.
|
Tested 429deb0 on the same 2x setup. Doesn't get to bugs 1 & 2... stalls much earlier. NCCL init hangs for 6+ minutes, both TP ranks pegged at 99% CPU, GPUs sitting at 914 MB (just CUDA context, weights never start loading). Last thing in the log is Same image ( ff2e14c got the server all the way to fwiw, the hang is well before any of the SM120 fallback paths even fire, so this isn't bug 1 or 2 popping back up in a different shape. Different bug. happy to re-test if you ship a candidate fix. |
Port Triton kernels for SM120 (RTX PRO 6000 Blackwell, CC 12.0) which lacks TMEM/tcgen05/WGMMA instructions required by DeepGEMM CUDA kernels. New kernels: - MXFP4 MoE Triton: per-slot fused FP4 dequant + GEMV, eliminating the Python for-loop fallback. Prefill 5.5→800+ tok/s, decode +16%. (SGLANG_SM120_TRITON_MOE=1) - FlashMLA Triton sparse decode: tiled kernel with 3 typed views of paged buffer, online softmax. No regression vs PyTorch fallback. (SGLANG_SM120_TRITON_FLASHMLA=1) - SM120 MQA fallback: pure PyTorch replacement for DeepGEMM MQA logits in NSA indexer, with wq precompute optimization. Modified: - fp8.py: Add Path B1 (Triton MXFP4 MoE) before Path B2 (Python loop) - nsa_indexer.py: Route all DeepGEMM MQA calls to SM120 fallback - debug_flash_mla_adapter.py: Integrate Triton FlashMLA with env toggle - deepseek_v4.py: Fix hash_topk input_ids dtype (int32→int64) Also includes FP8 GEMM autotune configs for RTX PRO 6000 Blackwell. Tested on 2× RTX PRO 6000 (TP=2, PCIe): - Decode: 5.2→28.8 tok/s (with CUDA Graph bs=1) - Prefill 7K tokens: ~29K tok/s (TTFT ~249ms, ~11ms compute)
b8zhong
left a comment
There was a problem hiding this comment.
Hi, since the diff is very large, I will reply to your PR description first, and then maybe once it's merged into main, we can review the code
- SM120 supports PDL. It is the same instructions
- SM120 supports TMA
Also, can you run an accuracy test, to see if the implementation is correct, as described in https://docs.sglang.io/cookbook/autoregressive/DeepSeek/DeepSeek-V4
Complete PyTorch MLA decode fallback
with fused KV gather + dequant
You can consider the Flashinfer FA2 MLA as a ground truth MLA implementation.
- Auto-enable SGLANG_SM120_TRITON_MOE and SGLANG_SM120_TRITON_FLASHMLA on SM120 via is_sm120_supported() detection (users no longer need to set env vars; both registered in environ.py with default True) - Fix is_arch_support_pdl(): change ==9 to >=9 && !=12 to preserve PDL on SM100 (B200/B100) while correctly disabling on SM120 - Unify SM120 detection: replace ad-hoc >=12 and ==120 checks with is_sm120_supported() in flash_mla_adapter and nsa_indexer - Remove dead code: _fused_rope_copy_kernel (unused, had BF16 bug) and abandoned FP4 arithmetic dequant in _dequant_fp4_to_bf16 - Clarify topk.cu 80KB SMEM: safe for all arches (was 128KB which overflowed SM80/SM89/SM120 99KB limit)
PDL (Programmatic Dependent Launch) is supported on all SM90+ GPUs including consumer Blackwell (SM120). Previous commit incorrectly disabled PDL on SM120 by conflating it with TMEM/tcgen05 (which are SM100-only datacenter features). Thanks to @b8zhong for the correction.
|
Hi @AdamPlatin123 , could you tell the difference between this PR and #24047, please? Seems 24047 has provided perf/accuracy results. |
|
Thanks @b8zhong for the corrections! PDL and TMA on SM120: You're right — SM120 does support both PDL and TMA. I was SM120 lacks:
This is why DeepGEMM kernels (which use tcgen05/TMEM) fail on SM120, but PDL-based Accuracy test results (commit
Performance:
On tilelang We'll also run the formal DeepSeek-V4 cookbook accuracy test with FlashInfer FA2 MLA |
|
@sonny-vleisides Thanks for the detailed testing! NCCL hang at
The latest commit ( Bug 1 (E2M1 LUT during CUDA graph capture): Fixed in Bug 2 (hyperconnection.hpp assertion): Fixed in Would be great if you could re-test at the latest commit. The server should start |
|
Hi @samuellees, thanks for pointing this out. Here’s a side-by-side comparison between the two approaches: Relationship: Key differences:
Our recommendation: |
Sound good! Thanks for reply @AdamPlatin123 . What's your slack name please? My slack name is Sam Li. Let's discuss more details on slack~ |
|
do you have any perf command on sm120 ? |
|
Hi @bbbearxyz, here's our benchmark setup on SM120 (2× RTX PRO 6000 Blackwell, 96GB each, TP=2): Server launch: Results (PyTorch 2.9.1+cu129):
|
|
@AdamPlatin123 do you have any env or just to run it? and i will try my best to optimize it, and you can contact me in the slack or email? |
|
Hi @bbbearxyz, at commit On earlier commits, these env vars are required: export SGLANG_DSV4_MODE=2604 |
|
Hi @AdamPlatin123 , we've rebased PR #24047 (AliceChenyy) into #24692. We'll move PR24692 forward and merge it as soon as possible. So you can cherry pick or rebase this PR based on PR24047/24692 (We can also help if you want). Please let me know if this doesn't make sense to you^ ^ |
|
Hi @samuellees, thanks for the update on #24692. Here's our plan: Rebase coordination: We're happy to rebase #24303 on top of #24692 once it merges. Our PR adds complementary optimizations beyond the Triton kernels from #24047:
Action items:
|
Motivation
Consumer Blackwell GPUs (RTX PRO 6000, CC 12.0) lack TMA, TCGEN5, and ACQBULK
instructions present in datacenter Blackwell (B100/B200, CC 12.0a). This makes
DeepSeek-V4-Flash inference fail out-of-the-box due to DeepGEMM/tilelang kernel
incompatibilities.
This PR adds full SM120 compatibility so DeepSeek-V4-Flash runs on consumer
Blackwell hardware.
Modifications
Kernel Fallbacks & Fixes
SM90-only, not available on SM120
computation using index/scan approach
with fused KV gather + dequant
gather + FP8 dequant, avoiding redundant copies
incompatible)
(~212 lines)
cases
Blackwell
Autotune Configs
Bug Fixes
Accuracy Tests
All fallback paths produce numerically correct results verified on SM120 hardware:
Speed Tests and Profiling
Tested on 2× RTX PRO 6000 Blackwell (96GB each), TP=2, model: DeepSeek-V4-Flash
(FP8, 158B params).
Benchmark:
python3 -m sglang.bench_serving --dataset-name random --random-input 128 --random-output 256Key profiling findings:
Checklist
All 8 modified Python files pass black --check and ruff check --select=F401,F821.
Hardware-specific (SM120) — no CI coverage for this GPU class yet. Smoke tests can be added if maintainers prefer.
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci