[SpecDecoding] extend mtp support for mimo 2.5 #41905

Merged
jeejeelee merged 3 commits into vllm-project:main from ZJY0516:mimo-mtp
May 9, 2026

Conversation

@ZJY0516

@ZJY0516 ZJY0516 commented May 7, 2026

Member

Purpose

Support num_speculative_tokens > 1 for MiMo 2.5 MTP.

Test Plan

vllm serve XiaomiMiMo/MiMo-V2.5 -tp 4 --trust-remote-code \
  --speculative_config '{"method":"mtp","num_speculative_tokens":3}' \
  --no-async-scheduling
lm_eval --model local-completions --model_args "model=XiaomiMiMo/MiMo-V2.5,base_url=http://0.0.0.0:8000/v1/completions,tokenized_requests=False,tokenizer_backend=None,num_concurrent=256,timeout=5000,max_length=40960" --tasks gsm8k --num_fewshot 5 --gen_kwargs max_gen_toks=5120
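
The --speculative_config value in the serve command is a JSON string, which is easy to mis-quote by hand. A minimal sketch of building the same argument list programmatically, assuming only the flags and values shown in the test plan above (the variable names are illustrative):

```python
import json
import shlex

# Values copied from the test plan above.
model = "XiaomiMiMo/MiMo-V2.5"
speculative_config = {"method": "mtp", "num_speculative_tokens": 3}

argv = [
    "vllm", "serve", model,
    "-tp", "4",
    "--trust-remote-code",
    "--speculative_config", json.dumps(speculative_config),
    "--no-async-scheduling",
]

# shlex.join quotes the JSON argument so it survives the shell intact.
command = shlex.join(argv)
print(command)
```

Passing argv directly to subprocess.run would avoid shell quoting altogether; the joined string is only useful for copy-pasting into a terminal.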

Test Result

w/o mtp

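| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9530 | ± 0.0058 |
|       |         | strict-match     | 5      | exact_match | 0.9538 | ± 0.0058 |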

mtp3

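| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9454 | ± 0.0063 |
|       |         | strict-match     | 5      | exact_match | 0.9469 | ± 0.0062 |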

mtp3 w/o async-scheduling

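| Tasks | Version | Filter           | n-shot | Metric      | Value  | Stderr   |
|-------|---------|------------------|--------|-------------|--------|----------|
| gsm8k | 3       | flexible-extract | 5      | exact_match | 0.9522 | ± 0.0059 |
|       |         | strict-match     | 5      | exact_match | 0.9522 | ± 0.0059 |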

Note: async scheduling may affect MTP accuracy.
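
To gauge how large the gaps between the three runs are relative to the reported noise, here is a rough sketch that expresses each difference from the baseline in units of the combined standard error. It uses the flexible-extract values reported above and assumes, as a simplification, that the runs can be treated as independent samples (they share the same gsm8k test set, so this is only an approximation):

```python
import math

# (value, stderr) pairs copied from the gsm8k flexible-extract results above.
results = {
    "w/o mtp": (0.9530, 0.0058),
    "mtp3": (0.9454, 0.0063),
    "mtp3 w/o async-scheduling": (0.9522, 0.0059),
}


def z_score(a, b):
    """Difference of two scores in units of their combined standard error."""
    (va, sa), (vb, sb) = a, b
    return (va - vb) / math.sqrt(sa**2 + sb**2)


baseline = results["w/o mtp"]
for name in ("mtp3", "mtp3 w/o async-scheduling"):
    print(f"{name}: z = {z_score(baseline, results[name]):.2f}")
```

Under this approximation both differences come out below one combined standard error, so the drop with async scheduling enabled is suggestive rather than conclusive on gsm8k alone.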


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>

@claude claude Bot left a comment

Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 16e86d293a

Comment on lines +204 to +205
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.mtp.layers[str(current_step_idx)](

P2 Badge Load the remaining MTP layers for multi-step drafts

When num_speculative_tokens > 1, this still sets self.num_mtp_layers to 1, so every spec_step_idx maps back to model.mtp.layers.0; load_weights then ignores the checkpoint's later MTP layers. The MiMo-V2.5-Pro model card lists 3 MTP layers and its deployment example uses 3 speculative steps (https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro/blob/main/README.md), so enabling multi-token drafting here runs all draft steps through the first-layer weights instead of the trained layer sequence, which makes the new >1 support materially incorrect/low-acceptance. Please instantiate/load the available MTP layers or keep rejecting num_speculative_tokens > 1.
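
The indexing concern can be illustrated in isolation. The sketch below reproduces only the modulo expression from the quoted diff (spec_step_idx % num_mtp_layers); the helper name layer_for_step is hypothetical and not part of vLLM:

```python
def layer_for_step(spec_step_idx: int, num_mtp_layers: int) -> int:
    """Same arithmetic as the quoted diff: step index modulo layer count."""
    return spec_step_idx % num_mtp_layers


steps = range(3)  # three speculative steps, as in num_speculative_tokens=3

# With a single instantiated layer, every step maps to layer 0.
one_layer = [layer_for_step(i, 1) for i in steps]
# With three instantiated layers, each step gets its own layer.
three_layers = [layer_for_step(i, 3) for i in steps]

print(one_layer)     # [0, 0, 0]
print(three_layers)  # [0, 1, 2]
```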

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request removes the restriction limiting MiMo-V2 MTP to a single speculative token and implements modulo-based indexing for MTP layers. Feedback indicates that the module-level documentation needs updating to reflect this change. Additionally, a critical issue was identified where num_mtp_layers remains hardcoded to 1, which would cause all speculative steps to incorrectly reuse the first layer instead of utilizing the appropriate layers for each step.

Comment on lines 51 to +52
  # MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports
- # only the first layer and only one speculative token.
+ # only the first layer

Contributor

high

The comment still states that vLLM only supports the first MTP layer. Since this PR aims to support multiple speculative tokens, this comment should be updated to reflect that multiple layers are now supported (assuming the hardcoded layer count is also addressed).

Suggested change
- # MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports
- # only the first layer and only one speculative token.
- # only the first layer
+ # MiMo-V2 checkpoints contain multiple MTP layers, and vLLM supports
+ # multiple speculative tokens by using these layers.

raise ValueError(
"MiMo-V2 MTP in vLLM only supports num_speculative_tokens=1."
)
num_mtp_layers = 1

Contributor

high

The variable num_mtp_layers is still hardcoded to 1. This contradicts the PR's objective of supporting num_speculative_tokens > 1. If num_mtp_layers remains 1, only the first MTP layer will be initialized and loaded. When num_speculative_tokens > 1, the forward method (line 204) will reuse this single layer for all speculative steps due to the modulo operation (spec_step_idx % 1). This is mathematically incorrect for MTP architectures like DeepSeek/MiMo where each speculative step typically requires a distinct layer trained for that specific offset. You should derive num_mtp_layers from the model configuration (e.g., config.num_nextn_predict_layers) or set it based on spec_cfg.num_speculative_tokens while ensuring it does not exceed the model's actual capacity.

Suggested change
- num_mtp_layers = 1
+ num_mtp_layers = spec_cfg.num_speculative_tokens
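
The reviewer's caveat about not exceeding the model's actual capacity could be handled by clamping. The following is a minimal sketch of that idea only; resolve_num_mtp_layers and layers_in_checkpoint are hypothetical names, and this is not necessarily what the merged change does:

```python
def resolve_num_mtp_layers(num_speculative_tokens: int,
                           layers_in_checkpoint: int) -> int:
    """Pick how many MTP layers to instantiate (illustrative helper).

    layers_in_checkpoint stands in for whatever the model config reports,
    for example the config.num_nextn_predict_layers field the review mentions.
    """
    if num_speculative_tokens < 1:
        raise ValueError("num_speculative_tokens must be at least 1")
    if layers_in_checkpoint < 1:
        raise ValueError("checkpoint must provide at least one MTP layer")
    # Never instantiate more layers than the checkpoint has weights for.
    return min(num_speculative_tokens, layers_in_checkpoint)


print(resolve_num_mtp_layers(3, 3))  # 3
print(resolve_num_mtp_layers(3, 1))  # 1
```

When fewer layers exist than requested steps, the modulo indexing in the forward pass would then cycle through the available layers instead of indexing past the end.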

@jeejeelee

Member

Can you provide the accuracy result here?

@ZJY0516

ZJY0516 commented May 7, 2026

Member Author

Can you provide the accuracy result here?

Updated, PTAL

@jeejeelee jeejeelee left a comment

Member

thank you

@jeejeelee jeejeelee added the "ready" label (ONLY add when PR is ready to merge/full CI is needed) May 7, 2026
@jeejeelee jeejeelee enabled auto-merge (squash) May 7, 2026 13:40
@jeejeelee jeejeelee merged commit 2ee8c2a into vllm-project:main May 9, 2026
60 checks passed