
[Qwen3.5] Add FlashInfer GDN decode backend #41966

Open

vadiklyutiy wants to merge 2 commits into vllm-project:main from vadiklyutiy:fi-gdn-decode

Conversation

@vadiklyutiy
Collaborator

Blocked on FlashInfer fix. This PR depends on flashinfer-ai/flashinfer#3230 (widen pool indices to int64 to prevent int32 element-offset overflow in gdn_decode) and a subsequent FlashInfer release/upgrade in vLLM. All accuracy and perf results below were produced against a FlashInfer build that includes the #3230 fix.

Add a FlashInfer-based decode path (flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose) for the GDN (Gated Delta Net) attention used in Qwen3-Next / Qwen3.5 hybrid models, side by side with the existing Triton/FLA fused_recurrent_gated_delta_rule_packed_decode. A new engine flag --gdn-decode-backend {flashinfer,triton} selects the kernel; FlashInfer is auto-picked on SM≥90 with a bf16/fp32 SSM state, with a fallback to Triton/FLA otherwise. Includes a kernel-level parity test.
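For orientation, a minimal sketch of that selection rule (function and parameter names here are hypothetical, not the PR's code; the actual gating in the PR also checks the CUDA runtime version, see the review thread further down):

```python
from typing import Optional

import torch


def pick_gdn_decode_backend(
    requested: Optional[str], ssm_state_dtype: torch.dtype
) -> str:
    """Illustrative sketch of the auto-selection described above."""
    if requested in ("flashinfer", "triton"):
        return requested  # explicit --gdn-decode-backend wins
    sm_major, _ = torch.cuda.get_device_capability()
    if sm_major >= 9 and ssm_state_dtype in (torch.bfloat16, torch.float32):
        return "flashinfer"  # SM>=90 with bf16/fp32 SSM state
    return "triton"  # Triton/FLA fallback
```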

Accuracy — GSM8K (1319 questions, 5-shot)

Command (run once per backend); each model uses its YAML-native topology (DP=2 + EP) with --gdn-decode-backend {flashinfer|triton} appended to server_args:

```bash
pytest -s -v tests/evals/gsm8k/test_gsm8k_correctness.py \
  --config-list-file=configs/<single-model.txt>
```

| Model (DP=2, EP) | flashinfer | triton | Threshold |
| --- | --- | --- | --- |
| Qwen/Qwen3.5-35B-A3B | 0.8552 | 0.8590 | 0.84 |
| Qwen/Qwen3.5-35B-A3B-FP8 (kv fp8) | 0.7786 | 0.7635 | 0.79 |
| nvidia/Qwen3.5-397B-A17B-NVFP4 | 0.8817 | 0.8825 | 0.88 |

Performance — Qwen3.5-397B-A17B-NVFP4

Command (DP=8 + EP, ISL=1 / OSL=1024, concurrency=2048, 2 iters):

```bash
vllm serve nvidia/Qwen3.5-397B-A17B-NVFP4 -tp 1 -dp 8 \
  --enable-expert-parallel --no-enable-prefix-caching --stream-interval=100 \
  --gdn-decode-backend {flashinfer|triton}

vllm bench serve --backend vllm --model nvidia/Qwen3.5-397B-A17B-NVFP4 \
  --endpoint /v1/completions --dataset-name random \
  --random-input 1 --random-output 1024 \
  --max-concurrency 2048 --num-prompt 2048 --ignore-eos --temperature 0.0
```

Steady-state (iter 2):

| Metric | flashinfer | triton | Δ |
| --- | --- | --- | --- |
| Output throughput (tok/s) | 48 736 | 47 210 | +3.2% |

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>

@claude (bot) left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@vadiklyutiy
Collaborator Author

cc @arpera

@mergify (bot) added the qwen (Related to Qwen models) label on May 7, 2026
Contributor

@gemini-code-assist (bot) left a comment


Code Review

This pull request integrates FlashInfer's gated_delta_rule_decode_pretranspose kernel as a backend for Gated Delta Net (GDN) decode operations, specifically for SM90+ hardware. It includes a new test suite for regression and parity, adds a CLI argument to select the backend, and implements the logic in the Mamba linear attention layer. Feedback suggests optimizing the decode hot path by avoiding redundant tensor allocations in torch.where and moving local imports to the module level to reduce overhead.

Comment on lines +1129 to +1133

```python
        state_indices_fi = torch.where(
            state_indices > 0,
            state_indices,
            torch.full_like(state_indices, -1),
        )
```

Severity: high

Using torch.full_like inside the forward pass is inefficient as it triggers a new tensor allocation and a fill kernel on every call. Since torch.where supports a scalar for the other argument, you can pass -1 directly. This avoids the extra allocation and is more efficient, especially in the hot path of the decode loop.

```python
        state_indices_fi = torch.where(state_indices > 0, state_indices, -1)
```
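As a standalone sanity check (not part of the PR) that the scalar form is equivalent to the torch.full_like version:

```python
import torch

state_indices = torch.tensor([3, 0, 7, -2], dtype=torch.int32)  # toy values

a = torch.where(state_indices > 0, state_indices,
                torch.full_like(state_indices, -1))     # original: extra alloc + fill
b = torch.where(state_indices > 0, state_indices, -1)   # suggested: scalar `other`
assert torch.equal(a, b)                                 # tensor([ 3, -1,  7, -1])
```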

```python
        # `.detach()` on `nn.Parameter`s: FI ingests via `from_dlpack`
        # which rejects `requires_grad=True`. detach is a free view
        # and preserves strides, so the FI compile cache stays stable.
        from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
```

Severity: high

Performing an import inside the _fi_decode method adds unnecessary overhead to every decode step. While Python caches imports in sys.modules, the lookup still incurs a cost in the hot path. Consider moving this import to the top of the file (wrapped in a try-except block) or importing it once during __init__ and storing the function reference on the layer instance.
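A rough sketch of that suggestion (class and attribute names here are hypothetical, not the PR's layer code): import once at module level and cache the reference on the layer.

```python
# Hypothetical sketch, not the PR's actual layer code.
try:
    from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
except ImportError:  # FlashInfer missing or too old
    gated_delta_rule_decode_pretranspose = None


class GDNDecodeLayer:  # stand-in for the real Mamba/linear-attention layer
    def __init__(self) -> None:
        # Cache the function reference so the decode hot path does no module lookup.
        self._fi_decode_fn = gated_delta_rule_decode_pretranspose

    def _fi_decode(self, *args, **kwargs):
        assert self._fi_decode_fn is not None, "FlashInfer GDN decode unavailable"
        return self._fi_decode_fn(*args, **kwargs)
```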

Member

@ZJY0516 left a comment


Happy to see this. It's good enough, but I think what we ultimately want is a pure FlashInfer path; right now we still use the Triton kernel in the mixed prefill/decode (mixed-PD) scenario.

And do you think we should unify the gdn-prefill-kernel and gdn-decode-kernel?

Comment thread on vllm/engine/arg_utils.py
Comment on lines 2155 to +2159

```python
        if self.gdn_prefill_backend is not None:
            self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend

        if self.gdn_decode_backend is not None:
            self.additional_config["gdn_decode_backend"] = self.gdn_decode_backend
```
Member


Why are we doing it this way through additional_config? It seems like this would naturally go into a field in AttentionConfig, even if it isn't an Attention layer. You can still keep the top-level UX of a dedicated flag.
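A hypothetical sketch of what that might look like (field name and class shape are illustrative only, not vLLM's actual config code):

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class AttentionConfig:  # stand-in for the real config class
    # Typed field instead of an untyped additional_config entry.
    gdn_decode_backend: Optional[Literal["flashinfer", "triton"]] = None


# The dedicated CLI flag would then populate it directly (hypothetical plumbing):
# attention_config.gdn_decode_backend = args.gdn_decode_backend
```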

Co-authored-by: Cursor Agent <cursor-agent@cursor.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Comment on lines +227 to +231

```python
cuda_version = torch.version.cuda
cuda_major = int(cuda_version.split(".")[0]) if cuda_version else 0
# FlashInfer GDN decode kernel has known issues on CUDA 12.x;
# only enable on CUDA 13+.
cuda_supports_fi_gdn = cuda_major >= 13
```
Contributor


I did a separate function for that, get_cuda_runtime_major, in my integration of Blackwell GDN for prefill (link). Please consider using it.
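For reference, a best-effort sketch of such a helper (the real get_cuda_runtime_major lives in the linked prefill PR and may be implemented differently):

```python
import torch


def get_cuda_runtime_major() -> int:
    """Sketch only: return the CUDA major version, or 0 if CUDA is unavailable."""
    cuda_version = torch.version.cuda  # e.g. "13.0"; None on CPU-only builds
    if not cuda_version:
        return 0
    return int(cuda_version.split(".")[0])


# The snippet above would then reduce to:
# cuda_supports_fi_gdn = get_cuda_runtime_major() >= 13
```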


Labels

qwen Related to Qwen models
