[Core] Add GDN packed decode path #35739
Conversation
Signed-off-by: hdj <1293066020@qq.com>
Hi @caozuoba, the pre-commit checks have failed. Please run:

```shell
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.
Code Review
This pull request introduces an opt-in packed recurrent decode fast path for Qwen3Next models, which is a valuable performance optimization. The implementation is well-structured, with the new functionality gated by a feature flag and a safe fallback mechanism. The addition of unit tests ensures correctness. I have one suggestion to enhance the robustness of the new Triton kernel by adding contiguity checks for A_log and dt_bias tensors, which could prevent potential silent errors.
```python
if not torch.is_floating_point(A_log) or not torch.is_floating_point(dt_bias):
    raise ValueError("`A_log`/`dt_bias` must be floating tensors.")
```
The function fused_recurrent_gated_delta_rule_packed_decode_fwd includes several validation checks for its inputs, which is excellent for ensuring correctness. However, it appears to be missing contiguity checks for the A_log and dt_bias tensors. The Triton kernel fused_recurrent_gated_delta_rule_packed_decode_fwd_kernel loads from these tensors assuming they are contiguous. If non-contiguous tensors are passed, it could lead to incorrect data being read and silent errors in the computation. To improve robustness, I suggest adding contiguity checks for these tensors.
Suggested change:

```python
if not torch.is_floating_point(A_log) or not torch.is_floating_point(dt_bias):
    raise ValueError("`A_log`/`dt_bias` must be floating tensors.")
if A_log.stride(0) != 1:
    raise ValueError("`A_log` must be contiguous.")
if dt_bias.stride(0) != 1:
    raise ValueError("`dt_bias` must be contiguous.")
```
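For intuition (not part of the PR), the stride check matters because a non-contiguous view shares storage with its base tensor, so a kernel that assumes unit stride reads the wrong elements. A small NumPy analogue of the failure mode (note NumPy reports strides in bytes, whereas PyTorch's `Tensor.stride()` is in elements):

```python
import numpy as np

# A contiguous 1-D buffer of 8 float32 values.
base = np.arange(8, dtype=np.float32)

# Every-other-element view: shares storage with `base`, element stride is 2.
view = base[::2]  # [0., 2., 4., 6.]

# NumPy strides are in bytes; 8 bytes == 2 float32 elements here.
assert view.strides[0] == 2 * base.itemsize
assert not view.flags["C_CONTIGUOUS"]

# A kernel that wrongly assumed unit stride would instead read the first
# len(view) elements of the underlying storage -- silently wrong data:
wrong = base[: len(view)]  # [0., 1., 2., 3.]
```

This is exactly the "incorrect data being read" case the review describes: no crash, just silently different values.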
Signed-off-by: hdj <1293066020@qq.com>
@sighingnow @mgoin Hi, would you have time to help move this PR forward? Thanks a lot!
Hi @tlrmchlsmth @ZJY0516 @ywang96, this PR is part of the GDN decode optimization work tracked in #35149. Now that #35777 has landed in `main`, if the direction still makes sense, I'm happy to rebase it onto current `main`.
I think it depends on the perf improvement. Honestly, the GDN code is a little messy now.
This pull request has merge conflicts that must be resolved before it can be merged.
Hi @ZJY0516, I think opening a new PR makes the follow-up discussion cleaner, so I went ahead with that approach. I also re-ran the benchmark against the latest `main`. When you have time, could you please take another look at the new PR, #36596?
Purpose
- Add an opt-in packed recurrent decode fast path for Qwen3Next (GDN), consuming the packed `mixed_qkv` and fusing gating + the recurrent update in a single Triton kernel.
- The fast path is only taken when `VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1` is set (default: `0`). If the fast path preconditions are not met, it falls back to the existing path (logs once) instead of crashing.

Implementation
- New kernel `fused_recurrent_gated_delta_rule_packed_decode_fwd` (`vllm/model_executor/layers/fla/ops/fused_recurrent.py`).
- Exported via `vllm/model_executor/layers/fla/ops/__init__.py`.
- Gated behind `VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE` (default: off) with a safe fallback (`vllm/model_executor/models/qwen3_next.py`).
- Unit tests (`tests/kernels/test_fused_recurrent_packed_decode.py`).
- Env var registered in `vllm/envs.py`.
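As a rough illustration of the gate-plus-fallback behavior described above (a sketch only: the function names, return values, and the `preconditions_met` argument are hypothetical, and the real vLLM code reads the flag through `vllm/envs.py` rather than `os.environ` directly):

```python
import logging
import os

logger = logging.getLogger(__name__)
_warned_fallback = False


def _packed_decode_enabled() -> bool:
    # Hypothetical stand-in for the envs.py-registered flag.
    return os.environ.get("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "0") == "1"


def packed_fast_path(mixed_qkv, state):  # illustrative stub
    return "packed"


def existing_path(mixed_qkv, state):  # illustrative stub
    return "fallback"


def recurrent_decode(mixed_qkv, state, *, preconditions_met: bool):
    """Dispatch to the packed fast path only when opted in and safe."""
    global _warned_fallback
    if _packed_decode_enabled() and preconditions_met:
        return packed_fast_path(mixed_qkv, state)
    if _packed_decode_enabled() and not _warned_fallback:
        # Log once, then keep using the existing path instead of crashing.
        logger.warning("Packed decode preconditions not met; falling back.")
        _warned_fallback = True
    return existing_path(mixed_qkv, state)
```

The key property is that enabling the flag can never turn an unsupported input into a hard error, only into a (once-logged) fallback.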
The existing Qwen3Next decode-uniform path performs extra work that becomes noticeable at high decode concurrency:
- It materializes separate `q`/`k`/`v` views (from packed projections) and runs standalone gating before the recurrent kernel.
- At decode with `T=1`, these extra tensor transformations + kernel launches add overhead and extra memory traffic.

This PR introduces a decode-only packed fast path that:
- consumes the packed projection output (`mixed_qkv`) directly, and
- computes gating (`g`/`beta`) inside the recurrent kernel,

which reduces intermediate reads/writes and kernel launch overhead.
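To make the "consumes the packed output directly" point concrete, here is a hypothetical NumPy sketch of the layout (head sizes and the two-request batch are made up for illustration): the baseline-style path slices three strided views out of the packed buffer before launching the kernel, while the fast path hands the packed buffer to a single kernel that applies the same offsets internally.

```python
import numpy as np

# Illustrative sizes only, not the real Qwen3Next config.
num_requests, d_qk, d_v = 2, 4, 4  # two decode requests, one token each (T=1)

# Packed projection output: q, k, v concatenated along the last dim.
mixed_qkv = (
    np.arange(num_requests * (2 * d_qk + d_v))
    .reshape(num_requests, 2 * d_qk + d_v)
    .astype(np.float32)
)

# Baseline-style path: materialize three strided views before the kernel.
q = mixed_qkv[:, :d_qk]
k = mixed_qkv[:, d_qk : 2 * d_qk]
v = mixed_qkv[:, 2 * d_qk :]

# All three are views into the same storage, just at different offsets;
# the packed fast path computes these offsets inside one kernel instead.
assert q.base is mixed_qkv and k.base is mixed_qkv and v.base is mixed_qkv
```

Skipping the intermediate views saves little per call, but at high decode concurrency with `T=1` the avoided launches and memory traffic add up, which is the motivation stated above.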
Correctness / Accuracy (How it matches the baseline)
The packed fast path is designed to be numerically equivalent to the existing implementation:
Same gating math
- `g = -exp(A_log) * softplus(a + dt_bias)`
- `beta = sigmoid(b)`

Same recurrent update (single token)
- `h = h * exp(g)`
- `v = (v - h @ k) * beta`
- `h = h + outer(v, k)`
- `o = h @ q`

Same normalization option
Uses the same `use_qk_l2norm_in_kernel` flag as the baseline path.

Same accumulation behavior
Accumulates in `float32` and then casts to the output/state dtype (fp16/bf16), consistent with the existing fused recurrent kernel behavior.

Same continuous batching semantics
- Uses `ssm_state_indices` to index per-request state.
- `PAD_SLOT_ID = -1` is handled by writing zeros to the output and skipping the state update (important for CUDAGraph replay, where the output buffer can be reused).

Safety / Rollout
- Default off: `VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=0`.
- Opt in with `export VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1`.

Test Result
Correctness (pytest)
Command
Result
Notes:
- Tests cover `fp16`/`bf16`, including a strided packed `mixed_qkv` view and `PAD_SLOT_ID=-1` cases.

Performance
Compared to `main`, on NVIDIA H800, this PR improves Output token throughput (tok/s) by ~9.58%, reduces Mean TPOT (ms) by ~12.15%, and reduces Mean E2EL (ms) by ~9.40%.
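For reference, the gating math and single-token recurrent update listed under Correctness / Accuracy can be written out as a plain NumPy sketch. This is illustration only: argument order, helper names, and the `[d_v, d_k]` state layout are assumptions, not the Triton kernel's actual signature, and it omits the optional q/k L2 norm and the fp16/bf16 casts.

```python
import numpy as np


def softplus(x):
    # Numerically stable softplus: log(1 + exp(x)).
    return np.logaddexp(0.0, x)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_delta_rule_step(h, q, k, v, a, b, A_log, dt_bias):
    """One single-token recurrent update; h is the state, shape [d_v, d_k]."""
    g = -np.exp(A_log) * softplus(a + dt_bias)  # decay gate
    beta = sigmoid(b)                           # update gate
    h = h * np.exp(g)                           # decay state
    v = (v - h @ k) * beta                      # delta-rule correction
    h = h + np.outer(v, k)                      # rank-1 state update
    o = h @ q                                   # output, shape [d_v]
    return o, h


# Illustrative decode step: one head with assumed d_k = d_v = 2.
rng = np.random.default_rng(0)
q0, k0, v0 = rng.standard_normal((3, 2))
o0, h0 = gated_delta_rule_step(
    np.zeros((2, 2)), q0, k0, v0, a=0.1, b=0.2, A_log=0.0, dt_bias=0.0
)
```

The fused kernel performs this math (accumulating in `float32`) for every request slot in one launch, rather than running the gating as a separate op first.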