[NVIDIA] Integrate FlashInfer decode kernel (Blackwell) for Qwen3.5#19150

Merged
ispobock merged 7 commits into sgl-project:main from kaixih:integrate_flashinfer_gdn
Mar 18, 2026

Conversation

@kaixih
Collaborator

@kaixih kaixih commented Feb 22, 2026

Summary

Integrate FlashInfer's gated_delta_rule_decode_pretranspose kernel as an optional backend (--gdn-backend flashinfer) for GDN (Gated Delta Network) layers in Qwen3.5 hybrid linear attention models.

What changed

  • Added --gdn-backend {triton,cutedsl,flashinfer} server argument (default: triton)
  • FlashInfer decode path: uses the bf16 pretranspose kernel (gated_delta_rule_decode_pretranspose) which operates directly on the K-last state pool via initial_state/initial_state_indices, eliminating explicit gather/scatter calls on every decode step (feat: add pool+indices support to gated_delta_rule_decode_pretranspose (bf16 path)  flashinfer-ai/flashinfer#2619)
  • FlashInfer prefill path: uses chunk_kda with initial_state_indices for gather/scatter into the pool
  • Requires FlashInfer ≥ 0.6.4 (Need update!)
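The pool+indices contract described above can be sketched in pure Python (a minimal illustration only — the real kernel is FlashInfer's `gated_delta_rule_decode_pretranspose`; the function name, shapes, and update rule below are assumptions, not the actual API):

```python
def gdn_decode_with_pool(state_pool, initial_state_indices, decay=0.5):
    """Update only the slots owned by the decode batch, in place.

    The kernel reads/writes state_pool directly at the given slots, so the
    caller never materializes a gathered batch copy or scatters it back.
    """
    for slot in initial_state_indices:
        # Stand-in for the GDN state update; the point is the in-place write.
        state_pool[slot] = [v * decay for v in state_pool[slot]]

pool = [[1.0, 1.0] for _ in range(8)]   # 8 slots, toy 2-element states
gdn_decode_with_pool(pool, [1, 5])      # only slots 1 and 5 are touched
print(pool[1], pool[0])                  # [0.5, 0.5] [1.0, 1.0]
```

Contrast this with the previous pattern, where each decode step issued a gather (`pool[indices]`), ran the kernel on the batch copy, then scattered the result back.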

Performance (Qwen3.5-FP8, decode-focused, 256 concurrency, 8xB200)

| Metric | Before | After | Δ |
| --- | --- | --- | --- |
| Output throughput (tok/s) | 5638 | 5741 | +1.8% |
| TPOT (ms) | 21.51 | 20.96 | -2.6% |
| Mean ITL (ms) | 21.58 | 21.04 | -2.5% |
| Total throughput (tok/s) | 7047 | 7176 | +1.8% |

Per-step profiling: overall decode step 353 µs → 316 µs (−10%), GDN kernel 30 µs → 16 µs (−47%).

The end-to-end gains are modest (~2%) because GDN accounts for a relatively small fraction of the total decode step. The kernel-level improvement is more significant: GDN alone drops by ~47%, contributing ~38% of the overall per-step savings. The remaining improvement comes from other FlashInfer kernel optimizations.
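The attribution above can be sanity-checked with a toy calculation from the profiled numbers:

```python
# Numbers from the per-step profile above (all in microseconds).
step_before, step_after = 353, 316   # overall decode step
gdn_before, gdn_after = 30, 16       # GDN kernel alone

total_saved = step_before - step_after           # 37 us saved per step
gdn_saved = gdn_before - gdn_after               # 14 us of that is GDN
print(round(gdn_saved / gdn_before * 100))       # 47 -> GDN drops ~47%
print(round(gdn_saved / total_saved * 100))      # 38 -> ~38% of the savings
```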

Note: the prefill path still uses KV layout with explicit gather/scatter (the pool->batch trick), which limits the overall e2e gains. A dedicated prefill kernel with native pool support is left for a follow-up PR.
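The pool->batch trick on the prefill side can be sketched as follows (hypothetical helper and stand-in kernel for illustration; the actual prefill path calls `chunk_kda`):

```python
def prefill_with_gather_scatter(state_pool, cache_indices, batch_kernel):
    """Gather states into a batch, run the kernel, scatter results back."""
    states = [state_pool[i] for i in cache_indices]   # gather (extra copy)
    states = batch_kernel(states)                     # batch-layout kernel
    for i, s in zip(cache_indices, states):           # scatter (extra copy)
        state_pool[i] = s

pool = [0.0] * 6
prefill_with_gather_scatter(pool, [2, 4], lambda xs: [x + 1.0 for x in xs])
print(pool)  # [0.0, 0.0, 1.0, 0.0, 1.0, 0.0]
```

The two extra copies per call are the overhead a native pool-aware prefill kernel would remove.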

Kernel microbenchmark — T=1 decode latency (µs)

At batch sizes ≥ 32, the FlashInfer bf16 kernel (this PR) achieves ~2.4–2.8x speedup over the Triton KV reference and ~1.4–1.6x over the CuteDSL VK kernel from #17981. Note that both #17981 and this PR use the VK layout.

| B | TRI-KV-BF16 (triton, ref) | SG-VK-BF16 (#17981) | FI-VK-BF16 (this PR) |
| --- | --- | --- | --- |
| 1 | 4.3 | 3.0 | 2.8 |
| 32 | 59.1 | 34.4 | 20.9 |
| 64 | 118.4 | 69.5 | 49.0 |
| 128 | 235.7 | 139.0 | 95.8 |
| 256 | 466.1 | 283.7 | 187.8 |
| 512 | 931.4 | 565.9 | 365.6 |
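The quoted speedup ranges follow directly from the B ≥ 32 rows of the table (a quick arithmetic check):

```python
# Latencies (us) for B = 32, 64, 128, 256, 512, copied from the table.
tri = [59.1, 118.4, 235.7, 466.1, 931.4]   # triton KV reference
sg  = [34.4, 69.5, 139.0, 283.7, 565.9]    # CuteDSL VK (#17981)
fi  = [20.9, 49.0, 95.8, 187.8, 365.6]     # FlashInfer VK (this PR)

vs_triton = [round(t / f, 2) for t, f in zip(tri, fi)]
vs_cutedsl = [round(s / f, 2) for s, f in zip(sg, fi)]
print(min(vs_triton), max(vs_triton))    # 2.42 2.83 -> ~2.4-2.8x
print(min(vs_cutedsl), max(vs_cutedsl))  # 1.42 1.65 -> ~1.4-1.6x
```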

Accuracy

| Benchmark | Backend | Accuracy |
| --- | --- | --- |
| GSM8K | flashinfer (this PR) | 0.945 |
| GPQA | flashinfer (this PR) | 0.848, 0.874, 0.884, 0.879, 0.889, 0.838, 0.889, 0.869 (mean: 0.871) |

@gemini-code-assist
Contributor

Summary of Changes

Hello @kaixih, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of Qwen3.5 hybrid linear attention models by integrating FlashInfer's optimized decode kernel for Gated Delta Network (GDN) layers. It introduces a new configurable backend option, --gdn-backend flashinfer, allowing users to leverage FlashInfer's gated_delta_rule_decode_pretranspose kernel. This integration streamlines state management during decoding, leading to notable speedups in GDN operations and overall throughput, while maintaining accuracy.

Highlights

  • FlashInfer GDN Kernel Integration: Integrated FlashInfer's gated_delta_rule_decode_pretranspose kernel as an optional backend (--gdn-backend flashinfer) for Gated Delta Network (GDN) layers in Qwen3.5 hybrid linear attention models.
  • New GDN Backend Argument: Introduced a new server argument --gdn-backend with choices triton, cutedsl, and flashinfer, allowing users to select their preferred GDN kernel implementation.
  • Optimized Decode Path: Optimized the FlashInfer decode path to directly operate on the K-last state pool via initial_state/initial_state_indices, eliminating explicit gather/scatter calls on every decode step.
  • Performance Improvements: Achieved notable performance gains, including a 1.8% increase in overall output throughput and a significant 47% reduction in GDN kernel time during decode operations.
  • Memory Layout Optimization: Ensured compatibility and proper state layout for FlashInfer by making the SSM state pool K-contiguous when the FlashInfer GDN backend is active, facilitating direct kernel operations.
  • Backend Validation: Added validation to enforce that the FlashInfer GDN backend requires the mamba_ssm_dtype to be bfloat16.
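The "Memory Layout Optimization" point can be made concrete with row-major stride arithmetic (a pure-Python sketch; the toy shapes are assumptions — the real pool is a torch tensor transposed to K-last and made contiguous):

```python
def row_major_strides(shape):
    """Element strides of a dense row-major (C-contiguous) tensor."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))

# KV layout [pool, HV, K, V]: V moves fastest, so stride(K) = V = 32.
print(row_major_strides((16, 4, 8, 32)))  # (1024, 256, 32, 1)
# VK (K-last) layout [pool, HV, V, K]: stride(K) = 1, which is what the
# pretranspose kernel expects when reading the state pool directly.
print(row_major_strides((16, 4, 32, 8)))  # (1024, 256, 8, 1)
```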


Changelog
  • python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
    • Removed an unused import for sglang.srt.environ.Envs.
    • Added conditional import and initialization of flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose based on the selected gdn_backend.
    • Modified the __init__ method to handle GDN backend selection, including fallback logic if FlashInfer is requested but unavailable.
    • Updated the forward_decode method to conditionally use the FlashInfer pretranspose kernel, adjusting tensor layouts and detaching parameters as required.
    • Refactored the forward_extend method to introduce use_gather_scatter logic, managing state for FlashInfer, NPU, and CPU backends, including making states contiguous and scattering updates back to the pool.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Modified the __init__ method to conditionally transpose and make the temporal_state (SSM state pool) K-contiguous when the FlashInfer GDN backend is enabled, optimizing its layout for FlashInfer's kernel.
    • Applied similar contiguous transformation to intermediate_ssm_state_cache for FlashInfer compatibility.
  • python/sglang/srt/server_args.py
    • Defined GDN_BACKEND_CHOICES to include "triton", "cutedsl", and "flashinfer".
    • Added gdn_backend as a new attribute to ServerArgs with a default value of "triton".
    • Implemented _handle_gdn_backend to manage GDN backend selection, including backward compatibility for SGLANG_USE_CUTEDSL_GDN_DECODE and validation for bfloat16 dtype requirement when using FlashInfer.
    • Extended add_cli_args to include --gdn-backend as a command-line argument for configuring the GDN kernel.
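The selection and validation logic summarized above can be sketched as follows (a simplified stand-in — the real logic lives in `server_args.py`, and the function signature and messages here are assumptions):

```python
GDN_BACKEND_CHOICES = ("triton", "cutedsl", "flashinfer")

def handle_gdn_backend(gdn_backend, mamba_ssm_dtype, use_cutedsl_env=False):
    if gdn_backend is None:
        # Backward compatibility with the SGLANG_USE_CUTEDSL_GDN_DECODE env var.
        gdn_backend = "cutedsl" if use_cutedsl_env else "triton"
    if gdn_backend not in GDN_BACKEND_CHOICES:
        raise ValueError(f"unknown GDN backend: {gdn_backend}")
    if gdn_backend == "flashinfer" and mamba_ssm_dtype != "bfloat16":
        # The flashinfer decode kernel only supports bf16 state pools.
        raise ValueError("--gdn-backend flashinfer requires bfloat16")
    return gdn_backend

print(handle_gdn_backend("flashinfer", "bfloat16"))          # flashinfer
print(handle_gdn_backend(None, None, use_cutedsl_env=True))  # cutedsl
```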

Contributor

@gemini-code-assist bot left a comment

Code Review

This pull request integrates FlashInfer's gated_delta_rule_decode_pretranspose kernel as an optional backend for Gated Delta Network (GDN) layers, specifically targeting Qwen3.5 hybrid linear attention models. The implementation includes a new server argument --gdn-backend, a specialized decode path that operates directly on a K-last state pool to eliminate gather/scatter overhead, and a prefill path that uses explicit gather/scatter for compatibility with existing kernels. My feedback focuses on a usability issue in the server argument validation and a minor inconsistency in parameter casting within the FlashInfer decode path.


# The flashinfer GDN decode path (gated_delta_rule_decode_pretranspose)
# uses a bf16 state kernel; non-bf16 pools are not supported.
if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype != "bfloat16":
Contributor

medium

The validation logic for mamba_ssm_dtype is too strict when gdn_backend is set to flashinfer. Since self.mamba_ssm_dtype defaults to None, this check will raise a ValueError even if the user doesn't explicitly provide the flag, forcing them to specify --mamba-ssm-dtype bfloat16 manually even if the model configuration already defaults to bfloat16. The check should only trigger if mamba_ssm_dtype is explicitly set to a value other than bfloat16.

Suggested change
if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype != "bfloat16":
if self.gdn_backend == "flashinfer" and self.mamba_ssm_dtype is not None and self.mamba_ssm_dtype != "bfloat16":

Comment on lines +929 to +931
A_log=layer.A_log.detach().float(),
a=a_fi,
dt_bias=layer.dt_bias.detach(),
Contributor

medium

There is an inconsistency in how parameters are passed to the FlashInfer kernel. A_log is explicitly cast to float(), which is good for precision during exponentiation, but dt_bias is passed without a cast. In the Triton path (line 945), both are effectively treated as float32 inside the kernel. If the FlashInfer kernel expects float32 for dt_bias to maintain numerical stability during the softplus calculation, it should be cast here as well.

Suggested change
A_log=layer.A_log.detach().float(),
a=a_fi,
dt_bias=layer.dt_bias.detach(),
A_log=layer.A_log.detach().float(),
a=a_fi,
dt_bias=layer.dt_bias.detach().float(),
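Why the fp32 cast matters can be demonstrated by emulating bf16 rounding of the bias before softplus (a demo only — the helper truncates the low fp32 mantissa bits, approximating bf16's 8-bit mantissa; it is not the kernel's code path):

```python
import math
import struct

def to_bf16(x):
    """Emulate bf16 by zeroing the low 16 bits of the fp32 encoding."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (out,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return out

def softplus(x):
    return math.log1p(math.exp(x))

dt_bias = 0.123456789
full = softplus(dt_bias)          # bias kept in full precision
low = softplus(to_bf16(dt_bias))  # bias rounded to bf16 first
print(abs(full - low) > 1e-4)     # True: the bf16 bias visibly shifts the result
```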

@kaixih
Collaborator Author

kaixih commented Feb 22, 2026

cc. @hlu1

@kaixih
Collaborator Author

kaixih commented Feb 27, 2026

Rebased to main to follow the new argument structure introduced by #18622 (linear attention backend refactor). The old --gdn-backend flag has been replaced by the new --linear-attn-decode-backend / --linear-attn-prefill-backend flags.

New usage:

python -m sglang.launch_server --model <Qwen3-Next-model> \
  --mamba-ssm-dtype bfloat16 \
  --linear-attn-decode-backend flashinfer

Note: --linear-attn-backend flashinfer is intentionally rejected (FlashInfer only provides a decode kernel; it will be enabled once FlashInfer ships prefill kernels). The prefill path uses Triton with gather/scatter to handle the K-contiguous state pool.



class FlashInferGDNKernel(LinearAttnKernelBase):
"""FlashInfer pretranspose kernel for GDN decode.
Collaborator

This is not necessarily decode-only; it could serve prefill in the future. Maybe revise the comment.

@kaixih kaixih force-pushed the integrate_flashinfer_gdn branch 4 times, most recently from 42a4404 to 93520f1 Compare March 8, 2026 16:14
@kaixih kaixih changed the title [NVIDIA] Integrate FlashInfer decode kernel for Qwen3.5 [NVIDIA] Integrate FlashInfer decode kernel for Qwen3.5 (WIP) Mar 8, 2026
@kaixih kaixih force-pushed the integrate_flashinfer_gdn branch 4 times, most recently from b6dd6f8 to cdffe64 Compare March 9, 2026 01:10
@kaixih kaixih force-pushed the integrate_flashinfer_gdn branch 3 times, most recently from 526e113 to 06ea4f3 Compare March 9, 2026 11:42
@kaixih kaixih changed the title [NVIDIA] Integrate FlashInfer decode kernel for Qwen3.5 (WIP) [NVIDIA] Integrate FlashInfer decode kernel for Qwen3.5 Mar 9, 2026
@kaixih kaixih changed the title [NVIDIA] Integrate FlashInfer decode kernel for Qwen3.5 [NVIDIA] Integrate FlashInfer decode kernel (Blackwell) for Qwen3.5 Mar 9, 2026
@kaixih
Collaborator Author

kaixih commented Mar 9, 2026

Also cc @xutizhou, who has done the same thing for Hopper.

@ispobock
Collaborator

/tag-and-rerun-ci

@ispobock
Collaborator

/tag-and-rerun-ci

@kaixih
Collaborator Author

kaixih commented Mar 16, 2026

@ispobock anything blocking the merge?

@ispobock
Collaborator

@kaixih The Qwen3.5 CI test should pass first.

@kaixih
Collaborator Author

kaixih commented Mar 16, 2026

@kaixih The Qwen3.5 CI test should pass first.

@ispobock is there a pointer for this test?

@ispobock
Collaborator

@kaixih
Collaborator Author

kaixih commented Mar 16, 2026

@ispobock Thanks for the pointer. The tests pass on my B200 machine after applying this patch:

diff --git a/python/sglang/srt/multimodal/processors/qwen_vl.py b/python/sglang/srt/multimodal/processors/qwen_vl.py
index 4395654e4..4ad214327 100644
--- a/python/sglang/srt/multimodal/processors/qwen_vl.py
+++ b/python/sglang/srt/multimodal/processors/qwen_vl.py
@@ -7,7 +7,10 @@ from typing import List, Union
 import numpy as np
 import torch
 import torchvision
-from decord import VideoReader
+try:
+    from decord import VideoReader
+except ImportError:
+    VideoReader = None
 from PIL import Image
 from torchvision.transforms import InterpolationMode
 
@@ -156,7 +159,7 @@ async def preprocess_video(
     video_config: dict = {},
 ) -> torch.Tensor:
     # preprocessed video
-    if not isinstance(vr, VideoReader):
+    if VideoReader is None or not isinstance(vr, VideoReader):
         return vr
     entry_time = time.perf_counter()

This patch is needed even without my change, so I am not sure whether you see the same decord failure. Do you want me to include this patch in this PR or in a separate one, or is it completely unrelated?

Update:
Rebasing seems to resolve the above issue. Rerun the tests.

kaixih and others added 4 commits March 16, 2026 20:03
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
The Triton prefill kernel (chunk_delta_h.py) already uses K-last strides
(K, 1) matching the VK pool layout [pool, HV, V, K] natively. The previous
_use_flashinfer_pool path unnecessarily gathered states, transposed to KV,
ran the kernel, then transposed back — the double transpose cancels for
zero initial states but is conceptually wrong.

Remove _use_flashinfer_pool and pass ssm_states + cache_indices directly
to the Triton prefill kernel for all backends.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Swap the use_state_pool branch to call fused_sigmoid_gating_delta_rule_update
(Triton) instead of the FlashInfer CuTe DSL kernel. Original FlashInfer call
is commented out for easy restoration.

Result: 0.980 accuracy on gsm8k (vs 0.890 with FlashInfer decode), confirming
the accuracy gap is entirely in the FlashInfer decode kernel, not prefill.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…x and add flashinfer gsm8k test

- Restore FlashInfer pool API (initial_state + initial_state_indices) for
  SM100+ decode path; remove the temporary Triton-at-callsite debug block
- Pass dt_bias as float32 to the kernel (dt_bias.detach().float()); the
  kernel reads dt_bias_val without an explicit fp32 cast, so passing bf16
  caused a precision gap (0.89 → 0.94 on gsm8k with all precision fixes)
- Add TestQwen35FP4Flashinfer: same as TestQwen35FP4 but with
  --linear-attn-decode-backend flashinfer and threshold 0.93

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@kaixih kaixih force-pushed the integrate_flashinfer_gdn branch from 06ea4f3 to aef6aa7 Compare March 17, 2026 16:50
…uracyTestParams

- Merge TestQwen35FP4 and TestQwen35FP4Flashinfer into a single class
  using run_combined_tests with Triton and FlashInfer variants
- Add top_k support to AccuracyTestParams/_run_simple_eval
- Restore dt_bias.detach().float() fix (regression from debug session)
- Remove stale debug comments in gdn_flashinfer.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@kaixih kaixih force-pushed the integrate_flashinfer_gdn branch from aef6aa7 to 836924d Compare March 17, 2026 16:54
@kaixih
Collaborator Author

kaixih commented Mar 17, 2026

Rebased and added one test for the flashinfer decode backend:

To test:

python test/registered/4-gpu-models/test_qwen35_models.py TestQwen35FP4

Results:

============================================================
Qwen3.5-397B-A17B-NVFP4 Results Summary
Dataset: gsm8k
Baseline: 0.95
============================================================

Model 1: nvidia/Qwen3.5-397B-A17B-NVFP4
  Accuracy: PASS
  Score: 0.980

Model 2: nvidia/Qwen3.5-397B-A17B-NVFP4
  Accuracy: PASS
  Score: 0.985

============================================================
OVERALL: ALL TESTS PASSED
============================================================

@ispobock PTAL

kaixih and others added 2 commits March 17, 2026 17:37
…duling

FlashInfer GDN decode is incompatible with --mamba-scheduler-strategy no_buffer,
causing ~5-10% accuracy degradation on gsm8k. Raise a clear ValueError pointing
to the upstream issue.

See sgl-project#20791

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Move the incompatibility check for --linear-attn-decode-backend flashinfer
with --mamba-scheduler-strategy no_buffer to before the extra_buffer block,
so it fails fast before unrelated extra_buffer validations run.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ispobock ispobock merged commit 4cc1986 into sgl-project:main Mar 18, 2026
88 of 95 checks passed
Wangzheee pushed a commit to Wangzheee/sglang that referenced this pull request Mar 21, 2026
0-693 pushed a commit to 0-693/sglang that referenced this pull request Mar 25, 2026