[Attention][Perf][Kernel] Improve topKperRow for large context decode path - DeepSeek-V3.2 sparse attention#34265
Conversation
Code Review
This pull request replaces the custom large_context_topk kernel with flashinfer.top_k_ragged_transform for handling top-k operations in the large context decode path. The changes primarily involve updating sparse_attn_indexer.py to use the FlashInfer function and passing a new offsets_buffer. Corresponding changes are made for API compatibility in the ROCm path. The tests are also updated to validate the new implementation. My review found a critical issue in the test file where a new test function shadows an existing one due to having the same name, and also misuses a pytest parameter. I've provided a suggestion to fix this.
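The shadowing issue flagged here is a common pytest pitfall; a minimal illustration follows (the function names are hypothetical, not the ones in the PR's test file):

```python
# When two test functions in one module share a name, the second
# definition rebinds the name, so pytest collects only the second one
# and the first silently never runs.
def test_topk_indexer():
    return "first"  # this version is lost after the redefinition below


def test_topk_indexer():  # shadows the function above
    return "second"


# Giving each test a distinct name restores coverage of both cases:
def test_topk_indexer_large_context():
    return "large-context variant"
```

Renaming one of the duplicates (as suggested in the review) is the whole fix; pytest then collects both tests again.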
Force-pushed from bbba437 to 2d74e0f
This pull request has merge conflicts that must be resolved before it can be merged.
Interesting that the gsm8k eval is ~15% faster even though it is only ~2k context length and has especially short prefills (due to high prefix cache hit rate)
@mgoin Yeah, there isn't actually a remarkable speedup on GSM8K. My goal wasn't to demonstrate performance improvements on this test, but simply to verify that accuracy was preserved. It looks like I probably just copied one of the later runs (second/third/fourth) from several consecutive executions of this test rather than the initial run, e.g.:
First execution:
Second execution:
I agree this can be confusing. I will update the PR description.
```python
# See: https://github.com/vllm-project/vllm/pull/34265
max_seq_len = common_attn_metadata.max_seq_len
use_radix_topk = max_seq_len >= 65536
use_large_context_topk = max_seq_len == 2048 or (8192 < max_seq_len < 65536)
```
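The thresholds above can be read as a small dispatch helper. This is an illustrative sketch (the helper name and return labels are hypothetical), not the actual vLLM code:

```python
def select_topk_backend(max_seq_len: int) -> str:
    # Thresholds mirror the heuristic in sparse_attn_indexer.py:
    # the radix-based FlashInfer kernel for very long contexts,
    # large_context_topk for the mid-range (plus the 2k special case),
    # and the default per-row decode kernel otherwise.
    if max_seq_len >= 65536:
        return "radix_topk"
    if max_seq_len == 2048 or 8192 < max_seq_len < 65536:
        return "large_context_topk"
    return "top_k_per_row_decode"
```

Note that the boundaries 8192 and 65536 themselves fall outside the mid-range interval: 8192 dispatches to the default kernel, while 65536 dispatches to the radix kernel.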
However, there may be a minor improvement for 2k contexts due to the updated heuristic introduced in this PR. Specifically, large_context_topk (integrated in a previous PR) is selected instead of top_k_per_row_decode. This choice is based on the microbenchmark results described above.
I've fused the
LucasWilkinson left a comment
Do we know how much perf the CUDA graph specialization brings end-to-end? It adds quite a bit of complexity, so I'm just wondering if it's worth it. How hard would it be to optimize the top-k for the shorter contexts?
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Force-pushed from 6e3a5f1 to 02bc7c5
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>

Summary
This PR integrates FlashInfer's radix-based top-k kernel as an alternative implementation for the large context top-k operation in the sparse attention indexer, specifically for DeepSeek-V3.2 models.
Kernel adapted from: flashinfer-ai/flashinfer#2215
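For intuition, a per-row top-k over a ragged buffer (rows delimited by an offsets buffer, as with the `offsets_buffer` this PR threads through) has the following reference semantics. This plain-Python sketch illustrates what such a kernel computes; it is not the FlashInfer implementation:

```python
def topk_per_row_ragged(scores, offsets, k):
    """Reference semantics for per-row top-k over a ragged buffer.

    `scores` is a flat list of scores for all rows; row i occupies
    scores[offsets[i]:offsets[i + 1]]. Returns, for each row, the
    indices (relative to the row start) of its k largest scores, in
    ascending index order. Rows shorter than k return all indices.
    """
    out = []
    for i in range(len(offsets) - 1):
        row = scores[offsets[i]:offsets[i + 1]]
        # Rank the row's positions by score, highest first.
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        out.append(sorted(order[: min(k, len(row))]))
    return out
```

For example, with two rows of lengths 3 and 4 packed into one flat buffer, `offsets = [0, 3, 7]` delimits the rows and each row's top-k is computed independently.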
Microbenchmark study
E2E results
Example on NVIDIA B200:
MAIN:
PR:
In this example, this PR improves throughput over MAIN by ~10%.
Here is a more general analysis across different sequence lengths on NVIDIA B300:
```
vllm bench serve --backend vllm --model nvidia/DeepSeek-V3.2-NVFP4 --input-len seq_len --output-len 4096 --num-prompts 1
```
Accuracy
```
python tests/evals/gsm8k/gsm8k_eval.py
```
MAIN:
PR: