diff --git a/vllm/model_executor/models/mimo_v2_mtp.py b/vllm/model_executor/models/mimo_v2_mtp.py index 442f4986b669..c863cedaeb88 100644 --- a/vllm/model_executor/models/mimo_v2_mtp.py +++ b/vllm/model_executor/models/mimo_v2_mtp.py @@ -49,7 +49,7 @@ from .utils import _merge_multimodal_embeddings, maybe_prefix # MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports -# only the first layer and only one speculative token. +# only the first layer _MIMO_V2_PRO_NUM_MTP_LAYERS = 1 _MIMO_V2_FLASH_NUM_MTP_LAYERS = 1 @@ -170,10 +170,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config spec_cfg = vllm_config.speculative_config assert spec_cfg is not None - if spec_cfg.num_speculative_tokens != 1: - raise ValueError( - "MiMo-V2 MTP in vLLM only supports num_speculative_tokens=1." - ) num_mtp_layers = 1 self.num_mtp_layers = num_mtp_layers @@ -203,10 +199,10 @@ def forward( inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: - assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." if inputs_embeds is None: inputs_embeds = self.embed_input_ids(input_ids) - return self.mtp.layers[str(spec_step_idx)]( + current_step_idx = spec_step_idx % self.num_mtp_layers + return self.mtp.layers[str(current_step_idx)]( inputs_embeds, positions, previous_hidden_states ) @@ -216,7 +212,6 @@ def compute_logits( lm_head: ParallelLMHead, spec_step_idx: int = 0, ) -> torch.Tensor: - assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." return self.logits_processor(lm_head, hidden_states) @@ -245,7 +240,6 @@ def forward( inputs_embeds: torch.Tensor | None = None, spec_step_idx: int = 0, ) -> torch.Tensor: - assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." return self.model( input_ids, positions, hidden_states, inputs_embeds, spec_step_idx ) @@ -255,7 +249,6 @@ def compute_logits( hidden_states: torch.Tensor, spec_step_idx: int = 0, ) -> torch.Tensor | None: - assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: