
[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling #32852

Merged
benchislett merged 6 commits into vllm-project:main from caozuoba:perf/rejection-sampler-greedy on Jan 31, 2026

Conversation

@caozuoba (Contributor) commented on Jan 22, 2026

Purpose

This PR avoids computing a full-vocabulary softmax in the v1 speculative decoding rejection sampler when the entire batch is greedy (sampling_metadata.all_greedy).

For all-greedy decoding, the rejection sampler only needs argmax(target_logits); a dense softmax is unnecessary work. Since argmax(softmax(logits)) == argmax(logits), this change is behavior-preserving for the greedy path while reducing compute/memory overhead.
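
For concreteness, a minimal PyTorch sketch (illustrative only, not the PR's diff; the shapes are made up) of the invariance the greedy path relies on: softmax is strictly increasing per row, so it cannot change which index is largest.

```python
import torch

torch.manual_seed(0)
# Illustrative shape: 8 draft positions over a 32k vocabulary.
target_logits = torch.randn(8, 32_000)

greedy_from_logits = target_logits.argmax(dim=-1)
greedy_from_probs = torch.softmax(target_logits, dim=-1).argmax(dim=-1)

# softmax is strictly increasing elementwise per row, so the argmax index
# is identical whether we normalize or not.
assert torch.equal(greedy_from_logits, greedy_from_probs)
```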

Test Result

Correctness (pytest)

Command

pytest -q tests/v1/sample/test_rejection_sampler.py

Result

.....................................                                                                                                                   [100%]
====================================================================== warnings summary =======================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
37 passed, 2 warnings in 30.04s

Performance

Compared to main, on NVIDIA H800, this PR improves Output token throughput (tok/s) by ~3.14%, reduces Mean TPOT (ms) by ~7.04%, and reduces Mean E2EL (ms) by ~3.67%.

main
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  11.14
Total input tokens:                      128000
Total generated tokens:                  100000
Request throughput (req/s):              89.74
Output token throughput (tok/s):         8974.16
Peak output token throughput (tok/s):    7331.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          20461.07
---------------Time to First Token----------------
Mean TTFT (ms):                          4673.80
Median TTFT (ms):                        2721.85
P99 TTFT (ms):                           8599.31
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          39.22
Median TPOT (ms):                        37.92
P99 TPOT (ms):                           67.62
---------------Inter-token Latency----------------
Mean ITL (ms):                           79.31
Median ITL (ms):                         72.16
P99 ITL (ms):                            122.20
----------------End-to-end Latency----------------
Mean E2EL (ms):                          8557.05
Median E2EL (ms):                        8733.72
P99 E2EL (ms):                           10952.13
==================================================
PR
============ Serving Benchmark Result ============
Successful requests:                     1000
Failed requests:                         0
Benchmark duration (s):                  10.80
Total input tokens:                      128000
Total generated tokens:                  100000
Request throughput (req/s):              92.56
Output token throughput (tok/s):         9255.84
Peak output token throughput (tok/s):    8709.00
Peak concurrent requests:                1000.00
Total token throughput (tok/s):          21103.32
---------------Time to First Token----------------
Mean TTFT (ms):                          4633.34
Median TTFT (ms):                        2941.03
P99 TTFT (ms):                           8419.22
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          36.46
Median TPOT (ms):                        34.74
P99 TPOT (ms):                           65.33
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.20
Median ITL (ms):                         69.39
P99 ITL (ms):                            109.64
----------------End-to-end Latency----------------
Mean E2EL (ms):                          8243.21
Median E2EL (ms):                        8503.44
P99 E2EL (ms):                           10646.14
==================================================

@mergify bot added the v1 label on Jan 22, 2026
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a valuable performance optimization by skipping the full-vocabulary softmax calculation during greedy decoding in the rejection sampler. The change correctly leverages the mathematical equivalence of argmax(softmax(logits)) and argmax(logits), leading to improved output token throughput and reduced latency as demonstrated by the provided benchmark results. The logic appears sound and the change is well-justified for performance gains.

Comment thread on vllm/v1/sample/rejection_sampler.py (outdated):

    # NOTE: For all-greedy decoding, the rejection sampler only needs
    # argmax(target_logits), so computing a full-vocab softmax is wasted.
    if sampling_metadata.all_greedy:
        target_probs = target_logits
Severity: high

The variable target_probs is assigned target_logits when sampling_metadata.all_greedy is true. While the current usage in rejection_sample correctly handles this dual meaning (using it for argmax in greedy mode and as probabilities in random mode), the name target_probs can be misleading as it typically implies a probability distribution (values summing to 1). This could lead to confusion for future developers and potentially introduce bugs if the variable is used in contexts where actual probabilities are strictly expected, without explicitly checking the all_greedy flag. Consider using a more generic name for this variable, such as target_sampling_input, to accurately reflect its conditional content.
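
A minimal sketch of the kind of rename the bot is suggesting (hypothetical helper and name; the PR ultimately resolved the ambiguity differently, by passing logits into rejection_sample):

```python
import torch

def select_sampling_input(target_logits: torch.Tensor, all_greedy: bool) -> torch.Tensor:
    # Hypothetical rename: "sampling input" covers both cases, avoiding the
    # implication that the returned tensor is always a normalized distribution.
    if all_greedy:
        return target_logits  # raw logits; only argmax is taken downstream
    return torch.softmax(target_logits, dim=-1, dtype=torch.float32)
```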

@mgoin requested a review from benchislett on January 22, 2026
@mgoin (Member) left a comment

This seems reasonable to me but I'd like @benchislett or @WoosukKwon to sign off

@mgoin requested a review from WoosukKwon on January 22, 2026
@mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Jan 22, 2026
@benchislett (Collaborator) commented

Only concern would be if those probs are used downstream in anything besides sampling. Otherwise looks good

@caozuoba (Contributor, Author) commented

> Only concern would be if those probs are used downstream in anything besides sampling. Otherwise looks good

Regarding this concern: in the all_greedy case we only use argmax and return early in rejection_sample, so the values are never treated as normalized probabilities. Using logits there is behavior-preserving (argmax(softmax(x)) == argmax(x)). Non-greedy paths still use softmax as before. Thanks for your review.
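
To make the early-return argument concrete, here is a toy, self-contained version of the greedy accept rule (hypothetical helper, not vLLM's actual code): a draft token is accepted iff it equals the target argmax, so the tensor's values are never read as probabilities.

```python
import torch

def greedy_accept_sketch(
    draft_token_ids: torch.Tensor,  # [num_tokens], int64
    target_logits: torch.Tensor,    # [num_tokens, vocab_size]
) -> torch.Tensor:
    """Toy all-greedy path: logits are only compared via argmax."""
    target_argmax = target_logits.argmax(dim=-1)
    accepted = draft_token_ids == target_argmax
    # Emit the draft token where it matches, otherwise the target's argmax.
    return torch.where(accepted, draft_token_ids, target_argmax)
```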

@jeejeelee (Collaborator) commented

@benchislett Could you please take another look?

@caozuoba (Contributor, Author) commented

@mgoin @benchislett Hi, could you please let me know if this PR is ready to be merged? If you’d like me to run any additional tests on my side, please tell me which ones and I’ll do that. Thanks for your time and review.

@caozuoba (Contributor, Author) commented

Could someone please help move this PR forward? Thanks. @mgoin @benchislett @WoosukKwon @njhill

@njhill (Member) left a comment

Thanks @caozuoba.

I think it would be clearer to have rejection_sample take the target logits rather than probs, and move the softmax inside there.

Please also sign-off your commits for the DCO.

@caozuoba (Contributor, Author) commented

> Thanks @caozuoba.
>
> I think it would be clearer to have rejection_sample take the target logits rather than probs, and move the softmax inside there.
>
> Please also sign-off your commits for the DCO.

@njhill Thanks for the feedback! Agree this would be clearer. I’ll update rejection_sample to take the target logits and move the softmax inside the function.
I’ll also sign off my commits for the DCO and push an updated version shortly.
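
A rough sketch of the suggested restructuring (hypothetical signature and names, not the merged diff): the caller hands raw logits to rejection_sample, which decides internally whether a softmax is needed.

```python
import torch

def rejection_sample_sketch(
    target_logits: torch.Tensor,  # [num_tokens, vocab_size], raw logits
    all_greedy: bool,
) -> torch.Tensor:
    if all_greedy:
        # Greedy fast path: argmax on logits, softmax skipped entirely.
        return target_logits.argmax(dim=-1)
    # Random-sampling path: normalize here, inside the function.
    target_probs = torch.softmax(target_logits, dim=-1, dtype=torch.float32)
    return torch.multinomial(target_probs, num_samples=1).squeeze(-1)
```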

@caozuoba force-pushed the perf/rejection-sampler-greedy branch from 4785ff8 to bf858cb on January 30, 2026
@caozuoba (Contributor, Author) commented

@njhill Thanks! I’ve updated the code to have rejection_sample take target logits and moved the softmax inside. I also rebased and added DCO sign-offs to all commits (force-pushed to the same PR branch). Could you please take another look when you have a chance?

@benchislett (Collaborator) left a comment

LGTM

@benchislett enabled auto-merge (squash) on January 30, 2026
@benchislett merged commit 8980001 into vllm-project:main on Jan 31, 2026
40 checks passed
PiratePai pushed a commit to PiratePai/epd_shm that referenced this pull request Feb 3, 2026
…vllm-project#32852)

Signed-off-by: hdj <1293066020@qq.com>
Signed-off-by: Pai <416932041@qq.com>
whx-sjtu pushed a commit to vllm-project/vllm-ascend that referenced this pull request on Feb 11, 2026

### What this PR does / why we need it?
This PR updates `target_probs` to `target_logits` in
`rejection_sample`, catching up with
vllm-project/vllm#32852. Otherwise, sampling
with temperature incurs accuracy problems, where tokens can be
accepted or rejected unreasonably.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
By CI.

- vLLM version: v0.15.0
- vLLM main:
vllm-project/vllm@1339784

Signed-off-by: Zetong Li <slippersss@126.com>
Several vllm-ascend forks (chenchuw886, banxiaduhuo, ZRJ026, maoxx241, LCAIZJ, yangzhe-2026) subsequently picked the same commit (vllm-project#6685) and referenced this pull request.

Labels: ready (ONLY add when PR is ready to merge/full CI is needed), v1

5 participants