Add FlashInfer SM90 cutlass MXFP4 MoE backend (W4A16) for GPT-OSS + DeepSeek-V4 #24816
/tag-and-rerun-ci |
Code Review
This pull request introduces support for FlashInfer's SM90 cutlass mixed-input MoE GEMM for MXFP4 quantization, specifically targeting DeepSeek-V4 models on Hopper GPUs. It adds the Mxfp4FlashinferCutlassMoEMethod, implements the necessary weight and scale interleaving logic, and includes comprehensive unit tests and benchmarks. Feedback suggests using local expert counts and dynamic device assignment for SwiGLU scalars to correctly support Expert Parallelism, and ensuring tensor contiguity before passing data to FlashInfer's C++ extensions.
samuellees
left a comment
It's great that this PR adds a candidate path for DS4 W4A16, but I think we still need end-to-end DS4 accuracy validation before calling it fully supported.
From the experience of #24492, a few areas may affect DS4 correctness:
- SwiGLU clamp: DS4 may need explicit alpha=1, beta=0, limit, similar to #24492.
- TP behavior: it would be good to confirm whether FlashInfer should receive real tp_size/tp_rank, or tp_size=1, tp_rank=0 after SGLang has already sharded the weights.
- Routed scaling factor: DS4 uses routed_scaling_factor=1.5, so we should verify it is applied at the right stage.
- Checkpoint layout: unit tests may not fully cover the real DS4 FP4 checkpoint loading path.
- Accuracy: ideally run DS4 with TP=4/8 and compare with Marlin on GSM8k/AIME or similar evals.
@samuellees Agree. Will do it. |
|
Updated DeepSeekV4 GSM8k result in description.
|
|
@yiakwy-xpu-ml-framework-team #23686 is the PR that adds Marlin W4A16 support. I've verified e2e DSV4 on H100 with the FlashInfer SM90 mxfp4 backend.
|
LGTM overall~ |
@samuellees I tested FlashInfer cutlass and Marlin for DSV4-Flash; here's the GPQA Diamond score.
At high concurrency (×8, 64 threads, ~64 in-flight requests), the gap collapses to ~3.5 % (366.7 s vs 379.6 s). For the decode phase under high concurrency, we can probably adopt the FlashInfer backend as well.
|
@samuellees We don't have the RSF-PREMUL fix in this PR, but we don't show the -7.5pp regression either. If the original code in your report was output += shared; output.mul_(rsf) (i.e. rsf * (shared + routed)), that would double-scale shared by rsf. That's the failure mode I'd expect to see -7.5pp on a smart-eval benchmark. This PR's add_(..., alpha=rsf) pattern avoids it. GPQA Diamond score updated. |
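To make the two orderings concrete, here is a minimal plain-Python sketch in which scalars stand in for tensors. The names `routed` and `shared` and the sample values are illustrative, and the operand order of the in-place add (shared output as the accumulator, routed output as the scaled operand) is an assumption rather than something read from the diff.

```python
# Scalars stand in for tensors; the values are made up for illustration.
RSF = 1.5  # routed_scaling_factor quoted for DS4 earlier in this thread


def scale_after_add(routed: float, shared: float, rsf: float = RSF) -> float:
    """Mimics `output += shared; output.mul_(rsf)`, i.e. rsf * (routed + shared)."""
    output = routed
    output += shared
    output *= rsf
    return output


def scaled_add(routed: float, shared: float, rsf: float = RSF) -> float:
    """Mimics `shared.add_(routed, alpha=rsf)`, i.e. shared + rsf * routed.

    The operand order is an assumption made for this sketch.
    """
    return shared + rsf * routed


routed, shared = 2.0, 4.0
scaled_both = scale_after_add(routed, shared)  # shared is scaled along with routed
scaled_routed_only = scaled_add(routed, shared)  # only routed is scaled
excess = scaled_both - scaled_routed_only  # equals (rsf - 1) * shared
```

The difference between the two results is exactly `(rsf - 1) * shared`, which is the extra scaling of the shared-expert output described above.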
samuellees
left a comment
This PR could be an alternative to #23681 and #24492
cc @Fridge003 for more comments
|
@yuan-luo The 0.73 accuracy result is not as expected.
Hopefully the result can be ~0.97 |
@Fridge003 tested Flash on H100. Here's the result.
Eval: sgl-eval aime25, 30 problems x 16 repeats, --num-threads 64
Server:
Client:
Results: pass@1 (avg-of-16) = 94.38 % +/- 2.91 % (SEM 0.73 %)
Tested on H800 with max-tokens 20000.
Server:
Client:
AIME25 score: 97.29%
Wires FlashInfer's SM90 mixed-input cutlass_fused_moe(use_w4_group_scaling=True) path (FlashInfer PR sgl-project#3084) into Mxfp4MoEMethod as an opt-in backend on Hopper. Triggered by --moe-runner-backend flashinfer_mxfp4 on H100/H200; today that flag auto-selects only on SM100, so SM90 users still default to Marlin.

Why: PR sgl-project#3084 ports TRT-LLM's LDSM + LUT FP4->BF16 weight load pipeline plus SM90 tactic-list pruning into FlashInfer's cutlass MoE kernel. Two new helpers (interleave_moe_weights_for_sm90_mixed_gemm, interleave_moe_scales_for_sm90_mixed_gemm) are required at weight-load time. Available in flashinfer-python >= 0.6.11.

Behavior split (kernel-level bench, H100, GPT-OSS-like body hidden=4096/inter=2048/E=256/topk=6, autotune ON):
  decode  (M <= 64)   : Marlin +12-15 %
  tie     (M ~= 256)  : ~equal
  prefill (M >= 1024) : FlashInfer +24-36 %  <-- new value

PD-disaggregation is the killer use case: prefill workers pick FlashInfer, decode workers keep Marlin, each phase runs the kernel that wins its regime. Focus on GPT-OSS model.
Also silence FlashInfer cutlass autotune trace via TLLM_LOG_LEVEL=INFO
…eepSeek-V4 (#24816) Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This reverts commit 1913cb4.


Summary
Wires FlashInfer's SM90 mixed-input cutlass_fused_moe(use_w4_group_scaling=True) path (FlashInfer PR #3084) into both SGLang MXFP4 entry points as an opt-in backend on Hopper:
- Mxfp4MoEMethod in mxfp4.py, dispatched from Mxfp4Config.
- Mxfp4FlashinferCutlassMoEMethod, sibling of Mxfp4MarlinMoEMethod / Mxfp4FlashinferTrtllmMoEMethod, dispatched from Fp8MoEConfig.get_quant_method when is_fp4_experts=True.
Both are triggered by --moe-runner-backend flashinfer_mxfp4 on H100/H200. Today that flag auto-selects only on SM100; on SM90 users still default to Marlin, so this PR is purely additive.
PD-disaggregation is the killer use case: prefill workers run FlashInfer
(+24–36% at M ≥ 1024), decode workers stay on Marlin (+12–15% at M ≤ 64),
each phase picks the kernel that wins its regime. See benchmark below.
No default behavior changes. Marlin remains the SM90 default for both models.
What FlashInfer PR #3084 ships
PR #3084 ports TensorRT-LLM PR #12451 into FlashInfer's existing cutlass_fused_moe kernel:
- mixed_input_utils.hpp and sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_*.hpp: replaces the previous bit-shuffle path on Hopper for 4-bit weights × 16/8-bit activations.
- cutlass_heuristic.cpp: skip CtaShape128x128x128B + COOPERATIVE for has_w4afp8 (register overflow on SM90), pick COOPERATIVE / PINGPONG per tile.
- moe_gemm_tma_ws_mixed_input_launcher.inl: max_swizzle_size=2, raster_order=Heuristic.
- interleave_moe_weights_for_sm90_mixed_gemm(weight, "fp4"|"int4"): C++ byte-level interleave of packed 4-bit weights.
- interleave_moe_scales_for_sm90_mixed_gemm(scales, group_size=32): pure-PyTorch reshape + permute of E8M0 block scales (factor 128 // group_size).
Both helpers must run once at weight-load time. They live in flashinfer.fused_moe.core and ship in FlashInfer ≥ 0.6.11. Depends on #24452
What changes in SGLang
GPT-OSS path: python/sglang/srt/layers/quantization/mxfp4.py
- Mxfp4MoEMethod.__init__: new self._fi_kernel discriminator ("trtllm_sm100" / "cutlass_sm90" / None). Imports of the new helpers are guarded with try/except so older FlashInfer builds still load.
- create_weights: adds an SM90 cutlass branch that pads intermediate_size and hidden_size to multiples of 128 (mixed-input GEMM contraction-dim constraint: K % 128 == 0 because the scale interleave factor is 128 // group_size = 4).
- process_weights_after_loading: early-dispatch into _process_weights_for_sm90_cutlass, which:
  - […] w13_weight / w2_weight with the FP4 helper,
  - […] w13_weight_scale / w2_weight_scale,
  - […] existing SM100 trtllm-gen path's GPT-OSS hardcoded defaults).
- apply: early-dispatch into _apply_sm90_cutlass, which pads input, calls flashinfer_cutlass_fused_moe(use_w4_group_scaling=True, …) with per-expert SwiGLU scalars and bias, and trims the output back to the original hidden width.
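The pad-then-trim flow described for create_weights and apply can be sketched in plain Python. This is only an illustration under stated assumptions: `round_up`, `pad_row`, and `trim_row` are hypothetical helper names, and Python lists stand in for tensors; the real code operates on GPU tensors.

```python
ALIGN = 128      # K % 128 == 0 constraint quoted above
GROUP_SIZE = 32  # block size passed to the scale-interleave helper

# Scale interleave factor from the description: 128 // group_size.
INTERLEAVE_FACTOR = ALIGN // GROUP_SIZE


def round_up(value: int, multiple: int = ALIGN) -> int:
    """Smallest multiple of `multiple` that is >= value."""
    return (value + multiple - 1) // multiple * multiple


def pad_row(row: list, padded_width: int) -> list:
    """Zero-pad one activation row up to the padded hidden width."""
    return row + [0.0] * (padded_width - len(row))


def trim_row(row: list, original_width: int) -> list:
    """Drop the padded tail so the output has the original hidden width."""
    return row[:original_width]


hidden = 2880                          # GPT-OSS-20B hidden size
padded_hidden = round_up(hidden)       # width the kernel actually sees
x = [1.0] * hidden
x_padded = pad_row(x, padded_hidden)   # input padding before the kernel call
y = trim_row(x_padded, hidden)         # output trimming after the kernel call
```

With these numbers `round_up(2880)` gives 2944 and `round_up(192)` gives 256, the two shapes the description and tests mention for the pad path.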
DeepSeek-V4 path: python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py (new)
New sibling class Mxfp4FlashinferCutlassMoEMethod paralleling Mxfp4MarlinMoEMethod / Mxfp4FlashinferTrtllmMoEMethod:
- fp8.py:get_quant_method checks SM at runtime: SM100 continues to route to Mxfp4FlashinferTrtllmMoEMethod (trtllm-gen), SM90 routes to the new class.
- process_weights_after_loading:
  - […] [w1; w3] → [w3; w1] via reorder_w1w3_to_w3w1 (cutlass expects [up; gate], matching the trtllm-gen convention),
  - […] .to(float8_e8m0fnu).view(uint8),
  - […]
- apply: cutlass_fused_moe(use_w4_group_scaling=True) with biases None (DSv4 has no MoE bias), swiglu_alpha / swiglu_beta None (kernel defaults give standard SwiGLU), and swiglu_limit from moe_runner_config.swiglu_limit (per-expert tensor).
- […] isinstance check in mxfp4_flashinfer_trtllm_moe.maybe_fuse_routed_scale_and_shared_add so the new class participates in the same shared-add fusion as Marlin / trtllm-gen.
Padding is asserted rather than applied: DSv4's standard config (hidden=7168,
intermediate=2048) is already a multiple of 128. Variants with non-aligned
shapes will raise a clear error.
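The three DSv4-specific steps (half swap, fp32 to E8M0 byte, alignment assert) can be illustrated with a small standard-library sketch. All three function names are hypothetical stand-ins, lists stand in for tensors, and the sketch assumes the fused weight stacks w1 and w3 as equal halves along the row axis and that scales are exact positive powers of two.

```python
import struct


def reorder_halves(rows: list) -> list:
    """[w1; w3] -> [w3; w1]: swap the two equal halves along the row axis."""
    if len(rows) % 2:
        raise ValueError("expected an even number of rows")
    half = len(rows) // 2
    return rows[half:] + rows[:half]


def e8m0_byte(scale: float) -> int:
    """Biased exponent byte of a positive, finite, power-of-two fp32 scale.

    E8M0 stores only an 8-bit exponent with bias 127, so for such scales the
    byte equals the fp32 exponent field.
    """
    bits = struct.unpack("<I", struct.pack("<f", scale))[0]
    if scale <= 0 or scale == float("inf") or bits & 0x7FFFFF:
        raise ValueError("expected a positive, finite power-of-two scale")
    return (bits >> 23) & 0xFF


def check_aligned(name: str, size: int, multiple: int = 128) -> None:
    """Assert alignment instead of padding, as the DSv4 path does."""
    if size % multiple:
        raise ValueError(f"{name}={size} is not a multiple of {multiple}")


swapped = reorder_halves(["gate0", "gate1", "up0", "up1"])
unit_scale = e8m0_byte(1.0)
check_aligned("hidden_size", 7168)
check_aligned("intermediate_size", 2048)
```

Here the swap moves the `up*` rows ahead of the `gate*` rows, and a scale of 1.0 maps to byte 127 because E8M0 uses the same exponent bias as fp32.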
Tests
Unit tests
python/sglang/test/test_mxfp4_sm90_cutlass.py: 12 cases, all PASSED on H100:
GPT-OSS path (5 + 4):
- test_process_weights_matches_direct_interleave × 5: verifies that _process_weights_for_sm90_cutlass (de-interleave + pad + halved swap + byte/scale interleave) produces bytes that are bit-identical to a hand-rolled reference performing the same transform sequence. Cases include the GPT-OSS-20B production shape (2880 × 2880) and the 192 × 192 boundary case to exercise the pad path.
- test_apply_sm90_cutlass_matches_flashinfer_direct × 4: end-to-end on random MXFP4 inputs. SGLang's _apply_sm90_cutlass output is bit-exact equal to a direct cutlass_fused_moe(use_w4_group_scaling=True, …) call fed with the layer's processed tensors (with manual pad x + trim out for the unaligned shape). Confirms that the bias / SwiGLU scalar / quant_scales plumbing and the input-pad / output-trim are correct.
DeepSeek-V4 path (3):
- test_dsv4_apply_matches_flashinfer_direct × 3: end-to-end on random DSv4-style MXFP4 inputs (int8-packed weights, fp32 E8M0 scales). Mxfp4FlashinferCutlassMoEMethod.apply output is bit-exact equal to a reference path that manually applies reorder_w1w3_to_w3w1 + the fp32→uint8 E8M0 cast + the FlashInfer interleave helpers + a direct cutlass_fused_moe(use_w4_group_scaling=True, biases=None, swiglu_*=None) call. Confirms that the DSv4-specific reorder + scale-cast plumbing is correct.
End-to-end accuracy
Both paths exercised on real production checkpoints with chat-API GSM8K
(200 questions, T=0, on H100):
GPT-OSS-20B: Mxfp4MoEMethod (mxfp4.py)
(Results table not captured; the only surviving row label is "triton_kernel baseline".)
It matches the published GPT-OSS-20B GSM8K range (≈ 91–94 %).
DeepSeek-V4-Flash: Mxfp4FlashinferCutlassMoEMethod (new)
98.5 % is consistent with DSv4-Flash's published GSM8K range.
Benchmark
python/sglang/test/bench_mxfp4_sm90_kernels.py: kernel-level perf on H100 (80 GB HBM3, cap 9.0). Both paths use the same random MXFP4 weights; autotune(True) is active for FlashInfer; bias-on and bias-off are autotuned independently. Body fixed at the GPT-OSS-like config (hidden=4096, inter=2048, E=256, top_k=6) and tokens swept across the decode → prefill range.
Each row is the median of 30 timed iterations after 5 warmups.
Bias overhead (FI bias-on vs FI bias-off) is consistently within ±2 %, i.e.
essentially free.
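The timing protocol (5 warmups, then the median of 30 timed iterations) can be sketched as follows. This is a host-side illustration, not the benchmark script itself: `median_latency_ms` is a hypothetical helper, and for real CUDA kernels the timed callable would also need to synchronize the device (or use CUDA events), because kernel launches are asynchronous.

```python
import statistics
import time
from typing import Callable


def median_latency_ms(fn: Callable[[], object], warmup: int = 5, iters: int = 30) -> float:
    """Median wall-clock latency of fn in milliseconds after warmup runs."""
    for _ in range(warmup):
        fn()  # untimed: lets caches, autotuning, and lazy init settle
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()  # for GPU work, fn must block until the kernel has finished
        samples.append((time.perf_counter() - start) * 1e3)
    return statistics.median(samples)


calls = []
latency = median_latency_ms(lambda: calls.append(sum(range(1000))))
```

The median is less sensitive than the mean to the occasional slow iteration, which is why it is a common choice for short kernel timings.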
Interpretation
The two backends have clearly distinct sweet spots:
- Decode (M ≤ 64): Marlin is 12–15 % faster. […] and native E8M0 scale path squeeze more out of small-M grouped GEMM where the kernel is launch-bound and per-CTA work matters more than aggregate throughput.
- Prefill (M ≥ 1024): FlashInfer is 24–36 % faster. At large M the kernel becomes GEMM-bound, the SM90 TMA-warp-specialized cooperative scheduler from PR #3084 fully populates the H100 SM array, and Marlin's per-CTA overhead dominates instead.
Sanity check: FlashInfer's m=4 number (0.195 ms) matches PR #3084's H200 table (0.193 ms), which indicates the integration reaches the same kernel path. PR #3084's reported 2.66–3.11× speedup is vs FlashInfer's own previous SM90 path (~0.79 ms at m=4), not vs Marlin, which wasn't in that comparison.
Conclusion
- Marlin stays the SM90 MXFP4 default; it dominates the small-M region typical of pure decode.
- […] chunked-prefill (≥ 1024 tokens/forward), or PD-disaggregation prefill workers all benefit by 24–36 %.
- With PD-disaggregation, prefill workers (--moe-runner-backend flashinfer_mxfp4) and decode workers (Marlin default) can pick the best kernel for their phase.
- […] throughput benchmark on a target workload is the next step before flipping any default.
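As a rough illustration of how the measured regimes could inform a per-worker choice, here is a hypothetical helper that maps tokens per forward pass to the backend that won the kernel benchmark. The function, the thresholds as hard cut-offs, and the "marlin" / "measure" return labels are assumptions made for this sketch; only flashinfer_mxfp4 is a flag value named in this PR.

```python
DECODE_MAX_TOKENS = 64     # M <= 64: Marlin ahead by 12-15 % in the kernel bench
PREFILL_MIN_TOKENS = 1024  # M >= 1024: FlashInfer ahead by 24-36 %


def preferred_backend(tokens_per_forward: int) -> str:
    """Backend that won the kernel benchmark for this batch size.

    Between the two thresholds the benchmark shows a rough tie around M ~= 256,
    so the sketch returns "measure" instead of guessing.
    """
    if tokens_per_forward >= PREFILL_MIN_TOKENS:
        return "flashinfer_mxfp4"
    if tokens_per_forward <= DECODE_MAX_TOKENS:
        return "marlin"  # hypothetical label for the SM90 default path
    return "measure"


prefill_choice = preferred_backend(4096)
decode_choice = preferred_backend(8)
```

In a PD-disaggregated deployment this decision is made once per worker role at launch time rather than per request, since the backend is selected by a server flag.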
Caveats / known issues
- […] flashinfer_python is a separate change; for now the import is gated and raises a clear error if the user picks flashinfer_mxfp4 on SM90 without the helpers.
- […] hidden=2880 → padded 2944) pay a ~2 % memory overhead in the MoE weights.
- […] (α=1.702, β=1.0, limit=7.0), mirroring the existing SM100 trtllm-gen path. When other models adopt this backend, they should be plumbed from moe_runner_config.gemm1_alpha / .gemm1_clamp_limit.
- […] toggles bias on/off through the same MoE method, run autotune for each configuration. (Affects only autotune cache hit-rate, not correctness.)
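For readers unfamiliar with the three SwiGLU scalars, the following scalar sketch shows one common form of the clamped activation with the GPT-OSS values quoted above. It follows the GPT-OSS reference activation as I understand it (gate clamped from above, up clamped on both sides, β added to the up branch); whether FlashInfer's kernel applies exactly this clamp convention is an assumption, and `clamped_swiglu` is a hypothetical name.

```python
import math
from typing import Optional


def _sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def clamped_swiglu(
    gate: float,
    up: float,
    alpha: float = 1.702,
    beta: float = 1.0,
    limit: Optional[float] = 7.0,
) -> float:
    """gate * sigmoid(alpha * gate) * (up + beta), with optional clamping."""
    if limit is not None:
        gate = min(gate, limit)            # gate is clamped from above only
        up = max(-limit, min(up, limit))   # up is clamped to [-limit, limit]
    return gate * _sigmoid(alpha * gate) * (up + beta)


gpt_oss_value = clamped_swiglu(2.0, 3.0)
standard_value = clamped_swiglu(2.0, 3.0, alpha=1.0, beta=0.0, limit=None)
```

Setting alpha=1, beta=0 and no limit reduces the expression to standard SwiGLU, `gate * sigmoid(gate) * up`, which is the behavior the DSv4 path relies on when it passes None for swiglu_alpha / swiglu_beta.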
Files changed
- python/sglang/srt/layers/quantization/mxfp4.py (+~140 lines, GPT-OSS path)
- python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py (new, ~230 lines, DSv4 path)
- python/sglang/srt/layers/quantization/fp8.py (+~15 lines, SM-aware dispatch)
- python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py (+~5 lines, fuse-check)
- python/sglang/test/test_mxfp4_sm90_cutlass.py (new, ~370 lines, 12 cases)
- python/sglang/test/bench_mxfp4_sm90_kernels.py (new, ~280 lines)