[Speculative decoding 1/9] Optimized rejection sampler #2336
LiuXiaoxuanPKU merged 10 commits into vllm-project:main
Conversation
Force-pushed from d032887 to 434c525
The next PR will be cadedaniel#1; I'll create it once this is merged.
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
What's the difference between accepted and accepted_mask?
accepted is the result of the rejection sampling condition. accepted_mask is True up until the first position rejected by the rejection sampling condition.
Example for k=3, bs=5:
>>> accepted
tensor([[ True, False, True],
[False, False, False],
[ True, True, False],
[ True, True, True],
[False, True, False]])
>>> accepted_mask
tensor([[ True, False, False],
[False, False, False],
[ True, True, False],
[ True, True, True],
[False, False, False]])
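For reference, here is a minimal, runnable sketch of one way to derive limits (the number of positions before the first rejection in each row) and rebuild accepted_mask for the example above. The limits computation is an assumption for illustration, not necessarily the PR's exact code:

```python
import torch

accepted = torch.tensor([[True, False, True],
                         [False, False, False],
                         [True, True, False],
                         [True, True, True],
                         [False, True, False]])
k = accepted.shape[1]

# limits[i] = number of positions in row i before the first rejection
# (k if every position was accepted).
limits = ((~accepted).cumsum(dim=1) == 0).sum(dim=1)

# Create masks using the indices (the snippet quoted above).
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
# accepted_mask now matches the tensor printed above.
```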
super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._num_bonus_tokens = 1
When can num_bonus_tokens > 1? Is it the last token generated by the target model iff all drafted tokens are accepted?
> When can num_bonus_tokens > 1?

It is always 1. This variable is for readability only. I'll add a comment.

> Is it the last token generated by the target model iff all drafted tokens are accepted?

Yep!
Very exciting work! I hope this feature can be merged soon, as many other frameworks such as TGI, TRT-LLM, llama.cpp, and gpt-fast already support speculative sampling.
f = torch.clamp(difference, min=self._smallest_positive_value)

# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
Nit: torch.multinomial does not require the probabilities to be normalized.
I will leave this in to keep the maths consistent with https://arxiv.org/pdf/2302.01318.pdf. This operation is not the compute or scheduling bottleneck.
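To make the surrounding context concrete, here is a self-contained sketch of the recovered distribution from the paper, norm(max(0, p − q)), using assumed names (target_probs for p, draft_probs for q); it mirrors the quoted lines but is not the PR's exact code:

```python
import torch

batch_size, k, vocab_size = 2, 3, 5
# Assumed inputs: target-model and draft-model probabilities over the
# vocabulary at each of the k speculative positions.
target_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)
draft_probs = torch.softmax(torch.randn(batch_size, k, vocab_size), dim=-1)

# max(0, p - q), clamped to a tiny positive value so each row remains a
# valid distribution even when p <= q everywhere (the role played by
# self._smallest_positive_value above).
difference = target_probs - draft_probs
f = torch.clamp(difference, min=torch.finfo(torch.float32).tiny)

# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

# One recovered token per (sequence, position); multinomial takes 2-D input.
recovered_token_ids = torch.multinomial(
    recovered_probs.view(-1, vocab_size), num_samples=1).view(batch_size, k)
```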
Speculative decoding
This PR is part of a larger series of PRs implementing speculative decoding, contributed to open-source vLLM by Anyscale. See #2188 and the Speculative decoding open sourcing plan for more information.
Rejection sampling
This PR implements optimized rejection sampling. It also contributes tests which verify the rejection sampler's ability to approximate distributions, given enough samples.
The following people contributed to it: @cadedaniel, @Yard1, @amogkam
Details
The basic idea behind rejection sampling is that one can sample from the target distribution (larger model) using samples from a proposal distribution (smaller draft model), while guaranteeing the output distribution is equivalent to the target distribution.
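Concretely, a proposed token x is accepted with probability min(1, p(x)/q(x)), where p is the target distribution and q the proposal distribution. A hedged sketch of that standard acceptance test (not the PR's code):

```python
import torch

def acceptance_test(target_prob_of_token: torch.Tensor,
                    draft_prob_of_token: torch.Tensor) -> torch.Tensor:
    """Accept each proposed token with probability min(1, p/q).

    Both inputs hold the probability each model assigned to the
    *proposed* token, shape [batch_size, k]; returns a bool tensor.
    """
    uniform = torch.rand_like(target_prob_of_token)
    # uniform < p/q accepts with probability min(1, p/q), since a
    # uniform draw in [0, 1) is always below any ratio >= 1.
    return uniform < target_prob_of_token / draft_prob_of_token
```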
"Modified" rejection sampling is introduced in the paper. It ensures that at least one token will always be emitted from the rejection sampling routine, even if all proposal tokens are rejected.
With LLMs, modified rejection sampling can reduce latency because multiple proposal sequences can be evaluated at once (batched on the GPU).
Finally, the paper introduces the notion of a "bonus" token. In the case where all proposed tokens are accepted, an additional token can be emitted. This is made possible by having the target model predict the next token given the entire proposed sequence as context.
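Putting the pieces together, a schematic sketch (hypothetical tensor names, not the PR's API) of what each sequence emits: its accepted prefix, then either a token resampled from the recovered distribution (on the first rejection) or the bonus token (when all k proposals were accepted):

```python
import torch

def assemble_outputs(proposed: torch.Tensor,   # [batch_size, k] draft tokens
                     recovered: torch.Tensor,  # [batch_size, k] resampled tokens
                     bonus: torch.Tensor,      # [batch_size, 1] bonus tokens
                     limits: torch.Tensor      # [batch_size] first-rejection index
                     ) -> list[list[int]]:
    batch_size, k = proposed.shape
    outputs = []
    for i in range(batch_size):
        n = int(limits[i])
        seq = proposed[i, :n].tolist()
        # All k accepted -> emit the bonus token; otherwise emit the
        # token drawn from the recovered distribution at the rejection.
        seq.append(int(bonus[i, 0]) if n == k else int(recovered[i, n]))
        outputs.append(seq)
    return outputs
```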
Visual confirmation that modified rejection sampling approximates the target distribution:


Code for the visualizations: https://gist.github.com/cadedaniel/07c1cd4ac003f51140b205580ac02613