[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] #33291
Conversation
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Code Review
This pull request introduces a performance optimization by changing the memory layout of the GDN attention state from [N, HV, K, V] to [N, HV, V, K], aimed at improving memory access patterns and throughput. The modifications are applied consistently across documentation, examples, and Triton kernel implementations, and the kernel logic has been correctly adapted to the new layout, including transpositions where necessary. The provided performance benchmarks and correctness verification results support the effectiveness and validity of this change.
cc @ZJY0516
This pull request has merge conflicts that must be resolved before it can be merged.
pavanimajety left a comment
Could we do a perf comparison against multiple batch sizes? Also, does this change naturally work for spec decode too?
```diff
  b_v = b_v.to(k.dtype.element_ty)

  p_k = tl.make_block_ptr(
      k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
  )
  b_k = tl.load(p_k, boundary_check=(0, 1))
- b_h1 += tl.dot(b_k, b_v)
+ b_h1 += tl.trans(tl.dot(b_k, b_v))
```
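For context, a minimal PyTorch sketch (not the actual Triton code) of the algebra behind the added transpose: with the state stored in the new [V, K] layout, the same rank update is simply written transposed. The tile sizes below are hypothetical.

```python
import torch

K_dim, V_dim, BT = 64, 32, 16  # hypothetical tile sizes
b_k = torch.randn(K_dim, BT)   # key tile loaded as (K, T), as in the block ptr above
b_v = torch.randn(BT, V_dim)   # value tile, (T, V)

update_kv = b_k @ b_v          # old [K, V] state layout: update is (K, V)
update_vk = (b_k @ b_v).T      # new [V, K] layout: the same update, stored transposed
assert torch.allclose(update_vk, update_kv.T)
```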
Do we see as much speedup for larger batch sizes too with these additional transposes?
What do you mean? In the description I reported results for batch=1024.
Thanks for the clarification, I misread it as num-prompt 32. IMO we should still have performance numbers across a range of batch sizes.
I added a comparison with several additional batch sizes to the description.
```diff
@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
     if inplace_final_state:
         final_state = initial_state
```
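For reference, a minimal sketch of what this branch does; only `inplace_final_state`, `initial_state`, and `final_state` appear in the diff, everything else is an assumption for illustration.

```python
import torch

def fused_recurrent_kda_fwd_sketch(
    initial_state: torch.Tensor,       # recurrent state in the [N, HV, V, K] layout
    inplace_final_state: bool = True,
) -> torch.Tensor:
    if inplace_final_state:
        # Alias the buffer: the kernel overwrites the initial state in place,
        # avoiding a second [N, HV, V, K] allocation.
        final_state = initial_state
    else:
        final_state = torch.empty_like(initial_state)
    # ... kernel launch would write the updated state into final_state here ...
    return final_state
```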
Does the ssm_state / kv_cache also need to be created in the [N, HV, V, K] layout?
This layout change is exactly the layout of the ssm_state/kv_cache.
Did I miss some place where I should change it?
Just checking whether anything needs to change when the kv_cache is initially created. If the current setup yields good accuracy, it should be fine. Let's double-check with spec decode since it’s currently supported.
I ran spec decoding and added the results to the description. No accuracy is lost with spec decoding.
Force-pushed from c9ecabd to 6db5d35.
Thanks for your contributions, overall LGTM.
It indirectly impacts prefill. See #32846, which tries to enable FlashInfer prefill; that kernel uses the new layout.
This is fluctuation. If you make several runs of vllm bench you will see 1-2% fluctuations.
youkaichao left a comment
thanks for the contribution, LGTM 👍 since @zhiyuan1i agrees
Summary
This PR changes the recurrent state memory layout in GDN (Gated Delta Net) attention from [N, HV, K, V] to [N, HV, V, K] for improved memory access patterns and throughput. Besides the speedup, it also allows using FlashInfer's GDN kernels.
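A minimal sketch of what the layout change means for the state tensor; the sizes below are hypothetical (N = batch, HV = value heads, K/V = head dims).

```python
import torch

N, HV, K, V = 4, 16, 128, 128  # hypothetical sizes

# Old layout: each per-head recurrent state is a (K, V) matrix.
state_old = torch.zeros(N, HV, K, V)

# New layout: the same state stored transposed as (V, K), so the innermost
# (contiguous) dimension is now K.
state_new = state_old.transpose(-1, -2).contiguous()
assert state_new.shape == (N, HV, V, K)
```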
Performance Results
Model: `nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4` (TP=2)

Server:

```
VLLM_USE_FLASHINFER_MOE_FP4=1 vllm serve nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4 \
  -tp 2 --enable-expert-parallel --async-scheduling --no-enable-prefix-caching \
  --compilation_config.max_cudagraph_capture_size 2048
```

Benchmark:
Correctness Verification (lm_eval)
Task: GSM8K (5-shot)

Model: `nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4`

Server:

Evaluation:

```
lm_eval --model local-chat-completions \
  --model_args model=nvidia/Qwen3-Next-80B-A3B-Instruct-NVFP4,base_url=http://localhost:8000/v1/chat/completions,num_concurrent=250 \
  --tasks gsm8k --apply_chat_template --num_fewshot 5 --output_path ./eval_results --log_samples
```

With speculative decoding
Unfortunately, we have a problem in the spec decoding + cudagraph case, so this run was done without cudagraph. It also used `local-completions` instead of the `local-chat-completions` above, which produces better accuracy.

Server:

Evaluation:

```
lm_eval --model local-completions --tasks gsm8k \
  --model_args base_url=http://localhost:8000/v1/completions,model=Qwen/Qwen3-Next-80B-A3B-Instruct-FP8,num_concurrent=109
```

The result is the same as baseline.