
[Kernel] Optimize sample_recovered_tokens_kernel #34974

Merged
vllm-bot merged 3 commits into vllm-project:main from xyang16:sample_recovered_tokens on Feb 21, 2026
Conversation


xyang16 (Contributor) commented on Feb 20, 2026

Purpose

This PR optimizes the sample_recovered_tokens_kernel in the rejection sampler.

  • Use a tiled reduction over the vocab. Instead of creating one massive tl.arange(0, PADDED_VOCAB_SIZE) vector and loading the whole vocab in a single program, load it in chunks of BLOCK_SIZE. This lowers register pressure and improves occupancy.
  • Replace the score = prob / q division with score = prob * inv_q, because division is slower than multiplication.
  • Kernel-level improvement is more than 5x. End-to-end improvement is about 1%, since this kernel is not launched many times (once per decode step, in gpu_model_runner.sample()).
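To make the tiling concrete, here is a hypothetical pure-Python sketch of the reduction (illustrative names only; the actual kernel is written in Triton). Scanning the vocab in BLOCK_SIZE-sized chunks while tracking a running argmax yields the same token as one full-vocab pass, and 1/q is computed once and reused as a multiplier:

```python
import math
import random

def sample_recovered_token_tiled(prob, q, block_size=8192):
    # score = prob * inv_q instead of prob / q: one reciprocal per element,
    # then only multiplies inside the hot loop.
    inv_q = [1.0 / x for x in q]
    best_score, best_token = -math.inf, -1
    # Tiled scan: visit the vocab in chunks of block_size, keeping a
    # running argmax instead of materializing the whole vocab at once.
    for start in range(0, len(prob), block_size):
        for i in range(start, min(start + block_size, len(prob))):
            score = prob[i] * inv_q[i]
            if score > best_score:
                best_score, best_token = score, i
    return best_token

random.seed(0)
vocab_size = 20000
prob = [random.random() for _ in range(vocab_size)]
q = [random.expovariate(1.0) for _ in range(vocab_size)]
# The tiled scan matches a single full-vocab argmax of prob / q.
full = max(range(vocab_size), key=lambda i: prob[i] / q[i])
assert sample_recovered_token_tiled(prob, q, 4096) == full
```

In the real kernel the inner loop is vectorized per block; the point of the sketch is only that chunking the reduction changes the memory access pattern, not the result.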

Test Plan

Added a unit test in test_rejection_sampler.py:

pytest -s -v tests/v1/sample/test_rejection_sampler.py -k test_sample_recovered_tokens

Test Result

Unit tests passed.

Profiling

(Profiler screenshots for main and this PR omitted; timings below.)

Main: sample_recovered_tokens_kernel takes around 159µs.

PR: sample_recovered_tokens_kernel takes around 29µs.
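A quick arithmetic check of the claimed kernel-level speedup from these timings:

```python
# Speedup implied by the profiles above (values copied from the timings).
main_us, pr_us = 159.0, 29.0
speedup = main_us / pr_us
assert speedup > 5  # consistent with the "more than 5x" claim
```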

Benchmarking

vllm serve deepseek-ai/DeepSeek-R1-0528 \
    --tensor-parallel-size 8 \
    --enable-expert-parallel \
    --max-num-seqs 16 \
    --trust-remote-code \
    --no-enable-prefix-caching \
    --speculative-config '{"method": "mtp", "num_speculative_tokens": 3}'
vllm bench serve \
  --model deepseek-ai/DeepSeek-R1-0528 \
  --dataset-name sharegpt \
  --dataset-path /tmp/ShareGPT_V3_unfiltered_cleaned_split.json \
  --sharegpt-output-len 300 \
  --max-concurrency 16 \
  --num-prompts 1000 \
  --num-warmups 50 \
  --ignore-eos \
  --temperature 0

Main:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  287.68    
Total input tokens:                      229418    
Total generated tokens:                  300000    
Request throughput (req/s):              3.48      
Output token throughput (tok/s):         1042.81   
Peak output token throughput (tok/s):    480.00    
Peak concurrent requests:                25.00     
Total token throughput (tok/s):          1840.28   
---------------Time to First Token----------------
Mean TTFT (ms):                          174.79    
Median TTFT (ms):                        190.15    
P99 TTFT (ms):                           308.77    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.76     
Median TPOT (ms):                        14.62     
P99 TPOT (ms):                           20.18     
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.52     
Median ITL (ms):                         33.48     
P99 ITL (ms):                            142.01    
---------------Speculative Decoding---------------
Acceptance rate (%):                     58.56     
Acceptance length:                       2.76      
Drafts:                                  108883    
Draft tokens:                            326649    
Accepted tokens:                         191290    
Per-position acceptance (%):
  Position 0:                            87.18     
  Position 1:                            56.63     
  Position 2:                            31.88     
==================================================

PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Failed requests:                         0         
Maximum request concurrency:             16        
Benchmark duration (s):                  284.99    
Total input tokens:                      229418    
Total generated tokens:                  300000    
Request throughput (req/s):              3.51      
Output token throughput (tok/s):         1053.68   
Peak output token throughput (tok/s):    480.00    
Peak concurrent requests:                25.00     
Total token throughput (tok/s):          1857.69   
---------------Time to First Token----------------
Mean TTFT (ms):                          170.27    
Median TTFT (ms):                        186.90    
P99 TTFT (ms):                           302.46    
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          14.62     
Median TPOT (ms):                        14.52     
P99 TPOT (ms):                           20.15     
---------------Inter-token Latency----------------
Mean ITL (ms):                           40.27     
Median ITL (ms):                         33.50     
P99 ITL (ms):                            138.88    
---------------Speculative Decoding---------------
Acceptance rate (%):                     58.88     
Acceptance length:                       2.77      
Drafts:                                  108514    
Draft tokens:                            325542    
Accepted tokens:                         191676    
Per-position acceptance (%):
  Position 0:                            87.32     
  Position 1:                            57.03     
  Position 2:                            32.28     
==================================================

End-to-end output token throughput improves by about 1%.
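A quick check of that figure, using the output token throughput values from the two benchmark tables above:

```python
# Relative end-to-end gain, main vs. PR (tok/s values from the tables).
main_tput, pr_tput = 1042.81, 1053.68
gain_pct = (pr_tput / main_tput - 1) * 100
assert 1.0 < gain_pct < 1.1  # about 1%
```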

Accuracy Testing

python3 -m lm_eval --model local-completions \
  --model_args model=deepseek-ai/DeepSeek-R1-0528,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=16 \
  --tasks gsm8k

Main:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9583|±  |0.0055|
|     |       |strict-match    |     5|exact_match|↑  |0.9553|±  |0.0057|

PR:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.9598|±  |0.0054|
|     |       |strict-match    |     5|exact_match|↑  |0.9568|±  |0.0056|
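As a sanity check, the gsm8k deltas are well within one standard error of the main numbers, consistent with the optimized kernel being numerically equivalent:

```python
# (value, stderr) pairs copied from the lm_eval tables above.
main = {"flexible-extract": (0.9583, 0.0055), "strict-match": (0.9553, 0.0057)}
pr   = {"flexible-extract": (0.9598, 0.0054), "strict-match": (0.9568, 0.0056)}
for k in main:
    # Each accuracy difference is smaller than one stderr.
    assert abs(pr[k][0] - main[k][0]) < main[k][1]
```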


@mergify mergify bot added the v1 label Feb 20, 2026

gemini-code-assist bot left a comment


Code Review

This pull request significantly optimizes the sample_recovered_tokens_kernel by implementing tiled reduction over the vocabulary and replacing division with multiplication for improved efficiency. The changes align well with the stated purpose of reducing register pressure and enhancing occupancy, leading to a reported 5x kernel-level improvement. The addition of a native Python reference implementation and comprehensive unit tests ensures the correctness of the optimized kernel. The profiling and benchmarking results demonstrate the positive impact of these optimizations. Overall, this is a well-executed and beneficial performance improvement.

mergify bot commented Feb 20, 2026

Hi @xyang16, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

xyang16 force-pushed the sample_recovered_tokens branch 2 times, most recently from b8f10d1 to 2567007 on February 20, 2026 18:18
Signed-off-by: Xin Yang <xyangx@amazon.com>
xyang16 force-pushed the sample_recovered_tokens branch from 2567007 to 5bade31 on February 20, 2026 18:43
mgoin (Member) left a comment

Really good find: using a block size of vocab_size is indeed a very bad idea. This makes sense to me and I don't think it should be controversial.

inv_q = q.reciprocal()

recovered_token_ids = torch.empty_like(draft_token_ids)
BLOCK_SIZE = 8192
mgoin (Member) commented

nit: since you do test with a vocab_size of 100, this does seem wasteful to not take the min or something

xyang16 (Contributor, Author) commented on Feb 20, 2026

Thanks for review!

Yes, I thought of having BLOCK_SIZE = min(8192, triton.next_power_of_2(vocab_size)).

But since most models' vocab_size is very large and a vocab of 100 only occurs in tests, I thought calling next_power_of_2 would add some overhead.

But I'm happy to make the change. Please let me know what you think.

mgoin (Member) commented

I agree it is probably not worth worrying over
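For reference, a pure-Python sketch of the min-with-next_power_of_2 alternative discussed here (the helper mirrors triton.next_power_of_2; this is illustrative only, not the code that was merged):

```python
# Stand-in for triton.next_power_of_2 (smallest power of two >= n).
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length() if n > 1 else 1

def choose_block_size(vocab_size: int, cap: int = 8192) -> int:
    # Cap the tile at 8192, but shrink for tiny (test-only) vocabs.
    return min(cap, next_power_of_2(vocab_size))

assert choose_block_size(100) == 128       # tiny vocab used in unit tests
assert choose_block_size(129280) == 8192   # large real vocab hits the cap
```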

mgoin added the performance and ready labels on Feb 20, 2026
xyang16 and others added 2 commits February 20, 2026 13:28
Signed-off-by: Xin Yang <xyangx@amazon.com>
@vllm-bot vllm-bot merged commit 7a5adad into vllm-project:main Feb 21, 2026
43 of 46 checks passed

dosubot bot commented Feb 21, 2026

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


DarkLight1337 pushed a commit to DarkLight1337/vllm that referenced this pull request Feb 21, 2026
joeqzzuo pushed a commit to joeqzzuo/vllm that referenced this pull request Feb 21, 2026
yugong333 pushed a commit to yugong333/vllm that referenced this pull request Feb 22, 2026
jmamou pushed a commit to jmamou/vllm that referenced this pull request Feb 23, 2026
@xyang16 xyang16 deleted the sample_recovered_tokens branch February 24, 2026 07:20
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026

Labels

performance, ready, v1
