
[Perf] add packed recurrent fast path for decode#36596

Merged
vllm-bot merged 9 commits into vllm-project:main from caozuoba:perf/gdn-packed
Mar 12, 2026

Conversation

@caozuoba
Contributor

@caozuoba caozuoba commented Mar 10, 2026

Purpose

  • Add a packed recurrent decode fast path for Qwen3Next GDN non-spec uniform decode (T=1).
  • Directly consume packed mixed_qkv in a decode-only fast path instead of materializing contiguous q/k/v.
  • Fuse the packed data flow with recurrent decode state update/output writeback to reduce intermediate memory traffic in the decode hot path.
  • Guard the fast path with VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE (default: 0).

Test Result

Correctness (pytest)

Command

python3 -m pytest -q tests/kernels/test_fused_recurrent_packed_decode.py

Result

....                                                                   [100%]
4 passed in 8.55s

Performance

Compared to the main branch, this PR improves Output token throughput (tok/s) by ~6.37%, reduces Mean TPOT (ms) by ~5.34%, reduces Mean E2EL (ms) by ~6.56%, and reduces Mean TTFT (ms) by ~9.79%.

Command

python3 -m vllm.entrypoints.openai.api_server \
  --host 0.0.0.0 \
  --port 19000 \
  --dtype bfloat16 \
  --model /nas/disk1/Qwen3.5-35B-A3B \
  --served-model-name Qwen3.5-35B-A3B \
  --tensor-parallel-size 2 \
  --gpu-memory-utilization 0.9 \
  --max-model-len 32768 \
  --max-num-batched-tokens 32768 \
  --trust-remote-code \
  --no-enable-prefix-caching
Main:
============ Serving Benchmark Result ============
Successful requests:                     800
Failed requests:                         0
Benchmark duration (s):                  9.37
Total input tokens:                      102400
Total generated tokens:                  80000
Request throughput (req/s):              85.34
Output token throughput (tok/s):         8533.93
Peak output token throughput (tok/s):    13948.00
Peak concurrent requests:                800.00
Total token throughput (tok/s):          19457.36
---------------Time to First Token----------------
Mean TTFT (ms):                          2289.22
Median TTFT (ms):                        2349.07
P99 TTFT (ms):                           8254.40
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          62.00
Median TPOT (ms):                        61.78
P99 TPOT (ms):                           69.39
---------------Inter-token Latency----------------
Mean ITL (ms):                           61.63
Median ITL (ms):                         56.47
P99 ITL (ms):                            343.68
----------------End-to-end Latency----------------
Mean E2EL (ms):                          8427.53
Median E2EL (ms):                        8461.47
P99 E2EL (ms):                           9300.76
==================================================
PR:
============ Serving Benchmark Result ============
Successful requests:                     800
Failed requests:                         0
Benchmark duration (s):                  8.81
Total input tokens:                      102400
Total generated tokens:                  80000
Request throughput (req/s):              90.78
Output token throughput (tok/s):         9077.62
Peak output token throughput (tok/s):    14744.00
Peak concurrent requests:                800.00
Total token throughput (tok/s):          20696.97
---------------Time to First Token----------------
Mean TTFT (ms):                          2064.99
Median TTFT (ms):                        1807.52
P99 TTFT (ms):                           7798.07
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          58.69
Median TPOT (ms):                        61.33
P99 TPOT (ms):                           66.13
---------------Inter-token Latency----------------
Mean ITL (ms):                           58.25
Median ITL (ms):                         53.12
P99 ITL (ms):                            310.54
----------------End-to-end Latency----------------
Mean E2EL (ms):                          7874.99
Median E2EL (ms):                        7882.78
P99 E2EL (ms):                           8737.92
==================================================
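For reference, the headline percentages quoted above can be recomputed from the two result blocks (metric values copied verbatim from the main and PR runs):

```python
# Mean metrics from the two serving benchmark blocks above.
main = {"tok_s": 8533.93, "tpot": 62.00, "e2el": 8427.53, "ttft": 2289.22}
pr   = {"tok_s": 9077.62, "tpot": 58.69, "e2el": 7874.99, "ttft": 2064.99}

# Throughput: higher is better, so report the relative gain.
throughput_gain = (pr["tok_s"] / main["tok_s"] - 1) * 100
# Latencies: lower is better, so report the relative reduction.
tpot_reduction = (main["tpot"] - pr["tpot"]) / main["tpot"] * 100
e2el_reduction = (main["e2el"] - pr["e2el"]) / main["e2el"] * 100
ttft_reduction = (main["ttft"] - pr["ttft"]) / main["ttft"] * 100

print(f"tok/s +{throughput_gain:.2f}%, TPOT -{tpot_reduction:.2f}%, "
      f"E2EL -{e2el_reduction:.2f}%, TTFT -{ttft_reduction:.2f}%")
```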

Signed-off-by: hdj <1293066020@qq.com>
@mergify mergify bot added the qwen Related to Qwen models label Mar 10, 2026
@mergify
Contributor

mergify bot commented Mar 10, 2026

Hi @caozuoba, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant performance optimization for Qwen3Next models by adding a packed recurrent fast path for the decode phase. The changes are well-structured and include a new Triton kernel that directly consumes a packed QKV tensor, which avoids materializing separate Q, K, and V tensors and fuses the gating logic. This new functionality is controlled by an environment variable and includes robust fallback mechanisms to the baseline implementation, ensuring safety. The pull request also adds a new test to verify the correctness of the new kernel. The integration into the existing model code is clean and follows established patterns. Overall, the changes appear to be correct, safe, and a valuable performance enhancement.

Signed-off-by: hdj <1293066020@qq.com>
@caozuoba
Contributor Author

@mgoin @tlrmchlsmth @WoosukKwon @ZJY0516 Hello everyone, could you please take a look at this PR and provide some feedback when you have a moment? This is a follow-up to PR #35739, submitted last week. Due to conflicts with an already merged PR, I've created a new PR and re-run the benchmarks. Thanks!

Member

@ZJY0516 ZJY0516 left a comment


please add accuracy test

tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


def fused_recurrent_gated_delta_rule_packed_decode_fwd(
Member


Suggested change
def fused_recurrent_gated_delta_rule_packed_decode_fwd(
def fused_recurrent_gated_delta_rule_packed_decode(

I prefer this name

use_qk_l2norm_in_kernel=True,
)
return
except ValueError as exc:
Member


When this will fail? I don't think this needs try and except here

else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

def _forward_core_packed(
Member


Suggested change
def _forward_core_packed(
def _forward_core_decode_non_spec(

vllm/envs.py Outdated
if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","),
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "0"))
Member


Could we enable this by default?

Signed-off-by: hdj <1293066020@qq.com>
@caozuoba
Contributor Author

Accuracy Testing

Command

python3 -m lm_eval --model local-completions \
  --model_args model=Qwen3.5-35B-A3B,base_url=http://127.0.0.1:19000/v1/completions,num_concurrent=80,tokenizer=/nas/disk1/Qwen3.5-35B-A3B \
  --tasks gsm8k

Baseline

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8529|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8370|±  |0.0102|

This PR

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8582|±  |0.0096|
|     |       |strict-match    |     5|exact_match|↑  |0.8431|±  |0.0100|
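As a quick sanity check (a sketch, not part of the PR), both gsm8k deltas fall well within the combined standard errors of the two runs, so the fast path shows no measurable accuracy change:

```python
import math

def within_noise(base: float, base_se: float, pr: float, pr_se: float) -> bool:
    # Compare the accuracy delta against the combined stderr of the two runs.
    return abs(pr - base) < math.hypot(base_se, pr_se)

# Values copied from the baseline and PR tables above.
assert within_noise(0.8529, 0.0098, 0.8582, 0.0096)  # flexible-extract
assert within_noise(0.8370, 0.0102, 0.8431, 0.0100)  # strict-match
```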

@caozuoba
Contributor Author

@ZJY0516 Hi, I’ve addressed the review comments and added an accuracy test.

When you have time, could you please take another look and share your thoughts on the next step?

Member

@ZJY0516 ZJY0516 left a comment


LGTM now

else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)

def _forward_core_decode_non_spec(
Member


Can we switch to this inside _forward_core instead of falling back inside this decode method?

Contributor Author


Can we switch to this inside _forward_core instead of falling back inside this decode method?

Good point, thanks. Moving this selection into _forward_core does make the flow cleaner.

Let me update that.

Contributor Author


Can we switch to this inside _forward_core instead of falling back inside this decode method?

Updated accordingly, thanks!

@ZJY0516
Member

ZJY0516 commented Mar 11, 2026

also cc @vadiklyutiy

Signed-off-by: hdj <1293066020@qq.com>
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)

if (
Member


I prefer this:

if is_decode:
    return self._forward_core_decode_non_spec

# original logic here

No need to introduce _forward_core_baseline

Contributor Author


I prefer this:

if is_decode:
    return self._forward_core_decode_non_spec

# original logic here

No need to introduce _forward_core_baseline

Makes sense — updated accordingly. Thanks.

Signed-off-by: hdj <1293066020@qq.com>
Member

@ZJY0516 ZJY0516 left a comment


Thanks for contributing

@ywang96 ywang96 added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 12, 2026
@caozuoba
Contributor Author

@ZJY0516 Thanks again for the review and approval! I also have a draft follow-up PR for the spec path, and I’m planning to submit it soon. I’d really appreciate it if you could take a look when you have time.

@caozuoba
Contributor Author

@ywang96 @ZJY0516 It looks like some of the failed checks are due to 403 / permission issues rather than the code change itself. When you have time, could you please rerun those failed checks from the maintainer side?

@caozuoba
Contributor Author

It looks like one check is still consistently hitting a 403. Do you happen to know what might be causing that, or if there’s anything I should do on my side? @ZJY0516 @ywang96

Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, I just don't see the need for an env var

Comment on lines +903 to +905
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
),
Member


Why do we need this env var at all if it is enabled by default?

Member


Yes, no need to add this

Contributor Author


Why do we need this env var at all if it is enabled by default?

Good point. I plan to clean this up together with the spec-path follow-up so both paths stay consistent.

@caozuoba
Contributor Author

LGTM, I just don't see the need for an env var

@mgoin When you have time, could you please help retrigger that failed CI check? It seems to have hit a 403 a few times already.

@mgoin
Member

mgoin commented Mar 12, 2026

We can just force merge for now

@vllm-bot vllm-bot merged commit 9e19f83 into vllm-project:main Mar 12, 2026
56 of 58 checks passed