
feat: native MTP speculative decoding for Qwen3.5#990

Open
AirRunner wants to merge 10 commits into ml-explore:main from AirRunner:feat/mtp-native

Conversation


@AirRunner AirRunner commented Mar 13, 2026

Summary

Qwen3.5 checkpoints ship with a built-in Multi-Token Prediction head (mtp_num_hidden_layers: 1 in config) that predicts token t+2 from the backbone hidden state at t and the embedding of token t+1. This PR adds support for using it as a native speculative decoding mechanism. No separate draft model is needed, and the extra compute is minimal (one additional transformer layer).

Changes

  • mlx_lm/models/qwen3_5.py: MTP head module, hidden state passthrough, SSM state rollback support
  • mlx_lm/generate.py: MTP generation loop with draft/verify, --mtp CLI flag
  • mlx_lm/models/cache.py: SSM state snapshot slot for rollback on draft rejection
  • mlx_lm/server.py: --mtp flag, disable batching when MTP is active
  • tests/test_mtp.py: 8 unit tests

How it works

Each backbone forward pass returns both logits and pre-norm hidden states. The MTP head fuses pre_fc_norm_hidden(h_t) and pre_fc_norm_embedding(embed(t+1)) via a linear projection, runs one full-attention transformer layer, and produces draft logits through the shared lm_head.
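As a rough numerical sketch of the fusion step (numpy stand-ins for the mlx modules; `rms_norm`, `W_fc`, and the shapes here are illustrative simplifications, not the PR's actual code):

```python
import numpy as np

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned weight, for illustration only
    return x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)

H = 64                                   # hidden size (toy)
h_t = np.random.randn(H)                 # backbone hidden state at position t (pre-norm)
e_next = np.random.randn(H)              # embedding of the confirmed token t+1
W_fc = np.random.randn(H, 2 * H) * 0.02  # fusion projection (2H -> H); kept unquantized in the PR

# Normalize both inputs, concatenate, and project back down to H
fused = W_fc @ np.concatenate([rms_norm(h_t), rms_norm(e_next)])
assert fused.shape == (H,)
# `fused` would then pass through one transformer layer and the shared lm_head
```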

The generation loop verifies drafts by feeding [confirmed_tok, draft_tok] to the backbone with n_confirmed=1. This causes GatedDeltaNet to snapshot its conv/SSM state after the confirmed token. On acceptance, both tokens are emitted. On rejection, the SSM state is rolled back to the snapshot and KV caches are trimmed.

Results (with Qwen3.5-27B 4-bit on M4 Pro)

| Metric | Standard | MTP |
| --- | --- | --- |
| Throughput | 15.3 tok/s | 23.3 tok/s (1.52x) |
| Acceptance rate | n/a | 80.6% avg |
| Identity test | n/a | Pass (greedy MTP == standard) |

Usage

```shell
mlx_lm.generate --model <path> --mtp
mlx_lm.server   --model <path> --mtp
```

Note: Requires a checkpoint converted with MTP weights (the default sanitize() previously stripped them). Re-convert from HF with this branch to preserve mtp.* weights.

Known limitation

MTP disables batch serving (is_batchable=False). As an improvement, one could dynamically switch between MTP for single requests and batched generation for concurrent requests.

Test plan

  • Unit tests (8/8 passing) — module existence, cache creation, shapes, pre-norm hidden states, quant predicate, generation identity, end-to-end
  • Manual validation on Qwen3.5-27B (4-bit) and Qwen3.5-0.8B (4-bit)
  • Not yet tested on MoE variants (though code paths are shared with backbone)

Relates to #872 — cc @janhilgard


Update - Production benchmarks

Benchmarks from Thump604 (M2 Ultra 128GB, greedy/temp=0):

| Model | Baseline tok/s | MTP tok/s | Speedup | Implied acceptance |
| --- | --- | --- | --- | --- |
| 27B dense 8-bit | 20.6 | 27.1 | 1.32x | ~32% |
| 35B-A3B MoE 8-bit | 74.4 | 82.3 | 1.11x | ~11% |
| 122B-A10B MoE 5-bit | 43.0 | 46.7 | 1.09x | ~9% (~5% at temp=0.6) |

My initial 80.6% acceptance was measured on simple, short prompts, which likely explains the difference.

MoE acceptance rates are structurally lower: the MTP layer must predict expert routing with only 1 layer of context depth.

fp16 models are net negative (0.61x on M2 Ultra). MTP overhead exceeds savings when bandwidth is saturated. So quantized models are the intended use case.


vlbosch commented Mar 15, 2026

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

@AirRunner
Author

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

Thanks!

Yes, mtp_generate_step() is fully reusable, but each model still needs its own model-side interface.

The Qwen3.5-specific part is MTPDecoderLayer, mtp_forward (produce draft logits), make_mtp_cache and the backbone's __call__ (with n_confirmed for SSM state rollback on hybrid models).

So the speculative-decoding logic lives in one place, and adding a new model is just a matter of exposing the right interface.

For GLM5 specifically, it would certainly be feasible, yes. But I don't think there is even a glm5.py currently.
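A sketch of that model-side surface as a structural type (method names are taken from the comment above; the signatures are hypothetical, not the PR's actual ones):

```python
from typing import Any, Protocol

class MTPCapableModel(Protocol):
    """Hypothetical shape of the per-model interface mtp_generate_step relies on."""

    def __call__(self, inputs: Any, cache: Any = None, n_confirmed: int = 0) -> Any:
        """Backbone forward pass; n_confirmed marks where SSM state is snapshotted."""
        ...

    def mtp_forward(self, hidden: Any, next_token_embedding: Any,
                    cache: Any = None) -> Any:
        """Produce draft logits from a backbone hidden state + next-token embedding."""
        ...

    def make_mtp_cache(self) -> Any:
        """Create the cache entries the MTP layer(s) need."""
        ...
```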

@Thump604

Great work on this! We've been using it on M2 Ultra (128GB) with all three Qwen3.5 sizes and it works well.

MoE fix needed

The PR works out of the box for the dense 27B, but MoE models (35B-A3B, 122B-A10B) fail conversion with "768 parameters not in model". The MTP layer's expert weights use unfused per-expert format (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) unlike the backbone which uses pre-fused gate_up_proj. The existing sanitize() in qwen3_5_moe.py only handles backbone expert stacking.

Fix (add to qwen3_5_moe.py sanitize(), after the backbone expert stacking loop):

```python
# Stack per-expert MTP weights into switch_mlp format.
mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0)
num_experts = self.language_model.args.num_experts
for l in range(mtp_num):
    prefix = f"language_model.mtp.layers.{l}.mlp"
    test_key = f"{prefix}.experts.0.gate_proj.weight"
    if test_key in new_weights:
        for n in ["gate_proj", "up_proj", "down_proj"]:
            to_join = [
                new_weights.pop(f"{prefix}.experts.{e}.{n}.weight")
                for e in range(num_experts)
            ]
            new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)
```

Also needs import mlx.core as mx at the top of the file.

Full fix on our fork: Thump604/mlx-lm@04a4383
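The stacking pattern in the fix above can be demonstrated standalone with numpy (toy shapes; the real code uses mx.stack on safetensors weights, and Qwen3.5 MoE uses far more experts):

```python
import numpy as np

num_experts, out_dim, in_dim = 4, 8, 8  # toy sizes for illustration
prefix = "language_model.mtp.layers.0.mlp"

# Unfused per-expert weights, as shipped in the MTP layer of MoE checkpoints
weights = {
    f"{prefix}.experts.{e}.gate_proj.weight": np.random.randn(out_dim, in_dim)
    for e in range(num_experts)
}

# Stack into the fused switch_mlp format the backbone already uses
stacked = np.stack(
    [weights.pop(f"{prefix}.experts.{e}.gate_proj.weight") for e in range(num_experts)]
)
weights[f"{prefix}.switch_mlp.gate_proj.weight"] = stacked

assert stacked.shape == (num_experts, out_dim, in_dim)
assert not any(".experts." in k for k in weights)  # all per-expert keys consumed
```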

Benchmark results (M2 Ultra, greedy)

| Model | Baseline | MTP | Speedup |
| --- | --- | --- | --- |
| 27B-8bit (dense) | 20.6 tok/s | 27.1 tok/s | 1.32x |
| 35B-A3B-8bit (MoE) | 74.4 tok/s | 82.3 tok/s | 1.11x |
| 122B-A10B-5bit (MoE) | 43.0 tok/s | 46.7 tok/s | 1.09x |

Pre-converted models with MTP weights: Thump604/Qwen3.5-27B-MLX-8bit, 35B, 122B

@AirRunner
Author

@Thump604 Thanks for the report and the fix! I've integrated it in AirRunner/mlx-lm@8d06796 with a credit.

Also, what acceptance rates did you get with MoE? I'm curious if it's somehow correlated to the speedup.

@Thump604

Thanks for the quick integration!

Here are the acceptance rates derived from our benchmarks (M2 Ultra 128GB, greedy/temp=0):

| Model | Baseline tok/s | MTP tok/s | Speedup | Implied accept rate |
| --- | --- | --- | --- | --- |
| 27B dense 8-bit | 20.6 | 27.1 | 1.32x | ~32% |
| 35B-A3B MoE 8-bit | 74.4 | 82.3 | 1.11x | ~11% |
| 122B-A10B MoE 5-bit | 43.0 | 46.7 | 1.09x | ~9% |

At temp=0.6 (production sampling), 122B drops to 1.05x (~5% acceptance).
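For a single-token draft, "implied acceptance" is plausibly just speedup minus one: each verify pass emits 1 + a tokens on average (a = acceptance rate), so with negligible MTP overhead throughput scales by roughly 1 + a. A quick back-of-envelope check against the rows above (this derivation is my reading of the numbers, not stated in the PR):

```python
rows = {
    "27B dense 8-bit": (20.6, 27.1),
    "35B-A3B MoE 8-bit": (74.4, 82.3),
    "122B-A10B MoE 5-bit": (43.0, 46.7),
}
for name, (base, mtp) in rows.items():
    implied = mtp / base - 1  # acceptance rate, assuming zero drafting overhead
    print(f"{name}: ~{implied:.0%}")
# Matches the table: ~32%, ~11%, ~9%
```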

So yes — it does correlate with architecture. MoE acceptance rates are significantly lower than dense. My hypothesis: the MTP layer contains a full 256-expert MoE routing step (same expert count as the backbone), but with only a single layer of context depth it struggles to predict the correct expert routing. The dense 27B's MTP layer is a standard transformer layer — much simpler prediction task, much higher acceptance.

The fp16 27B was actually 0.61x (slower) — bandwidth-saturated, the MTP overhead exceeds the savings. 8-bit quantization is the sweet spot where MTP helps most.

@Thump604

Hey @AirRunner — thanks for integrating the MoE sanitize fix! The PR has merge conflicts with main now though. Would you be able to rebase? Happy to help if needed.

Also, any thoughts on tagging a maintainer for review? This has been open since March 13 with zero maintainer engagement. The implementation is solid (8 tests, code review feedback addressed, MoE fix integrated), just needs someone to look at it.


AirRunner commented Mar 21, 2026

Hey @Goekdeniz-Guelmez, would you be able to take a look when you get a chance?

Quick summary: 8 unit tests, code review feedback from @janhilgard and @Thump604, rebased on main.
Results: 1.52x token generation on Qwen3.5-27B dense on M4 Pro, validated independently on M2 Ultra across three Qwen3.5 sizes (MoE and dense).

@layer4down

Subject: Successfully running Qwen3.5-27B locally with workaround

Transparency Note: This comment was drafted with the assistance of an AI assistant to help document the troubleshooting process. All technical details and findings are from actual testing.


Thanks for this PR! I was able to get Qwen3.5-27B working locally with MLX, but encountered an issue that might help others.

The Bug I Was Addressing

When trying to use the model with a client that passes short model IDs, I encountered:

```
401 Client Error. (Request ID: Root=1-69bfb0a8...)
Repository Not Found for url: https://huggingface.co/api/models/qwen3_5-27b_4bit/revision/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
User Access Token "Claude-flow-ro" is expired
```

The error message was misleading - it suggested an expired token, but the real issue was a config/weight mismatch described below.

Issue Encountered

The model failed to load with:

```
ValueError: Missing 15 parameters:
language_model.mtp.fc.weight,
language_model.mtp.layers.0.input_layernorm.weight,
...
```

Root Cause

The model's config.json (from mlx-community/Qwen3.5-27B-4bit on HuggingFace) has:

```json
{
  "text_config": {
    "mtp_num_hidden_layers": 1
  }
}
```

However, the actual .safetensors weights do not contain any MTP parameters. The PR code correctly expects MTP weights when mtp_num_hidden_layers > 0, but this particular model's config claims MTP support that isn't present in the weights.

Workaround

Set mtp_num_hidden_layers to 0 in the model's config:

```shell
cat config.json | jq '.text_config.mtp_num_hidden_layers = 0' > config_fixed.json
mv config_fixed.json config.json
```

Other Configuration Notes

For anyone trying this setup:

  • Context length: Model supports 98K+ context; works with --max-tokens 98304
  • KV cache quantization: Works with MLX_KV_CACHE_QUANT=true environment variable
  • Model path as ID: The server uses the full local path as the model ID in API calls. For example:
    // Request to /v1/chat/completions
    {
      "model": "/path/to/local/models/mlx-community/Qwen3.5-27B-4bit",
      "messages": [...]
    }
    Short names like "Qwen3.5-27B" will trigger a HuggingFace lookup (and fail if the repo doesn't exist or auth is expired).

Suggestion

It might be helpful to add a check/warning when:

  1. mtp_num_hidden_layers > 0 in config
  2. But MTP weights are missing from the loaded model

This would help users identify config/weight mismatches more quickly and avoid confusing auth error messages.
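The suggested check could look something like this (a hypothetical standalone helper; not the actual fix that was later pushed to the branch):

```python
def check_mtp_weights(config: dict, weight_keys: list) -> None:
    """Fail fast when the config promises MTP layers the checkpoint doesn't contain."""
    mtp_layers = config.get("text_config", {}).get("mtp_num_hidden_layers", 0)
    has_mtp_weights = any(".mtp." in k or k.startswith("mtp.") for k in weight_keys)
    if mtp_layers > 0 and not has_mtp_weights:
        raise ValueError(
            f"Config declares mtp_num_hidden_layers={mtp_layers} but the checkpoint "
            "contains no mtp.* weights. Re-convert the model with MTP weights, or "
            "set mtp_num_hidden_layers to 0 in config.json."
        )

# Mismatch: config claims MTP, weights have none -> raises a clear error
try:
    check_mtp_weights(
        {"text_config": {"mtp_num_hidden_layers": 1}},
        ["language_model.layers.0.self_attn.q_proj.weight"],
    )
    raised = False
except ValueError:
    raised = True
assert raised
```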

@AirRunner
Author

@layer4down thanks for the write-up!

You're right: mlx-community/Qwen3.5-27B-4bit was quantized without the MTP head weights. The mtp_num_hidden_layers: 1 in the config is inherited from the original Qwen3.5 config, but the MTP parameters were not included when quantizing.

To actually use MTP acceleration, the model needs to be re-quantized including the MTP layers using this branch.

As you suggested, I just pushed a fix that raises a clear ValueError instead of the cryptic "Missing N parameters" crash.

@Thump604

@angeloskath -- this PR has been open 11 days with no maintainer review. AirRunner rebased on 2026-03-21, all conflicts resolved, 8 unit tests passing.

We've been running this in production on M2 Ultra 128GB since day one. Qwen3.5-122B-A10B-VLM-MTP-5bit, 24/7 inference serving coding agents. MTP acceptance rates:

  • 27B dense 8-bit: 1.32x (32% acceptance, best fit)
  • 35B MoE 8-bit: 1.11x (11% acceptance)
  • 122B MoE 5-bit: 1.09x (9% acceptance)

MoE acceptance rates are lower because a single MTP layer can't predict expert routing well. Still a net win for the latency-sensitive use case.

The MoE sanitize fix (commit 8d06796) is essential for Qwen3.5 MoE models -- without it, 768 MTP parameters are silently missing. We've also published pre-converted VLM+MTP models on HuggingFace that depend on this code path.

Would be great to get this reviewed and merged so the community models work out of the box.


cresseelia commented Mar 29, 2026

Can we @ the reviewer again? It's an important update for Qwen3.5.

@Thump604

@angeloskath @awni — this PR has been open 17 days with no maintainer review or feedback. Multiple community members have asked for review (AirRunner, ourselves, cresseelia).

Is there a concern with the approach, scope, or implementation that's blocking review? We're happy to help address any issues — split the PR, rework the API surface, add tests, whatever is needed.

We're running this in production on 122B and have validated it across three Qwen3.5 model sizes. The community is actively hitting the config/weight mismatch that AirRunner already fixed in this branch (layer4down's report above). Without this merged, users have to manually patch config.json to use MTP on Qwen3.5 models.

If the PR needs changes or a different direction, we'd rather know than wait. Let us know how we can help move this forward.

@Goekdeniz-Guelmez
Contributor

As you can see, 6 files have been changed/added alongside ~700 lines of new code. This PR makes big changes in the codebase itself. Reviewing and (correctly) integrating it will take time; 17 days is not long. My full-weight fine-tuning PR took multiple weeks to be merged. Just keep it open, keep it updated, and please be patient. Adding completely new features takes a long time.

@janhilgard

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer
in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights.

Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable).
Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls.
After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state.
On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before.

Acceptance rate ~64% average / ~85% on 100-token run.
- Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate)
- Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache)
- Disable batch generation for MTP models (draft/verify loop requires single-sequence processing)

Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator
…t_predicate)

- Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only, avoiding double normalization (model.norm + pre_fc_norm_hidden).
- Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy).

27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x).
Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding.

MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional.
8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers.
- MTP module instantiation and cache creation
- return_hidden shape and pre-norm verification
- mtp_forward output shape
- quant_predicate excludes mtp.fc
- Token identity: mtp_generate_step == generate_step (greedy)
- End-to-end mtp_generate_step completion
AirRunner and others added 4 commits April 1, 2026 04:17
Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect.
MTP layers in MoE models (35B-A3B, 122B-A10B) ship unfused per-expert weights (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) whereas the backbone uses pre-fused switch_mlp format. Conversion was failing with ~768 parameters not in model.

Add a stacking loop in qwen3_5_moe.py sanitize() after the backbone expert loop, mirroring the same pattern for MTP prefixes.

Co-authored-by: Thump604 <thump604@users.noreply.github.com>
When mtp_num_hidden_layers > 0 but the model weights contain no MTP parameters, the previous error was a cryptic 'Missing N parameters'.
Now raises a ValueError with an actionable message.

AirRunner commented Apr 1, 2026

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@janhilgard I'm not sure splitting would actually help the review here.

The PRs you suggest wouldn't be reviewable in isolation, because the architecture changes only make sense in the context of how mtp_generate_step uses them. Also, the changes in generate would be dead code until the other PR lands, so one would need to review both PRs together anyway.

(Also, 183 of the 683 added lines are just unit tests).

That said, I'm open to whatever helps, happy to reorganize if it does :).


Thump604 commented Apr 1, 2026

@angeloskath @awni — this PR has been open 20+ days with no maintainer review. It is the foundation for MTP speculative decoding on Qwen3.5 models, which several of us are using in production. My PR #1085 (probabilistic acceptance, 2.3x throughput on 122B) builds directly on top of it.

AirRunner's implementation is solid: 8 tests, 80.6% acceptance on M4 Pro. Is there a concern about scope or approach blocking review?


gyzerok commented Apr 1, 2026

@Thump604 can you stop pinging people? The more annoying you are the less likely anyone is going to respond.

@janhilgard

Great work — I've been running MTP on Qwen3.5 MoE models in production (M3 Ultra, 256 GB) and wanted to share findings that might explain the low MoE acceptance rates.

BF16 MTP weights are critical for MoE acceptance

Your quant_predicate excludes only mtp.fc:

```python
if path.endswith("mtp.fc"):
    return False
```

But the MTP transformer layer (attention, MLP, norms) still gets quantized. We found that quantized MTP weights give near-0% acceptance on MoE models — the quantization error compounds through the expert routing prediction.

Fix: exclude ALL MTP weights from quantization:

```python
if "mtp." in path:
    return False
```
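The broadened predicate behaves like this (standalone illustration; the real predicate may also receive the module object, but here only the path matters):

```python
def quant_predicate(path: str) -> bool:
    # Return False to keep a weight in BF16 (excluded from quantization)
    return "mtp." not in path

# MTP transformer layer and fusion projection both stay BF16
assert quant_predicate("language_model.mtp.fc") is False
assert quant_predicate("language_model.mtp.layers.0.self_attn.q_proj") is False
# Backbone weights still get quantized
assert quant_predicate("language_model.layers.3.mlp.gate_proj") is True
```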

Our MoE results with BF16 MTP weights

| Model | Quantization | MTP weights | Acceptance | Speedup |
| --- | --- | --- | --- | --- |
| 35B-A3B | 4-bit | BF16 | 79-85% | 1.18x |
| 122B-A10B | 4-bit | BF16 | 77-78% | 1.12x |
| 35B-A3B | 4-bit | dequantized 4→BF16 | ~0% | n/a |

vs your MoE benchmarks (quantized MTP weights):

| Model | MTP weights | Implied acceptance | Speedup |
| --- | --- | --- | --- |
| 35B-A3B 8-bit | quantized | ~11% | 1.11x |
| 122B-A10B 5-bit | quantized | ~5% | 1.09x |

The difference is stark: BF16 MTP weights → 79-85% acceptance, quantized → 5-11%.

Batch auto-skip

Your PR sets is_batchable = False when MTP is active. In our vllm-mlx integration (#245 on waybarrios/vllm-mlx) we auto-skip MTP when batch_size > 1:

```python
if len(active_batch) > 1:
    # Skip MTP, fall back to standard generation
    return _orig_step(input_tokens, cache)
```

This gives the best of both worlds:

  • 1 request: MTP active → 86 tok/s (1.18x)
  • 8 requests: MTP skipped → 307 tok/s (full batching throughput)

Instead of disabling batching entirely, you could dynamically switch.

Weight extraction

We extract BF16 MTP weights from the original HF model (not the quantized MLX model) with a dedicated script. See vllm-mlx PR #245 for the add_mtp_weights_qwen35.py script that:

  • Downloads only MTP-containing shards (not entire model)
  • Stacks per-expert weights into SwitchLinear format
  • Applies RMSNorm +1.0 shift
  • Outputs native BF16

Happy to collaborate on getting BF16 MTP weights into the standard conversion pipeline.
