
[core] add gdn packed decode path #35739

Open
caozuoba wants to merge 3 commits into vllm-project:main from caozuoba:pr/gdn-packed-decode

Conversation

caozuoba (Contributor) commented Mar 2, 2026

Purpose

  • Add an opt-in packed recurrent decode fast path for Qwen3Next (FLA/GDN) uniform decode (T=1) by directly consuming packed mixed_qkv and fusing gating + recurrent update in a single Triton kernel.
  • Keep default behavior unchanged unless VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1 is set (default: 0). If the fast path preconditions are not met, it falls back to the existing path (logs once) instead of crashing.

Implementation

  • Add a decode-only Triton kernel and Python wrapper: fused_recurrent_gated_delta_rule_packed_decode_fwd (vllm/model_executor/layers/fla/ops/fused_recurrent.py).
  • Export the new op in vllm/model_executor/layers/fla/ops/__init__.py.
  • Integrate into Qwen3Next uniform decode path behind VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE (default: off) with a safe fallback (vllm/model_executor/models/qwen3_next.py).
  • Add CUDA unit test for output + state parity vs baseline (tests/kernels/test_fused_recurrent_packed_decode.py).
  • Add env var definition in vllm/envs.py.
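The opt-in dispatch with safe fallback described above can be sketched as follows. This is a hypothetical illustration of the pattern, not the actual code in qwen3_next.py; only the env var name comes from the PR, and the helper names (`packed_preconditions_ok`, `recurrent_decode`) are made up.

```python
import os
import logging

logger = logging.getLogger("qwen3_next")
_warned_once = False

# The real flag is read via vllm/envs.py; reading os.environ here is a
# simplification for the sketch.
PACKED_DECODE_ENABLED = os.environ.get(
    "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "0") == "1"

def packed_preconditions_ok(mixed_qkv):
    # Hypothetical check: e.g. verify expected layout/strides before
    # taking the fast path.
    return mixed_qkv.get("layout") == "packed"

def recurrent_decode(mixed_qkv, baseline_fn, packed_fn):
    """Try the packed fast path; log once and fall back on any mismatch."""
    global _warned_once
    if PACKED_DECODE_ENABLED and packed_preconditions_ok(mixed_qkv):
        return packed_fn(mixed_qkv)
    if PACKED_DECODE_ENABLED and not _warned_once:
        logger.warning("packed decode preconditions not met; falling back")
        _warned_once = True
    return baseline_fn(mixed_qkv)
```

With the flag unset (the default), the baseline path runs unconditionally, matching the "default behavior unchanged" guarantee.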

Motivation

The existing Qwen3Next decode-uniform path performs extra work that becomes noticeable at high decode concurrency:

  • It materializes contiguous q/k/v views (from packed projections) and runs standalone gating before the recurrent kernel.
  • For decode T=1, these extra tensor transformations + kernel launches add overhead and extra memory traffic.

This PR introduces a decode-only packed fast path that:

  • reuses the packed projection output (mixed_qkv) directly,
  • computes g/beta inside the recurrent kernel,
  • writes results directly into the caller-provided output buffer,
    which reduces intermediate reads/writes and kernel launch overhead.
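The difference between materializing contiguous q/k/v copies and reading the packed buffer in place can be illustrated with a small NumPy sketch (head sizes and the 2-row batch here are made up for illustration; the real kernel does the offset arithmetic in Triton):

```python
import numpy as np

# A packed projection output laid out [q | k | v] along the last dim,
# standing in for mixed_qkv.
DK, DV = 4, 4
mixed_qkv = np.arange(2 * (2 * DK + DV), dtype=np.float32).reshape(2, -1)

# Baseline-style path: materialize contiguous q/k/v copies before the kernel.
q = np.ascontiguousarray(mixed_qkv[:, :DK])
k = np.ascontiguousarray(mixed_qkv[:, DK:2 * DK])
v = np.ascontiguousarray(mixed_qkv[:, 2 * DK:])

# Packed-style path: read q/k/v via offsets into mixed_qkv, skipping the
# intermediate copies (emulated here with plain slicing).
q_view = mixed_qkv[:, :DK]
assert np.shares_memory(q_view, mixed_qkv)       # no copy
assert not np.shares_memory(q, mixed_qkv)        # baseline copied
```

For T=1 decode the copies are small individually but add kernel launches and memory traffic per step, which is the overhead this PR targets.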

Correctness / Accuracy (How it matches the baseline)

The packed fast path is designed to be numerically equivalent to the existing implementation:

  • Same gating math

    • g = -exp(A_log) * softplus(a + dt_bias)
    • beta = sigmoid(b)
    • This matches the existing fused gating used by the baseline path.
  • Same recurrent update (single token)

    • h = h * exp(g)
    • v = (v - h @ k) * beta
    • h = h + outer(v, k)
    • o = h @ q
  • Same normalization option

    • The packed kernel supports the same use_qk_l2norm_in_kernel flag as the baseline path.
  • Same accumulation behavior

    • The kernel performs math in float32 and then casts to output/state dtype (fp16/bf16), consistent with the existing fused recurrent kernel behavior.
  • Same continuous batching semantics

    • The kernel uses ssm_state_indices to index per-request state.
    • PAD_SLOT_ID = -1 is handled by writing zeros to output and skipping state update (important for CUDAGraph replay where the output buffer can be reused).
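The gating and single-token update listed above can be written as a small NumPy reference. This is a sketch of the math only, not the Triton kernel; the per-head scalar gating and the (DV, DK) state shape are assumptions for illustration.

```python
import numpy as np

def softplus(x):
    return np.log1p(np.exp(x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_delta_decode_step(h, q, k, v, a, b, A_log, dt_bias):
    """Reference single-token update following the bullets above.

    h: (DV, DK) per-head state; q, k: (DK,); v: (DV,); a, b, A_log,
    dt_bias: per-head scalars. All math in float32, as in the kernel.
    """
    g = -np.exp(A_log) * softplus(a + dt_bias)
    beta = sigmoid(b)
    h = h * np.exp(g)           # decay state
    v = (v - h @ k) * beta      # delta-rule residual, scaled by beta
    h = h + np.outer(v, k)      # rank-1 state update
    o = h @ q                   # readout
    return o, h
```

A reference like this is what the parity test compares the fused kernel against (modulo fp16/bf16 casting at the boundaries).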

Safety / Rollout

  • Default is OFF: VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=0.
  • Enable with:
    export VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1
  • When enabled, Qwen3Next uniform-decode will attempt the packed fast path; if preconditions are violated (e.g., unexpected strides/layouts), it logs once and falls back to the original path.

Test Result

Correctness (pytest)

Command

python3 -m pytest -q tests/kernels/test_fused_recurrent_packed_decode.py

Result

....
[100%]
4 passed in 6.79s

Notes:

  • Requires a CUDA device + Triton environment.
  • The test compares packed vs baseline outputs and final states for fp16/bf16, including a strided packed mixed_qkv view and PAD_SLOT_ID=-1 cases.
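The PAD_SLOT_ID semantics the test exercises can be sketched as follows. This is a toy stand-in for the real kernel: the point is only that padded slots produce zero output and leave state untouched, which matters when CUDAGraph replay reuses the output buffer.

```python
import numpy as np

PAD_SLOT_ID = -1

def decode_batch(states, ssm_state_indices, out_dim):
    """Toy model of per-request state indexing with PAD_SLOT_ID handling."""
    out = np.empty((len(ssm_state_indices), out_dim), dtype=np.float32)
    for i, slot in enumerate(ssm_state_indices):
        if slot == PAD_SLOT_ID:
            out[i] = 0.0        # zero output, skip state update
            continue
        out[i] = states[slot]   # stand-in for the real recurrent readout
        states[slot] += 1.0     # stand-in for the state write-back
    return out
```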

Performance

Compared to main, on NVIDIA H800, this PR improves Output token throughput (tok/s) by ~9.58%, reduces Mean TPOT (ms) by ~12.15%, and reduces Mean E2EL (ms) by ~9.40%.

The performance numbers for this PR are collected with
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE=1.

main
============ Serving Benchmark Result ============
Successful requests:                     800
Failed requests:                         0
Benchmark duration (s):                  11.03
Total input tokens:                      102400
Total generated tokens:                  80000
Request throughput (req/s):              72.55
Output token throughput (tok/s):         7254.76
Peak output token throughput (tok/s):    11880.00
Peak concurrent requests:                800.00
Total token throughput (tok/s):          16540.86
---------------Time to First Token----------------
Mean TTFT (ms):                          2523.85
Median TTFT (ms):                        2346.24
P99 TTFT (ms):                           3835.41
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          77.36
Median TPOT (ms):                        79.67
P99 TPOT (ms):                           85.37
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.61
Median ITL (ms):                         70.00
P99 ITL (ms):                            215.30
----------------End-to-end Latency----------------
Mean E2EL (ms):                          10182.42
Median E2EL (ms):                        10239.01
P99 E2EL (ms):                           10372.99
==================================================
PR
============ Serving Benchmark Result ============
Successful requests:                     800
Failed requests:                         0
Benchmark duration (s):                  10.06
Total input tokens:                      102400
Total generated tokens:                  80000
Request throughput (req/s):              79.50
Output token throughput (tok/s):         7950.07
Peak output token throughput (tok/s):    13868.00
Peak concurrent requests:                800.00
Total token throughput (tok/s):          18126.17
---------------Time to First Token----------------
Mean TTFT (ms):                          2497.60
Median TTFT (ms):                        2432.61
P99 TTFT (ms):                           3765.33
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          67.96
Median TPOT (ms):                        69.22
P99 TPOT (ms):                           76.57
---------------Inter-token Latency----------------
Mean ITL (ms):                           67.37
Median ITL (ms):                         59.42
P99 ITL (ms):                            213.89
----------------End-to-end Latency----------------
Mean E2EL (ms):                          9225.53
Median E2EL (ms):                        9284.60
P99 E2EL (ms):                           9422.97
==================================================

caozuoba added 2 commits March 2, 2026 18:03
Signed-off-by: hdj <1293066020@qq.com>
Signed-off-by: hdj <1293066020@qq.com>
@mergify mergify bot added the qwen Related to Qwen models label Mar 2, 2026
mergify bot (Contributor) commented Mar 2, 2026

Hi @caozuoba, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request introduces an opt-in packed recurrent decode fast path for Qwen3Next models, which is a valuable performance optimization. The implementation is well-structured, with the new functionality gated by a feature flag and a safe fallback mechanism. The addition of unit tests ensures correctness. I have one suggestion to enhance the robustness of the new Triton kernel by adding contiguity checks for A_log and dt_bias tensors, which could prevent potential silent errors.

Comment on lines +337 to +338
if not torch.is_floating_point(A_log) or not torch.is_floating_point(dt_bias):
    raise ValueError("`A_log`/`dt_bias` must be floating tensors.")
Severity: high
The function fused_recurrent_gated_delta_rule_packed_decode_fwd includes several validation checks for its inputs, which is excellent for ensuring correctness. However, it appears to be missing contiguity checks for the A_log and dt_bias tensors. The Triton kernel fused_recurrent_gated_delta_rule_packed_decode_fwd_kernel loads from these tensors assuming they are contiguous. If non-contiguous tensors are passed, it could lead to incorrect data being read and silent errors in the computation. To improve robustness, I suggest adding contiguity checks for these tensors.

Suggested change
if not torch.is_floating_point(A_log) or not torch.is_floating_point(dt_bias):
    raise ValueError("`A_log`/`dt_bias` must be floating tensors.")
if A_log.stride(0) != 1:
    raise ValueError("`A_log` must be contiguous.")
if dt_bias.stride(0) != 1:
    raise ValueError("`dt_bias` must be contiguous.")

Signed-off-by: hdj <1293066020@qq.com>
caozuoba (Contributor, Author) commented Mar 4, 2026

@sighingnow @mgoin Hi, would you have time to help move this PR forward? Thanks a lot!

@tlrmchlsmth tlrmchlsmth self-assigned this Mar 7, 2026
caozuoba (Contributor, Author) commented Mar 9, 2026

Hi @tlrmchlsmth @ZJY0516 @ywang96,

This PR is part of the GDN decode optimization work tracked in #35149.

Now that #35777 has landed in main, this branch has some overlap with the new fused sigmoid-gating path and will likely need a rebase / integration update. Before I invest more time rebasing and re-benchmarking it on top of the latest baseline, could you help confirm whether you'd still like this packed mixed_qkv decode path to move forward as a follow-up optimization?

If the direction still makes sense, I'm happy to rebase it onto current main and refresh the performance numbers.

ZJY0516 (Member) commented Mar 9, 2026

> Hi @tlrmchlsmth @ZJY0516 @ywang96,
>
> This PR is part of the GDN decode optimization work tracked in #35149.
>
> Now that #35777 has landed in main, this branch has some overlap with the new fused sigmoid-gating path and will likely need a rebase / integration update. Before I invest more time rebasing and re-benchmarking it on top of the latest baseline, could you help confirm whether you'd still like this packed mixed_qkv decode path to move forward as a follow-up optimization?
>
> If the direction still makes sense, I'm happy to rebase it onto current main and refresh the performance numbers.

I think it depends on perf improvement. Honestly, the gdn code is a little messy now

mergify bot (Contributor) commented Mar 9, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @caozuoba.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 9, 2026
caozuoba (Contributor, Author) commented:

> Hi @tlrmchlsmth @ZJY0516 @ywang96,
> This PR is part of the GDN decode optimization work tracked in #35149.
> Now that #35777 has landed in main, this branch has some overlap with the new fused sigmoid-gating path and will likely need a rebase / integration update. Before I invest more time rebasing and re-benchmarking it on top of the latest baseline, could you help confirm whether you'd still like this packed mixed_qkv decode path to move forward as a follow-up optimization?
> If the direction still makes sense, I'm happy to rebase it onto current main and refresh the performance numbers.

> I think it depends on perf improvement. Honestly, the gdn code is a little messy now

Hi @ZJY0516,

I think opening a new PR makes the follow-up discussion cleaner, so I went ahead with that approach.

I also re-ran the benchmark against the latest main, and the remaining perf gain still looks meaningful on my side.

When you have time, could you please take another look at the new PR? #36596


Labels

needs-rebase qwen Related to Qwen models
