Skip to content

Add FlashInfer SM90 cutlass MXFP4 MoE backend (W4A16) for GPT-OSS + DeepSeek-V4#24816

Merged
Fridge003 merged 5 commits into
sgl-project:mainfrom
yuan-luo:support_fi_w4a16
May 13, 2026
Merged

Add FlashInfer SM90 cutlass MXFP4 MoE backend (W4A16) for GPT-OSS + DeepSeek-V4#24816
Fridge003 merged 5 commits into
sgl-project:mainfrom
yuan-luo:support_fi_w4a16

Conversation

@yuan-luo
Copy link
Copy Markdown
Collaborator

@yuan-luo yuan-luo commented May 9, 2026

Summary

Wires FlashInfer's SM90 mixed-input cutlass_fused_moe(use_w4_group_scaling=True) path (FlashInfer PR #3084) into both SGLang MXFP4 entry points as an opt-in backend on Hopper:

  • GPT-OSS path: Mxfp4MoEMethod in mxfp4.py, dispatched from Mxfp4Config.
  • DeepSeek-V4 path: new Mxfp4FlashinferCutlassMoEMethod, sibling of
    Mxfp4MarlinMoEMethod / Mxfp4FlashinferTrtllmMoEMethod, dispatched from
    Fp8MoEConfig.get_quant_method when is_fp4_experts=True.
image

Both are triggered by --moe-runner-backend flashinfer_mxfp4 on H100/H200.
Today that flag auto-selects only on SM100; on SM90 users still default to Marlin, so this PR is purely additive.

PD-disaggregation is the killer use case: prefill workers run FlashInfer
(+24–36% at M ≥ 1024), decode workers stay on Marlin (+12–15% at M ≤ 64),
each phase picks the kernel that wins its regime.
See benchmark below.

No default behavior changes. Marlin remains the SM90 default for both models.

What FlashInfer PR #3084 ships

