-
-
Notifications
You must be signed in to change notification settings - Fork 17.9k
[SpecDecoding] extend mtp support for mimo 2.5 #41905
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -49,7 +49,7 @@ | |||||
| from .utils import _merge_multimodal_embeddings, maybe_prefix | ||||||
|
|
||||||
| # MiMo-V2 checkpoints contain multiple MTP layers, but vLLM currently supports | ||||||
| # only the first layer and only one speculative token. | ||||||
| # only the first layer | ||||||
| _MIMO_V2_PRO_NUM_MTP_LAYERS = 1 | ||||||
| _MIMO_V2_FLASH_NUM_MTP_LAYERS = 1 | ||||||
|
|
||||||
|
|
@@ -170,10 +170,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: | |||||
| config = vllm_config.model_config.hf_config | ||||||
| spec_cfg = vllm_config.speculative_config | ||||||
| assert spec_cfg is not None | ||||||
| if spec_cfg.num_speculative_tokens != 1: | ||||||
| raise ValueError( | ||||||
| "MiMo-V2 MTP in vLLM only supports num_speculative_tokens=1." | ||||||
| ) | ||||||
| num_mtp_layers = 1 | ||||||
|
Contributor
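There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The variable `num_mtp_layers` is hardcoded to 1, so `spec_step_idx % self.num_mtp_layers` always evaluates to 0 and only the first MTP layer is ever used, even when more than one speculative token is requested.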
Suggested change
|
||||||
|
|
||||||
| self.num_mtp_layers = num_mtp_layers | ||||||
|
|
@@ -203,10 +199,10 @@ def forward( | |||||
| inputs_embeds: torch.Tensor | None = None, | ||||||
| spec_step_idx: int = 0, | ||||||
| ) -> torch.Tensor: | ||||||
| assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." | ||||||
| if inputs_embeds is None: | ||||||
| inputs_embeds = self.embed_input_ids(input_ids) | ||||||
| return self.mtp.layers[str(spec_step_idx)]( | ||||||
| current_step_idx = spec_step_idx % self.num_mtp_layers | ||||||
| return self.mtp.layers[str(current_step_idx)]( | ||||||
|
Comment on lines
+204
to
+205
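There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When `num_speculative_tokens` is greater than 1, `spec_step_idx % self.num_mtp_layers` maps every speculative step to layer 0 because `num_mtp_layers` is fixed at 1, so the same MTP layer is reused for all draft tokens. Useful? React with 👍 / 👎. |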
||||||
| inputs_embeds, positions, previous_hidden_states | ||||||
| ) | ||||||
|
|
||||||
|
|
@@ -216,7 +212,6 @@ def compute_logits( | |||||
| lm_head: ParallelLMHead, | ||||||
| spec_step_idx: int = 0, | ||||||
| ) -> torch.Tensor: | ||||||
| assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." | ||||||
| return self.logits_processor(lm_head, hidden_states) | ||||||
|
|
||||||
|
|
||||||
|
|
@@ -245,7 +240,6 @@ def forward( | |||||
| inputs_embeds: torch.Tensor | None = None, | ||||||
| spec_step_idx: int = 0, | ||||||
| ) -> torch.Tensor: | ||||||
| assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." | ||||||
| return self.model( | ||||||
| input_ids, positions, hidden_states, inputs_embeds, spec_step_idx | ||||||
| ) | ||||||
|
|
@@ -255,7 +249,6 @@ def compute_logits( | |||||
| hidden_states: torch.Tensor, | ||||||
| spec_step_idx: int = 0, | ||||||
| ) -> torch.Tensor | None: | ||||||
| assert spec_step_idx == 0, "MiMo-V2 MTP only supports one speculative token." | ||||||
| return self.model.compute_logits(hidden_states, self.lm_head, spec_step_idx) | ||||||
|
|
||||||
| def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment still states that vLLM only supports the first MTP layer. Since this PR aims to support multiple speculative tokens, this comment should be updated to reflect that multiple layers are now supported (assuming the hardcoded layer count is also addressed).