Conversation
Code Review
This pull request introduces support for using draft model probabilities in speculative decoding, a significant enhancement. The changes involve refactoring the sampling logic and updating data structures like `InputBatch` and `RequestState` to handle draft logits. While the overall direction is good, I've identified two critical bugs involving incorrect tensor shapes and arguments passed to kernel functions. These issues will lead to incorrect behavior and must be addressed.
```diff
-        return draft_tokens.view(-1, 1)
+        return Speculation(
+            draft_tokens=draft_tokens.view(-1, 1),
+            draft_logits=draft_logits.view(-1, 1),
```
The `draft_logits` tensor is being reshaped incorrectly. `draft_logits` has a shape of `[num_reqs, vocab_size]`. Using `.view(-1, 1)` results in a tensor of shape `[num_reqs * vocab_size, 1]`, which is incorrect.
The `Speculation` dataclass expects `draft_logits` to have a shape of `[num_reqs, num_speculative_steps, vocab_size]`. For `num_speculative_steps=1`, the correct shape is `[num_reqs, 1, vocab_size]`. You should use `.unsqueeze(1)` to add a dimension.
```diff
-            draft_logits=draft_logits.view(-1, 1),
+            draft_logits=draft_logits.unsqueeze(1),
```
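The difference between the two reshapes can be checked in isolation; this is a minimal sketch with made-up sizes (4 requests, a toy vocabulary of 8 tokens), not code from the PR:

```python
import torch

# Hypothetical sizes for illustration only.
num_reqs, vocab_size = 4, 8
draft_logits = torch.randn(num_reqs, vocab_size)  # [num_reqs, vocab_size]

# .view(-1, 1) flattens all elements into a column: [num_reqs * vocab_size, 1].
wrong = draft_logits.view(-1, 1)
assert wrong.shape == (num_reqs * vocab_size, 1)

# .unsqueeze(1) inserts the num_speculative_steps dimension instead:
# [num_reqs, 1, vocab_size], which is what a per-step logits layout expects.
right = draft_logits.unsqueeze(1)
assert right.shape == (num_reqs, 1, vocab_size)
```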
/gemini review I think there's a bug when cudagraph is enabled, do you see it?
Code Review
This pull request introduces support for using draft model probabilities in speculative decoding, which is a significant enhancement. The changes propagate draft logits through the system, refactor the sampler for probabilistic rejection sampling, and update the speculator accordingly. A new correctness test for sampling recovered and bonus tokens is a valuable addition. While the overall direction is good, I've identified a critical issue with the log-probability calculation in the speculative decoding path: the logic for computing logprobs for accepted tokens appears to be incorrect due to tensor shape mismatches, and it does not account for different token origins (accepted draft vs. recovered), which will likely lead to incorrect output.
```python
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
    expanded_logits = logits.shape[0] != idx_mapping_np.shape[0]
    cu_num_logits_list = cu_num_logits_np.tolist() if expanded_logits else None
    logprobs_tensors = compute_topk_logprobs(
        processed_logits, max_num_logprobs, sampled, cu_num_logits_list
    )
```
The log-probability calculation for speculative decoding appears to be incorrect. The `compute_topk_logprobs` function is called with a `sampled` tensor of shape `[num_reqs, num_speculative_steps + 1]` and `processed_logits` of shape `[num_draft_tokens + num_reqs, vocab_size]`. The `sampled` tensor is 2D and padded, while `compute_topk_logprobs` likely expects a 1D tensor of token IDs that aligns with the provided logits.
Furthermore, the sampled tokens are a mix of accepted draft tokens and "recovered" tokens. The logprobs for these two types of tokens come from different distributions (`p_target` vs. `renormalize(max(0, p_target - p_draft))`). The current implementation does not seem to distinguish between them when computing logprobs, passing only `processed_logits` (derived from `p_target` for proposed tokens). This will lead to incorrect logprob values.
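To make the two distributions concrete, here is a toy example (values and the 3-token vocabulary are made up for illustration) of the recovered distribution `renormalize(max(0, p_target - p_draft))` used when a draft token is rejected:

```python
import torch

# Toy 3-token vocabulary; probabilities chosen only for illustration.
p_target = torch.tensor([0.5, 0.3, 0.2])
p_draft = torch.tensor([0.2, 0.6, 0.2])

# Recovered distribution: keep only the mass where the target exceeds the
# draft, then renormalize so it sums to 1.
recovered = torch.clamp(p_target - p_draft, min=0.0)
recovered = recovered / recovered.sum()
# Here only token 0 has p_target > p_draft, so it receives all the mass.
```

Logprobs for recovered tokens would have to be taken from `recovered`, not from `p_target`, which is why a single `processed_logits` tensor cannot serve both cases.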
Hi @andylolu2, I was tasked with working on probabilistic rejection sampling for MRV2 in #35461, and stumbled upon your PR here. Looks like we took similar approaches. I have some benchmark results on my PR showing acceptance rate improvements. I'm happy to compare implementations and figure out the best path forward for us to unlock this feature. Perhaps we can combine the best parts of both! cc: @WoosukKwon
This pull request has merge conflicts that must be resolved before it can be merged.
Oooo amazing, I am a bit short on bandwidth atm so please feel free to take over! I have some takeaways from implementing this, let me comment on your PR
Purpose
Support drafter probabilities in ModelRunnerV2.
Test Plan
TBD
Test Result
TBD