[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling#38045
Conversation
Code Review
This pull request introduces a new "synthetic" rejection sampling method for speculative decoding, including its configuration and implementation. The RejectionSampler class has been updated to support this new method, alongside existing "strict" and "probabilistic" methods. A review comment suggests replacing an assert statement with a ValueError for validating the synthetic_acceptance_rate to ensure robust error handling.
@claude review
benchislett
left a comment
- Way too much complexity here for an optional feature.
- I don't like the idea of skipping the main rejection sampling and replacing it with a synthetic kernel. If we are using synthetic sampling to approximate a workload, I would prefer to still call the main rejection sampler and then override the outputs so that we never underestimate the cost of the sampling itself
@benchislett Thanks for sharing your thoughts.
I don't really agree. I think 99% of the "complexity" in this PR can go away if we relocate the code for synthetic rejection sampling into a separate file. After that, one can clearly tell that the synthetic rejection sampling is not entangled with the other options.
I think the overhead matches that of strict rejection sampling. Probabilistic rejection sampling will be more expensive, but we probably don't need to model it precisely.
Addressed the following:
And echoing what Woosuk said, the overhead here should closely approximate the strict rejection sampler kernel, which is the main one we are focused on approximating.
if (
    synthetic_acceptance_rate is None
    or not 0.0 <= synthetic_acceptance_rate <= 1.0
):
    raise ValueError(
        f"synthetic_acceptance_rate must be in [0, 1], "
        f"but got {synthetic_acceptance_rate}"
    )
nit: Maybe we can move this logic into compute_synthetic_rejection_sampler_params? This might look a bit cleaner
I'm tempted to, but we'd still want to raise a ValueError here if synthetic_acceptance_rate is None (it would be odd for compute_synthetic_rejection_sampler_params to accept a nullable input). So all we'd gain in cleanliness is dropping the not 0.0 <= synthetic_acceptance_rate <= 1.0 condition here.
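To make the trade-off concrete, here is a hypothetical sketch (not the PR's actual code; the helper name is invented): even if the range check moved into a helper, the caller still needs the None check, so only the range condition is saved at the call site.

```python
# Hypothetical helper illustrating the discussion above: the None check
# cannot be pushed into a non-nullable helper, so folding validation in
# only saves the range condition at the call site.
from typing import Optional


def validate_synthetic_acceptance_rate(rate: Optional[float]) -> float:
    # Reject both a missing rate and an out-of-range rate.
    if rate is None or not 0.0 <= rate <= 1.0:
        raise ValueError(
            f"synthetic_acceptance_rate must be in [0, 1], but got {rate}"
        )
    return rate
```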
@WoosukKwon added
WoosukKwon
left a comment
@benchislett The PR looks good to me. Can you take another look?
return sampled, num_sampled

def compute_synthetic_rejection_sampler_params(
Would this be too complicated? I was thinking we just impose a position-independent acceptance rate, like accepting every token with a constant ratio (e.g. 90%).
So you're suggesting that the exposed parameter determines the fixed conditional acceptance rate at each draft position.
In this PR, the parameter determines the average of the joint per-position acceptance rates (the average acceptance rate you see in the benchmark metrics). That is the metric I generally find most useful and intuitive, but I'm open to hearing other people's thoughts.
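To illustrate the difference between the two parameterizations (a hypothetical example, not code from the PR): with a constant conditional rate `a` at each of `k` draft positions, the joint acceptance rate at position i is a**(i+1), so the benchmark-style mean joint rate is noticeably lower than the conditional rate itself.

```python
# Hypothetical illustration of the two parameterizations discussed above.
# With a constant conditional acceptance rate `a` at each of `k` draft
# positions, the joint rate at position i is a**(i+1); the "average
# acceptance rate" seen in benchmark metrics is the mean of those joints.

def mean_joint_from_constant(a: float, k: int) -> float:
    return sum(a ** (i + 1) for i in range(k)) / k
```

For example, a fixed 90% conditional rate over 5 draft positions yields a mean joint acceptance rate of about 0.737, well below 0.9.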
@TheEpicDolphin Can you please rebase?
…jection sampling Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
…hold to 0.91 The LM Eval Large Models (H200) CI job was failing because the NVIDIA-Nemotron-3-Super-120B-A12B-BF16 model scored slightly below the 0.93 accuracy threshold on GSM8K. The model uses MTP speculative decoding with 5 speculative tokens. Recent changes to the Model Runner V2 spec decode path (PRs vllm-project#38045 and vllm-project#38311) adjusted rejection sampling behavior and rebuilt attention metadata before eagle decode, which can marginally affect the acceptance rate and therefore the final accuracy score. Lower the threshold from 0.93 to 0.91 to reflect the current achievable accuracy with the updated spec decode implementation. The model still demonstrates strong GSM8K performance above 91%. Signed-off-by: SandishKumarHN <sandishkumarhn@gmail.com>
Purpose
Enable testing/debugging with a fixed, expected acceptance rate. To easily support this in MRV2, I added a new rejection sampling method (synthetic) to the speculative config, along with a field for setting the desired mean acceptance rate (synthetic_acceptance_rate). That field is only used if synthetic rejection sampling is selected.
In practice, with speculative decoding using > 1 draft tokens, we see the acceptance rate drop off at each successive position. To simulate this drop-off, the base acceptance rate (at position 0) and a decay factor are computed from the desired mean acceptance rate. Those yield the per-position conditional acceptance rates used for accepting/rejecting draft tokens, which adds realism to the results. The decay factor (minimum value of 0.85, roughly what is seen in practice for well-tuned draft models) has to be selected carefully to ensure that the base acceptance rate does not exceed 1, which would prevent achieving the desired mean acceptance rate. That is what compute_base_acceptance_rate and min_valid_decay_factor do.
Test Plan
Added test to ensure that the base acceptance rate and decay factor are computed correctly, and that the average of the joint probabilities in fact equals the desired value.
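A minimal sketch of the base-rate derivation described in the Purpose section (hypothetical; the function names and the bisection approach are my assumptions, not the PR's implementation): given a decay factor, the mean of the joint acceptance rates is monotonically increasing in the base conditional rate, so the base rate matching a target mean can be recovered numerically.

```python
# Hypothetical sketch (not vLLM's actual implementation). The conditional
# acceptance rate at draft position i is base * decay**i; the joint rate
# is the running product of conditionals, and we solve for `base` so that
# the mean of the joint rates equals the target mean acceptance rate.

def mean_joint_acceptance(base: float, decay: float, num_draft: int) -> float:
    mean, joint = 0.0, 1.0
    for i in range(num_draft):
        joint *= base * decay**i  # conditional rate decays per position
        mean += joint
    return mean / num_draft


def compute_base_rate(target_mean: float, decay: float, num_draft: int) -> float:
    # Mean joint acceptance is monotonic in `base`, so bisection suffices.
    lo, hi = 0.0, 1.0
    for _ in range(60):
        mid = (lo + hi) / 2
        if mean_joint_acceptance(mid, decay, num_draft) < target_mean:
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2
```

Note that for a given decay and draft length, only target means up to mean_joint_acceptance(1.0, decay, num_draft) are achievable, which is why the decay factor must be chosen so the base rate stays at or below 1.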
Additionally, I manually tested spec decoding with Llama3:
Server
Benchmark
Results
Speculative Decoding Benchmark Results