[WIP] Speculative Decoding #1797
Conversation
pian13131 left a comment
Just some tiny comments.
    last_child_sample = child_samples[-1]
    parent.append_token_id(last_child_sample.output_token,
                           last_child_sample.logprobs)
    if last_child_sample.accepted_tokens:
nit: if FLAGS.ENABLE_SD:?
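For context, the step that produces `accepted_tokens` could follow the rejection-sampling rule from the speculative decoding paper. A minimal sketch (the helper name and dict-based distributions are hypothetical, not this PR's API):

```python
import random

def accept_draft_tokens(draft_tokens, draft_probs, target_probs, rng):
    """Accept each draft token with probability min(1, p(tok) / q(tok)),
    where q is the draft model's distribution and p the target model's.
    On the first rejection, resample from the residual max(0, p - q)
    and stop. Returns the list of accepted token ids."""
    accepted = []
    for tok, q, p in zip(draft_tokens, draft_probs, target_probs):
        if rng.random() < min(1.0, p.get(tok, 0.0) / q[tok]):
            accepted.append(tok)  # draft token accepted
        else:
            # Resample from the renormalized residual distribution.
            residual = {t: max(0.0, p.get(t, 0.0) - q.get(t, 0.0)) for t in p}
            total = sum(residual.values())
            if total > 0:
                r = rng.random() * total
                acc = 0.0
                for t, w in residual.items():
                    acc += w
                    if r <= acc:
                        accepted.append(t)
                        break
            break  # stop at the first rejected position
    return accepted
```

When the target agrees with the draft everywhere, all proposed tokens are accepted; when it assigns the draft token zero probability, the resampled token comes from the target's residual mass.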
    self.propose_cnt = config.propose_cnt
    self.draft_model_config = config.draft_model_config
nit: self.config = config?
    # propose draft tokens
    # the function will run the draft model and set draft_tokens and
    # draft_token_probs of each seq
    def set_draft_tokens(self, seq_group_list: List[SequenceGroupMetadata],
propose() might be a better name
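For illustration, a `propose()` along these lines might just run the draft model autoregressively for `propose_cnt` steps per sequence. A sketch with the draft model stubbed as a plain callable (signature is hypothetical; the real worker would also record per-token probabilities for the acceptance step):

```python
from typing import Callable, List

def propose(seq_token_ids: List[List[int]],
            draft_model: Callable[[List[int]], int],
            propose_cnt: int) -> List[List[int]]:
    """For each sequence, greedily sample propose_cnt draft tokens
    from the draft model, feeding each new token back as context."""
    all_drafts = []
    for tokens in seq_token_ids:
        context = list(tokens)
        drafts = []
        for _ in range(propose_cnt):
            nxt = draft_model(context)  # one draft-model decode step (stub)
            drafts.append(nxt)
            context.append(nxt)
        all_drafts.append(drafts)
    return all_drafts
```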
    )
    if FLAGS.ENABLE_SD:
        output = _multi_query_cached_kv_attention(
            query, key, value, key_cache, value_cache, input_metadata)
Why do we need to pass key and value? I think those two vars have already been copied into key_cache and value_cache by cache_ops.reshape_and_cache(). Maybe I am missing something?
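To make the point concrete, here is a toy model of what `reshape_and_cache` does (pure-Python stand-in, not vLLM's CUDA kernel): once each new key is scattered into its (block, offset) slot of the paged cache, attention can gather everything from the cache alone, so passing `key`/`value` separately would be redundant.

```python
def reshape_and_cache(keys, key_cache, slot_mapping, block_size):
    """Toy stand-in for cache_ops.reshape_and_cache: copy each new
    key vector into its paged-cache slot given a flat slot index."""
    for key, slot in zip(keys, slot_mapping):
        block, offset = divmod(slot, block_size)
        key_cache[block][offset] = key

# Two blocks of four slots each, head_dim 2.
key_cache = [[[0.0, 0.0] for _ in range(4)] for _ in range(2)]
new_keys = [[1.0, 2.0], [3.0, 4.0]]
# Slots 5 and 6 land in block 1, offsets 1 and 2.
reshape_and_cache(new_keys, key_cache, slot_mapping=[5, 6], block_size=4)
```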
    for seq_group_metadata in seq_group_metadata_list:
        assert len(
            seq_group_metadata.seq_data
        ) == 1, f"Speculative Decoding does not support beam search for now: {len(seq_group_metadata.seq_data)}"
    # only enable speculative decoding for generation run
    if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
        self.spec_dec_worker.set_draft_tokens(seq_group_metadata_list,
In a multi-GPU inference scenario, will this method be called by all the workers?
Do you think it's a better idea to only run it on rank 0 and broadcast the tokens to the other ranks?
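The suggested pattern could look like the sketch below. `Bus` is a toy single-process stand-in for the collective so the shape of the idea is visible; real code would use torch.distributed.broadcast (or broadcast_object_list) with the draft tokens as the payload:

```python
def run_step(rank: int, propose_fn, broadcast_fn):
    """Only rank 0 runs the draft model; every rank then receives the
    proposed tokens through the same broadcast call."""
    draft_tokens = propose_fn() if rank == 0 else None
    return broadcast_fn(draft_tokens, src=0)

class Bus:
    """Toy stand-in for a collective: the source rank stores the payload,
    other ranks read it back."""
    def __init__(self):
        self.value = None
    def broadcast(self, obj, src):
        if obj is not None:   # only the source rank supplies a payload
            self.value = obj
        return self.value

bus = Bus()
out0 = run_step(0, lambda: [7, 8, 9], bus.broadcast)  # rank 0 proposes
out1 = run_step(1, lambda: [0, 0, 0], bus.broadcast)  # rank 1 only receives
```

Note that rank 1's `propose_fn` is never called, so the draft model runs exactly once per step.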
    logger.setLevel("WARNING")

class SpecDecWorker(Worker):
This worker is too tightly coupled with assisted decoding.
Do you think it's a good idea to abstract a base class for SpD and move these specific implementations into a concrete class like AssistedSpecDecWorker?
But I believe we could refactor this later.
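One possible shape for that refactor (class and method names here are suggestions, not the PR's API):

```python
from abc import ABC, abstractmethod
from typing import List

class SpecDecWorkerBase(ABC):
    """Strategy-agnostic speculative decoding interface."""

    @abstractmethod
    def propose(self, seq_group_list: List) -> None:
        """Attach draft tokens/probs to each sequence."""

    @abstractmethod
    def accept(self, target_outputs: List) -> None:
        """Verify draft tokens against the target model's outputs."""

class AssistedSpecDecWorker(SpecDecWorkerBase):
    """Assisted-decoding-specific logic lives in the concrete class."""

    def propose(self, seq_group_list: List) -> None:
        pass  # run the small draft model here

    def accept(self, target_outputs: List) -> None:
        pass  # rejection-sample against the target distribution here
```

Other strategies (e.g. Medusa-style heads or n-gram lookup) could then plug in as sibling subclasses without touching the engine.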
    tokenizer_revision: Optional[str] = None,
    seed: int = 0,
-   gpu_memory_utilization: float = 0.9,
+   gpu_memory_utilization: float = 0.8,
This is a little hacky to me. What if the sequence is long and the draft model takes more than 0.1 of the GPU memory?
Do you think it's a better idea if we actually run the assisted model in profile_num_available_blocks?
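To sketch the alternative: profile both models' peak memory (e.g. via a profiling forward pass and torch.cuda.max_memory_allocated) and derive the KV-cache block count from the measurements instead of hard-coding 0.8. The arithmetic, with purely hypothetical numbers:

```python
def num_available_blocks(total_gpu_bytes: int,
                         target_peak_bytes: int,
                         draft_peak_bytes: int,
                         block_bytes: int,
                         utilization: float = 0.9) -> int:
    """KV-cache blocks left after budgeting for BOTH models' measured
    activation peaks (sketch; peaks would come from profiling runs)."""
    usable = int(total_gpu_bytes * utilization)
    free = usable - target_peak_bytes - draft_peak_bytes
    return max(free, 0) // block_bytes

# Hypothetical 80 GiB card, 50 GiB target peak, 6 GiB draft peak, 2 MiB blocks.
GiB, MiB = 1 << 30, 1 << 20
blocks = num_available_blocks(80 * GiB, 50 * GiB, 6 * GiB, 2 * MiB)
```

This way the budget shrinks automatically when the draft model is large, rather than relying on a fixed 0.2 slack.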
    pass

if triton.__version__ >= "2.1.0":
Maybe we should assert that the version is greater than or equal to 2.1.0? Also note that a plain string comparison on triton.__version__ misorders multi-digit components.
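A sketch of a numeric version check (packaging.version.parse would also work; this parser deliberately ignores rc/dev suffixes):

```python
def version_tuple(v: str):
    """Parse 'X.Y.Z' into comparable integers."""
    return tuple(int(part) for part in v.split(".")[:3] if part.isdigit())

def assert_min_triton(version: str, minimum: str = "2.1.0") -> None:
    assert version_tuple(version) >= version_tuple(minimum), (
        f"triton >= {minimum} required, found {version}")

# Why string comparison is unsafe: lexicographically "2.2.0" sorts
# after "2.10.0", but numerically 2.2 < 2.10.
string_says = "2.2.0" >= "2.10.0"                               # True, wrong
tuple_says = version_tuple("2.2.0") >= version_tuple("2.10.0")  # False, correct
```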
    offs_d[:, None] // x) * stride_k_cache_d + (
        (start_n + offs_n[None, :]) %
        block_size) * stride_k_cache_bl + (
            offs_d[:, None] % x) * stride_k_cache_x
good job! this would be faster than my version! 👍
    block_mask = tl.where(
        block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

    for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
I wonder what's special about the K and V of the draft tokens: why do we need to process these tokens separately?
        self.scale,
        self.alibi_slopes,
    )
    if FLAGS.ENABLE_SD:
Correct me if I'm wrong, but for assisted decoding the propose_cnt is usually small (maybe around 4?), which makes the first dimension of q small, so both the q@k GEMM and the qk@v GEMM are small. For such cases, is it really worth using Tensor Cores for the GEMMs?
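A back-of-envelope check supports the concern. With hypothetical shapes (5 query tokens = 4 drafts + 1, a 2048-token context, head_dim 128, fp16), the q@k GEMM's arithmetic intensity is only a few FLOPs per byte, far below where Tensor Core throughput matters; the op is memory-bound:

```python
def gemm_flops_and_bytes(m, n, k, elem_bytes=2):
    """FLOPs and naive memory traffic for an (m x k) @ (k x n) GEMM."""
    flops = 2 * m * n * k
    bytes_moved = elem_bytes * (m * k + k * n + m * n)
    return flops, bytes_moved

# q: [5, 128], k^T: [128, 2048]
flops, bytes_moved = gemm_flops_and_bytes(m=5, n=2048, k=128)
intensity = flops / bytes_moved  # FLOPs per byte moved
```

With `intensity` around 5 FLOP/byte, the kernel is limited by reading the K cache, not by math throughput, so the MMA path buys little here.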
    return ignored

    # only enable speculative decoding for generation run
    if self.spec_dec_worker and (not scheduler_outputs.prompt_run):
Reporting a bug here: when we start the vLLM API server with python3 -m vllm.entrypoints.api_server --model=/path/to/tgt_model/ --draft-model=/path/to/draft/model/ --propose-cnt=5, the server errors out. It looks like you forgot to call set_draft_tokens and accept_tokens in AsyncLLMEngine.
This is an attempt to implement speculative decoding (paper) in vLLM. It is not optimized and not tested (please avoid using it for now). The current design: