feat(w4a16): SM90 W4A16 MoE path for DeepSeek-V4-Flash + RSF pre-multiply fix#24492
Closed
samuellees wants to merge 8 commits into
Closed
feat(w4a16): SM90 W4A16 MoE path for DeepSeek-V4-Flash + RSF pre-multiply fix#24492samuellees wants to merge 8 commits into
samuellees wants to merge 8 commits into
Conversation
Contributor
There was a problem hiding this comment.
Code Review
This pull request implements support for DeepSeek-V4 W4A16 quantization on Hopper architectures, adding specialized MoE runners and quantization logic. Key changes include refactoring the sgl-kernel build system, optimizing DeepGEMM warmup, and replacing custom JIT kernels with native PyTorch operations. Additionally, the PR includes a suite of benchmarking scripts and evaluation utilities. Review feedback identifies security risks in the benchmarking scripts, specifically the use of --privileged and elevated capabilities in Docker commands, which should be restricted to the minimum necessary permissions.
Comment on lines
+64
to
+65
| --cap-add SYS_PTRACE \ | ||
| --security-opt seccomp=unconfined \ |
Contributor
|
|
||
| docker run --gpus all --rm -d \ | ||
| --name flash_nsys_$BACKEND --ipc=host --network=host \ | ||
| --privileged \ |
Contributor
…00) (sgl-project#27) * feat(w4a16-deepseek): add SM90 W4A16 MoE path for DSv4 FP4 checkpoint Add DeepSeekW4A16MoEMethod, the H200/SM90 counterpart to DeepSeekMxfp4MoEMethod. Both classes consume the same DSv4 FP4 checkpoint (SGLANG_DSV4_MODE=2604 SGLANG_DSV4_FP4_EXPERTS=1); mxfp4_deepseek targets B200's trtllm_fp4_block_scale_routed_moe (MXFP8xMXFP4), and this new path targets flashinfer's SM90 mixed-input cutlass_fused_moe(..., use_w4_group_scaling=True) (BF16xMXFP4) introduced in flashinfer-ai/flashinfer#3084. Key differences from mxfp4_deepseek: - Pre-interleaves FP4 weights and MXFP4 block scales at load time via flashinfer's interleave_moe_weights_for_hopper_mixed_gemm / interleave_moe_scales_for_hopper_mixed_gemm helpers. Without this the SM90 LDSM-based FP4->BF16 pipeline reads LUT bytes from wrong positions and the output decorrelates for K > 128 (DSv4 has K=4096). - Kernel takes raw (token_selected_experts, token_final_scales) rather than the packed int32 (id<<16 | weight_bf16) that the TRT-LLM routed kernel expects; no PackTopkIds step. - Local-expert filtering is done via ep_size/ep_rank parameters on cutlass_fused_moe, so topk_ids are handed over in the GLOBAL id space (same as the mxfp4_deepseek dispatcher-mapping-undo logic). - SwiGLU clamp is plumbed via swiglu_limit (no separate gemm1_clamp_limit). w13 row order is unchanged: checkpoint stores [w1(gate), w3(up)] and we reorder to [w3(up), w1(gate)] to match the SM90 kernel's reference (test_moe_bf16_mxfp4 splits as `w3, w1 = torch.chunk(w31, 2, dim=0)`). Enabled by --moe-runner-backend flashinfer_w4a16. * fix(w4a16-deepseek): bump sunrise_moe_code_path_checker.observed on 260415 deepseek_v4.py forward asserts observed == 1 exactly once per MoE layer under SGLANG_DSV4_2604_SUBMODE=260415 and then resets. mxfp4_deepseek bumps the counter; the W4A16 path was forgetting, so the server crashed at the first real forward on the DSv4 0415_v5 checkpoint. Mirror the mxfp4_deepseek bump. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * feat(w4a16-deepseek): add SGLANG_HACK_DEBUG_W4A16_REMOVE_SWIGLU_LIMIT Force-override swiglu_limit to None in cutlass_fused_moe call when flag is set, to A/B-test whether the swiglu_limit=10.0 up-branch clamp path is the root cause of the AIME25 accuracy regression observed on the DSv4 260415 FP4 checkpoint. Keeps _swiglu_limit_tensor non-None so create_moe_runner's sanity assert and the sunrise_moe_code_path_checker bump remain unaffected. Flag defaults to False: production behavior unchanged. See journal 2026-04-21-024 for the repro plan. * fix(w4a16-deepseek): rename interleave_moe_*_for_hopper → _sm90 Upstream flashinfer PR sgl-project#3084 branch (samuellees/flashinfer@feat/w4a16-moe-kernel commit cb90611) renamed interleave_moe_{weights,scales}_for_{Hopper,hopper}_mixed_gemm to the _sm90_ variants. Sync the sglang W4A16 MoE wrapper to the new names. * journal(2026-04-21-022): DSv4 W4A16 H200 with cuda graph enabled First cuda-graph run of the W4A16 path on DSv4 FP4 ckpt. Records a clean 8m52s cold start (41s capture for 36 batch sizes), ~6000 tok/s aggregate peak at mc=256, and 17k+ decode batches all dispatching under graph with no kernel-level issues. AIME25 x16 ran ~5h to 12/16 seeds before a gloo TCP connection reset peeled the scheduler; partial pass@1[avg-of-12] = 75.56% ± 6.25%. GPQA queued but did not run. Next: relaunch + GPQA. * revert: remove 2026-04-21-022 journal (belongs on rcli-config branch) Journal committed to wrong branch; this branch (w4a16 PR) contains code changes only. Journal will be added under rcli-config per convention. * feat(w4a16-deepseek): add aime25_q6 single-question bench dataset Generated via sunrise/filter_nemo_skills_questions.py from the canonical nemo_skills aime25 source jsonl (see generate.sh for the reproducible one-liner). Purpose: one-question dense A/B subset for the W4A16 accuracy regression investigation. Journal 0421-024 observed pred=271 clustering on 11/31 wrong seeds for aime25-6 across all three arms; aime25_q6:64 lets us run a 64-repeat concentration experiment in minutes instead of 9 hours. * debug(w4a16-deepseek): assert fp32→UE8M0 scale conversion is lossless UE8M0 stores only the biased exponent, so a float32 block scale is only preserved when it's an exact power of 2. If DSv4's ckpt stores scales that aren't pure powers of 2, the .to(float8_e8m0fnu) round-trip silently rounds and feeds the kernel wrong scales — a plausible culprit for the AIME25 accuracy drop. This helper crashes loudly on the first mismatch with a sample of bad values instead of silently degrading. * feat(w4a16-deepseek): add SGLANG_HACK_DEBUG_W4A16_USE_BF16_API for dequant-ref When the flag is set, process_weights_after_loading dequants FP4+UE8M0 expert weights into plain bf16 (post reorder_w1w3_to_w3w1) and drops the scale parameters, and apply() calls cutlass_fused_moe with bf16 weights, quant_scales=None, use_w4_group_scaling=False. This routes the MoE through flashinfer's CutlassMoeFCRunner<bf16, bf16> specialization — a numerically independent reference path that does not share the SM90 mixed-input dequant/interleave code of PR sgl-project#3084, so any acc gap it closes isolates the regression to the W4A16 kernel / interleave side. Flag defaults to False; W4A16 behavior unchanged when off. * docs(w4a16-deepseek): cite flashinfer source for _dequant_mxfp4 copy Body and LUT copied verbatim from flashinfer-sunrise PR sgl-project#3084 (commit 77746b81) at tests/moe/test_trtllm_cutlass_fused_moe.py lines 2419-2452 (_MXFP4_LUT + _dequant_mxfp4_on_device). Bitwise equivalence verified on 5 random uint8 shapes (CPU torch.equal on bf16 output; NaN-position agreement on UE8M0=255 case separately). * test(sunrise): add verify_dequant_mxfp4.py Bitwise equivalence check between flashinfer-sunrise PR sgl-project#3084's _dequant_mxfp4_on_device (tests/moe/test_trtllm_cutlass_fused_moe.py @ 77746b81) and the sglang local copy in w4a16_deepseek.py. Paths are resolved relative to the script, with FLASHINFER_SUNRISE_TEST_FILE env override. CPU-only; no CUDA needed. * fix(w4a16-deepseek): extend StandardDispatcher skip_local_expert_mapping to flashinfer_w4a16 Previously the skip gate was `enable_flashinfer_mxfp4_moe and SGLANG_OPT_MXFP4_SKIP_DISPATCHER_MAPPING`, so for --moe-runner-backend=flashinfer_w4a16 the dispatcher always applied the global->local+sentinel mapping, while w4a16_deepseek.apply() (copy of mxfp4 logic) skipped the inverse undo when the env default (True) was active. Net effect under --ep>1: cutlass_fused_moe received local-index+(-1)-sentinel topk_ids interpreted as globals, causing ep_rank>0 experts to be filtered out and producing garbage output (degenerate token loops). TP-only arm masked it because local-id == global-id and no sentinels fire when num_local_experts == num_experts. Fix: include flashinfer_w4a16 in the skip gate alongside flashinfer_mxfp4_moe. Repro + diagnosis: sunrise/bench_records/journals/2026-04-22-003-w4a16-ep-garbage-bug-repro.md * feat(w4a16-deepseek): add SGLANG_HACK_DEBUG_W4A16_USE_TORCH_REF env flag Gate a new pure-torch MoE reference path as an acc-investigation arm that sits one level deeper than SGLANG_HACK_DEBUG_W4A16_USE_BF16_API: both dequant FP4 to bf16 once at load time, but BF16_API still calls the flashinfer bf16 grouped GEMM while TORCH_REF bypasses it entirely. * feat(w4a16-deepseek): add pure-torch MoE ref in debug_utils torch_ref_cutlass_fused_moe mirrors the flashinfer cutlass_fused_moe signature so the w4a16_deepseek apply() site can swap one for the other via a single local import, matching the pattern used by mxfp4_deepseek/naive_torch_trtllm_fp4_block_scale_routed_moe. Body adapted from flashinfer-sunrise tests/moe/test_trtllm_cutlass_fused_moe.py _compute_with_active_experts (commit 77746b81). * feat(w4a16-deepseek): wire torch-ref path through single MoE call site Gate the bf16-weight dequant branch in process_weights_after_loading on either BF16_API or TORCH_REF (mutually exclusive), and swap the MoE function in apply() via local import rather than duplicating the call with different args. * test(sunrise): add verify_torch_ref_w4a16_moe.py Element-wise smoke comparing torch_ref_cutlass_fused_moe against the flashinfer cutlass_fused_moe(use_w4_group_scaling=True, swiglu_limit=...) kernel on tiny random MXFP4 weights (shapes borrowed from flashinfer's own W4A16_CORRECTNESS_CONFIGS). Before committing bench-scale wall-clock to the torch-ref path, we want this to show that the two agree within ~1% at small shape. --------- Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Co-authored-by: fzyzcjy <ch271828n@outlook.com>
… DSv4-Flash 2604B Key fix (RSF-PREMUL): - cutlass_fused_moe expects token_final_scales as the *final* output weight per expert, which must include routed_scaling_factor (1.5 for DSv4). - Original code passed raw topk_weights then did output.mul_(rsf) post-hoc, which is semantically wrong and caused -7.5pp accuracy regression. - Fix: pre-multiply topk_weights by RSF before passing to the kernel; remove the post-hoc multiply entirely. - Result: GSM8k 89.5% → 96.5% (+7pp), AIME25 31.7% → 46.7% (+15pp) Compatibility patches for public DeepSeek-V4-Flash (2604B submode): - w4a16_deepseek.py: update submode check to accept both '2604B' and '260415'; import code-path checker from deepseek_v4_debug_utils (renamed in main branch); always bump counter unconditionally; pass swiglu_alpha/beta as ones/zeros - mxfp4_deepseek.py: extend swiglu_limit assertion to include '260415' - fused_moe.py, deep_gemm.py: same submode extension - deepseek_v4.py: allow '260415' in load_weights submode assertion - mhc.py: add pure-PyTorch fallback for hc_split_sinkhorn (for envs without new tilelang supporting wg_wait) - sunrise_debug_utils.py: stub file for compatibility Accuracy validation on DeepSeek-V4-Flash (H200 TP=4): GSM8k 8-shot 1319q: W4A16=95.2% marlin=95.7% FP8=95.5% ✓ aligned AIME25 30q×4: W4A16=46.7% marlin=45.8% FP8=48.3% ✓ aligned GPQA 198q×8: W4A16=75.8% marlin=75.3% FP8=75.7% ✓ aligned Co-Authored-By: Claude Sonnet 4.6 (1M context) <noreply@anthropic.com>
- Rename flashinfer_w4a16 → flashinfer_cutlass_wmxfp4a16 (backend name, enum value, method names, standard dispatcher flag) - Remove debug_utils/sunrise_debug_utils.py (not needed in PR) - Remove fused_moe_triton/fused_moe.py (large unrelated file) - Remove mhc.py _hc_split_sinkhorn_torch (unrelated to W4A16 kernel) - Remove sunrise/ test artifacts (aime25_q6, verify_*.py) - Remove internal submode asserts (SGLANG_DSV4_2604_SUBMODE, "260415", "2604B") from deepseek_v4.py, w4a16_deepseek.py, mxfp4_flashinfer_trtllm_moe.py - Remove debug env vars (SGLANG_HACK_DEBUG_W4A16_*, SGLANG_DSV4_FIX_ATTN_PADDING) from environ.py and all usages in w4a16_deepseek.py - Remove sunrise_moe_code_path_checker counter bump (internal DSv4 artifact)
- fp8.py: fix broken class import Mxfp4FlashinferTrtllmMoEMethod → DeepSeekMxfp4MoEMethod; simplify W4A16 routing condition to match mxfp4 pattern (remove SGLANG_DSV4_MODE check, use self.is_fp4_experts) - mxfp4_flashinfer_trtllm_moe.py: remove broken import of deleted deepseek_v4_debug_utils module; remove SGLANG_DSV4_2604_SUBMODE internal counter bump; restore maybe_fuse_routed_scale_and_shared_add (used by deepseek_v2.py) with updated DeepSeekMxfp4MoEMethod name; remove spurious leading blank line - w4a16_deepseek.py: remove debug_utils/w4a16_moe_ref_related.py (pure- torch debug ref, not production code); remove unused get_global_server_args import; remove RC-X2 disabled renorm comment block; remove dead commented-out RSF post-hoc multiply block
- utils.py: rename enum member FLASHINFER_W4A16 → FLASHINFER_CUTLASS_WMXFP4A16 - mxfp4_flashinfer_trtllm_moe.py: revert class rename back to Mxfp4FlashinferTrtllmMoEMethod; remove all is_marlin() branches (MarlinMoeQuantInfo, marlin create_moe_runner/process_weights/apply paths) that have no place in the mxfp4 trtllm file - w4a16_deepseek.py → mxfp4_flashinfer_cutlass_moe.py: move W4A16 impl to a new file named after the backend, not the model family; rename class DeepSeekW4A16MoEMethod → Mxfp4FlashinferCutlassMoEMethod - fp8.py: update import and instantiation to use new file/class name
…e / Wmxfp4A16FlashinferCutlassMoEMethod
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR introduces W4A16 (BF16 × MXFP4) MoE inference support for DeepSeek-V4-Flash on H200 (SM90), and fixes a critical accuracy bug in the W4A16 kernel that caused a -7pp regression on GSM8k.
Accuracy
Model: DeepSeek-V4-Flash
Codebase:
w4a16_on_dsv4(=origin/deepseek_v4+ this PR, same code for both backends)Throughput
Output token throughput (tok/s), random 1024-in / 1024-out, H200 TP=4:
W4A16 runs at 0.80–0.97x of marlin throughput. The gap is widest at mid-concurrency (cc=32–64) where MXFP4 GEMM tile utilization is sub-optimal; performance optimization is ongoing.
Server Launch
Checklist
Review and Merge Process
/tag-and-rerun-ci,/tag-run-ci-label,/rerun-failed-ci