Merged
13 changes: 3 additions & 10 deletions vllm/model_executor/models/mimo_v2_mtp.py
@@ -49,7 +49,7 @@
 from .utils import _merge_multimodal_embeddings, maybe_prefix

 # MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports
-# only the first layer and only one speculative token.
+# only the first layer
Comment on lines 51 to +52

Contributor


high

The comment still states that vLLM only supports the first MTP layer. Since this PR aims to support multiple speculative tokens, this comment should be updated to reflect that multiple layers are now supported (assuming the hardcoded layer count is also addressed).

Suggested change
-# MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports
-# only the first layer
+# MiMo-V2 checkpoints contain multiple MTP layers, and vLLM supports
+# multiple speculative tokens by using these layers.

 _MIMO_V2_PRO_NUM_MTP_LAYERS = 1
 _MIMO_V2_FLASH_NUM_MTP_LAYERS = 1

@@ -170,10 +170,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         config = vllm_config.model_config.hf_config
         spec_cfg = vllm_config.speculative_config
         assert spec_cfg is not None
-        if spec_cfg.num_speculative_tokens != 1:
-            raise ValueError(
-                "MiMo-V2 MTP in vLLM only supports num_speculative_tokens=1."
-            )
         num_mtp_layers = 1

Contributor


high

The variable num_mtp_layers is still hardcoded to 1. This contradicts the PR's objective of supporting num_speculative_tokens > 1. If num_mtp_layers remains 1, only the first MTP layer will be initialized and loaded. When num_speculative_tokens > 1, the forward method (line 204) will reuse this single layer for all speculative steps due to the modulo operation (spec_step_idx % 1). This is mathematically incorrect for MTP architectures like DeepSeek/MiMo where each speculative step typically requires a distinct layer trained for that specific offset. You should derive num_mtp_layers from the model configuration (e.g., config.num_nextn_predict_layers) or set it based on spec_cfg.num_speculative_tokens while ensuring it does not exceed the model's actual capacity.

Suggested change
-        num_mtp_layers = 1
+        num_mtp_layers = spec_cfg.num_speculative_tokens
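
The capping mentioned above ("ensuring it does not exceed the model's actual capacity") could look roughly like the sketch below. It is standalone Python, not vLLM code: `resolve_num_mtp_layers` and both of its parameters are hypothetical names, and whether the MiMo-V2 config actually exposes the layer count as `config.num_nextn_predict_layers` is an assumption that would need checking against the checkpoint.

```python
def resolve_num_mtp_layers(
    num_speculative_tokens: int,
    num_checkpoint_layers: int,
) -> int:
    """Pick how many MTP layers to instantiate (hypothetical helper).

    num_checkpoint_layers is assumed to come from the model config,
    e.g. something like config.num_nextn_predict_layers if it exists.
    """
    if num_speculative_tokens < 1:
        raise ValueError("num_speculative_tokens must be >= 1")
    if num_checkpoint_layers < 1:
        raise ValueError("checkpoint must provide at least one MTP layer")
    # Never instantiate more layers than the checkpoint has weights for,
    # and never more than the number of draft steps that will run.
    return min(num_speculative_tokens, num_checkpoint_layers)


print(resolve_num_mtp_layers(3, 3))  # 3
print(resolve_num_mtp_layers(5, 3))  # 3
print(resolve_num_mtp_layers(1, 3))  # 1
```

With this kind of cap, requesting more draft steps than the checkpoint has layers would still wrap around via the modulo in `forward`; whether to allow that or raise an error is a separate design decision for the PR.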


         self.num_mtp_layers = num_mtp_layers
@@ -203,10 +199,10 @@ def forward(
         inputs_embeds: torch.Tensor | None = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token."
         if inputs_embeds is None:
             inputs_embeds = self.embed_input_ids(input_ids)
-        return self.mtp.layers[str(spec_step_idx)](
+        current_step_idx = spec_step_idx % self.num_mtp_layers
+        return self.mtp.layers[str(current_step_idx)](
Comment on lines +204 to +205



P2: Load the remaining MTP layers for multi-step drafts

When num_speculative_tokens > 1, this still sets self.num_mtp_layers to 1, so every spec_step_idx maps back to model.mtp.layers.0; load_weights then ignores the checkpoint's later MTP layers. The MiMo-V2.5-Pro model card lists 3 MTP layers and its deployment example uses 3 speculative steps (https://huggingface.co/XiaomiMiMo/MiMo-V2.5-Pro/blob/main/README.md), so enabling multi-token drafting here runs all draft steps through the first-layer weights instead of the trained layer sequence, which makes the new >1 support materially incorrect/low-acceptance. Please instantiate/load the available MTP layers or keep rejecting num_speculative_tokens > 1.
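
To make the remapping concrete, here is a small standalone sketch (not vLLM code) of which layer key each draft step resolves to. Only the `str(spec_step_idx % num_mtp_layers)` indexing is taken from the diff; `draft_layer_schedule` is a hypothetical helper introduced for illustration.

```python
def draft_layer_schedule(
    num_speculative_tokens: int,
    num_mtp_layers: int,
) -> list[str]:
    """Return the layer key used at each draft step (hypothetical helper).

    Mirrors the indexing in forward():
        self.mtp.layers[str(spec_step_idx % self.num_mtp_layers)]
    """
    return [
        str(spec_step_idx % num_mtp_layers)
        for spec_step_idx in range(num_speculative_tokens)
    ]


# With num_mtp_layers hardcoded to 1, every step reuses layer "0".
print(draft_layer_schedule(3, 1))  # ['0', '0', '0']

# With all three layers instantiated, each step gets its own layer.
print(draft_layer_schedule(3, 3))  # ['0', '1', '2']
```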


             inputs_embeds, positions, previous_hidden_states
         )

@@ -216,7 +212,6 @@ def compute_logits(
         lm_head: ParallelLMHead,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token."
         return self.logits_processor(lm_head, hidden_states)


@@ -245,7 +240,6 @@ def forward(
         inputs_embeds: torch.Tensor | None = None,
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
-        assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token."
         return self.model(
             input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
         )
@@ -255,7 +249,6 @@ def compute_logits(
         hidden_states: torch.Tensor,
         spec_step_idx: int = 0,
     ) -> torch.Tensor | None:
-        assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token."
         return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: