
[NVIDIA] [GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell) #23273

Open
wenscarl wants to merge 1 commit into sgl-project:main from wenscarl:gdnmtp_decode
Open

[NVIDIA] [GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell)#23273
wenscarl wants to merge 1 commit into
sgl-project:mainfrom
wenscarl:gdnmtp_decode

Conversation


@wenscarl wenscarl commented Apr 20, 2026

[GDN] Enable FlashInfer MTP verify on SM100+ (Blackwell)

Co-authored by @YAMY1234 (main contributor).

Summary

Enables FlashInfer GDN MTP (speculative decoding) verify on SM100+ (Blackwell) hardware, which previously raised NotImplementedError. SM90 (Hopper) MTP was already supported; this PR completes SM100+ coverage.

Root cause: target_verify guarded on use_state_pool, blocking SM100+ even though the FlashInfer gated_delta_rule_mtp kernel already accepts initial_state_indices (pool API) — the same API used by the SM90 path.

Changes (2 files, ~15 lines):

  • gdn_flashinfer.py: remove use_state_pool guard in target_verify; unify SM90 + SM100+ into a single pool-API path; add A_log.detach().float() cast (matches SM100+ decode path, no-op on SM90).
  • server_args.py: remove and self.speculative_algorithm is None from the SM100+ FlashInfer auto-default — FlashInfer is now safe to default on SM100+ regardless of whether MTP is enabled.
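The control-flow change above can be sketched as follows. This is a hedged illustration, not sglang's actual code: the class names, kernel stub, and return value are stand-ins, and only use_state_pool, target_verify, gated_delta_rule_mtp, and initial_state_indices come from the PR description.

```python
# Illustrative sketch of removing the use_state_pool guard in target_verify.

class GDNVerifyBefore:
    def __init__(self, use_state_pool: bool):
        # True on SM90 (Hopper); False on SM100+ (Blackwell).
        self.use_state_pool = use_state_pool

    def _run_mtp_kernel(self, initial_state_indices):
        # Stand-in for FlashInfer's gated_delta_rule_mtp pool-API call,
        # which accepts initial_state_indices on SM90 and SM100+ alike.
        return ("mtp", initial_state_indices)

    def target_verify(self, initial_state_indices):
        if not self.use_state_pool:
            # Old guard: SM100+ hit this error even though the kernel
            # above already supports the pool API.
            raise NotImplementedError("GDN MTP verify unsupported on SM100+")
        return self._run_mtp_kernel(initial_state_indices)


class GDNVerifyAfter(GDNVerifyBefore):
    def target_verify(self, initial_state_indices):
        # Unified path: SM90 and SM100+ both go through the pool API.
        return self._run_mtp_kernel(initial_state_indices)
```

With the guard gone, the SM100+ case no longer raises and runs the same pool-API path as SM90.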

Accuracy (Qwen3.5-397B-A17B-NVFP4, B200)

gsm8k (TODO: examples, baseline threshold: 0.95)

SGLANG_ENABLE_SPEC_V2=1 python3 -m sglang.launch_server --model-path nvidia/Qwen3.5-397B-A17B-NVFP4 --tokenizer-path nvidia/Qwen3.5-397B-A17B-NVFP4 --trust-remote-code --host 0.0.0.0 --port 8000 --tp-size 4 --chunked-prefill-size 2048 --mamba-scheduler-strategy extra_buffer --mamba-track-interval 128 --mamba-ssm-dtype bfloat16 --max-running-requests 128 --reasoning-parser qwen3 --attention-backend trtllm_mha --quantization modelopt_fp4 --speculative-algorithm NEXTN --speculative-num-steps 3 --speculative-eagle-topk 1 --speculative-num-draft-tokens 4 --mem-fraction-static 0.8 --model-loader-extra-config '{"enable_multithread_load": true,"num_threads": 64}'

python3 -m sglang.test.run_eval --model nvidia/Qwen3.5-397B-A17B-NVFP4 --eval-name gsm8k --num-shots 5 --num-examples 200 --max-tokens 16000 --num-threads 128 --repeat 1 --temperature 0.6 --top-p 0.95 --top-k 20 --base-url http://127.0.0.1:8000 --host http://127.0.0.1 --port 8000

| Backend | Score |
| --- | --- |
| Triton (decode + MTP) | 0.985 |
| FlashInfer (decode + MTP) | 0.980 |

GPQA Diamond (TODO: examples, repeat=8, temperature=0.6)

FlashInfer:
Total latency: 247.500 s
Score: 0.859
Output throughput: 6286.781 token/s

Triton:
Total latency: 253.352 s
Score: 0.854
Output throughput: 6196.159 token/s

| Backend | Score |
| --- | --- |
| Triton (decode + MTP) | 0.854 |
| FlashInfer (decode + MTP) | 0.859 |

Throughput Benchmark (GB200, Qwen3.5-397B-A17B-NVFP4, TP=4)

Focus: long output sequence length (OSL), where per-step GDN state-update cost is most significant.

Server settings:

  • --tp-size 4 --max-running-requests 128
  • --mamba-ssm-dtype bfloat16 --mamba-scheduler-strategy no_buffer --mamba-track-interval 128
  • --attention-backend trtllm_mha --linear-attn-decode-backend <triton|flashinfer>
  • --speculative-algorithm NEXTN (MTP runs)
  • --disable-radix-cache --quantization modelopt_fp4

Benchmark settings:

  • --dataset-name random --random-input-len 32 --random-output-len <512|1024|2048|4096>
  • --num-prompts <varied> --request-rate inf

Decode throughput (w/ MTP), output throughput (tok/s) — ISL=32
acc len: 3.13-3.22
num_prompts: 256

w/ vs w/o MTP:

| OSL | no-MTP | MTP | Speedup |
| --- | --- | --- | --- |
| 1024 | 2731.86 | 3682.65 | 1.35x |
| 2048 | 2937.99 | 4329.87 | 1.47x |
| 4096 | 2915.84 | 4831.15 | 1.66x |

Triton vs FlashInfer (w/ MTP):

| OSL | Triton | FlashInfer | Speedup |
| --- | --- | --- | --- |
| 1024 | 3645.60 | 3682.65 | 1.01x |
| 2048 | 4145.32 | 4329.87 | 1.04x |
| 4096 | 4707.04 | 4831.15 | 1.03x |
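As a quick sanity check, the Speedup columns above can be recomputed from the reported throughputs. This is plain arithmetic on the table values, not code from the PR:

```python
# Recompute the Speedup columns from the throughput numbers reported above.
# All figures are copied from the two tables; rounding to 2 decimals
# reproduces the printed speedups.

mtp_vs_no_mtp = {  # OSL -> (no-MTP tok/s, MTP tok/s)
    1024: (2731.86, 3682.65),
    2048: (2937.99, 4329.87),
    4096: (2915.84, 4831.15),
}
flashinfer_vs_triton = {  # OSL -> (Triton tok/s, FlashInfer tok/s)
    1024: (3645.60, 3682.65),
    2048: (4145.32, 4329.87),
    4096: (4707.04, 4831.15),
}

def speedup(baseline: float, candidate: float) -> float:
    return round(candidate / baseline, 2)

for osl, (base, mtp) in mtp_vs_no_mtp.items():
    print(osl, speedup(base, mtp))   # 1.35, 1.47, 1.66
for osl, (tri, fi) in flashinfer_vs_triton.items():
    print(osl, speedup(tri, fi))     # 1.01, 1.04, 1.03
```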

Mean TPOT (ms/tok), ISL=32, OSL=512

| Concurrency | FlashInfer TPOT (ms) | Triton TPOT (ms) | Speedup (Triton / FlashInfer) | Winner |
| --- | --- | --- | --- | --- |
| 1 | 3.23 | 3.34 | 1.034 | FlashInfer |
| 4 | 4.89 | 4.94 | 1.010 | FlashInfer |
| 16 | 10.21 | 10.29 | 1.008 | FlashInfer |
| 32 | 15.58 | 15.79 | 1.013 | FlashInfer |
| 64 | 23.29 | 24.64 | 1.058 | FlashInfer |
| 128 | 32.86 | 34.25 | 1.042 | FlashInfer |
| 256 | 31.14 | 31.62 | 1.015 | FlashInfer |

Traces

Traces were collected at ISL=32, OSL=512, concurrency=64.

FlashInfer: (screenshot)

Triton: (screenshot)

@wenscarl wenscarl marked this pull request as ready for review April 22, 2026 14:56

mmangkad added a commit to mmangkad-dev/sglang that referenced this pull request Apr 28, 2026
…fy on SM100+ (Blackwell)

Resolved conflicts with PR sgl-project#22921:
- gdn_flashinfer.py: combined module and class docstrings to reflect that
  SM100+ now supports decode, prefill, and MTP verify.
- gdn_flashinfer.py target_verify: dropped the SM100+ NotImplementedError
  guard so the pool-API MTP path runs on both SM90 and SM100+.
- server_args.py: kept the bf16 gate from sgl-project#22921 and removed the
  speculative_algorithm gate now that MTP verify is supported on SM100+.
mmangkad added a commit to mmangkad-dev/sglang that referenced this pull request Apr 28, 2026
PR sgl-project#22921 renamed the SM-gating attribute from use_state_pool to
is_sm100plus (updating all existing call sites). PR sgl-project#23273 was authored
against the older name and added a new reference in the bf16 MTP adapter
setup. The git auto-merge picked up sgl-project#22921's renames and sgl-project#23273's new
block, leaving a single dangling use_state_pool access that crashed at
FlashInferGDNKernel.__init__.

Rename the one remaining reference to is_sm100plus to match the rest of
the class.
