[Qwen3.5] Add FlashInfer GDN decode backend#41966
vadiklyutiy wants to merge 2 commits into vllm-project:main
Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
cc @arpera
Code Review
This pull request integrates FlashInfer's gated_delta_rule_decode_pretranspose kernel as a backend for Gated Delta Net (GDN) decode operations, specifically for SM90+ hardware. It includes a new test suite for regression and parity, adds a CLI argument to select the backend, and implements the logic in the Mamba linear attention layer. Feedback suggests optimizing the decode hot path by avoiding redundant tensor allocations in torch.where and moving local imports to the module level to reduce overhead.
state_indices_fi = torch.where(
    state_indices > 0,
    state_indices,
    torch.full_like(state_indices, -1),
)
Using torch.full_like inside the forward pass is inefficient as it triggers a new tensor allocation and a fill kernel on every call. Since torch.where supports a scalar for the other argument, you can pass -1 directly. This avoids the extra allocation and is more efficient, especially in the hot path of the decode loop.
state_indices_fi = torch.where(state_indices > 0, state_indices, -1)
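A minimal parity check for the scalar form, with made-up index values purely for illustration:

```python
import torch

# Illustrative stand-in for the decode-path index tensor.
state_indices = torch.tensor([3, 0, -2, 7], dtype=torch.int64)

# Current form: allocates and fills a fresh tensor of -1 on every call.
a = torch.where(state_indices > 0, state_indices, torch.full_like(state_indices, -1))

# Suggested form: a scalar `other`, no extra allocation.
b = torch.where(state_indices > 0, state_indices, -1)

assert torch.equal(a, b)
```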
# `.detach()` on `nn.Parameter`s: FI ingests via `from_dlpack`
# which rejects `requires_grad=True`. detach is a free view
# and preserves strides, so the FI compile cache stays stable.
from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
Performing an import inside the _fi_decode method adds unnecessary overhead to every decode step. While Python caches imports in sys.modules, the lookup still incurs a cost in the hot path. Consider moving this import to the top of the file (wrapped in a try-except block) or importing it once during __init__ and storing the function reference on the layer instance.
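A minimal sketch of the pattern this comment describes, module-level import plus a cached reference; the class and attribute names here are illustrative, not the PR's actual code:

```python
# Module level: import once; tolerate FlashInfer builds without the GDN decode kernel.
try:
    from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose
except ImportError:
    gated_delta_rule_decode_pretranspose = None


class GDNDecodeLayer:  # illustrative stand-in for the Mamba linear attention layer
    def __init__(self) -> None:
        # Cache the function reference so the hot path does no sys.modules lookup.
        self._fi_decode_kernel = gated_delta_rule_decode_pretranspose

    def _fi_decode(self, *args, **kwargs):
        if self._fi_decode_kernel is None:
            raise RuntimeError("FlashInfer GDN decode kernel is not available")
        return self._fi_decode_kernel(*args, **kwargs)
```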
ZJY0516 left a comment
Happy to see this. It's good enough, but I think what we ultimately want is a pure FlashInfer path; right now we still use the Triton kernel in the mixed prefill-decode scenario.
And do you think we should unify the gdn-prefill-kernel and gdn-decode-kernel?
if self.gdn_prefill_backend is not None:
    self.additional_config["gdn_prefill_backend"] = self.gdn_prefill_backend

if self.gdn_decode_backend is not None:
    self.additional_config["gdn_decode_backend"] = self.gdn_decode_backend
Why are we doing it this way through additional_config? It seems like this would naturally go into a field in AttentionConfig, even if it isn't an Attention layer. You can still keep the top-level UX of a dedicated flag.
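A rough sketch of the alternative suggested here, with the backends as typed fields rather than `additional_config` entries; the `AttentionConfig` shape and helper below are hypothetical, not vLLM's actual API:

```python
from dataclasses import dataclass
from typing import Literal, Optional

GDNBackend = Literal["flashinfer", "triton"]


@dataclass
class AttentionConfig:  # hypothetical shape of the config object named in the review
    gdn_prefill_backend: Optional[GDNBackend] = None
    gdn_decode_backend: Optional[GDNBackend] = None


# The dedicated top-level flag keeps its UX and simply fills the typed field.
def apply_gdn_flags(
    config: AttentionConfig,
    gdn_prefill_backend: Optional[GDNBackend] = None,
    gdn_decode_backend: Optional[GDNBackend] = None,
) -> None:
    if gdn_prefill_backend is not None:
        config.gdn_prefill_backend = gdn_prefill_backend
    if gdn_decode_backend is not None:
        config.gdn_decode_backend = gdn_decode_backend
```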
Co-authored-by: Cursor Agent <cursor-agent@cursor.com>
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
cuda_version = torch.version.cuda
cuda_major = int(cuda_version.split(".")[0]) if cuda_version else 0
# FlashInfer GDN decode kernel has known issues on CUDA 12.x;
# only enable on CUDA 13+.
cuda_supports_fi_gdn = cuda_major >= 13
I added a separate function for this, get_cuda_runtime_major, in my integration of Blackwell GDN for prefill (link). Please consider using it.
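A minimal sketch of what such a helper could look like, derived only from the snippet above; the real `get_cuda_runtime_major` in the linked prefill integration may differ:

```python
import torch


def get_cuda_runtime_major() -> int:
    """Major component of the CUDA version torch reports, or 0 when unavailable."""
    cuda_version = torch.version.cuda  # e.g. "12.8"; None on CPU-only builds
    return int(cuda_version.split(".")[0]) if cuda_version else 0


# Caller side: gate the FlashInfer GDN decode kernel on CUDA 13+.
cuda_supports_fi_gdn = get_cuda_runtime_major() >= 13
```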
Blocked on FlashInfer fix. This PR depends on flashinfer-ai/flashinfer#3230 (widen pool indices to int64 to prevent int32 element-offset overflow in `gdn_decode`) and a subsequent FlashInfer release/upgrade in vLLM. All accuracy and perf results below were produced against a FlashInfer build that includes the #3230 fix.

Add a FlashInfer-based decode path (`flashinfer.gdn_decode.gated_delta_rule_decode_pretranspose`) for the GDN (Gated Delta Net) attention used in Qwen3-Next / Qwen3.5 hybrid models, side by side with the existing Triton/FLA `fused_recurrent_gated_delta_rule_packed_decode`. A new engine flag `--gdn-decode-backend {flashinfer,triton}` selects the kernel; auto-selection picks FlashInfer on SM≥90 with bf16/fp32 SSM state and falls back to Triton/FLA otherwise. Includes a kernel-level parity test.
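A minimal sketch of the auto-selection rule described above, assuming the SM version and SSM-state dtype are already known; function and parameter names are illustrative, not the PR's actual code:

```python
from typing import Optional

import torch


def pick_gdn_decode_backend(
    requested: Optional[str],   # value of --gdn-decode-backend; None means auto
    sm_version: int,            # e.g. 90 for Hopper
    ssm_state_dtype: torch.dtype,
) -> str:
    # An explicit user choice always wins.
    if requested is not None:
        return requested
    # Auto: FlashInfer on SM>=90 with bf16/fp32 SSM state, otherwise Triton/FLA.
    if sm_version >= 90 and ssm_state_dtype in (torch.bfloat16, torch.float32):
        return "flashinfer"
    return "triton"


print(pick_gdn_decode_backend(None, 90, torch.bfloat16))  # -> "flashinfer"
```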
Accuracy — GSM8K (1319 questions, 5-shot)

Command (per backend); each model uses its YAML-native topology (DP=2 + EP) with `--gdn-decode-backend {flashinfer|triton}` appended to `server_args`:

Performance — Qwen3.5-397B-A17B-NVFP4
Command (DP=8 + EP, ISL=1 / OSL=1024, concurrency=2048, 2 iters):

vllm serve nvidia/Qwen3.5-397B-A17B-NVFP4 -tp 1 -dp 8 \
  --enable-expert-parallel --no-enable-prefix-caching --stream-interval=100 \
  --gdn-decode-backend {flashinfer|triton}

vllm bench serve --backend vllm --model nvidia/Qwen3.5-397B-A17B-NVFP4 \
  --endpoint /v1/completions --dataset-name random \
  --random-input 1 --random-output 1024 \
  --max-concurrency 2048 --num-prompt 2048 --ignore-eos --temperature 0.0

Steady-state (iter 2):