
[Attention][Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention#34265

Closed
LopezCastroRoberto wants to merge 16 commits into vllm-project:main from LopezCastroRoberto:perf/topKperRow-FI

Conversation

Contributor

@LopezCastroRoberto LopezCastroRoberto commented Feb 10, 2026

Summary

This PR integrates FlashInfer's radix-based top-k kernel as an alternative implementation for the large context top-k operation in the sparse attention indexer, specifically for DeepSeek-V3.2 models.

Kernel adapted from: flashinfer-ai/flashinfer#2215
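For intuition, here is a pure-Python sketch of the radix-select idea behind this class of kernel: rather than fully sorting each row, keys are bucketed digit by digit starting from the most significant bits; buckets entirely above the k-th element are accepted wholesale, and only the bucket straddling the boundary is refined in the next pass. This is illustrative only, not the FlashInfer CUDA code, which parallelizes the passes across GPU threads on packed integer keys:

```python
import struct

def float_to_ordered_u32(x: float) -> int:
    # Map a float32 bit pattern to a uint32 whose integer order matches
    # the float order (a standard trick in radix-select kernels).
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    return (u ^ 0xFFFFFFFF) if (u >> 31) else (u | 0x80000000)

def radix_topk_indices(scores, k):
    # MSB-first radix select over 8-bit digits. Whole buckets above the
    # running threshold are accepted; only the straddling bucket is refined.
    keys = [float_to_ordered_u32(s) for s in scores]
    candidates = list(range(len(scores)))
    selected = []
    for shift in (24, 16, 8, 0):
        buckets = [[] for _ in range(256)]
        for i in candidates:
            buckets[(keys[i] >> shift) & 0xFF].append(i)
        for d in range(255, -1, -1):
            need = k - len(selected)
            if need == 0:
                break
            if len(buckets[d]) <= need:
                selected.extend(buckets[d])   # whole bucket is in the top-k
            else:
                candidates = buckets[d]       # straddler: refine next pass
                break
        if len(selected) == k:
            break
    else:
        # After the last pass, all remaining candidate keys are equal (ties).
        selected.extend(candidates[: k - len(selected)])
    return selected
```

The payoff over a full per-row sort is that most elements are discarded after inspecting only a few high-order digits, which is what makes the approach attractive at very long context lengths.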

Microbenchmark study

[Figure: microbenchmark overview]

E2E results

Example on NVIDIA B200:

vllm serve nvidia/DeepSeek-V3.2-NVFP4 -tp 4
vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 --input-len 128000 --output-len 4096 --num-prompts 1

MAIN:

============ Serving Benchmark Result ============
Successful requests:                     1         
Failed requests:                         0         
Benchmark duration (s):                  59.15     
Total input tokens:                      128000    
Total generated tokens:                  4096      
Request throughput (req/s):              0.02      
Output token throughput (tok/s):         69.24     
Peak output token throughput (tok/s):    71.00     
Peak concurrent requests:                1.00      
Total token throughput (tok/s):          2233.14   
---------------Time to First Token----------------
Mean TTFT (ms):                          717.80    
Median TTFT (ms):                        717.80    
P99 TTFT (ms):                           717.80    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.27     
Median TPOT (ms):                        14.27     
P99 TPOT (ms):                           14.27     
---------------Inter-token Latency----------------
Mean ITL (ms):                           14.27     
Median ITL (ms):                         14.27     
P99 ITL (ms):                            14.57     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1         
Failed requests:                         0         
Benchmark duration (s):                  53.62     
Total input tokens:                      128000    
Total generated tokens:                  4096      
Request throughput (req/s):              0.02      
Output token throughput (tok/s):         76.39     
Peak output token throughput (tok/s):    80.00     
Peak concurrent requests:                1.00      
Total token throughput (tok/s):          2463.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          732.15    
Median TTFT (ms):                        732.15    
P99 TTFT (ms):                           732.15    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          12.92     
Median TPOT (ms):                        12.92     
P99 TPOT (ms):                           12.92     
---------------Inter-token Latency----------------
Mean ITL (ms):                           12.92     
Median ITL (ms):                         12.72     
P99 ITL (ms):                            13.81     
==================================================

In this example, the PR improves output token throughput over MAIN by ~10% (69.24 → 76.39 tok/s).

Here is a more general analysis across different sequence lengths on NVIDIA B300:

vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 --input-len seq_len --output-len 4096 --num-prompts 1
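A sweep like the one above can be scripted, for example (illustrative only: the set of lengths and the `sweep` helper are not part of vLLM; `dry_run=False` actually launches the benchmarks):

```python
import shlex
import subprocess

CMD = ("vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 "
       "--input-len {seq_len} --output-len 4096 --num-prompts 1")

def sweep(seq_lens, dry_run=True):
    # Build one benchmark command per input length; run them when requested.
    cmds = [CMD.format(seq_len=n) for n in seq_lens]
    if not dry_run:
        for c in cmds:
            subprocess.run(shlex.split(c), check=True)
    return cmds

for c in sweep([8192, 16384, 32768, 65536, 131072]):
    print(c)
```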

[Figure: output token throughput vs. input sequence length]

Accuracy

python tests/evals/gsm8k/gsm8k_eval.py

MAIN:

Results:
Accuracy: 0.926
Invalid responses: 0.000
Total latency: 54.086 s
Questions per second: 24.387
Total output tokens: 121416
Output tokens per second: 2244.889

PR:

Results:
Accuracy: 0.929
Invalid responses: 0.000
Total latency: 52.035 s
Questions per second: 25.348
Total output tokens: 121881
Output tokens per second: 2342.299

@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft February 10, 2026 18:47
@LopezCastroRoberto LopezCastroRoberto changed the title Add FlashInfer top-k support to large context decode path [Perf] Add FlashInfer top-k support to large context decode path Feb 10, 2026
@mergify mergify bot added rocm Related to AMD ROCm v1 labels Feb 10, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 10, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request replaces the custom large_context_topk kernel with flashinfer.top_k_ragged_transform for handling top-k operations in the large context decode path. The changes primarily involve updating sparse_attn_indexer.py to use the FlashInfer function and passing a new offsets_buffer. Corresponding changes are made for API compatibility in the ROCm path. The tests are also updated to validate the new implementation. My review found a critical issue in the test file where a new test function shadows an existing one due to having the same name, and also misuses a pytest parameter. I've provided a suggestion to fix this.

@LopezCastroRoberto LopezCastroRoberto changed the title [Perf] Add FlashInfer top-k support to large context decode path [Perf] Add FlashInfer top-k support to large context decode path - DeepSeek-V3.2 sparse attention Feb 10, 2026
@mergify mergify bot added the deepseek Related to DeepSeek models label Feb 10, 2026
@mergify mergify bot added the nvidia label Feb 12, 2026
@LopezCastroRoberto LopezCastroRoberto changed the title [Perf] Add FlashInfer top-k support to large context decode path - DeepSeek-V3.2 sparse attention [Perf][Kernel] Improve topKperRow routine for large context decode path - DeepSeek-V3.2 sparse attention Feb 12, 2026
@LopezCastroRoberto LopezCastroRoberto changed the title [Perf][Kernel] Improve topKperRow routine for large context decode path - DeepSeek-V3.2 sparse attention [Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention Feb 12, 2026
@LopezCastroRoberto LopezCastroRoberto marked this pull request as ready for review February 12, 2026 18:49

mergify bot commented Feb 13, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LopezCastroRoberto.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 13, 2026
@mergify mergify bot added the performance Performance-related issues label Feb 16, 2026
Member

mgoin commented Feb 16, 2026

Interesting that the gsm8k eval is ~15% faster even though it is only ~2k context length and has especially short prefills (due to high prefix cache hit rate)

Contributor Author

LopezCastroRoberto commented Feb 16, 2026

Interesting that the gsm8k eval is ~15% faster even though it is only ~2k context length and has especially short prefills (due to high prefix cache hit rate)

@mgoin Yeah, there isn’t actually a remarkable speedup on GSM8K. My goal wasn’t to demonstrate performance improvements on this test, but simply to verify that accuracy was preserved. It looks like I probably just copied one of the later runs (second/third/fourth) from several consecutive executions I did for this test, rather than the initial run.

e.g.,
python tests/evals/gsm8k/gsm8k_eval.py

First execution:

Results:
Accuracy: 0.929
Invalid responses: 0.000
Total latency: 52.035 s
Questions per second: 25.348
Total output tokens: 121881
Output tokens per second: 2342.299

Second execution:

Accuracy: 0.929
Invalid responses: 0.000
Total latency: 46.138 s
Questions per second: 28.588
Total output tokens: 122155
Output tokens per second: 2647.619

I agree this can be confusing. I will update it in the PR description.

# See: https://github.com/vllm-project/vllm/pull/34265
max_seq_len = common_attn_metadata.max_seq_len
use_radix_topk = max_seq_len >= 65536
use_large_context_topk = max_seq_len == 2048 or (8192 < max_seq_len < 65536)
Contributor Author


However, there may be a minor improvement for 2k contexts due to the updated heuristic introduced in this PR. Specifically, large_context_topk (integrated in a previous PR) is selected instead of top_k_per_row_decode. This choice is based on the microbenchmark results described above.
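Restated as a standalone function (an illustrative mirror of the snippet above, not the vLLM code; `select_topk_impl` and the returned names are hypothetical labels for the kernels discussed in this PR):

```python
def select_topk_impl(max_seq_len: int) -> str:
    # Thresholds come from the microbenchmarks in the PR description.
    if max_seq_len >= 65536:
        return "radix_topk"            # FlashInfer radix-based kernel (this PR)
    if max_seq_len == 2048 or 8192 < max_seq_len < 65536:
        return "large_context_topk"    # custom kernel from a previous PR
    return "top_k_per_row_decode"      # default decode-path kernel
```

Note the boundary behavior: exactly 8192 still uses the default decode kernel, while 2048 is carved out as a special case for `large_context_topk`.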

@LopezCastroRoberto LopezCastroRoberto changed the title [Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention [Attention][Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention Feb 18, 2026
Contributor Author

LopezCastroRoberto commented Mar 2, 2026

This could be fairly expensive, any way to avoid? I noticed it isn't included in the benchmarking code. Maybe the kernel itself could zero only the required sections just for RadixRowState?
cc: @mgoin

I’ve fused the topk_workspace.zero_() op directly into the kernel. At the kernel level, there’s no remarkable overhead from this change. I also added a few additional test points to further validate that the heuristic behaves as expected.
[Figure: top-k decode microbenchmark]
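Conceptually, the fusion works like this: rather than launching a separate `topk_workspace.zero_()` over the whole buffer before the kernel, each block clears only its own per-row state as its first step inside the kernel. A pure-Python sketch of that idea (sizes, names, and the histogram body are illustrative, not the CUDA implementation):

```python
ROWS, STATE_WORDS = 4, 8          # illustrative sizes

def topk_kernel_row(workspace, row, digits):
    # Stand-in for one GPU block processing one row.
    base = row * STATE_WORDS
    # Fused init: clear only this row's state, so no separate
    # full-buffer zeroing launch is needed between calls.
    for j in range(base, base + STATE_WORDS):
        workspace[j] = 0
    for d in digits:              # e.g. accumulate a digit histogram
        workspace[base + d] += 1

workspace = [7] * (ROWS * STATE_WORDS)   # stale data from a previous launch
for row in range(ROWS):
    topk_kernel_row(workspace, row, digits=[1, 1, 5])
```

Since each row's state is overwritten before it is read, stale contents from earlier launches are harmless and the extra memset launch disappears.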

Collaborator

@LucasWilkinson LucasWilkinson left a comment


Do we know how much perf the cudagraph specialization brings e2e? It adds quite a bit of complexity, so I'm just wondering if it's worth it. How hard would it be to optimize the topk for the shorter contexts?

@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft March 17, 2026 10:33
LopezCastroRoberto and others added 13 commits March 17, 2026 14:50
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as ready for review March 17, 2026 15:13
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
@LopezCastroRoberto LopezCastroRoberto marked this pull request as draft March 17, 2026 15:21
@mergify mergify bot removed the needs-rebase label Mar 17, 2026
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>

Labels

deepseek (Related to DeepSeek models), nvidia, performance (Performance-related issues), rocm (Related to AMD ROCm), v1

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

4 participants