
[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] #33291

Merged
youkaichao merged 2 commits into vllm-project:main from CentML:vadim/gdn-kv-2-vk on Feb 4, 2026
Conversation

@vadiklyutiy (Collaborator) commented Jan 28, 2026

Summary

This PR changes the recurrent state memory layout in GDN (Gated Delta Net) attention from [N, HV, K, V] to [N, HV, V, K] for improved memory access patterns and throughput.

Beyond the speedup, it also allows using FlashInfer's (FI) GDN kernels.
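
For illustration, a minimal sketch of what the layout swap means for the cached recurrent state (plain PyTorch; the variable names and sizes here are illustrative, not the actual vLLM cache-allocation code):

```python
import torch

# Illustrative dimensions only (not the real model/config values).
num_blocks, num_v_heads, head_k_dim, head_v_dim = 4, 8, 128, 128

# Old recurrent-state layout: [N, HV, K, V]
state_kv = torch.zeros(num_blocks, num_v_heads, head_k_dim, head_v_dim)

# New layout in this PR: [N, HV, V, K] -- the last two dims are swapped, so a
# state held in the old layout corresponds to its transpose in the new one.
state_vk = state_kv.transpose(-1, -2).contiguous()
assert state_vk.shape == (num_blocks, num_v_heads, head_v_dim, head_k_dim)
```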

Performance Results

Model: nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 (TP=2)

Server:

VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
    -tp 2 --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
    --compilation_config.max_cudagraph_capture_size 2048

Benchmark:

vllm bench serve --backend vllm --model nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
--endpoint /v1/completions --dataset-name random --random-input 32 --random-output 1000 \
--max-concurrency $CONC --num-prompt $CONC --ignore-eos

| Batch Size | Baseline (tok/s) | With PR (tok/s) | Delta |
|---|---|---|---|
| 1 | 199.58 | 201.37 | +0.9% |
| 16 | 2,251.70 | 2,225.81 | -1.2% |
| 64 | 6,148.08 | 6,088.84 | -1.0% |
| 256 | 14,420.40 | 14,620.51 | +1.4% |
| 1024 | 23,245 | 24,350 | +4.8% |

Correctness Verification (lm_eval)

Task: GSM8K (5-shot)
Model: nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4

Server:

VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
-tp 2 --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
--compilation_config.max_cudagraph_capture_size 2048

Evaluation:

lm_eval --model local-chat-completions \
    --model_args model=nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=250 \
    --tasks gsm8k --apply_chat_template --num_fewshot 5 --output_path ./eval_results --log_samples

| Metric | Baseline | With PR | Delta |
|---|---|---|---|
| exact_match (flexible-extract) | 0.7703 | 0.7718 | +0.0015 |
| exact_match (strict-match) | 0.6406 | 0.6368 | -0.0038 |

With speculative decoding.

Unfortunately, there is a problem when spec decoding is combined with CUDA graphs, so this run was done without CUDA graphs. It also uses local-completions instead of the local-chat-completions used above, which produces better accuracy.

Server:

vllm serve Qwen/Qwen3-Next-80B-A3B-Instruct-FP8 -tp 4 --enable-expert-parallel \
    --async-scheduling --no-enable-prefix-caching --compilation_config.cudagraph_mode NONE \
    --speculative_config.method qwen3_next_mtp --speculative_config.num_speculative_tokens 3

Evaluation:

lm_eval --model local-completions --tasks gsm8k   --model_args base_url=http://localhost:8000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct-FP8,num_concurrent=109;

| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.8537 | ± 0.0097 |
| | | strict-match | 5 | exact_match | 0.8143 | ± 0.0107 |

The results are the same as the baseline.

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
@vadiklyutiy vadiklyutiy requested a review from tdoublep as a code owner January 28, 2026 23:41
@vadiklyutiy vadiklyutiy requested review from sighingnow and youkaichao and removed request for tdoublep January 28, 2026 23:41
@gemini-code-assist (Contributor, Bot) left a comment

Code Review

This pull request introduces a performance optimization by changing the memory layout of the GDN attention state from [N, HV, K, V] to [N, HV, V, K]. This change is aimed at improving memory access patterns and throughput. The modifications are consistently applied across documentation, examples, and Triton kernel implementations. The kernel logic has been correctly adapted to the new layout, including the use of transpositions where necessary. The provided performance benchmarks and correctness verification results support the effectiveness and validity of this change. The code appears to be correct and well-implemented.

@vadiklyutiy vadiklyutiy changed the title [PERF] PR: Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] [PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] Jan 28, 2026
@vadiklyutiy (Collaborator, Author)

cc @ZJY0516

@mergify (Contributor) Bot commented Jan 29, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @vadiklyutiy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@pavanimajety (Collaborator) left a comment
Could we do a perf comparison against multiple batch sizes? Also does this change naturally work for spec decode too?

    b_v = b_v.to(k.dtype.element_ty)

    p_k = tl.make_block_ptr(
        k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
    )
    b_k = tl.load(p_k, boundary_check=(0, 1))
-   b_h1 += tl.dot(b_k, b_v)
+   b_h1 += tl.trans(tl.dot(b_k, b_v))
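
For context on the transposed accumulation in the diff above: storing the state as [V, K] means accumulating the transpose of the old [K, V] update, so the math is unchanged. A minimal standalone check in plain PyTorch (illustrative shapes; this is not the Triton kernel itself):

```python
import torch

BT, K, V = 64, 128, 128          # chunk length and head dims (illustrative)
b_k = torch.randn(K, BT)         # keys for one chunk, laid out as (K, BT)
b_v = torch.randn(BT, V)         # values for the same chunk, laid out as (BT, V)

old_update = b_k @ b_v           # (K, V) update, matching the old state layout
new_update = old_update.T        # (V, K) update, what tl.trans(tl.dot(...)) yields
assert torch.allclose(new_update, b_v.T @ b_k.T)  # (A @ B).T == B.T @ A.T
```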
Collaborator
do we see as much speedup for larger batch sizes too with these additional transposes?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean? In the description I included the result for batch=1024.

Collaborator
Thanks for the clarification, I misread it as num-prompt 32. IMO we should still have performance numbers across a range of batch sizes.

Collaborator Author
I added a comparison with several additional batch sizes to the description.

@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
    if inplace_final_state:
        final_state = initial_state
Collaborator
Does the ssm_state / kv_cache also need to be created in the N, HV, V, K layout?

Collaborator Author
This layout change is exactly the layout of the ssm_state/kv_cache.
Did I miss some place where I should change it?

Collaborator
Just checking whether anything needs to change when the kv_cache is initially created. If the current setup yields good accuracy, it should be fine. Let's double-check with spec decode since it’s currently supported.

Collaborator Author
I ran spec decoding and added the results to the description. No accuracy loss with spec decoding.

@vadiklyutiy (Collaborator, Author)

> Could we do a perf comparison against multiple batch sizes? Also does this change naturally work for spec decode too?

I added a comparison with several additional batch sizes to the description.

@zhiyuan1i (Contributor)

Thanks for your contributions, overall LGTM.
I understand that this will affect the throughput more in the decode case and not so much on the prefill, right?
I see a certain negative impact from your PR for the small batch-size interval; can we look into the reason for this at the same time?
For me, it would be better to look at the performance impact of prefill and decode separately if there were a kernel-level benchmark.

@vadiklyutiy (Collaborator, Author)

> I understand that this will affect the throughput more in the decode case and not so much on the prefill, right?

It indirectly impacts prefill. See #32846, which tries to enable the FlashInfer prefill kernel; that kernel uses the new layout.
In general, several kernel experts have been working on GDN prefill and decode kernels and independently came to the conclusion that this layout is better.

> I see a certain negative impact from your PR for the small batch-size interval; can we look into the reason for this at the same time?

This is just fluctuation. If you run vllm bench several times you will see 1-2% fluctuations.

@youkaichao (Member) left a comment
thanks for the contribution, LGTM 👍 since @zhiyuan1i agrees

@youkaichao youkaichao enabled auto-merge (squash) February 4, 2026 02:29
@github-actions github-actions Bot added the ready ONLY add when PR is ready to merge/full CI is needed label Feb 4, 2026
@youkaichao youkaichao merged commit 8240580 into vllm-project:main Feb 4, 2026
45 checks passed
gameofdimension pushed a commit to gameofdimension/vllm that referenced this pull request Feb 5, 2026
…, V, K] (vllm-project#33291)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: felix01.yu <felix01.yu@vipshop.com>
@vadiklyutiy vadiklyutiy deleted the vadim/gdn-kv-2-vk branch March 11, 2026 08:00
mystous pushed a commit to mystous/vllm_hybrid that referenced this pull request May 10, 2026
…, V, K] (vllm-project#33291)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>