PR #3084 ports TensorRT-LLM PR #12451 into FlashInfer's existing cutlass_fused_moe kernel:

  • LDSM + LUT FP4 / INT4 → BF16 weight load pipeline in mixed_input_utils.hpp
    and sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_*.hpp — replaces the
    previous bit-shuffle path on Hopper for 4-bit weights × 16/8-bit activations.
  • Tactic-list pruning in cutlass_heuristic.cpp: skip
    CtaShape128x128x128B + COOPERATIVE for has_w4afp8 (register overflow on SM90),
    pick COOPERATIVE / PINGPONG per tile.
  • Scheduler tweaks in moe_gemm_tma_ws_mixed_input_launcher.inl:
    max_swizzle_size=2, raster_order=Heuristic.
  • Two new Python helpers that the kernel layout requires:
    • interleave_moe_weights_for_sm90_mixed_gemm(weight, "fp4"|"int4") — C++
      byte-level interleave of packed 4-bit weights.
    • interleave_moe_scales_for_sm90_mixed_gemm(scales, group_size=32) — pure-PyTorch
      reshape + permute of E8M0 block scales (factor 128 // group_size).

Both helpers must run once at weight-load time. They live in flashinfer.fused_moe.core and ship in FlashInfer ≥ 0.6.11. Depends on #24452

What changes in SGLang

GPT-OSS path — python/sglang/srt/layers/quantization/mxfp4.py

  1. Mxfp4MoEMethod.__init__ — new self._fi_kernel discriminator
    ("trtllm_sm100" / "cutlass_sm90" / None). Imports of the new helpers are
    guarded with try/except so older FlashInfer builds still load.
  2. create_weights — adds an SM90 cutlass branch that pads intermediate_size
    and hidden_size to multiples of 128 (mixed-input GEMM contraction-dim
    constraint: K % 128 == 0 because the scale interleave factor is
    128 // group_size = 4).
  3. process_weights_after_loading — early-dispatch into
    _process_weights_for_sm90_cutlass, which:
    • byte-interleaves w13_weight / w2_weight with the FP4 helper,
    • reshape+permute interleaves w13_weight_scale / w2_weight_scale,
    • builds per-expert SwiGLU scalars (α=1.702, β=1.0, limit=7.0; matches the
      existing SM100 trtllm-gen path's GPT-OSS hardcoded defaults).
  4. apply — early-dispatch into _apply_sm90_cutlass, which pads input,
    calls flashinfer_cutlass_fused_moe(use_w4_group_scaling=True, …) with
    per-expert SwiGLU scalars and bias, and trims the output back to the original
    hidden width.

DeepSeek-V4 path — python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py (new)

New sibling class Mxfp4FlashinferCutlassMoEMethod paralleling
Mxfp4MarlinMoEMethod / Mxfp4FlashinferTrtllmMoEMethod:

  1. Dispatchfp8.py:get_quant_method checks SM at runtime: SM100
    continues to route to Mxfp4FlashinferTrtllmMoEMethod (trtllm-gen),
    SM90 routes to the new class.
  2. process_weights_after_loading
    • calls the FP8 base hook (ROCm normalization etc),
    • reorders [w1; w3] → [w3; w1] via reorder_w1w3_to_w3w1 (cutlass
      expects [up; gate], matching the trtllm-gen convention),
    • converts the fp32-stored E8M0 scales to raw uint8 bytes via
      .to(float8_e8m0fnu).view(uint8),
    • byte-interleaves weights, reshape+permute interleaves scales.
  3. applycutlass_fused_moe(use_w4_group_scaling=True) with biases
    None (DSv4 has no MoE bias), swiglu_alpha/swiglu_beta None
    (kernel defaults give standard SwiGLU), and swiglu_limit from
    moe_runner_config.swiglu_limit (per-expert tensor).
  4. Routed-scale fusion — extends the isinstance check in
    mxfp4_flashinfer_trtllm_moe.maybe_fuse_routed_scale_and_shared_add so the
    new class participates in the same shared-add fusion as Marlin / trtllm-gen.

Padding is asserted rather than applied: DSv4's standard config (hidden=7168,
intermediate=2048) is already a multiple of 128. Variants with non-aligned
shapes will raise a clear error.

Tests

Unit tests

python/sglang/test/test_mxfp4_sm90_cutlass.py — 12 cases, all PASSED on H100:

GPT-OSS path (5 + 4):

  • test_process_weights_matches_direct_interleave × 5 — verifies that
    _process_weights_for_sm90_cutlass (de-interleave + pad + halved swap +
    byte/scale interleave) produces bit-exact the same bytes as a hand-rolled
    reference that performs the same transform sequence. Cases include the
    GPT-OSS-20B production shape (2880 × 2880) and the 192 × 192 boundary case
    to exercise the pad path.
  • test_apply_sm90_cutlass_matches_flashinfer_direct × 4 — end-to-end on
    random MXFP4 inputs: SGLang's _apply_sm90_cutlass output is bit-exact
    equal to a direct cutlass_fused_moe(use_w4_group_scaling=True, …) call
    fed with the layer's processed tensors (with manual pad x + trim out for
    the unaligned shape). Confirms bias / SwiGLU scalar / quant_scales
    plumbing + input-pad / output-trim is correct.

DeepSeek-V4 path (3):

  • test_dsv4_apply_matches_flashinfer_direct × 3 — end-to-end on random
    DSv4-style MXFP4 inputs (int8-packed weights, fp32 E8M0 scales).
    Mxfp4FlashinferCutlassMoEMethod.apply output is bit-exact equal to a
    reference path that manually applies reorder_w1w3_to_w3w1 + the fp32→uint8
    E8M0 cast + the FlashInfer interleave helpers + a direct
    cutlass_fused_moe(use_w4_group_scaling=True, biases=None, swiglu_*=None)
    call. Confirms the DSv4-specific reorder + scale-cast plumbing is correct.

End-to-end accuracy

Both paths exercised on real production checkpoints with chat-API GSM8K
(200 questions, T=0, on H100):

GPT-OSS-20BMxfp4MoEMethod (mxfp4.py)

$ python -m sglang.launch_server --model openai/gpt-oss-20b \
      --moe-runner-backend flashinfer_mxfp4 --attention-backend triton \
      --port 30011 --tp 1
$ python -m sglang.test.run_eval --port 30011 --eval-name gsm8k \
      --num-examples 200 --api chat --max-tokens 1024 --num-threads 32 \
      --temperature 0
backend GSM8K score latency throughput
FlashInfer cutlass (this PR) 0.940 20.25 s 3338 tok/s
triton_kernel baseline 0.945 17.44 s 3926 tok/s

It matches the published GPT-OSS-20B GSM8K range (≈ 91–94 %).

DeepSeek-V4-FlashMxfp4FlashinferCutlassMoEMethod (new)

$ python -m sglang.launch_server --model deepseek-ai/DeepSeek-V4-Flash \
      --moe-runner-backend flashinfer_mxfp4 --port 30000 --tp 4 \
      --trust-remote-code
$ python -m sglang.test.run_eval --port 30000 --eval-name gsm8k \
      --num-examples 200 --api chat --max-tokens 1024 --num-threads 32 \
      --temperature 0
metric value
GSM8K score (200 q, chat API, T=0) 0.985
Total latency 49.21 s
Output throughput 469 tok/s

98.5 % is consistent with DSv4-Flash's published GSM8K range.

Benchmark

python/sglang/test/bench_mxfp4_sm90_kernels.py — kernel-level perf on H100
(80 GB HBM3, cap 9.0). Both paths use the same random MXFP4 weights;
autotune(True) is active for FlashInfer; bias-on and bias-off are autotuned
independently. Body fixed at the GPT-OSS-like config (hidden=4096, inter=2048,
E=256, top_k=6) and tokens swept across the decode → prefill range.

tokens Marlin FI cutlass (AT, bias) FI cutlass (AT, no bias) Winner Margin
4 0.174 ms 0.195 ms 0.190 ms Marlin +12 %
16 0.484 ms 0.549 ms 0.549 ms Marlin +13 %
64 1.046 ms 1.204 ms 1.204 ms Marlin +15 %
256 1.549 ms 1.587 ms 1.583 ms tie +2 %
1024 2.084 ms 1.943 ms 1.954 ms FlashInfer +7 %
2048 3.420 ms 2.663 ms 2.624 ms FlashInfer +28 %
4096 7.011 ms 5.149 ms 5.134 ms FlashInfer +36 %
8192 12.541 ms 10.085 ms 10.176 ms FlashInfer +24 %

Each row is the median of 30 timed iterations after 5 warmups.
Bias overhead (FI bias-on vs FI bias-off) is consistently within ±2 %, i.e.
essentially free.

Interpretation

The two backends have clearly distinct sweet spots:

  • Decode regime (m ≤ 64). Marlin is 12–15 % faster. Marlin's tile policy
    and native E8M0 scale path squeeze more out of small-M grouped GEMM where
    the kernel is launch-bound and per-CTA work matters more than aggregate
    throughput.
  • Tie zone (m ≈ 256). Both within 2 % of each other.
  • Prefill / chunked-prefill regime (m ≥ 1024). FlashInfer cutlass is
    24–36 % faster. At large M the kernel becomes GEMM-bound, the SM90
    TMA-warp-specialized cooperative scheduler from PR [hotfix] fix test_sampling_scaling_penalties.py ci test #3084 fully populates the
    H100 SM array, and Marlin's per-CTA overhead dominates instead.

Sanity check: FlashInfer's m=4 number (0.195 ms) matches PR #3084's H200 table
(0.193 ms), so the integration is correct. PR #3084's reported 2.66–3.11×
speedup is vs FlashInfer's own previous SM90 path (~0.79 ms at m=4),
not vs Marlin — Marlin wasn't in that comparison.

Conclusion

  • Default unchanged for decode-heavy serving. Marlin remains the SM90
    MXFP4 default; it dominates the small-M region typical of pure decode.
  • FlashInfer cutlass is the right pick for prefill. Long context, large
    chunked-prefill (≥ 1024 tokens/forward), or PD-disaggregation prefill
    workers all benefit by 24–36 %.
  • PD-disaggregation natural fit: prefill workers
    (--moe-runner-backend flashinfer_mxfp4) and decode workers (Marlin default)
    can pick the best kernel for their phase.
  • Mixed-batch deployments: pick by expected average M; an end-to-end
    throughput benchmark on a target workload is the next step before flipping
    any default.

Caveats / known issues

  • Requires FlashInfer ≥ 0.6.11 (PR [hotfix] fix test_sampling_scaling_penalties.py ci test #3084 landed after 0.6.10). Bumping
    flashinfer_python is a separate change; for now the import is gated and
    raises a clear error if the user picks flashinfer_mxfp4 on SM90 without the
    helpers.
  • Padding to 128 is required on SM90 cutlass; small models (e.g. GPT-OSS
    hidden=2880 → padded 2944) pay a ~2 % memory overhead in the MoE weights.
  • SwiGLU per-expert scalars are hardcoded to GPT-OSS defaults
    (α=1.702, β=1.0, limit=7.0), mirroring the existing SM100 trtllm-gen path.
    When other models adopt this backend, they should be plumbed from
    moe_runner_config.gemm1_alpha / .gemm1_clamp_limit.
  • FlashInfer's autotuner does not key tactics on bias presence; if a deployment
    toggles bias on/off through the same MoE method, run autotune for each
    configuration. (Affects only autotune cache hit-rate, not correctness.)

Files changed

  • python/sglang/srt/layers/quantization/mxfp4.py (+~140 lines, GPT-OSS path)
  • python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py (new, ~230 lines, DSv4 path)
  • python/sglang/srt/layers/quantization/fp8.py (+~15 lines, SM-aware dispatch)
  • python/sglang/srt/layers/quantization/mxfp4_flashinfer_trtllm_moe.py (+~5 lines, fuse-check)
  • python/sglang/test/test_mxfp4_sm90_cutlass.py (new, ~370 lines, 9 cases)
  • python/sglang/test/bench_mxfp4_sm90_kernels.py (new, ~280 lines)

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 9, 2026

/tag-and-rerun-ci

@github-actions github-actions Bot added the run-ci label May 9, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FlashInfer's SM90 cutlass mixed-input MoE GEMM for MXFP4 quantization, specifically targeting DeepSeek-V4 models on Hopper GPUs. It adds the Mxfp4FlashinferCutlassMoEMethod, implements the necessary weight and scale interleaving logic, and includes comprehensive unit tests and benchmarks. Feedback suggests using local expert counts and dynamic device assignment for SwiGLU scalars to correctly support Expert Parallelism, and ensuring tensor contiguity before passing data to FlashInfer's C++ extensions.

Comment thread python/sglang/srt/layers/quantization/mxfp4.py
Comment thread python/sglang/srt/layers/quantization/mxfp4.py Outdated
@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 9, 2026

Per FlashInfer PR's result (below), it was tested on H200, but the performance is almost the same as my test on H100. FlashInfer kernel has some space to improve in the short token part.
image

image

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 9, 2026

GPT-OSS path:
test_process_weights_matches_direct_interleave[4-256-256]      PASSED
test_process_weights_matches_direct_interleave[8-768-384]      PASSED
test_process_weights_matches_direct_interleave[8-1024-1024]    PASSED
test_apply_sm90_cutlass_matches_flashinfer_direct[4-4-...]     PASSED
test_apply_sm90_cutlass_matches_flashinfer_direct[16-8-...]    PASSED
test_apply_sm90_cutlass_matches_flashinfer_direct[32-8-...]    PASSED

DSv4 path:
test_dsv4_apply_matches_flashinfer_direct[4-4-256-256-2]       PASSED
test_dsv4_apply_matches_flashinfer_direct[16-8-768-384-2]      PASSED
test_dsv4_apply_matches_flashinfer_direct[256-8-1024-1024-4]   PASSED

@yuan-luo yuan-luo marked this pull request as draft May 9, 2026 10:12
@yuan-luo yuan-luo marked this pull request as ready for review May 9, 2026 10:13
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@yuan-luo yuan-luo force-pushed the support_fi_w4a16 branch 2 times, most recently from 122fbc1 to 177e638 Compare May 9, 2026 14:47
Copy link
Copy Markdown
Contributor

@samuellees samuellees left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's great that this PR adds a candidate path for DS4 W4A16, I think we still need end-to-end DS4 accuracy validation before calling it fully supported.

From the experience of #24492, a few areas may affect DS4 correctness:

  • SwiGLU clamp: DS4 may need explicit alpha=1, beta=0, limit, similar to #24492.
  • TP behavior: it would be good to confirm whether FlashInfer should receive real tp_size/tp_rank, or tp_size=1, tp_rank=0 after SGLang has already sharded the weights.
  • Routed scaling factor: DS4 uses routed_scaling_factor=1.5, so we should verify it is applied at the right stage.
  • Checkpoint layout: unit tests may not fully cover the real DS4 FP4 checkpoint loading path.
  • Accuracy: ideally run DS4 with TP=4/8 and compare with Marlin on GSM8k/AIME or similar evals.

Comment thread python/sglang/srt/layers/quantization/mxfp4.py
Comment thread python/sglang/srt/layers/quantization/mxfp4.py
@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 10, 2026

It's great that this PR adds a candidate path for DS4 W4A16, I think we still need end-to-end DS4 accuracy validation before calling it fully supported.

From the experience of #24492, a few areas may affect DS4 correctness:

  • SwiGLU clamp: DS4 may need explicit alpha=1, beta=0, limit, similar to feat(w4a16): SM90 W4A16 MoE path for DeepSeek-V4-Flash + RSF pre-multiply fix #24492.
  • TP behavior: it would be good to confirm whether FlashInfer should receive real tp_size/tp_rank, or tp_size=1, tp_rank=0 after SGLang has already sharded the weights.
  • Routed scaling factor: DS4 uses routed_scaling_factor=1.5, so we should verify it is applied at the right stage.
  • Checkpoint layout: unit tests may not fully cover the real DS4 FP4 checkpoint loading path.
  • Accuracy: ideally run DS4 with TP=4/8 and compare with Marlin on GSM8k/AIME or similar evals.

@samuellees Agree. Will do it.

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 10, 2026

Updated DeepSeekV4 GSM8k result in description.

$ python -m sglang.launch_server --model deepseek-ai/DeepSeek-V4-Flash \
      --moe-runner-backend flashinfer_mxfp4 --port 30000 --tp 4 \
      --trust-remote-code
$ python -m sglang.test.run_eval --port 30000 --eval-name gsm8k \
      --num-examples 200 --api chat --max-tokens 1024 --num-threads 32 \
      --temperature 0
metric value
GSM8K score (200 q, chat API, T=0) 0.985
Total latency 49.21 s
Output throughput 469 tok/s

@yiakwy-xpu-ml-framework-team
Copy link
Copy Markdown
Contributor

@yuan-luo is the issue resolved ? ref #23686

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

@yuan-luo is the issue resolved ? ref #23686

@yiakwy-xpu-ml-framework-team #23686 is the marlin w4a16 supporting. I've verified e2e DSV4 on H100 with FlashInffer SM90 mxfp4 backend.

@samuellees
Copy link
Copy Markdown
Contributor

LGTM overall~
Could you also verify the AIME or GPQA accuracy, please? @Fridge003 provided an insight that DeepSeek V4 could do a great job on GSM8k even with some numeric bugs. That's exactly what #23681 and #24492 want to solve.

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 11, 2026

LGTM overall~ Could you also verify the AIME or GPQA accuracy, please? @Fridge003 provided an insight that DeepSeek V4 could do a great job on GSM8k even with some numeric bugs. That's exactly what #23681 and #24492 want to solve.

@samuellees I tested FlashInfer cutlass and Marlin for DSV4-Flash, here's the GPQA Diamond score.

eval n_repeats examples FlashInfer cutlass score Marlin score FlashInfer latency Marlin latency
GPQA Diamond ×8 (chat API) 8 1584 0.7386 0.7330 379.6 s 366.7 s

High concurrency (×8, 64 threads, ~64 in-flight requests): the gap collapses to ~3.5 % (366.7 s vs 379.6 s). In decode part (high concurrency scenario), probably we can also adopt FlashInfer backend.

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 11, 2026

@samuellees We don't have the RSF-PREMUL fix in this PR, but we don't show the -7.5pp regression either.
Our post-hoc multiply is a single fused op: shared.add_(routed, alpha=routed_scaling_factor), which is exactly shared + rsf * routed. It does not scale the shared expert.

If the original code in your report was output += shared; output.mul_(rsf) (i.e. rsf * (shared + routed)), that would double-scale shared by rsf. That's the failure mode I'd expect to see -7.5pp on a smart-eval benchmark. This PR's add_(..., alpha=rsf) pattern avoids it.

GPQA Diamond score updated.

Copy link
Copy Markdown
Contributor

@samuellees samuellees left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR could be an alternative of #23681 and #24492
cc @Fridge003 for more comments

Comment thread python/sglang/srt/layers/quantization/fp8.py Outdated
Comment thread python/sglang/srt/layers/quantization/mxfp4.py
Comment thread python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py Outdated
Comment thread python/sglang/srt/layers/quantization/mxfp4_flashinfer_cutlass_moe.py Outdated
Comment thread python/sglang/srt/layers/quantization/mxfp4.py
@Fridge003
Copy link
Copy Markdown
Collaborator

@yuan-luo The 0.73 accuracy result is not as expected.
For accuracy testing, please benchmark AIME25 on V4-Pro with this tool https://github.com/sgl-project/sgl-eval, and:

  1. When launching server, add these flags:
SGLANG_DEFAULT_THINKING=1  SGLANG_DSV4_REASONING_EFFORT=max
  1. Set temperature to 1.0, top_p to 1.0 and OSL to 400K as benchmark setting
  2. Enable MTP and larger concurrency (like 512) during benchmark, otherwise it will be really long

Hopefully the result can be ~0.97

@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 12, 2026

@yuan-luo The 0.73 accuracy result is not as expected. For accuracy testing, please benchmark AIME25 on V4-Pro with this tool https://github.com/sgl-project/sgl-eval, and:

  1. When launching server, add these flags:
SGLANG_DEFAULT_THINKING=1  SGLANG_DSV4_REASONING_EFFORT=max
  1. Set temperature to 1.0, top_p to 1.0 and OSL to 400K as benchmark setting
  2. Enable MTP and larger concurrency (like 512) during benchmark, otherwise it will be really long

Hopefully the result can be ~0.97

@Fridge003 tested Flash on H100. Here's the result.

Eval: sgl-eval aime25, 30 problems x 16 repeats, --num-threads 64
Backend: flashinfer_mxfp4 (SM90 cutlass mixed-input, this PR)
Hardware: 8x H100 80GB, TP=8
Sampling: T=1.0, top_p=1.0, max_tokens=131072

Server:

FLASHINFER_DISABLE_VERSION_CHECK=1 \
SGLANG_DEFAULT_THINKING=1 \
SGLANG_DSV4_REASONING_EFFORT=max \
nohup python -m sglang.launch_server \
  --model deepseek-ai/DeepSeek-V4-Flash \
  --moe-runner-backend flashinfer_mxfp4 \
  --port 30000 --tp 8 \
  --trust-remote-code \
  --reasoning-parser deepseek-v4 \
  --context-length 393216 \
  --speculative-algorithm EAGLE \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4

Client:

OPENAI_API_KEY=EMPTY nohup sgl-eval run aime25 \
  --base-url http://127.0.0.1:30000/v1 \
  --temperature 1.0 --top-p 1.0 \
  --thinking \
  --max-tokens 131072 \
  --num-threads 64
aime25 rep  1/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=90.00%]
aime25 rep  2/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep  3/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep  4/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=90.00%]
aime25 rep  5/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]
aime25 rep  6/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep  7/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep  8/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=90.00%]
aime25 rep  9/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=100.00%]
aime25 rep 10/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep 11/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]]
aime25 rep 12/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]
aime25 rep 13/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=96.67%]
aime25 rep 14/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]
aime25 rep 15/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]
aime25 rep 16/16: 100%|██████████| 30/30 [1:35:46<00:00, 191.55s/it, acc=93.33%]
aime25 overall  : 100%|██████████| 480/480 [1:35:46<00:00, 11.97s/it, acc=94.38%]
== aime25 ==
30 examples x 16 repeats  |  5746.5s  |  2581 tok/s  |  14.8M tokens

