[feat] add draft_model spec_decode #4003
base: main
Conversation
Signed-off-by: 01267596 <[email protected]>
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces support for speculative decoding using a draft model. The implementation involves adding a DraftModelProposer and refactoring the existing EagleProposer into a base class to share common logic. The changes are extensive, touching attention utilities, scheduling configuration, and the model runner.
My review has identified a critical bug in the shared _propose method that will cause a crash when using the new draft_model method due to incorrect handling of hidden states. I have also found a high-severity issue related to logging that could impact production environments. I recommend addressing these issues to ensure the feature is robust and maintainable.
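For orientation, here is a minimal sketch of the class layout the review describes. The names SpecDecodeBaseProposer, EagleProposer, DraftModelProposer, and _propose come from the review text; the signatures and docstrings are assumptions, not the PR's actual code.

# Hypothetical sketch of the refactor described above; illustrative only.
class SpecDecodeBaseProposer:
    """Drafting logic shared by all speculative-decoding proposers."""

    def _propose(self, target_token_ids, next_token_ids,
                 target_hidden_states=None):
        # Shared drafting loop (simplified placeholder).
        raise NotImplementedError


class EagleProposer(SpecDecodeBaseProposer):
    """EAGLE-style proposer: consumes the target model's hidden states,
    so it passes target_hidden_states to _propose."""


class DraftModelProposer(SpecDecodeBaseProposer):
    """New proposer backed by a standalone draft model: it calls
    _propose with target_hidden_states=None."""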
# # Replace the last token with the next token.
# # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
# self.input_ids[last_token_indices] = next_token_ids
self.set_input_ids_first_pass(target_token_ids, next_token_ids, num_tokens, last_token_indices)
The _propose method in the base class SpecDecodeBaseProposer is not fully generic and will raise an exception when used by DraftModelProposer. The DraftModelProposer calls _propose with target_hidden_states=None, but _propose attempts to use this value unconditionally, which will lead to a crash.
Specifically:
- The block for SpecDcodeType.EAGLE3 (line 437) accesses target_hidden_states without checking whether it is None.
- The assignment self.hidden_states[:num_tokens] = target_hidden_states (line 496) will fail when target_hidden_states is None.
This is a critical issue that will prevent draft_model speculative decoding from working. Since the problematic lines are not part of this diff, I recommend either modifying _propose to be fully generic by adding the necessary guards (e.g., if self.pass_hidden_states_to_model:) or overriding _propose in DraftModelProposer with a simplified implementation that doesn't handle hidden states.
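To make the first option concrete, here is a minimal sketch of the guarded version. The pass_hidden_states_to_model flag is taken from the review's example; everything else is illustrative and not the PR's actual code.

# Illustrative sketch of the suggested guard, not the real implementation.
class SpecDecodeBaseProposer:
    pass_hidden_states_to_model = True  # would be False for DraftModelProposer

    def _propose(self, target_token_ids, next_token_ids, num_tokens,
                 last_token_indices, target_hidden_states=None):
        self.set_input_ids_first_pass(target_token_ids, next_token_ids,
                                      num_tokens, last_token_indices)
        # Guard every use of target_hidden_states so proposers that pass
        # None (the draft_model path) can reuse this shared method.
        if self.pass_hidden_states_to_model:
            assert target_hidden_states is not None
            self.hidden_states[:num_tokens] = target_hidden_states
        # ... remainder of the shared drafting loop ...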
new_token_ids = extend_flat_seqs(
    seqs=input_token_ids, end_locs=query_end_locs, new_vals=next_token_ids
)
logger.warning("new_token_ids: {}".format(new_token_ids))
This log message appears to be for debugging. Using logger.warning for diagnostic information can flood the logs and obscure actual warnings. Please use logger.debug instead.
Suggested change:
- logger.warning("new_token_ids: {}".format(new_token_ids))
+ logger.debug("new_token_ids: {}".format(new_token_ids))
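As a side note beyond the bot's suggestion, passing the value as a logging argument (instead of pre-formatting with str.format) defers the string construction until DEBUG logging is actually enabled:

logger.debug("new_token_ids: %s", new_token_ids)  # formatted only if DEBUG is on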
Signed-off-by: 01267596 <[email protected]>
Signed-off-by: 01267596 <[email protected]>
04d7929 to 458036d
Signed-off-by: 01267596 <[email protected]>
Signed-off-by: 01267596 <[email protected]>
What this PR does / why we need it?
This PR implements the draft_model speculative decoding feature; the corresponding RFC is here: #3585
This PR depends on adjustments to the code in vLLM; the specific changes are made here: https://github.com/HF-001/vllm/pull/1/files#diff-645d58630d5acf3a0b07226bfef1e890a584c32502ab97c3d4642070f39a783c, or in this PR: vllm-project/vllm#24322
Does this PR introduce any user-facing change?
How was this patch tested?
export CUDA_VISIBLE_DEVICES=7
export TP=1
export MODEL_PATH=/model/qwen3-0.6b
export MODEL_NAME=qwen3-0.6b
export PORT=10113
python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port ${PORT} --dtype bfloat16 --model ${MODEL_PATH} --served-model-name ${MODEL_NAME} --tensor-parallel-size ${TP} --gpu-memory-utilization 0.9 --max-model-len 32768 --trust-remote-code --seed 42 --speculative_config '{"method":"draft_model","model":"/model/qwen3-0.6b","num_speculative_tokens":3,"draft_tensor_parallel_size":1, "disable_padded_drafter_batch":true}'
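For completeness, a minimal client-side check against the server launched above; it assumes the openai Python package is installed and reuses the port and served model name from the command. This check is an illustrative addition, not part of the original test description.

# Smoke test: send one completion request to the locally started server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:10113/v1", api_key="EMPTY")
resp = client.completions.create(
    model="qwen3-0.6b",
    prompt="Speculative decoding lets a draft model propose tokens that",
    max_tokens=32,
)
print(resp.choices[0].text)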