[WIP] [BugFix] Forward spec_step_idx in MTP wrappers and eagle proposer/speculator#36910
haosdent wants to merge 1 commit into vllm-project:main
Conversation
…culator

Fix `spec_step_idx` not being forwarded from MTP wrapper classes (`Qwen3NextMTP`, `Qwen3_5MTP`) to the inner `MultiTokenPredictor`, and not being passed by the eagle proposer/speculator loops. This caused all draft tokens to always use MTP layer 0 instead of cycling through layers via `spec_step_idx % num_mtp_layers`.

Fixes vllm-project#36872

Signed-off-by: haosdent <haosdent@gmail.com>
Code Review
This pull request aims to fix a bug where `spec_step_idx` is not correctly forwarded in MTP wrappers and the Eagle speculative decoding implementation. The changes correctly propagate this index through the MTP model wrappers and the Eagle proposer's `forward` and `compute_logits` calls. However, I've identified a critical issue where an optimized code path in the proposer is missed, which will lead to incorrect behavior. Additionally, there's a high-severity issue in the speculator code that could cause a `TypeError` when used with non-MTP target models. Please see my detailed comments below.
```diff
 if self.use_local_argmax_reduction:
     return self.model.get_top_tokens(hidden_states)
-return self.model.compute_logits(hidden_states).argmax(dim=-1)
+return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)
```
The `get_top_tokens` method is called without the `spec_step_idx` argument. When `use_local_argmax_reduction` is enabled, this will cause the wrong decoder layer to be used for speculative steps beyond the first, leading to incorrect draft tokens. This seems to defeat the purpose of this bug fix for this code path.
To fix this, `spec_step_idx` should be passed to `get_top_tokens`. You will also need to update the `get_top_tokens` method on the MTP models to accept and use this parameter, similar to how `compute_logits` is being updated.
Suggested change:

```diff
 if self.use_local_argmax_reduction:
-    return self.model.get_top_tokens(hidden_states)
+    return self.model.get_top_tokens(hidden_states, spec_step_idx=spec_step_idx)
 return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)
```
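A minimal sketch of the model-side companion change the comment asks for: `get_top_tokens` gains a `spec_step_idx` parameter and uses it to select the per-step layer, mirroring the `compute_logits` change in this PR. The class name, the stand-in logits arithmetic, and the plain-list tensors are illustrative assumptions, not the real vLLM signatures.

```python
# Illustrative sketch (not the real vLLM implementation): get_top_tokens
# accepts spec_step_idx and selects the same layer as compute_logits,
# so the local-argmax fast path agrees with the full logits path.

class MTPModelSketch:
    def __init__(self, num_mtp_layers: int):
        self.num_mtp_layers = num_mtp_layers

    def _select_layer(self, spec_step_idx: int) -> int:
        # Layer-cycling rule referenced by the PR description.
        return spec_step_idx % self.num_mtp_layers

    def compute_logits(self, hidden_states, spec_step_idx: int = 0):
        layer = self._select_layer(spec_step_idx)
        # Stand-in: each layer shifts the logits so the layer choice
        # is observable; real code would run the selected layer's head.
        return [h + layer for h in hidden_states]

    def get_top_tokens(self, hidden_states, spec_step_idx: int = 0):
        # After the fix, the fast path uses the same layer as
        # compute_logits for the current speculative step.
        logits = self.compute_logits(hidden_states, spec_step_idx)
        return max(range(len(logits)), key=logits.__getitem__)
```

With both methods threading `spec_step_idx` through, enabling or disabling `use_local_argmax_reduction` no longer changes which layer produces the draft token.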
```diff
     slot_mappings,
     num_tokens_across_dp,
     cudagraph_runtime_mode,
+    spec_step_idx=step,
 )
 last_hidden_states = last_hidden_states[:num_reqs]
 hidden_states = hidden_states[:num_reqs]
-logits = self.model.compute_logits(last_hidden_states)
+logits = self.model.compute_logits(last_hidden_states, spec_step_idx=step)
```
The `run_model` and `compute_logits` calls now pass `spec_step_idx`. However, `self.model` here is the target model, which is not guaranteed to be an MTP model that accepts this argument. If a non-MTP model is used as the target with eagle speculative decoding, this will raise a `TypeError`, as its `forward` and `compute_logits` methods may not accept `spec_step_idx`. A similar issue exists in the `propose` method.
To make this more robust, you could check if the model's methods support the `spec_step_idx` parameter before passing it, for example by using `inspect.signature`.
benchislett left a comment
AFAIK, passing `spec_step_idx=0` is intentional since we do eagle-style drafting and not multi-mtp drafting. For multi-mtp support, see #33561
How did you validate that the fix worked?

Closing for now

Thanks @benchislett, I misunderstood the code here.
Purpose
Fix `spec_step_idx` not being forwarded from MTP wrapper classes (`Qwen3NextMTP`, `Qwen3_5MTP`) to the inner `MultiTokenPredictor`, and not being passed by the eagle proposer/speculator loops.

MTP models with `num_mtp_layers > 1` use `spec_step_idx % num_mtp_layers` to select which decoder layer to use for each speculative step. Two issues caused the wrong layer to always be selected:

1. Wrapper classes swallow `spec_step_idx`: `Qwen3NextMTP.forward()` and `Qwen3_5MTP.forward()` accept `**kwargs` but never forward `spec_step_idx` to the inner model's `forward()`, which expects it.
2. Eagle proposer/speculator never pass `spec_step_idx`: the proposer loop and speculator loop never include `spec_step_idx` in `model_kwargs` or `compute_logits` calls, so every draft token beyond the first silently uses layer 0 instead of the correct layer.

This follows the existing correct pattern used by `DeepSeekMTPModel`, `ExaoneMoeMTP`, and other MTP implementations that explicitly accept and forward `spec_step_idx`.

Fixes #36872
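The two issues above can be sketched together in miniature: a wrapper that explicitly forwards `spec_step_idx` instead of letting `**kwargs` swallow it, and an inner predictor that cycles layers by the modulo rule. Class names mirror the real vLLM classes, but the bodies are simplified stand-ins, not the actual implementation.

```python
# Simplified sketch of the fix pattern; not the real vLLM code.

class MultiTokenPredictor:
    """Inner model: selects a decoder layer per speculative step."""

    def __init__(self, num_mtp_layers: int):
        self.num_mtp_layers = num_mtp_layers

    def forward(self, hidden_states, spec_step_idx: int = 0):
        # The layer-selection rule described above.
        layer_idx = spec_step_idx % self.num_mtp_layers
        return layer_idx  # stand-in for running decoder layer `layer_idx`


class Qwen3NextMTP:
    """Wrapper: before the fix, **kwargs silently dropped spec_step_idx."""

    def __init__(self, num_mtp_layers: int = 3):
        self.model = MultiTokenPredictor(num_mtp_layers)

    def forward(self, hidden_states, spec_step_idx: int = 0, **kwargs):
        # The fix: forward spec_step_idx explicitly instead of dropping it.
        return self.model.forward(hidden_states,
                                  spec_step_idx=spec_step_idx)
```

With the forwarding in place, a proposer loop passing `spec_step_idx=step` cycles through layers 0, 1, 2, 0, … per draft step; without it, every step silently hits layer 0.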
Test Plan
Test Result