* pass@1[avg-of-16]  =  94.38% +/- 2.91% (SEM 0.73%)
  pass@16            =  100.00%
  majority@16        =  100.00%
  no_answer          =  5.00%  [warn: consider --max-tokens]

Results

pass@1 (avg-of-16) = 94.38 % +/- 2.91 % (SEM 0.73 %)
pass@16 = 100.00 %
majority@16 = 100.00 %
no_answer = 5.00 % (capped by --max-tokens 131072)
total tokens = 14.8 M
wall time = 5746.5 s (~96 min, 2581 tok/s)

@yuan-luo yuan-luo force-pushed the support_fi_w4a16 branch from 45c3b4c to 0b4c7e1 Compare May 12, 2026 09:01
@yuan-luo
Copy link
Copy Markdown
Collaborator Author

yuan-luo commented May 12, 2026

Tested on H800 with max-tokens 20000.

Server:

FLASHINFER_DISABLE_VERSION_CHECK=1 SGLANG_DEFAULT_THINKING=1 SGLANG_DSV4_REASONING_EFFORT=max nohup python -m sglang.launch_server --model deepseek-ai/DeepSeek-V4-Flash --moe-runner-backend flashinfer_mxfp4 --port 30000 --tp 8 --trust-remote-code --reasoning-parser deepseek-v4 --context-length 393216 --speculative-algorithm EAGLE --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --max-running-requests 64 --mem-fraction-static 0.88

