[Model Runner V2] Spec decode rejection sampler greedy support #37238
Conversation
```python
rejected_steps = sampled.new_empty(num_reqs)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
# [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
```
I felt it made more sense to compute this in `_probabilistic_rejection_kernel` rather than `_compute_residual_logits_kernel`, so I moved it here. I also renamed it from `residual_pos` to `rejected_pos`.
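As a rough illustration of what a `rejected_pos` computation would look like outside a Triton kernel (the function name, shapes, and vectorized form here are illustrative, not the actual kernel's signature):

```python
import numpy as np

def first_rejected_pos(accepted_mask):
    """Return, for each request, the index of the first rejected
    speculative step, or num_steps if every step was accepted.

    accepted_mask: [num_reqs, num_steps] boolean array, True where the
    draft token was accepted. (Illustrative batched form; the real
    kernel operates per-request.)
    """
    num_reqs, num_steps = accepted_mask.shape
    rejected = ~accepted_mask
    # argmax finds the first True; it returns 0 when no step is
    # rejected, so guard with any() and fall back to num_steps.
    pos = np.where(rejected.any(axis=1), rejected.argmax(axis=1), num_steps)
    return pos

mask = np.array([[True, True, False],   # rejected at step 2
                 [True, True, True]])   # fully accepted
print(first_rejected_pos(mask))  # -> [2 3]
```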
Hi @TheEpicDolphin, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
Code Review
This pull request adds support for greedy sampling (temperature=0) in the speculative decoding rejection sampler. The changes are well-structured, introducing new Triton kernels to handle greedy and probabilistic paths efficiently. The logic for rejection sampling and resampling in the greedy case is sound. I've found one potential issue in a newly added but currently unused kernel that should be addressed.
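As a rough sketch of the greedy (temperature = 0) rejection path this PR implements: draft tokens are accepted while they match the target model's argmax, and at the first mismatch the target argmax token is emitted instead. The function below is a plain-Python illustration under assumed shapes, not the actual Triton kernel logic:

```python
def greedy_rejection_sample(draft_token_ids, target_argmax):
    """Greedy (temperature=0) rejection sampling for one request.

    draft_token_ids: [num_spec] draft-proposed token ids.
    target_argmax:   [num_spec + 1] argmax of the target logits at each
                     position (the extra entry covers the bonus token).

    Accept draft tokens while they match the target argmax; on the
    first mismatch, emit the target argmax instead and stop. If all
    drafts are accepted, append the bonus token.
    """
    out = []
    for i, tok in enumerate(draft_token_ids):
        if tok == target_argmax[i]:
            out.append(tok)               # draft agrees with target: accept
        else:
            out.append(target_argmax[i])  # reject: resample = target argmax
            return out
    out.append(target_argmax[len(draft_token_ids)])  # all accepted: bonus token
    return out

# Drafts [5, 7, 9] against target argmax [5, 7, 2, 4]:
print(greedy_rejection_sample([5, 7, 9], [5, 7, 2, 4]))
# -> [5, 7, 2]: first two accepted, third rejected and replaced
```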
WoosukKwon left a comment
Thanks for the PR!
I think we can fuse more kernels to minimize the materialization of *_logits tensors, but we can probably follow up after this.
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
…project#37238) Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai> Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
Purpose
Following up on #35461, specifically with support for greedy sampling (temperature = 0).
To support this efficiently, I compute the local argmax/max of the target logits for greedy requests in
`_gather_draft_logits_and_target_argmax_kernel`. Then, during `_probabilistic_rejection_kernel`, the target argmax token is sampled only for the greedy requests. This limits the performance impact of greedy requests on the rest of the batch.
Benchmarks
Server
Client
Base model:
openai/gpt-oss-20b
Draft model:
RedHatAI/gpt-oss-20b-speculator.eagle3
Benchmark: mt-bench · 1000 requests · max concurrency 16 · 256K total generated tokens
Benchmark Comparison