
[Perf] Triton-based top-p/top-k masking#32558

Closed
njhill wants to merge 6 commits into vllm-project:main from njhill:triton-topk-topp

Conversation

@njhill (Member) commented Jan 18, 2026

The current pytorch sort-based top-p/top-k implementation uses a large amount of memory and scales very poorly with vocab and batch sizes. There is a top-k only path which is much faster but it involves a GPU->CPU sync that's detrimental to async scheduling.
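
For context, the sort-based masking in question works roughly like this (a simplified sketch, not the exact vLLM code; per-row `k`/`p` tensors and a 2-D logits tensor are assumed):

```python
from typing import Optional

import torch

def apply_top_k_top_p_sorted(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """Mask logits outside the per-row top-k / top-p (nucleus) sets."""
    if k is None and p is None:
        return logits
    # Full ascending sort over the vocab dimension: this is the step that
    # allocates large temporaries and scales poorly with vocab/batch size.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # In ascending order, the first kept index is vocab_size - k.
        first_kept = (logits_sort.size(1) - k.to(torch.long)).unsqueeze(1)
        threshold = logits_sort.gather(1, first_kept)
        logits_sort.masked_fill_(logits_sort < threshold, float("-inf"))

    if p is not None:
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        # Drop the smallest tokens whose cumulative mass is <= 1 - p.
        top_p_mask = probs_sum <= 1.0 - p.unsqueeze(1)
        top_p_mask[:, -1] = False  # always keep the most likely token
        logits_sort.masked_fill_(top_p_mask, float("-inf"))

    # Scatter the masked values back to the original token order.
    return torch.empty_like(logits_sort).scatter_(1, logits_idx, logits_sort)
```

The full sort plus the softmax/cumsum temporaries is where the extra memory mentioned above comes from.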

The FlashInfer sampler is much faster but has various downsides and is not enabled by default:

  • It's historically had stability and correctness issues (Disable FlashInfer sampler by default #26859)
  • It also involves a CPU sync (or at least used to)
  • It does not support per-request seeds
  • Since it combines the logit selection and sampling, we don't have access to the intermediate logits/logprobs which are needed for some use cases.

This PR adds a Triton kernel which is not as fast as FlashInfer but is much faster than the pytorch impl at large sizes, and does not use extra memory. Written mostly by Claude Code with some human-in-the-loop iteration.

The pytorch sorting impl still wins at small batch sizes (how small depends on the vocab size and on whether both top-p and top-k are used; the cutoff falls somewhere between 8 and 128).

Scenario            Batch   Vocab   Ops%  Triton (ms)  PyTorch (ms)  Speedup    Tri Mem    Pyt Mem
--------------------------------------------------------------------------------------------------
topk_whole            256  131072   100%        1.504         3.812    2.53x  384.33 MB    1.25 GB
topk_partial          256  131072    50%        1.371         3.815    2.78x  384.33 MB    1.25 GB
topp_whole            256  131072   100%        1.847         4.258    2.31x  384.33 MB    1.25 GB
topp_partial          256  131072    50%        1.713         4.259    2.49x  384.33 MB    1.25 GB
topk_topp_whole       256  131072   200%        3.236         4.460    1.38x  384.33 MB    1.25 GB
mixed_partial         256  131072   134%        3.127         4.460    1.43x  384.33 MB    1.25 GB
topk_whole            512  131072   100%        1.893         7.509    3.97x  768.33 MB    2.50 GB
topk_partial          512  131072    50%        1.501         7.506    5.00x  768.33 MB    2.50 GB
topp_whole            512  131072   100%        2.207         8.366    3.79x  768.33 MB    2.50 GB
topp_partial          512  131072    50%        1.864         8.366    4.49x  768.33 MB    2.50 GB
topk_topp_whole       512  131072   200%        3.911         8.758    2.24x  768.33 MB    2.50 GB
mixed_partial         512  131072   134%        3.561         8.758    2.46x  768.33 MB    2.50 GB
topk_whole           1024  131072   100%        3.224        14.911    4.63x    1.50 GB    5.00 GB
topk_partial         1024  131072    50%        1.906        14.909    7.82x    1.50 GB    5.00 GB
topp_whole           1024  131072   100%        3.441        16.368    4.76x    1.50 GB    5.00 GB
topp_partial         1024  131072    50%        2.271        16.371    7.21x    1.50 GB    5.00 GB
topk_topp_whole      1024  131072   200%        6.286        17.141    2.73x    1.50 GB    5.00 GB
mixed_partial        1024  131072   133%        5.021        17.141    3.41x    1.50 GB    5.00 GB
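
(Aside: a minimal harness in the spirit of the numbers above, illustrative only; the actual benchmark script presumably times the real kernels with CUDA events and tracks peak memory:)

```python
import time

import torch

def bench_ms(fn, *args, iters: int = 20) -> float:
    # Crude wall-clock timing; on GPU you would bracket the loop with
    # torch.cuda.Event pairs and torch.cuda.synchronize() instead.
    fn(*args)  # warm-up
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters * 1e3  # ms per call

# Deliberately tiny compared with the 131072-vocab runs in the table.
logits = torch.randn(32, 4096)
ms = bench_ms(lambda x: x.sort(dim=-1, descending=True), logits)
print(f"full-vocab sort: {ms:.3f} ms")
```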

End-to-end benchmark on Qwen3-8B, 1xH100

Both top-p and top-k set on all requests (this represents out-of-the-box performance since Qwen defaults to setting these).

Before

============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             800       
Benchmark duration (s):                  81.77     
Total input tokens:                      256000    
Total generated tokens:                  1024000   
Request throughput (req/s):              24.46     
Output token throughput (tok/s):         12523.40  
Peak output token throughput (tok/s):    17344.00  
Peak concurrent requests:                1167.00   
Total token throughput (tok/s):          15654.25  
---------------Time to First Token----------------
Mean TTFT (ms):                          2305.83   
Median TTFT (ms):                        2840.91   
P99 TTFT (ms):                           3535.57   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.66     
Median TPOT (ms):                        55.50     
P99 TPOT (ms):                           78.03     
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.66     
Median ITL (ms):                         50.28     
P99 ITL (ms):                            207.79    
==================================================

After (+17% throughput)

============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             800       
Benchmark duration (s):                  69.92     
Total input tokens:                      256000    
Total generated tokens:                  1024000   
Request throughput (req/s):              28.61     
Output token throughput (tok/s):         14645.93  
Peak output token throughput (tok/s):    21395.00  
Peak concurrent requests:                1104.00   
Total token throughput (tok/s):          18307.42  
---------------Time to First Token----------------
Mean TTFT (ms):                          2178.85   
Median TTFT (ms):                        2607.79   
P99 TTFT (ms):                           3227.72   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          45.11     
Median TPOT (ms):                        46.58     
P99 TPOT (ms):                           63.48     
---------------Inter-token Latency----------------
Mean ITL (ms):                           45.12     
Median ITL (ms):                         42.18     
P99 ITL (ms):                            200.54    
==================================================

No top-k, setting top-p on only 10% of requests:

Before

============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             800       
Benchmark duration (s):                  80.89     
Total input tokens:                      256000    
Total generated tokens:                  1024000   
Request throughput (req/s):              24.72     
Output token throughput (tok/s):         12659.03  
Peak output token throughput (tok/s):    17600.00  
Peak concurrent requests:                1078.00   
Total token throughput (tok/s):          15823.79  
---------------Time to First Token----------------
Mean TTFT (ms):                          2191.01   
Median TTFT (ms):                        2822.09   
P99 TTFT (ms):                           3473.53   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          53.25     
Median TPOT (ms):                        55.17     
P99 TPOT (ms):                           77.04     
---------------Inter-token Latency----------------
Mean ITL (ms):                           53.26     
Median ITL (ms):                         49.65     
P99 ITL (ms):                            207.46    
==================================================

After (+26% throughput)

============ Serving Benchmark Result ============
Successful requests:                     2000      
Failed requests:                         0         
Maximum request concurrency:             800       
Benchmark duration (s):                  64.36     
Total input tokens:                      256000    
Total generated tokens:                  1024000   
Request throughput (req/s):              31.07     
Output token throughput (tok/s):         15909.96  
Peak output token throughput (tok/s):    24752.00  
Peak concurrent requests:                1105.00   
Total token throughput (tok/s):          19887.45  
---------------Time to First Token----------------
Mean TTFT (ms):                          2144.75   
Median TTFT (ms):                        2577.34   
P99 TTFT (ms):                           3074.03   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          41.34     
Median TPOT (ms):                        42.53     
P99 TPOT (ms):                           58.37     
---------------Inter-token Latency----------------
Mean ITL (ms):                           41.34     
Median ITL (ms):                         38.30     
P99 ITL (ms):                            197.00    
==================================================

@mergify mergify bot added performance Performance-related issues v1 labels Jan 18, 2026
@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new Triton-based kernel for top-k/top-p masking to improve performance, along with comprehensive benchmarks and tests. The changes are well-structured. However, I found a critical issue in vllm/v1/sample/ops/topk_topp_sampler.py where the native forward pass incorrectly calls a TPU-specific sampler function due to a mistaken import. This needs to be fixed to ensure correctness and intended performance on non-TPU platforms.

@robertgshaw2-redhat (Collaborator) commented:

didnt know you were a kernel guy now :0

Signed-off-by: Nick Hill <nickhill123@gmail.com>
@njhill (Member, Author) commented Jan 18, 2026:

didnt know you were a kernel guy now :0

with some help from claude lol

Signed-off-by: Nick Hill <nickhill123@gmail.com>
@njhill njhill marked this pull request as ready for review January 19, 2026 17:47
@cursor bot left a comment

Cursor Bugbot has reviewed your changes and found 2 potential issues.


    if batch_size < threshold:
        # Use pytorch sort implementation for smaller batch sizes.
        return apply_top_k_top_p_pytorch(logits, k, p)
    return apply_top_k_top_p_triton(logits, k, p)

Triton dispatch breaks non-CUDA device support

High Severity

The apply_top_k_top_p function dispatches to apply_top_k_top_p_triton for large batch sizes without checking if the tensor is on CUDA. This function is called by forward_native, which is used on non-CUDA devices including CPU with RISCV/POWERPC architectures, XPU, and ROCm when aiter import fails. Since apply_top_k_top_p_triton asserts logits.is_cuda, this causes an assertion failure on these platforms when batch sizes exceed the threshold.

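
A guard along these lines would avoid the assertion failure (a sketch with stand-in functions; the names and threshold value are illustrative, not the PR's actual code):

```python
import torch

def _pytorch_path(logits, k, p):
    return logits  # stand-in for the sort-based implementation

def _triton_path(logits, k, p):
    assert logits.is_cuda  # the Triton kernel supports CUDA tensors only
    return logits  # stand-in for the kernel launch

def apply_top_k_top_p(logits, k, p, threshold: int = 64):
    # Route to the Triton path only when the tensor is actually on CUDA
    # AND the batch is large enough for the kernel to win.
    if not logits.is_cuda or logits.size(0) < threshold:
        return _pytorch_path(logits, k, p)
    return _triton_path(logits, k, p)
```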

@njhill (Member, Author) replied:
Yes, I think some adjustments are still needed to ensure this isn't on any non-cuda path.

    if batch_size < threshold:
        # Use pytorch sort implementation for smaller batch sizes.
        return apply_top_k_top_p_pytorch(logits, k, p)
    return apply_top_k_top_p_triton(logits, k, p)

Inconsistent dtype handling causes batch-size-dependent failures

Medium Severity

The apply_top_k_top_p dispatcher creates inconsistent behavior based on batch size for non-float32 logits. The Triton kernel asserts logits.dtype == torch.float32, but apply_top_k_top_pytorch works with any float dtype. Since logits can arrive in fp16/bf16 (evidenced by explicit dtype=torch.float32 in all downstream softmax calls), small batches succeed while large batches fail with an assertion error. This is a regression from the previous behavior where all batch sizes worked with any float dtype.

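
One way to make the dispatch dtype-safe is to upcast at the boundary and cast back (illustrative sketch; the wrapper and stand-in names are hypothetical):

```python
import torch

def _mask_fp32(logits, k, p):
    assert logits.dtype == torch.float32  # mimics the kernel's assertion
    return logits  # stand-in for the Triton masking kernel

def apply_masking_any_dtype(logits, k, p):
    # Upcast fp16/bf16 logits to float32 for the kernel, then cast the
    # masked result back to the caller's original dtype.
    orig_dtype = logits.dtype
    out = _mask_fp32(logits.float(), k, p)
    return out if orig_dtype == torch.float32 else out.to(orig_dtype)
```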

@robertgshaw2-redhat (Collaborator) commented:

nice, this is really good.

Do you think we should do something similar for the rejection sampler?

@njhill (Member, Author) commented Jan 19, 2026:

nice, this is really good.

Do you think we should do something similar for the rejection sampler?

The same function is already used by the rejection sampler, so it should hopefully also benefit from this:

# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
return apply_top_k_top_p(logits, top_k, top_p)

@njhill njhill added the ready ONLY add when PR is ready to merge/full CI is needed label Jan 21, 2026
Signed-off-by: Nick Hill <nickhill123@gmail.com>
@njhill (Member, Author) commented Jan 26, 2026:

I wasn't aware of the similar existing PR #25824, which looks to be in good shape. I'm hoping to switch to that one and include the tests/benchmarks from this one.

@cakeng (Contributor) commented Feb 2, 2026:

Hi @njhill, I created a new PR #33538 since I wasn't able to re-open PR #25824. I've included the tests/benchmarks from this PR; please take a look!

@njhill (Member, Author) commented Feb 18, 2026:

Closing in favour of #33538, which is now merged.

@njhill njhill closed this Feb 18, 2026

Labels: performance, ready, v1