
[GDN] Remove FlashInfer GDN decode + no_buffer guard and default to FlashInfer on SM100+ #21861

Open
YAMY1234 wants to merge 2 commits into sgl-project:main from YAMY1234:remove-gdn-no-buffer-guard

Conversation

Contributor

@YAMY1234 YAMY1234 commented Apr 1, 2026

Motivation

The FlashInfer GDN decode kernel was previously blocked from being used with `--mamba-scheduler-strategy no_buffer` due to accuracy degradation caused by OOB memory access from negative padding indices in the bf16 decode kernel (see #20791).

The root cause was fixed in FlashInfer v0.6.7 via flashinfer-ai/flashinfer#2810 (padding index guard for bf16 decode kernel). Thanks to @kaixih!

With PR #21422 merged (upgrading FlashInfer to v0.6.7), we can now remove this guard and proceed with benchmarking.

Modifications

  1. Remove the `raise ValueError` guard that blocked `--linear-attn-decode-backend flashinfer` with `--mamba-scheduler-strategy no_buffer`.
  2. Default to FlashInfer GDN decode on SM100+ when `mamba-ssm-dtype=bfloat16` and no decode backend is explicitly specified. The default is not applied when MTP speculative decoding is active, since FlashInfer GDN MTP verify is not yet supported on SM100+.
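The selection logic in item 2 can be sketched roughly as follows. This is an illustrative sketch only; the function and parameter names here are hypothetical and do not correspond to the actual SGLang code:

```python
def pick_linear_attn_decode_backend(
    decode_backend,      # value of --linear-attn-decode-backend, or None
    device_sm,           # CUDA SM version, e.g. 100 for GB200 (Blackwell)
    mamba_ssm_dtype,     # value of mamba-ssm-dtype
    mtp_spec_decoding,   # True if MTP speculative decoding is active
):
    """Hypothetical sketch of the new default-backend selection.

    - An explicitly requested backend always wins.
    - On SM100+ with bf16 SSM state and no MTP speculative decoding,
      default to the FlashInfer GDN decode kernel.
    - Otherwise fall back to the Triton kernel (the previous default).
    """
    if decode_backend is not None:
        return decode_backend
    if device_sm >= 100 and mamba_ssm_dtype == "bfloat16" and not mtp_spec_decoding:
        return "flashinfer"
    return "triton"
```

Under this sketch, an explicit `--linear-attn-decode-backend triton` still forces the Triton kernel, so the behavior change only affects users who did not pass the flag.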

Benchmarking

Accuracy Validation

Verified on Qwen3.5-397B-A17B (4×GB200), GSM8K (1319 examples, 8-shot, temp=0.6):

| Scenario | Quantization | Scheduler | MTP | conc=128 | conc=512 |
|---|---|---|---|---|---|
| NVFP4 + no_buffer | modelopt_fp4 | no_buffer | - | 0.977 | 0.977 |
| NVFP4 + extra_buffer | modelopt_fp4 | extra_buffer | - | 0.979 | 0.977 |
| FP8 + no_buffer | fp8 | no_buffer | - | 0.978 | 0.977 |
| FP8 + extra_buffer | fp8 | extra_buffer | - | 0.979 | 0.979 |
| NVFP4 + MTP | modelopt_fp4 | extra_buffer | NEXTN | N/A | N/A |

MTP is excluded because FlashInfer GDN MTP verify is not yet supported on SM100+.

Performance

sa-bench, ISL=1024, OSL=1024, NVFP4, 4×GB200
Comparison: `--linear-attn-decode-backend flashinfer` vs. default (triton)

| Concurrency | Baseline Mean TPOT | FlashInfer Mean TPOT | TPOT Difference | Output tok/s Improvement |
|---|---|---|---|---|
| 2 | 5.68 ms | 5.61 ms | -1.3% | +1.3% |
| 4 | 6.43 ms | 6.35 ms | -1.2% | +0.5% |
| 8 | 7.33 ms | 7.29 ms | -0.6% | +1.2% |
| 16 | 8.58 ms | 8.54 ms | -0.5% | +0.5% |
| 32 | 10.86 ms | 10.77 ms | -0.8% | +0.7% |
| 64 | 13.61 ms | 13.43 ms | -1.3% | +1.6% |
| 128 | 17.36 ms | 17.16 ms | -1.2% | +1.2% |
| 256 | 23.60 ms | 23.04 ms | -2.4% | +2.3% |
| 512 | 33.70 ms | 32.47 ms | -3.7% | +3.6% |
| 1024 | 53.50 ms | 51.08 ms | -4.5% | +4.7% |

TPOT improves across all concurrency levels, with gains increasing at higher concurrencies, reaching up to 4.5% TPOT reduction and 4.7% throughput improvement at conc=1024.
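The percentage columns follow directly from the raw TPOT values; as a sanity check on the conc=1024 row (variable names here are illustrative):

```python
# conc=1024 row of the table (mean TPOT in ms)
baseline_tpot = 53.50      # default (triton) backend
flashinfer_tpot = 51.08    # --linear-attn-decode-backend flashinfer

# Relative change in time-per-output-token
tpot_diff_pct = (flashinfer_tpot - baseline_tpot) / baseline_tpot * 100

# Since decode is per-token sequential, token throughput scales roughly
# with the inverse of TPOT
throughput_gain_pct = (baseline_tpot / flashinfer_tpot - 1) * 100

print(f"TPOT difference: {tpot_diff_pct:.1f}%")         # -4.5%
print(f"Throughput gain: +{throughput_gain_pct:.1f}%")  # +4.7%
```

This also explains why the throughput gain slightly exceeds the TPOT reduction in magnitude at each concurrency level.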

Closes #20791

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

YAMY1234 added 2 commits April 1, 2026 00:25
The root cause (OOB memory access from negative padding indices in
bf16 decode kernel) was fixed in FlashInfer v0.6.7 via
flashinfer-ai/flashinfer#2810.

Verified on Qwen3.5-397B-A17B-NVFP4 (4xGB200, no_buffer +
disable-radix-cache + --linear-attn-decode-backend flashinfer):
  - GSM8K accuracy: 0.977-0.979 across conc=128/512
  - sa-bench TPOT improvement: 1-5% vs baseline (no ladfi)

Closes sgl-project#20791
On SM100+ with mamba-ssm-dtype=bfloat16, automatically set
--linear-attn-decode-backend to flashinfer when not explicitly
specified. This gives 1-5% TPOT improvement at higher concurrencies.

The prerequisite bug (OOB from negative padding indices in bf16
decode kernel) was fixed in FlashInfer v0.6.7 via
flashinfer-ai/flashinfer#2810.

Verified on Qwen3.5-397B-A17B-NVFP4 (4xGB200, no_buffer +
disable-radix-cache), sa-bench ISL=1024 OSL=1024, conc 2-1024:
  - GSM8K accuracy: 0.977-0.979
  - Mean TPOT: -1.3% (conc=2) to -4.5% (conc=1024)
  - Output throughput: +1.3% (conc=2) to +4.7% (conc=1024)
  - Excluded when MTP speculative decoding is active (not yet supported)
Contributor Author

YAMY1234 commented Apr 1, 2026

/tag-and-rerun-ci

@github-actions github-actions bot added the run-ci label Apr 1, 2026

Development

Successfully merging this pull request may close these issues.

[Bug] [GDN] Accuracy degradation with flashinfer gated_delta_rule_decode_pretranspose under no_buffer scheduling
