feat: native MTP speculative decoding for Qwen3.5#990
AirRunner wants to merge 10 commits into `ml-explore:main`
Conversation
Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your `mtp_generate_step` function for other models? Thanks for your work so far!
Thanks! Yes. The Qwen3.5-specific part is the MTP head wiring in the model file, so the speculative-decoding logic lives in one place, and adding a new model is just a matter of exposing the right interface. For GLM5 specifically, it would certainly be feasible. But I don't think there is even a glm5.py currently.
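For reference, a minimal sketch of the model-side interface implied by this PR. The method names `return_hidden` and `mtp_forward` come from the commits and tests in this thread; the exact signatures below are assumptions, not the committed code:

```python
from typing import Any, Protocol, Tuple, runtime_checkable


@runtime_checkable
class SupportsMTP(Protocol):
    """Hypothetical interface a model exposes so a shared
    mtp_generate_step loop can drive it (signatures illustrative)."""

    def __call__(
        self, inputs: Any, cache: Any = None, return_hidden: bool = False
    ) -> Tuple[Any, Any]:
        """Backbone pass: logits, plus pre-norm hidden states when
        return_hidden=True (the hidden states feed the MTP head)."""
        ...

    def mtp_forward(self, hidden: Any, next_token: Any, cache: Any = None) -> Any:
        """Run the built-in MTP head: draft logits for token t+2 from
        the hidden state at t and the embedding of token t+1."""
        ...
```

A new model (GLM5, say) would then only need to implement these two entry points; the draft/verify loop itself stays model-agnostic.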
Great work on this! We've been using it on M2 Ultra (128GB) with all three Qwen3.5 sizes and it works well.

**MoE fix needed**

The PR works out of the box for the dense 27B, but MoE models (35B-A3B, 122B-A10B) fail conversion with "768 parameters not in model". The MTP layer's expert weights use the unfused per-expert format (`mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight`), whereas the backbone uses the pre-fused `switch_mlp` format.

Fix (add to `sanitize()` in `qwen3_5_moe.py`):

```python
# Stack per-expert MTP weights into switch_mlp format.
mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0)
num_experts = self.language_model.args.num_experts
for l in range(mtp_num):
    prefix = f"language_model.mtp.layers.{l}.mlp"
    test_key = f"{prefix}.experts.0.gate_proj.weight"
    if test_key in new_weights:
        for n in ["gate_proj", "up_proj", "down_proj"]:
            to_join = [
                new_weights.pop(f"{prefix}.experts.{e}.{n}.weight")
                for e in range(num_experts)
            ]
            new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)
```

(Also needs one more small change; full fix on our fork: Thump604/mlx-lm@04a4383)

**Benchmark results (M2 Ultra, greedy)**

Pre-converted models with MTP weights: Thump604/Qwen3.5-27B-MLX-8bit, 35B, 122B
@Thump604 Thanks for the report and the fix! I've integrated it in AirRunner/mlx-lm@8d06796 with credit. Also, what acceptance rates did you get with MoE? I'm curious whether the speedup correlates with acceptance.
Thanks for the quick integration! Here are the acceptance rates derived from our benchmarks (M2 Ultra 128GB, greedy/temp=0):
At temp=0.6 (production sampling), 122B drops to 1.05x (~5% acceptance). So yes — it does correlate, and with architecture: MoE acceptance rates are significantly lower than dense.

My hypothesis: the MTP layer contains a full 256-expert MoE routing step (the same expert count as the backbone), but with only a single layer of context depth it struggles to predict the correct expert routing. The dense 27B's MTP layer is a standard transformer layer: a much simpler prediction task, hence much higher acceptance.

The fp16 27B was actually 0.61x (slower): it's bandwidth-saturated, so the MTP overhead exceeds the savings. 8-bit quantization is the sweet spot where MTP helps most.
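As a rough sanity check on that correlation, here is the standard 1-draft speculative decoding arithmetic (a simplified model I'm assuming here, ignoring cache/rollback costs): each verify pass emits 2 tokens with probability p (acceptance) and 1 otherwise, while drafting and verification add a relative per-pass cost c.

```python
def expected_speedup(p: float, c: float) -> float:
    """Expected speedup of 1-draft speculative decoding.

    p: draft acceptance rate, in [0, 1]
    c: relative extra cost per backbone pass from drafting + verification
       (e.g. roughly one extra transformer layer plus bookkeeping)

    Expected tokens per pass = 1 + p; relative cost per pass = 1 + c.
    """
    return (1.0 + p) / (1.0 + c)
```

With ~80% acceptance and ~20% overhead this lands near the 1.5x reported for the dense 27B; at ~5% acceptance it's below break-even, consistent with the numbers above. The actual overhead c depends on model size and memory bandwidth, which also fits fp16 being a net loss.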
Hey @AirRunner — thanks for integrating the MoE sanitize fix! The PR now has merge conflicts with main, though. Would you be able to rebase? Happy to help if needed. Also, any thoughts on tagging a maintainer for review? This has been open since March 13 with zero maintainer engagement. The implementation is solid (8 tests, code review feedback addressed, MoE fix integrated); it just needs someone to look at it.
Hey @Goekdeniz-Guelmez, would you be able to take a look when you get a chance? Quick summary: 8 unit tests, code review feedback from @janhilgard and @Thump604, rebased on main.
Subject: Successfully running Qwen3.5-27B locally with workaround
Thanks for this PR! I was able to get Qwen3.5-27B working locally with MLX, but encountered an issue that might help others.

**The Bug I Was Addressing**

When trying to use the model with a client that passes short model IDs, I hit an error whose message was misleading: it suggested an expired token, but the real issue was the config/weight mismatch described below.

**Issue Encountered**

The model failed to load with a missing-parameters error.

**Root Cause**

The model's `config.json` declares:

```json
{
  "text_config": {
    "mtp_num_hidden_layers": 1
  }
}
```

However, the actual weights contain no MTP parameters.

**Workaround**

Set `mtp_num_hidden_layers` to 0:

```shell
cat config.json | jq '.text_config.mtp_num_hidden_layers = 0' > config_fixed.json
mv config_fixed.json config.json
```

**Other Configuration Notes**

For anyone trying this setup:

**Suggestion**

It might be helpful to add a check/warning when the config declares `mtp_num_hidden_layers > 0` but the weights contain no MTP parameters. This would help users identify config/weight mismatches more quickly and avoid confusing auth error messages.
@layer4down thanks for the write-up! You're right: that checkpoint was converted without the MTP weights (the default `sanitize()` previously stripped them), so the config declares an MTP layer that isn't present in the weights. To actually use MTP acceleration, the model needs to be re-quantized including the MTP layers using this branch. As you suggested, I just pushed a fix that raises a clear `ValueError` instead of the cryptic missing-parameters error.
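The check described above might look something like this. This is a sketch, not the committed code; the helper name and message wording are mine:

```python
def check_mtp_weights(mtp_num_hidden_layers: int, weight_keys) -> None:
    """Fail fast when the config declares an MTP head but the checkpoint
    contains no mtp.* parameters (e.g. converted before this branch)."""
    if mtp_num_hidden_layers > 0 and not any(
        ".mtp." in k or k.startswith("mtp.") for k in weight_keys
    ):
        raise ValueError(
            f"config sets mtp_num_hidden_layers={mtp_num_hidden_layers} "
            "but the weights contain no MTP parameters; re-convert the "
            "model with MTP support, or set mtp_num_hidden_layers to 0 "
            "in config.json."
        )
```

Called once at load time with the flattened weight keys, this turns the confusing "Missing N parameters" failure into an actionable message.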
@angeloskath -- this PR has been open 11 days with no maintainer review. AirRunner rebased on 2026-03-21, all conflicts resolved, 8 unit tests passing. We've been running this in production on M2 Ultra 128GB since day one. Qwen3.5-122B-A10B-VLM-MTP-5bit, 24/7 inference serving coding agents. MTP acceptance rates:
MoE acceptance rates are lower because a single MTP layer can't predict expert routing well. Still a net win for our latency-sensitive use case. The MoE sanitize fix (commit 8d06796) is essential for Qwen3.5 MoE models; without it, 768 MTP parameters are silently missing. We've also published pre-converted VLM+MTP models on HuggingFace that depend on this code path. It would be great to get this reviewed and merged so the community models work out of the box.
Can we ping the reviewer again? It's an important update for Qwen3.5.
@angeloskath @awni — this PR has been open 17 days with no maintainer review or feedback. Multiple community members have asked for review (AirRunner, ourselves, cresseelia). Is there a concern with the approach, scope, or implementation that's blocking review?

We're happy to help address any issues: split the PR, rework the API surface, add tests, whatever is needed. We're running this in production on 122B and have validated it across three Qwen3.5 model sizes. The community is actively hitting the config/weight mismatch that AirRunner already fixed in this branch (layer4down's report above). Without this merged, users have to manually patch config.json to use MTP on Qwen3.5 models.

If the PR needs changes or a different direction, we'd rather know than wait. Let us know how we can help move this forward.
As you can see, 6 files have been changed/added, alongside ~700 lines of added code. This is a PR with big changes in the codebase itself. Reviewing and (correctly) integrating it will take time; 17 days is not long. My full-weight fine-tuning PR took multiple weeks to be merged. Just keep it open, keep it updated, and please be patient. Completely new features take a while to land.
@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it. To make the review easier, we can split this into two smaller PRs:

- **PR 1 — Model architecture (~260 lines):** MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.
- **PR 2 — Generation + tests (~420 lines):** the `mtp_generate_step()` loop, the `--mtp` CLI/server flags, and the unit tests.

Would splitting it this way help with the review process? Happy to do the work if so. @AirRunner — would you be open to splitting the PR this way?
Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights. Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable).
Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls. After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state. On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before. Acceptance rate ~64% average / ~85% on 100-token run.
- Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate)
- Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache)
- Disable batch generation for MTP models (draft/verify loop requires single-sequence processing)

Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator
…t_predicate)

- Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only (avoiding double normalization (model.norm + pre_fc_norm_hidden)).
- Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy).

27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x).
Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding. MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional.
8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers.

- MTP module instantiation and cache creation
- return_hidden shape and pre-norm verification
- mtp_forward output shape
- quant_predicate excludes mtp.fc
- Token identity: mtp_generate_step == generate_step (greedy)
- End-to-end mtp_generate_step completion
Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect.
MTP layers in MoE models (35B-A3B, 122B-A10B) ship unfused per-expert weights (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) whereas the backbone uses pre-fused switch_mlp format. Conversion was failing with ~768 parameters not in model.
Add a stacking loop in qwen3_5_moe.py sanitize() after the backbone expert loop, mirroring the same pattern for MTP prefixes.
Co-authored-by: Thump604 <thump604@users.noreply.github.com>
When mtp_num_hidden_layers > 0 but the model weights contain no MTP parameters, the previous error was a cryptic 'Missing N parameters'. Now raises a ValueError with an actionable message.
@janhilgard I'm not sure splitting would actually help the review here. The PRs you suggest wouldn't be reviewable in isolation, because the architecture changes only make sense in the context of how the generation loop uses them. (Also, 183 of the 683 added lines are just unit tests.) That said, I'm open to whatever helps; happy to reorganize if it does :).
@angeloskath @awni — this PR has been open 20+ days with no maintainer review. It is the foundation for MTP speculative decoding on Qwen3.5 models, which several of us are using in production. My PR #1085 (probabilistic acceptance, 2.3x throughput on 122B) builds directly on top of it. AirRunner's implementation is solid: 8 tests, 80.6% acceptance on M4 Pro. Is there a concern about scope or approach blocking review?
@Thump604 can you stop pinging people? The more annoying you are the less likely anyone is going to respond.
Great work — I've been running MTP on Qwen3.5 MoE models in production (M3 Ultra, 256 GB) and wanted to share findings that might explain the low MoE acceptance rates.

**BF16 MTP weights are critical for MoE acceptance**

Your `quant_predicate` excludes only `mtp.fc`:

```python
if path.endswith("mtp.fc"):
    return False
```

But the MTP transformer layer (attention, MLP, norms) still gets quantized. We found that quantized MTP weights give near-0% acceptance on MoE models — the quantization error compounds through the expert routing prediction.

Fix: exclude ALL MTP weights from quantization:

```python
if "mtp." in path:
    return False
```

**Our MoE results with BF16 MTP weights**

vs your MoE benchmarks (quantized MTP weights):

The difference is stark: BF16 MTP weights → 79-85% acceptance, quantized → 5-11%.

**Batch auto-skip**

Your PR sets `is_batchable=False`. Instead of disabling batching entirely, you could dynamically switch:

```python
if len(active_batch) > 1:
    # Skip MTP, fall back to standard generation
    return _orig_step(input_tokens, cache)
```

This gives the best of both worlds: MTP speedup for single requests, full batch serving for concurrent ones.

**Weight extraction**

We extract BF16 MTP weights from the original HF model (not the quantized MLX model) with a dedicated script. See vllm-mlx PR #245 for the extraction script.

Happy to collaborate on getting BF16 MTP weights into the standard conversion pipeline.
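For anyone wanting to try this, here is the suggested exclusion wrapped as a standalone predicate. The `(path, module)` signature is an assumption on my part; check the `quant_predicate` callback shape of your mlx_lm version before using it:

```python
def mtp_bf16_quant_predicate(path: str, module=None) -> bool:
    """Return False to skip quantization for a given parameter path.

    Keeps every MTP parameter (attention, MLP, norms, and the fc fusion
    projection) in bf16, per the findings above; everything else is
    quantized as usual.
    """
    return "mtp." not in path
```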
Summary
Qwen3.5 checkpoints ship with a built-in Multi-Token Prediction head (`mtp_num_hidden_layers: 1` in config) that predicts token t+2 from the backbone hidden state at t and the embedding of token t+1. This PR adds support for using it as a native speculative decoding mechanism: no separate draft model needed, at minimal extra compute (1 extra transformer layer).

Changes
- `mlx_lm/models/qwen3_5.py`: MTP head module, hidden state passthrough, SSM state rollback support
- `mlx_lm/generate.py`: MTP generation loop with draft/verify, `--mtp` CLI flag
- `mlx_lm/models/cache.py`: SSM state snapshot slot for rollback on draft rejection
- `mlx_lm/server.py`: `--mtp` flag, disable batching when MTP is active
- `tests/test_mtp.py`: 8 unit tests

How it works
Each backbone forward pass returns both logits and pre-norm hidden states. The MTP head fuses `pre_fc_norm_hidden(h_t)` and `pre_fc_norm_embedding(embed(t+1))` via a linear projection, runs one full-attention transformer layer, and produces draft logits through the shared `lm_head`.

The generation loop verifies drafts by feeding `[confirmed_tok, draft_tok]` to the backbone with `n_confirmed=1`. This causes `GatedDeltaNet` to snapshot its conv/SSM state after the confirmed token. On acceptance, both tokens are emitted. On rejection, the SSM state is rolled back to the snapshot and KV caches are trimmed.

Results (with Qwen3.5-27B 4-bit on M4 Pro)
Usage
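(The original usage snippet didn't survive extraction; a plausible invocation based on the `--mtp` flag this PR adds — the model path is illustrative:)

```shell
# Hypothetical example; requires a checkpoint converted with MTP weights.
mlx_lm.generate --model Qwen3.5-27B-MLX-8bit --mtp --prompt "Hello"
```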
Note: requires a checkpoint converted with MTP weights (the default `sanitize()` previously stripped them). Re-convert from HF with this branch to preserve `mtp.*` weights.

Known limitation
MTP disables batch serving (`is_batchable=False`). As an improvement, one could dynamically switch between MTP for single requests and batch generation for concurrent requests.

Test plan
Relates to #872 — cc @janhilgard
Update - Production benchmarks
Benchmarks from Thump604 (M2 Ultra 128GB, greedy/temp=0):
My initial 80.6% acceptance was measured on simple short prompts, which likely explains the difference.
MoE acceptance rates are structurally lower: the MTP layer must predict expert routing with only 1 layer of context depth.
fp16 models are net negative (0.61x on M2 Ultra): the MTP overhead exceeds the savings when bandwidth is saturated, so quantized models are the intended use case.