[NVIDIA] Integrate FlashInfer decode kernel (Blackwell) for Qwen3.5#19150
ispobock merged 7 commits into sgl-project:main
Conversation
Summary of Changes (Gemini Code Assist): This pull request enhances the performance of Qwen3.5 hybrid linear attention models by integrating FlashInfer's optimized decode kernel for Gated Delta Network (GDN) layers. It introduces a new configurable backend option.
Activity
Code Review
This pull request integrates FlashInfer's gated_delta_rule_decode_pretranspose kernel as an optional backend for Gated Delta Network (GDN) layers, specifically targeting Qwen3.5 hybrid linear attention models. The implementation includes a new server argument --gdn-backend, a specialized decode path that operates directly on a K-last state pool to eliminate gather/scatter overhead, and a prefill path that uses explicit gather/scatter for compatibility with existing kernels. My feedback focuses on a usability issue in the server argument validation and a minor inconsistency in parameter casting within the FlashInfer decode path.
python/sglang/srt/server_args.py (outdated)

```python
# The flashinfer GDN decode path (gated_delta_rule_decode_pretranspose)
# uses a bf16 state kernel; non-bf16 pools are not supported.
if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype != "bfloat16":
```
The validation logic for mamba_ssm_dtype is too strict when gdn_backend is set to flashinfer. Since self.mamba_ssm_dtype defaults to None, this check will raise a ValueError even if the user doesn't explicitly provide the flag, forcing them to specify --mamba-ssm-dtype bfloat16 manually even if the model configuration already defaults to bfloat16. The check should only trigger if mamba_ssm_dtype is explicitly set to a value other than bfloat16.
```diff
-if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype != "bfloat16":
+if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype is not None and self.mamba_ssm_dtype != "bfloat16":
```
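The None-handling the reviewer asks for can be sketched as a standalone check. This is a minimal illustration, assuming the two fields mirror the server arguments; `validate_gdn_backend` is a hypothetical helper, not the actual sglang API:

```python
def validate_gdn_backend(gdn_backend: str, mamba_ssm_dtype) -> None:
    """Reject only an *explicit* non-bf16 dtype; None means 'defer to the model default'."""
    if (
        gdn_backend == "flashinfer"
        and mamba_ssm_dtype is not None
        and mamba_ssm_dtype != "bfloat16"
    ):
        raise ValueError(
            "--gdn-backend flashinfer requires --mamba-ssm-dtype bfloat16"
        )

validate_gdn_backend("flashinfer", None)        # accepted: dtype left unset
validate_gdn_backend("flashinfer", "bfloat16")  # accepted: explicit bf16
```

With this shape, only an explicitly set non-bf16 dtype (e.g. `--mamba-ssm-dtype float32`) raises, matching the reviewer's intent.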
```python
A_log=layer.A_log.detach().float(),
a=a_fi,
dt_bias=layer.dt_bias.detach(),
```
There is an inconsistency in how parameters are passed to the FlashInfer kernel. A_log is explicitly cast to float(), which is good for precision during exponentiation, but dt_bias is passed without a cast. In the Triton path (line 945), both are effectively treated as float32 inside the kernel. If the FlashInfer kernel expects float32 for dt_bias to maintain numerical stability during the softplus calculation, it should be cast here as well.
```diff
 A_log=layer.A_log.detach().float(),
 a=a_fi,
-dt_bias=layer.dt_bias.detach(),
+dt_bias=layer.dt_bias.detach().float(),
```
cc @hlu1
Force-pushed a48b5b6 to f69fdf3
Rebased to main to follow the new argument structure introduced by #18622 (linear attention backend refactor). The old `--gdn-backend` flag is replaced by `--linear-attn-decode-backend`. New usage:

```shell
python -m sglang.launch_server --model <Qwen3-Next-model> \
    --mamba-ssm-dtype bfloat16 \
    --linear-attn-decode-backend flashinfer
```

Note: …
```python
class FlashInferGDNKernel(LinearAttnKernelBase):
    """FlashInfer pretranspose kernel for GDN decode.
```
This is not necessarily decode-only; it could also serve prefill in the future. Maybe revise the comment.
Force-pushed 42a4404 to 93520f1
Force-pushed b6dd6f8 to cdffe64
Force-pushed 526e113 to 06ea4f3
Also cc @xutizhou, who has done the same thing for Hopper.
/tag-and-rerun-ci

/tag-and-rerun-ci

@ispobock anything blocking the merge?

@kaixih the Qwen3.5 CI test should pass first
@ispobock thanks for the pointer. The tests pass on my B200 machine after I apply this patch: …

Update: …
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The Triton prefill kernel (chunk_delta_h.py) already uses K-last strides (K, 1) matching the VK pool layout [pool, HV, V, K] natively. The previous _use_flashinfer_pool path unnecessarily gathered states, transposed to KV, ran the kernel, then transposed back — the double transpose cancels for zero initial states but is conceptually wrong. Remove _use_flashinfer_pool and pass ssm_states + cache_indices directly to the Triton prefill kernel for all backends. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Swap the use_state_pool branch to call fused_sigmoid_gating_delta_rule_update (Triton) instead of the FlashInfer CuTe DSL kernel. Original FlashInfer call is commented out for easy restoration. Result: 0.980 accuracy on gsm8k (vs 0.890 with FlashInfer decode), confirming the accuracy gap is entirely in the FlashInfer decode kernel, not prefill. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…x and add flashinfer gsm8k test - Restore FlashInfer pool API (initial_state + initial_state_indices) for SM100+ decode path; remove the temporary Triton-at-callsite debug block - Pass dt_bias as float32 to the kernel (dt_bias.detach().float()); the kernel reads dt_bias_val without an explicit fp32 cast, so passing bf16 caused a precision gap (0.89 → 0.94 on gsm8k with all precision fixes) - Add TestQwen35FP4Flashinfer: same as TestQwen35FP4 but with --linear-attn-decode-backend flashinfer and threshold 0.93 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed 06ea4f3 to aef6aa7
…uracyTestParams - Merge TestQwen35FP4 and TestQwen35FP4Flashinfer into a single class using run_combined_tests with Triton and FlashInfer variants - Add top_k support to AccuracyTestParams/_run_simple_eval - Restore dt_bias.detach().float() fix (regression from debug session) - Remove stale debug comments in gdn_flashinfer.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Force-pushed aef6aa7 to 836924d
Rebased and added one test for the flashinfer decode backend.

To test: …

Results: …

@ispobock PTAL
…duling FlashInfer GDN decode is incompatible with --mamba-scheduler-strategy no_buffer, causing ~5-10% accuracy degradation on gsm8k. Raise a clear ValueError pointing to the upstream issue. See sgl-project#20791 Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move the incompatibility check for --linear-attn-decode-backend flashinfer with --mamba-scheduler-strategy no_buffer to before the extra_buffer block, so it fails fast before unrelated extra_buffer validations run. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…gl-project#19150) Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Summary
Integrate FlashInfer's `gated_delta_rule_decode_pretranspose` kernel as an optional backend (`--gdn-backend flashinfer`) for GDN (Gated Delta Network) layers in Qwen3.5 hybrid linear attention models.

What changed

- Add a `--gdn-backend {triton,cutedsl,flashinfer}` server argument (default: `triton`)
- Decode path: call FlashInfer's `gated_delta_rule_decode_pretranspose`, which operates directly on the K-last state pool via `initial_state`/`initial_state_indices`, eliminating explicit gather/scatter calls on every decode step (feat: add pool+indices support to gated_delta_rule_decode_pretranspose (bf16 path), flashinfer-ai/flashinfer#2619)
- Prefill path: use `chunk_kda` with `initial_state_indices` for gather/scatter into the pool

Performance (Qwen3.5-FP8, decode-focused, 256 concurrency, 8xB200)
Per-step profiling: overall decode step 353 µs → 316 µs (−10%), GDN kernel 30 µs → 16 µs (−47%).
The end-to-end gains are modest (~2%) because GDN accounts for a relatively small fraction of the total decode step. The kernel-level improvement is more significant: GDN alone drops by ~47%, contributing ~38% of the overall per-step savings. The remaining improvement comes from other FlashInfer kernel optimizations.
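As a sanity check, the percentages above follow directly from the quoted microsecond figures:

```python
step_before, step_after = 353, 316  # overall decode step, µs (from the profile above)
gdn_before, gdn_after = 30, 16      # GDN kernel alone, µs

print(round(100 * (1 - step_after / step_before)))  # 10  -> ~10% faster decode step
print(round(100 * (1 - gdn_after / gdn_before)))    # 47  -> ~47% faster GDN kernel
# Share of the per-step saving attributable to GDN alone:
print(round(100 * (gdn_before - gdn_after) / (step_before - step_after)))  # 38
```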
Note: the prefill path still uses KV layout with explicit gather/scatter (the pool->batch trick), which limits the overall e2e gains. A dedicated prefill kernel with native pool support is left for a follow-up PR.
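The gather/scatter-vs-pool distinction that runs through this PR can be illustrated with a toy Python sketch. This is not the kernel API; the state pool, the `+1.0` update, and both function names are purely illustrative:

```python
# Toy state pool: one small state vector per cache slot.
def make_pool(num_slots: int = 8, width: int = 4):
    return {slot: [float(slot)] * width for slot in range(num_slots)}

def decode_with_gather_scatter(pool, cache_indices):
    """Old-style path: copy states out, update, copy back (two extra passes)."""
    batch = [list(pool[i]) for i in cache_indices]      # gather
    for state in batch:
        for k in range(len(state)):
            state[k] += 1.0                             # stand-in for the GDN update
    for i, state in zip(cache_indices, batch):
        pool[i] = state                                 # scatter

def decode_in_place(pool, cache_indices):
    """Pool+indices path: read and write pool slots directly, no gather/scatter."""
    for i in cache_indices:
        for k in range(len(pool[i])):
            pool[i][k] += 1.0

pool = make_pool()
decode_in_place(pool, [5, 2, 7])
print(pool[5][0])  # 6.0: slot updated in place; untouched slots keep their values
```

Both paths compute the same states; the in-place variant just skips the two copy passes, which is where the decode-kernel saving comes from.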
Kernel microbenchmark — T=1 decode latency (µs)
At batch sizes ≥ 32, the FlashInfer bf16 kernel (this PR) achieves ~2.4–2.8x speedup over the Triton KV reference and ~1.4–1.6x over the CuteDSL VK kernel from #17981. Note: both #17981 and this PR use the VK layout.
Accuracy