
[WIP] [BugFix] Forward spec_step_idx in MTP wrappers and eagle proposer/speculator#36910

Closed
haosdent wants to merge 1 commit into vllm-project:main from haosdent:fix-36872

Conversation

@haosdent
Contributor

Purpose

Fix spec_step_idx not being forwarded from MTP wrapper classes (Qwen3NextMTP, Qwen3_5MTP) to the inner MultiTokenPredictor, and not being passed by the eagle proposer/speculator loops.

MTP models with num_mtp_layers > 1 use spec_step_idx % num_mtp_layers to select which decoder layer to use for each speculative step. Two issues caused the wrong layer to always be selected:

  1. Wrapper classes swallow spec_step_idx: Qwen3NextMTP.forward() and Qwen3_5MTP.forward() accept **kwargs but never forward spec_step_idx to the inner model's forward(), which expects it.

  2. Eagle proposer/speculator never pass spec_step_idx: The proposer loop and speculator loop never include spec_step_idx in model_kwargs or compute_logits calls, so every draft token beyond the first silently uses layer 0 instead of the correct layer.

This follows the existing correct pattern used by DeepSeekMTPModel, ExaoneMoeMTP, and other MTP implementations that explicitly accept and forward spec_step_idx.
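The failure mode described above can be illustrated with a minimal, self-contained sketch (class and method names here are simplified stand-ins for illustration, not vLLM's actual implementation): a wrapper that swallows spec_step_idx pins every draft step to layer 0, while one that forwards it cycles layers via spec_step_idx % num_mtp_layers.

```python
class MultiTokenPredictor:
    """Stand-in for the inner MTP model that selects a decoder layer per step."""
    def __init__(self, num_mtp_layers):
        self.num_mtp_layers = num_mtp_layers
        self.calls = []  # record which MTP layer handled each step

    def forward(self, hidden_states, spec_step_idx=0):
        layer_idx = spec_step_idx % self.num_mtp_layers
        self.calls.append(layer_idx)
        return hidden_states


class BuggyWrapper:
    """Accepts **kwargs but drops spec_step_idx (the reported bug)."""
    def __init__(self, model):
        self.model = model

    def forward(self, hidden_states, **kwargs):
        return self.model.forward(hidden_states)  # spec_step_idx is lost here


class FixedWrapper:
    """Explicitly accepts and forwards spec_step_idx (the DeepSeekMTP-style pattern)."""
    def __init__(self, model):
        self.model = model

    def forward(self, hidden_states, spec_step_idx=0, **kwargs):
        return self.model.forward(hidden_states, spec_step_idx=spec_step_idx)


buggy = BuggyWrapper(MultiTokenPredictor(num_mtp_layers=2))
fixed = FixedWrapper(MultiTokenPredictor(num_mtp_layers=2))
for step in range(4):  # four speculative steps
    buggy.forward("h", spec_step_idx=step)
    fixed.forward("h", spec_step_idx=step)

print(buggy.model.calls)  # [0, 0, 0, 0] -- always layer 0
print(fixed.model.calls)  # [0, 1, 0, 1] -- cycles through MTP layers
```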

Fixes #36872

Test Plan

Test Result


Fix spec_step_idx not being forwarded from MTP wrapper classes
(Qwen3NextMTP, Qwen3_5MTP) to the inner MultiTokenPredictor, and
not being passed by the eagle proposer/speculator loops. This caused
all draft tokens to always use MTP layer 0 instead of cycling through
layers via spec_step_idx % num_mtp_layers.

Fixes vllm-project#36872

Signed-off-by: haosdent <haosdent@gmail.com>
@mergify mergify bot added the qwen (Related to Qwen models), speculative-decoding, v1, and bug (Something isn't working) labels on Mar 12, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request aims to fix a bug where spec_step_idx is not correctly forwarded in MTP wrappers and the Eagle speculative decoding implementation. The changes correctly propagate this index through the MTP model wrappers and the Eagle proposer's forward and compute_logits calls. However, I've identified a critical issue where an optimized code path in the proposer is missed, which will lead to incorrect behavior. Additionally, there's a high-severity issue in the speculator code that could cause a TypeError when used with non-MTP target models. Please see my detailed comments below.

Comment on lines 383 to +385

      if self.use_local_argmax_reduction:
          return self.model.get_top_tokens(hidden_states)
  -   return self.model.compute_logits(hidden_states).argmax(dim=-1)
  +   return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)

critical

The get_top_tokens method is called without the spec_step_idx argument. When use_local_argmax_reduction is enabled, this will cause the wrong decoder layer to be used for speculative steps beyond the first, leading to incorrect draft tokens. This seems to defeat the purpose of this bug fix for this code path.

To fix this, spec_step_idx should be passed to get_top_tokens. You will also need to update the get_top_tokens method on the MTP models to accept and use this parameter, similar to how compute_logits is being updated.

Suggested change

      if self.use_local_argmax_reduction:
  -       return self.model.get_top_tokens(hidden_states)
  +       return self.model.get_top_tokens(hidden_states, spec_step_idx=spec_step_idx)
      return self.model.compute_logits(hidden_states, spec_step_idx).argmax(dim=-1)

Comment on lines 155 to +162

          slot_mappings,
          num_tokens_across_dp,
          cudagraph_runtime_mode,
  +       spec_step_idx=step,
      )
      last_hidden_states = last_hidden_states[:num_reqs]
      hidden_states = hidden_states[:num_reqs]
  -   logits = self.model.compute_logits(last_hidden_states)
  +   logits = self.model.compute_logits(last_hidden_states, spec_step_idx=step)

high

The run_model and compute_logits calls now pass spec_step_idx. However, self.model here is the target model, which is not guaranteed to be an MTP model that accepts this argument. If a non-MTP model is used as the target with eagle speculative decoding, this will raise a TypeError as its forward and compute_logits methods may not accept spec_step_idx. A similar issue exists in the propose method.

To make this more robust, you could check if the model's methods support the spec_step_idx parameter before passing it, for example by using inspect.signature.
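The inspect.signature approach the reviewer mentions could look roughly like the following sketch (the model classes and helper names are illustrative assumptions, not vLLM's real API): probe whether the target model's compute_logits accepts spec_step_idx before passing it, so non-MTP targets never see an unexpected keyword.

```python
import inspect


def supports_kwarg(fn, name):
    """Return True if fn can accept a keyword argument called `name`."""
    params = inspect.signature(fn).parameters
    if name in params:
        return True
    # A **kwargs parameter also accepts arbitrary keyword arguments.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


class MTPModel:
    """Illustrative target whose compute_logits takes spec_step_idx."""
    def compute_logits(self, hidden_states, spec_step_idx=0):
        return ("logits", spec_step_idx)


class PlainModel:
    """Illustrative non-MTP target that would raise TypeError otherwise."""
    def compute_logits(self, hidden_states):
        return ("logits", None)


def compute_logits_safe(model, hidden_states, step):
    if supports_kwarg(model.compute_logits, "spec_step_idx"):
        return model.compute_logits(hidden_states, spec_step_idx=step)
    return model.compute_logits(hidden_states)  # avoid TypeError on non-MTP models


print(compute_logits_safe(MTPModel(), "h", 1))    # ('logits', 1)
print(compute_logits_safe(PlainModel(), "h", 1))  # ('logits', None)
```

Note that signature inspection adds per-call overhead, so in practice the check would likely be done once at initialization and cached.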

Collaborator

@benchislett benchislett left a comment


AFAIK, passing spec_step_idx=0 is intentional since we do eagle-style drafting and not multi-mtp drafting. For multi-mtp support, see #33561

@benchislett
Collaborator

How did you validate that the fix worked?

@benchislett
Collaborator

Closing for now

@haosdent
Contributor Author

Thanks @benchislett, I misunderstood the code here.


Labels

bug (Something isn't working), qwen (Related to Qwen models), speculative-decoding, v1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug]: Gibberish output and collapsing generation throughput with Qwen3.5-35B-A3B-FP8 and speculative decoding enabled

2 participants