Client:

OPENAI_API_KEY=EMPTY nohup sgl-eval run aime25 --base-url http://127.0.0.1:30000/v1 --temperature 1.0 --top-p 1.0 --thinking --max-tokens 200000 --num-threads 64 --n-repeats 16 --out-dir /tmp/aime25_h800_n32 

AIME25 score: 97.29%

aime25 rep  1/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=100.00%]
aime25 rep  2/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=100.00%]  
aime25 rep  3/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]
aime25 rep  4/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=100.00%]
aime25 rep  5/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]  
aime25 rep  6/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=100.00%]  
aime25 rep  7/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=100.00%]  
aime25 rep  8/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=93.33%]
aime25 rep  9/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]   
aime25 rep 10/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]%]
aime25 rep 11/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=93.33%]
aime25 rep 12/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]
aime25 rep 13/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]  %]
aime25 rep 14/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]  
aime25 rep 15/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%]%]
aime25 rep 16/16: 100%|██████████| 30/30 [1:49:40<00:00, 219.36s/it, acc=96.67%] %]
aime25 overall  : 100%|██████████| 480/480 [1:49:40<00:00, 13.71s/it, acc=97.29%]
== aime25 ==
30 examples x 16 repeats  |  6580.7s  |  2369 tok/s  |  15.6M tokens

