[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling#32852
Conversation
Code Review
This pull request introduces a valuable performance optimization by skipping the full-vocabulary softmax calculation during greedy decoding in the rejection sampler. The change correctly leverages the mathematical equivalence of argmax(softmax(logits)) and argmax(logits), leading to improved output token throughput and reduced latency as demonstrated by the provided benchmark results. The logic appears sound and the change is well-justified for performance gains.
```python
# NOTE: For all-greedy decoding, the rejection sampler only needs
# argmax(target_logits), so computing a full-vocab softmax is wasted.
if sampling_metadata.all_greedy:
    target_probs = target_logits
```
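For orientation, here is a minimal caller-side sketch of what this patch does. The helper name `_compute_target_probs` and the bare boolean flag are illustrative assumptions; the actual PR gates on `sampling_metadata.all_greedy` inside the sampler:

```python
import torch

def _compute_target_probs(target_logits: torch.Tensor, all_greedy: bool) -> torch.Tensor:
    # When every request in the batch is greedy, rejection_sample only
    # takes argmax of this tensor, so the raw logits can be passed
    # through unchanged and the full-vocab softmax is skipped.
    if all_greedy:
        return target_logits
    # Mixed / random-sampling batch: a real probability distribution
    # is required downstream.
    return torch.softmax(target_logits, dim=-1, dtype=torch.float32)
```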
The variable target_probs is assigned target_logits when sampling_metadata.all_greedy is true. While the current usage in rejection_sample correctly handles this dual meaning (using it for argmax in greedy mode and as probabilities in random mode), the name target_probs can be misleading as it typically implies a probability distribution (values summing to 1). This could lead to confusion for future developers and potentially introduce bugs if the variable is used in contexts where actual probabilities are strictly expected, without explicitly checking the all_greedy flag. Consider using a more generic name for this variable, such as target_sampling_input, to accurately reflect its conditional content.
mgoin left a comment:
This seems reasonable to me but I'd like @benchislett or @WoosukKwon to sign off
Only concern would be if those probs are used downstream in anything besides sampling. Otherwise looks good.
Regarding this concern: in the all_greedy case we only use argmax and return early in rejection_sample, so the values are never treated as normalized probabilities. Using logits there is behavior-preserving (argmax(softmax(x)) == argmax(x)). Non-greedy paths still use softmax as before. Thanks for your review.
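Softmax is a strictly increasing function of each logit with a shared row-wise normalizer, so it cannot change which index is largest. A minimal PyTorch check of the equivalence claimed above (illustrative, not part of the PR):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 32000)  # dummy [batch, vocab] logits

greedy_from_logits = logits.argmax(dim=-1)
greedy_from_probs = torch.softmax(logits, dim=-1).argmax(dim=-1)

# The monotone softmax preserves the argmax exactly.
assert torch.equal(greedy_from_logits, greedy_from_probs)
```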
@benchislett Could you please take another look?
@mgoin @benchislett Hi, could you please let me know if this PR is ready to be merged? If you’d like me to run any additional tests on my side, please tell me which ones and I’ll do that. Thanks for your time and review.
Could someone please help move this PR forward? Thanks. @mgoin @benchislett @WoosukKwon @njhill
@njhill Thanks for the feedback! Agree this would be clearer. I’ll update rejection_sample to take the target logits and move the softmax inside the function.
Signed-off-by: hdj <1293066020@qq.com>
Force-pushed from 4785ff8 to bf858cb.
@njhill Thanks! I’ve updated the code to have rejection_sample take target logits and moved the softmax inside. I also rebased and added DCO sign-offs to all commits (force-pushed to the same PR branch). Could you please take another look when you have a chance?
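A rough sketch of the refactored shape described here, with the softmax moved inside the function. The signature is heavily reduced for illustration; the real rejection_sample also takes draft tokens, draft probs, and sampling metadata:

```python
import torch

def rejection_sample(target_logits: torch.Tensor, all_greedy: bool) -> torch.Tensor:
    # Assumed reduced interface: the function now receives raw logits
    # and decides internally whether normalization is needed.
    if all_greedy:
        # Greedy path: argmax over raw logits; no softmax is computed.
        return target_logits.argmax(dim=-1)
    # Random-sampling path: normalize inside the function, as before.
    target_probs = torch.softmax(target_logits, dim=-1, dtype=torch.float32)
    return torch.multinomial(target_probs, num_samples=1).squeeze(-1)
```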
…vllm-project#32852) Signed-off-by: hdj <1293066020@qq.com> Signed-off-by: Pai <416932041@qq.com>
…vllm-project#6685)

### What this PR does / why we need it?
This PR updates `target_probs` to `target_logits` in `rejection_sample`, to catch up with vllm-project/vllm#32852. Otherwise, sampling with temperature will hit an accuracy problem where tokens can be accepted or rejected unreasonably.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
By CI.

- vLLM version: v0.15.0
- vLLM main: vllm-project/vllm@1339784

Signed-off-by: Zetong Li <slippersss@126.com>
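For context on why the downstream fix matters: rejection sampling accepts a draft token with probability min(1, p_target(token)/p_draft(token)), which only makes sense when both quantities are normalized probabilities. A toy illustration of the failure mode (all values made up):

```python
import torch

torch.manual_seed(0)
target_logits = torch.randn(32)                        # unnormalized scores
draft_probs = torch.softmax(torch.randn(32), dim=-1)   # a valid distribution
token = 5

# Correct: normalize target logits before forming the acceptance ratio.
target_probs = torch.softmax(target_logits, dim=-1)
good_ratio = (target_probs[token] / draft_probs[token]).clamp(max=1.0)

# Broken: raw logits can be negative or arbitrarily large, so the
# "ratio" no longer behaves like an acceptance probability at all.
bad_ratio = (target_logits[token] / draft_probs[token]).clamp(max=1.0)
print(float(good_ratio), float(bad_ratio))
```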
Purpose
This PR avoids computing a full-vocabulary softmax in the v1 speculative decoding rejection sampler when the entire batch is greedy (sampling_metadata.all_greedy).
For all-greedy decoding, the rejection sampler only needs argmax(target_logits); a dense softmax is unnecessary work. Since argmax(softmax(logits)) == argmax(logits), this change is behavior-preserving for the greedy path while reducing compute/memory overhead.
Test Result
Correctness (pytest)
Command
Result
Performance
Compared to main, on NVIDIA H800, this PR improves Output token throughput (tok/s) by ~3.14%, reduces Mean TPOT (ms) by ~7.04%, and reduces Mean E2EL (ms) by ~3.67%.
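As a rough way to see where the savings come from, one can time a full-vocab softmax+argmax against a bare argmax on a [num_tokens, vocab]-sized tensor. This is a hypothetical micro-benchmark with made-up shapes, not the serving benchmark reported above:

```python
import time
import torch

tokens, vocab = 512, 128_000  # assumed shapes for illustration
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(tokens, vocab, device=device)

def timeit(fn, iters=50):
    # Synchronize around the loop so GPU kernels are fully measured.
    if logits.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if logits.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

t_full = timeit(lambda: torch.softmax(logits, dim=-1).argmax(dim=-1))
t_skip = timeit(lambda: logits.argmax(dim=-1))
print(f"softmax+argmax: {t_full * 1e3:.2f} ms, argmax only: {t_skip * 1e3:.2f} ms")
```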