[Perf] Triton-based top-p/top-k masking #32558
njhill wants to merge 6 commits into vllm-project:main
Conversation
Code Review
This pull request introduces a new Triton-based kernel for top-k/top-p masking to improve performance, along with comprehensive benchmarks and tests. The changes are well-structured. However, I found a critical issue in vllm/v1/sample/ops/topk_topp_sampler.py where the native forward pass incorrectly calls a TPU-specific sampler function due to a mistaken import. This needs to be fixed to ensure correctness and intended performance on non-TPU platforms.
didn't know you were a kernel guy now :0
Signed-off-by: Nick Hill <nickhill123@gmail.com>
with some help from claude lol
Cursor Bugbot has reviewed your changes and found 2 potential issues.
```python
if batch_size < threshold:
    # Use pytorch sort implementation for smaller batch sizes.
    return apply_top_k_top_pytorch(logits, k, p)
return apply_top_k_top_p_triton(logits, k, p)
```
There was a problem hiding this comment.
Triton dispatch breaks non-CUDA device support
High Severity
The apply_top_k_top_p function dispatches to apply_top_k_top_p_triton for large batch sizes without checking if the tensor is on CUDA. This function is called by forward_native, which is used on non-CUDA devices including CPU with RISCV/POWERPC architectures, XPU, and ROCm when aiter import fails. Since apply_top_k_top_p_triton asserts logits.is_cuda, this causes an assertion failure on these platforms when batch sizes exceed the threshold.
Yes, I think some adjustments are still needed to ensure this isn't on any non-cuda path.
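One way to address this is to gate the dispatch on the Triton kernel's own preconditions. A minimal sketch (the two implementation functions are placeholders standing in for the real kernels, so the dispatch logic is runnable anywhere; the threshold value is illustrative):

```python
import torch

# Placeholder implementations mirroring the names in the diff, so the
# dispatch logic below can run on any machine without CUDA.
def apply_top_k_top_pytorch(logits, k, p):
    return "pytorch"  # stands in for the sort-based implementation

def apply_top_k_top_p_triton(logits, k, p):
    assert logits.is_cuda  # the real kernel asserts this too
    return "triton"  # stands in for the Triton kernel

def apply_top_k_top_p(logits, k, p, threshold=32):
    # Route to Triton only when its preconditions hold: CUDA tensor and a
    # batch large enough that the kernel beats the sort-based path.
    if logits.is_cuda and logits.shape[0] >= threshold:
        return apply_top_k_top_p_triton(logits, k, p)
    # CPU / XPU / ROCm-without-aiter, or small batches: PyTorch fallback.
    return apply_top_k_top_pytorch(logits, k, p)
```

With this guard, a large batch on a non-CUDA device takes the PyTorch path instead of hitting the `logits.is_cuda` assertion.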
```python
if batch_size < threshold:
    # Use pytorch sort implementation for smaller batch sizes.
    return apply_top_k_top_pytorch(logits, k, p)
return apply_top_k_top_p_triton(logits, k, p)
```
Inconsistent dtype handling causes batch-size-dependent failures
Medium Severity
The apply_top_k_top_p dispatcher creates inconsistent behavior based on batch size for non-float32 logits. The Triton kernel asserts logits.dtype == torch.float32, but apply_top_k_top_pytorch works with any float dtype. Since logits can arrive in fp16/bf16 (evidenced by explicit dtype=torch.float32 in all downstream softmax calls), small batches succeed while large batches fail with an assertion error. This is a regression from the previous behavior where all batch sizes worked with any float dtype.
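Besides falling back to the PyTorch path for non-float32 inputs, another option is to upcast once before the Triton path, since all downstream softmax calls already compute in float32 anyway. A hedged sketch (helper name is hypothetical, not part of the PR):

```python
import torch

def ensure_float32(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: upcast fp16/bf16 logits before the Triton
    kernel, so success no longer depends on which batch-size branch the
    dispatcher takes."""
    return logits if logits.dtype == torch.float32 else logits.float()
```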
nice, this is really good. Do you think we should do something similar for the rejection sampler?
The same function is already used by the rejection sampler, so it should hopefully also benefit from this: vllm/v1/sample/rejection_sampler.py, lines 496 to 498 (7350331)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
I wasn't aware of the similar existing PR #25824, which looks in good shape. Hoping to switch to that one and include the tests/benchmark from this one.
Closing in favour of #33538 which is now merged. |
The current PyTorch sort-based top-p/top-k implementation uses a large amount of memory and scales very poorly with vocab and batch sizes. There is a top-k-only path which is much faster, but it involves a GPU->CPU sync that's detrimental to async scheduling.
The FlashInfer sampler is much faster but has various downsides and is not enabled by default.
This PR adds a Triton kernel which is not as fast as FlashInfer but is much faster than the PyTorch impl for large sizes and does not use extra memory. Written mostly by claude code with some human-in-the-loop iteration.
The PyTorch sorting impl still wins for small batch sizes (how small depends on vocab size and whether both top-p and top-k are used; the cutoff is between 8 and 128).
End-to-end benchmark on Qwen3-8B, 1xH100
Both top-p and top-k are set on all requests (this represents out-of-the-box performance, since Qwen defaults to setting these).
Before
After (+17% throughput)
No top-k, setting top-p on only 10% of requests:
Before
After (+26% throughput)