* pass@1[avg-of-16]  =  97.29% +/- 2.18% (SEM 0.55%)
  pass@16            =  100.00%
  majority@16        =  100.00%
  no_answer          =  1.67%

@Fridge003 Fridge003 mentioned this pull request May 12, 2026
34 tasks
@yuan-luo yuan-luo force-pushed the support_fi_w4a16 branch from 0b4c7e1 to 15421b5 Compare May 13, 2026 05:55
Comment thread test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py
Comment thread test/registered/unit/layers/quantization/test_mxfp4_sm90_cutlass.py
Comment thread python/sglang/srt/layers/quantization/mxfp4.py
luoyuan.luo added 4 commits May 13, 2026 15:13
  Wires FlashInfer's SM90 mixed-input cutlass_fused_moe(use_w4_group_scaling=True)
  path (FlashInfer PR sgl-project#3084) into Mxfp4MoEMethod as an opt-in backend on Hopper.
  Triggered by --moe-runner-backend flashinfer_mxfp4 on H100/H200; today that flag
  auto-selects only on SM100, so SM90 users still default to Marlin.

  Why: PR sgl-project#3084 ports TRT-LLM's LDSM + LUT FP4->BF16 weight load pipeline plus
  SM90 tactic-list pruning into FlashInfer's cutlass MoE kernel. Two new helpers
  (interleave_moe_weights_for_sm90_mixed_gemm, interleave_moe_scales_for_sm90_
  mixed_gemm) are required at weight-load time. Available in flashinfer-python
  >= 0.6.11.

  Behavior split (kernel-level bench, H100, GPT-OSS-like body
  hidden=4096/inter=2048/E=256/topk=6, autotune ON):

    decode  (M <=   64) :  Marlin     +12-15 %
    tie     (M ~=  256) :  ~equal
    prefill (M >= 1024) :  FlashInfer +24-36 %  <-- new value

  PD-disaggregation is the killer use case: prefill workers pick FlashInfer,
  decode workers keep Marlin, each phase runs the kernel that wins its regime.

  Focus on GPT-OSS model.
Also silence FlashInfer cutlass autotune trace via TLLM_LOG_LEVEL=INFO
@yuan-luo yuan-luo force-pushed the support_fi_w4a16 branch from 3e51770 to d2c7346 Compare May 13, 2026 07:14
@Fridge003 Fridge003 merged commit 28758d3 into sgl-project:main May 13, 2026
390 of 482 checks passed
Fridge003 pushed a commit that referenced this pull request May 13, 2026
…eepSeek-V4 (#24816)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
@yuan-luo yuan-luo deleted the support_fi_w4a16 branch May 14, 2026 02:37
ch-wan added a commit that referenced this pull request May 15, 2026
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Fridge003 added a commit that referenced this pull request May 15, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants