From 603b005ef7d8d77a9469a3ac0dccbb50d69dc54d Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 18:03:46 +0100 Subject: [PATCH 01/10] feat: native MTP speculative decoding for Qwen3.5 Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights. Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable). --- mlx_lm/generate.py | 196 +++++++++++++++++++++++++++++++++++++-- mlx_lm/models/qwen3_5.py | 153 ++++++++++++++++++++++++++++-- 2 files changed, 335 insertions(+), 14 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 22531c644..1a49208db 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -643,6 +643,183 @@ def _draft_generate(y, num_draft): _rewind_cache(num_draft, n) +def mtp_generate_step( + prompt: mx.array, + model: nn.Module, + *, + max_tokens: int = 256, + sampler: Optional[Callable[[mx.array], mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + kv_bits: Optional[int] = None, + kv_group_size: int = 64, + quantized_kv_start: int = 0, +) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: + """A generator using the model's native MTP head for speculative decoding. + + Produces up to 2 tokens per forward pass: + - 1 backbone token (always accepted) + - 1 MTP draft token (accepted if the backbone agrees on the next step) + + The model must expose ``mtp_forward(hidden, next_tok, mtp_cache)`` and + support ``return_hidden=True`` in its ``__call__``. + + Yields: + Tuple[mx.array, mx.array, bool]: token, log-probabilities, from_draft. + """ + y = prompt.astype(mx.uint32) + prev_tokens = None + + if prompt_cache is None: + model_cache = cache.make_prompt_cache(model) + mtp_cache = model.make_mtp_cache() + else: + # When a pre-built cache is provided, split at backbone length + n_main = len(model.layers) + model_cache = prompt_cache[:n_main] + mtp_cache = prompt_cache[n_main:] + + sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) + + quantize_cache_fn = functools.partial( + maybe_quantize_kv_cache, + quantized_kv_start=quantized_kv_start, + kv_group_size=kv_group_size, + kv_bits=kv_bits, + ) + + def _process_and_sample(tokens, logits): + if logits_processors: + for processor in logits_processors: + logits = processor(tokens, logits) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + tok = sampler(logprobs) + return tok, logprobs + + def _step_backbone(y, n_predict=1): + """One backbone forward pass. Returns (tokens, logprobs, hidden).""" + with mx.stream(generation_stream): + logits, hidden = model(y[None], cache=model_cache, return_hidden=True) + logits = logits[:, -n_predict:, :] + quantize_cache_fn(model_cache) + nonlocal prev_tokens + toks, lps = [], [] + y_ctx = y if n_predict == 1 else y[: -(n_predict - 1)] + for i in range(n_predict): + if logits_processors: + prev_tokens = ( + mx.concatenate([prev_tokens, y_ctx]) + if prev_tokens is not None + else y_ctx + ) + tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0)) + toks.append(tok) + lps.append(lp) + return mx.stack(toks), mx.stack(lps), hidden + + def _step_mtp(hidden_last, main_tok): + """Run MTP head. 
Returns (draft_token, draft_logprobs).""" + # hidden_last: (1, 1, H), main_tok: 0-d or scalar + next_ids = main_tok.reshape(1, 1) + with mx.stream(generation_stream): + mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) + quantize_cache_fn(mtp_cache) + mtp_logits = mtp_logits[:, -1, :].squeeze(0) + draft_tok, draft_lp = _process_and_sample(prev_tokens, mtp_logits) + return draft_tok, draft_lp + + def _prefill(y): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=model_cache) + quantize_cache_fn(model_cache) + mx.eval([c.state for c in model_cache if hasattr(c, "state")]) + y = y[prefill_step_size:] + mx.clear_cache() + return y + + with mx.stream(generation_stream): + y = _prefill(y) + + ntoks = 0 + draft_tok = None + draft_lp = None + + try: + while True: + if draft_tok is None: + # No pending draft — run backbone only, then generate first draft + toks, lps, hidden = _step_backbone(y, n_predict=1) + mx.eval(toks) + main_tok = toks[0] + main_lp = lps[0] + + ntoks += 1 + yield main_tok, main_lp, False + if ntoks >= max_tokens: + break + + draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok) + mx.eval(draft_tok) + y = mx.array([main_tok.item()], mx.uint32) + else: + # Verify draft: process [y, draft_tok] through backbone together + y_with_draft = mx.concatenate( + [y, mx.array([draft_tok.item()], mx.uint32)] + ) + toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2) + mx.eval(toks, draft_tok) + + verify_pred = toks[0] # backbone prediction after y → verify draft + bonus_tok = toks[1] # backbone prediction after draft_tok + verify_lp = lps[0] + bonus_lp = lps[1] + + if verify_pred.item() == draft_tok.item(): + # Draft accepted + ntoks += 1 + yield draft_tok, draft_lp, True + if ntoks >= max_tokens: + break + + ntoks += 1 + yield bonus_tok, bonus_lp, False + if ntoks >= max_tokens: + break + + # Next draft from MTP at draft_tok's hidden state + draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok) + mx.eval(draft_tok) + y = mx.array([bonus_tok.item()], mx.uint32) + else: + # Draft rejected — trim caches. + # + # Qwen3.5 is a hybrid SSM+Attention model: attention layers use + # KVCache (trimmable), SSM layers use ArraysCache (not trimmable). + # trim_prompt_cache() is all-or-nothing, so we trim KV entries + # individually. The SSM state will retain a 1-token contamination + # from the rejected draft, which is empirically negligible compared + # to the sequence length but means output may differ slightly from + # standard generate_step. A correct fix would require exposing + # per-token intermediate SSM states from GatedDeltaNet (future work). 
+ for c in model_cache: + if c.is_trimmable(): + c.trim(1) + cache.trim_prompt_cache(mtp_cache, 1) + + ntoks += 1 + yield verify_pred, verify_lp, False + if ntoks >= max_tokens: + break + + # Next draft from MTP at y's hidden state + draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred) + mx.eval(draft_tok) + y = mx.array([verify_pred.item()], mx.uint32) + finally: + pass + + def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], @@ -687,19 +864,24 @@ def stream_generate( kwargs["max_tokens"] = max_tokens - if draft_model is None: + if draft_model is not None: + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + token_generator = speculative_generate_step( + prompt, model, draft_model, **kwargs + ) + elif hasattr(model, "mtp_forward"): + kwargs.pop("max_kv_size", None) + kwargs.pop("prompt_progress_callback", None) + kwargs.pop("num_draft_tokens", None) + token_generator = mtp_generate_step(prompt, model, **kwargs) + else: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) # from_draft always false for non-speculative generation token_generator = ( (token, logprobs, False) for token, logprobs in token_generator ) - else: - kwargs.pop("max_kv_size", None) - kwargs.pop("prompt_progress_callback", None) - token_generator = speculative_generate_step( - prompt, model, draft_model, **kwargs - ) with wired_limit(model, [generation_stream]): tic = time.perf_counter() for n, (token, logprobs, from_draft) in enumerate(token_generator): diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 16bb8b7ca..0b8adfa6b 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -50,6 +50,10 @@ class TextModelArgs(BaseModelArgs): moe_intermediate_size: int = 0 norm_topk_prob: bool = True + # MTP fields + mtp_num_hidden_layers: int = 0 + mtp_use_dedicated_embeddings: bool = False + # Rope parameters rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( default_factory=lambda: { @@ -234,6 +238,79 @@ def __call__( return out +class MTPDecoderLayer(nn.Module): + """Full-attention-only transformer layer for the MTP head (no GatedDeltaNet).""" + + def __init__(self, args: TextModelArgs): + super().__init__() + self.self_attn = Attention(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + if args.num_experts > 0: + self.mlp = SparseMoeBlock(args) + else: + self.mlp = MLP(args.hidden_size, args.intermediate_size) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + return h + self.mlp(self.post_attention_layernorm(h)) + + +class MTPModule(nn.Module): + """Multi-Token Prediction head. 
+ + Predicts the token at position t+2 given: + - h_t : backbone hidden state at the last accepted position t + - t_main: the main model's sampled prediction for t+1 + + Forward: + fused = fc(cat([pre_fc_norm_embedding(embed(t_main)), + pre_fc_norm_hidden(h_t)])) + → MTPDecoderLayer(s) + → norm + → (caller applies lm_head, shared with backbone) + """ + + def __init__(self, args: TextModelArgs): + super().__init__() + self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) + self.layers = [ + MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + embed_tokens: nn.Embedding, + cache: Optional[Any] = None, + ) -> mx.array: + # hidden_states : (B, 1, H) — backbone hidden at last accepted position + # next_token_ids: (B, 1) — t_main (main model's prediction for t+1) + embeds = embed_tokens(next_token_ids) # (B, 1, H) + e = self.pre_fc_norm_embedding(embeds) + h = self.pre_fc_norm_hidden(hidden_states) + fused = self.fc(mx.concatenate([e, h], axis=-1)) # (B, 1, H) + + if cache is None: + cache = [None] * len(self.layers) + + mask = create_attention_mask(fused, cache[0]) + for layer, c in zip(self.layers, cache): + fused = layer(fused, mask, c) + + return self.norm(fused) # (B, 1, H) + + class Qwen3_5TextModel(nn.Module): def __init__(self, args: TextModelArgs): super().__init__() @@ -277,20 +354,51 @@ def __init__(self, args: TextModelArgs): self.model = Qwen3_5TextModel(args) if not args.tie_word_embeddings: self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + if args.mtp_num_hidden_layers > 0: + self.mtp = MTPModule(args) def __call__( self, inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, ) -> mx.array: - out = self.model(inputs, cache, input_embeddings=input_embeddings) + hidden = self.model(inputs, cache, input_embeddings=input_embeddings) if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) + out = self.model.embed_tokens.as_linear(hidden) else: - out = self.lm_head(out) + out = self.lm_head(hidden) + if return_hidden: + return out, hidden return out + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Run the MTP head and apply the shared lm_head. + + Args: + hidden_states: (B, 1, H) — backbone hidden state at the last position. + next_token_ids: (B, 1) — sampled main token (t_main). + mtp_cache: list of KVCache entries for the MTP transformer layer(s). + + Returns: + logits: (B, 1, vocab_size) + """ + mtp_out = self.mtp( + hidden_states, + next_token_ids, + self.model.embed_tokens, + mtp_cache, + ) + if self.args.tie_word_embeddings: + return self.model.embed_tokens.as_linear(mtp_out) + return self.lm_head(mtp_out) + @property def layers(self): return self.model.layers @@ -298,13 +406,23 @@ def layers(self): def make_cache(self): return [ArraysCache(size=2) if l.is_linear else KVCache() for l in self.layers] + def make_mtp_cache(self): + """Return a fresh list of KVCache entries for the MTP layer(s).""" + if hasattr(self, "mtp"): + return [KVCache() for _ in self.mtp.layers] + return [] + def sanitize(self, weights): - has_mtp_weights = any("mtp." 
in k for k in weights) has_unsanitized_conv1d = any( "conv1d.weight" in k and v.shape[-1] != 1 for k, v in weights.items() ) - should_shift_norm_weights = has_mtp_weights or has_unsanitized_conv1d - weights = {k: v for k, v in weights.items() if "mtp." not in k} + # Norm weights need a +1 shift only in raw HF checkpoints (detected via + # unsanitized conv1d). Already-converted MLX models (conv1d fixed) must NOT + # be shifted again — even when they contain MTP weights. + should_shift_norm_weights = has_unsanitized_conv1d + # Keep MTP weights if this model has an MTP head; drop them otherwise + if not hasattr(self, "mtp"): + weights = {k: v for k, v in weights.items() if "mtp." not in k} if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None) @@ -315,6 +433,10 @@ def sanitize(self, weights): "model.norm.weight", ".q_norm.weight", ".k_norm.weight", + # MTP-specific norms (not covered by the patterns above) + ".pre_fc_norm_hidden.weight", + ".pre_fc_norm_embedding.weight", + "mtp.norm.weight", ) for k, v in weights.items(): if "conv1d.weight" in k and v.shape[-1] != 1: @@ -370,9 +492,13 @@ def __call__( inputs: mx.array, cache=None, input_embeddings: Optional[mx.array] = None, + return_hidden: bool = False, ): return self.language_model( - inputs, cache=cache, input_embeddings=input_embeddings + inputs, + cache=cache, + input_embeddings=input_embeddings, + return_hidden=return_hidden, ) def sanitize(self, weights): @@ -509,6 +635,19 @@ def _repeat(p): layer.mlp.switch_mlp.up_proj, "all-to-sharded", group=group ) + def mtp_forward( + self, + hidden_states: mx.array, + next_token_ids: mx.array, + mtp_cache: Any, + ) -> mx.array: + """Delegate to language_model.mtp_forward. See TextModel.mtp_forward.""" + return self.language_model.mtp_forward(hidden_states, next_token_ids, mtp_cache) + + def make_mtp_cache(self): + """Return fresh KVCache entries for the MTP layer(s).""" + return self.language_model.make_mtp_cache() + @property def layers(self): return self.language_model.model.layers From 9cb4b92883650205ea4c3b237cc565b047398323 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 18:27:29 +0100 Subject: [PATCH 02/10] fix(mtp): eliminate SSM state contamination on draft rejection Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls. After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state. On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before. Acceptance rate ~64% average / ~85% on 100-token run. --- mlx_lm/generate.py | 56 ++++++++++-------- mlx_lm/models/cache.py | 3 + mlx_lm/models/qwen3_5.py | 125 ++++++++++++++++++++++----------------- 3 files changed, 106 insertions(+), 78 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 1a49208db..b5d91cc78 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -656,17 +656,19 @@ def mtp_generate_step( kv_group_size: int = 64, quantized_kv_start: int = 0, ) -> Generator[Tuple[mx.array, mx.array, bool], None, None]: - """A generator using the model's native MTP head for speculative decoding. + """A generator that uses the model's native MTP head for speculative decoding. 
- Produces up to 2 tokens per forward pass: - - 1 backbone token (always accepted) - - 1 MTP draft token (accepted if the backbone agrees on the next step) + Each iteration runs one backbone forward pass over the current token and its + pending draft, then one MTP forward pass to propose the next draft. Up to 2 + tokens are emitted per backbone step: one always-accepted backbone token and + one conditionally-accepted draft token. - The model must expose ``mtp_forward(hidden, next_tok, mtp_cache)`` and + The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and support ``return_hidden=True`` in its ``__call__``. Yields: - Tuple[mx.array, mx.array, bool]: token, log-probabilities, from_draft. + Tuple[mx.array, mx.array, bool]: (token, log-probabilities, from_draft). + ``from_draft`` is ``True`` when the token came from the MTP head. """ y = prompt.astype(mx.uint32) prev_tokens = None @@ -697,10 +699,10 @@ def _process_and_sample(tokens, logits): tok = sampler(logprobs) return tok, logprobs - def _step_backbone(y, n_predict=1): - """One backbone forward pass. Returns (tokens, logprobs, hidden).""" + def _step_backbone(y, n_predict=1, n_confirmed=0): + """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" with mx.stream(generation_stream): - logits, hidden = model(y[None], cache=model_cache, return_hidden=True) + logits, hidden = model(y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) nonlocal prev_tokens @@ -719,8 +721,7 @@ def _step_backbone(y, n_predict=1): return mx.stack(toks), mx.stack(lps), hidden def _step_mtp(hidden_last, main_tok): - """Run MTP head. Returns (draft_token, draft_logprobs).""" - # hidden_last: (1, 1, H), main_tok: 0-d or scalar + """Run the MTP head and return (draft_token, draft_logprobs).""" next_ids = main_tok.reshape(1, 1) with mx.stream(generation_stream): mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache) @@ -763,11 +764,13 @@ def _prefill(y): mx.eval(draft_tok) y = mx.array([main_tok.item()], mx.uint32) else: - # Verify draft: process [y, draft_tok] through backbone together + # Verify draft: run backbone over [y, draft_tok]. + # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state + # after the confirmed token y, enabling exact rollback on rejection. y_with_draft = mx.concatenate( [y, mx.array([draft_tok.item()], mx.uint32)] ) - toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2) + toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) mx.eval(toks, draft_tok) verify_pred = toks[0] # backbone prediction after y → verify draft @@ -776,7 +779,11 @@ def _prefill(y): bonus_lp = lps[1] if verify_pred.item() == draft_tok.item(): - # Draft accepted + # Draft accepted — discard rollback snapshots. + for c in model_cache: + if hasattr(c, "rollback_state"): + c.rollback_state = None + ntoks += 1 yield draft_tok, draft_lp, True if ntoks >= max_tokens: @@ -792,18 +799,17 @@ def _prefill(y): mx.eval(draft_tok) y = mx.array([bonus_tok.item()], mx.uint32) else: - # Draft rejected — trim caches. - # - # Qwen3.5 is a hybrid SSM+Attention model: attention layers use - # KVCache (trimmable), SSM layers use ArraysCache (not trimmable). - # trim_prompt_cache() is all-or-nothing, so we trim KV entries - # individually. 
The SSM state will retain a 1-token contamination - # from the rejected draft, which is empirically negligible compared - # to the sequence length but means output may differ slightly from - # standard generate_step. A correct fix would require exposing - # per-token intermediate SSM states from GatedDeltaNet (future work). + # Draft rejected — roll back all caches to the state after y. + # SSM layers (ArraysCache): restore the conv/ssm snapshot saved + # by GatedDeltaNet after the confirmed token. + # Attention layers (KVCache): trim the draft-token entry. for c in model_cache: - if c.is_trimmable(): + if hasattr(c, "rollback_state") and c.rollback_state is not None: + conv_snap, ssm_snap = c.rollback_state + c[0] = conv_snap + c[1] = ssm_snap + c.rollback_state = None + elif c.is_trimmable(): c.trim(1) cache.trim_prompt_cache(mtp_cache, 1) diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index e6993243c..0b880ff1b 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -596,6 +596,9 @@ def __new__(cls, *args, **kwargs): instance = super().__new__(cls) instance.left_padding = None instance.lengths = None + # Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens + # in an MTP draft-verification step. Cleared after each step. + instance.rollback_state = None return instance def __init__(self, size, left_padding: Optional[List[int]] = None): diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 0b8adfa6b..71d61e857 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -52,7 +52,6 @@ class TextModelArgs(BaseModelArgs): # MTP fields mtp_num_hidden_layers: int = 0 - mtp_use_dedicated_embeddings: bool = False # Rope parameters rope_parameters: Optional[Dict[str, Union[float, str, bool, List[int]]]] = field( @@ -133,11 +132,45 @@ def __init__(self, config: TextModelArgs): self.sharding_group = None + def _process_chunk( + self, + qkv_chunk: mx.array, + a_chunk: mx.array, + b_chunk: mx.array, + conv_state: mx.array, + ssm_state: Optional[mx.array], + ssm_mask: Optional[mx.array] = None, + ): + B, S_chunk = qkv_chunk.shape[:2] + conv_in = mx.concatenate([conv_state, qkv_chunk], axis=1) + new_conv_state = conv_in[:, -(self.conv_kernel_size - 1) :] + conv_out = nn.silu(self.conv1d(conv_in)) + + q, k, v = [ + t.reshape(B, S_chunk, h, d) + for t, h, d in zip( + mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), + [self.num_k_heads, self.num_k_heads, self.num_v_heads], + [self.head_k_dim, self.head_k_dim, self.head_v_dim], + ) + ] + inv_scale = k.shape[-1] ** -0.5 + q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) + k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) + + out, new_ssm_state = gated_delta_update( + q, k, v, a_chunk, b_chunk, + self.A_log, self.dt_bias, ssm_state, ssm_mask, + use_kernel=not self.training, + ) + return out, new_conv_state, new_ssm_state + def __call__( self, inputs: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: B, S, _ = inputs.shape @@ -149,50 +182,38 @@ def __call__( b = self.in_proj_b(inputs) a = self.in_proj_a(inputs) - if cache is not None and cache[0] is not None: - conv_state = cache[0] - else: - conv_state = mx.zeros( - (B, self.conv_kernel_size - 1, self.conv_dim), - dtype=inputs.dtype, - ) + conv_state = ( + cache[0] + if cache is not None and cache[0] is not None + else mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) + ) + ssm_state = cache[1] if cache else None if mask is not None: 
qkv = mx.where(mask[..., None], qkv, 0) - conv_input = mx.concatenate([conv_state, qkv], axis=1) - if cache is not None: - cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :] - conv_out = nn.silu(self.conv1d(conv_input)) - q, k, v = [ - t.reshape(B, S, h, d) - for t, h, d in zip( - mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1), - [self.num_k_heads, self.num_k_heads, self.num_v_heads], - [self.head_k_dim, self.head_k_dim, self.head_v_dim], + if n_confirmed > 0 and n_confirmed < S: + # Process confirmed and draft tokens separately so we can snapshot the + # SSM/conv state between them for exact rollback on draft rejection. + mask_c = mask[:, :n_confirmed] if mask is not None else None + mask_d = mask[:, n_confirmed:] if mask is not None else None + out_c, conv_c, ssm_c = self._process_chunk( + qkv[:, :n_confirmed], a[:, :n_confirmed], b[:, :n_confirmed], + conv_state, ssm_state, mask_c, ) - ] - - state = cache[1] if cache else None - inv_scale = k.shape[-1] ** -0.5 - q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6) - k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) - - out, state = gated_delta_update( - q, - k, - v, - a, - b, - self.A_log, - self.dt_bias, - state, - mask, - use_kernel=not self.training, - ) + if cache is not None: + cache.rollback_state = (conv_c, ssm_c) + out_d, conv_f, ssm_f = self._process_chunk( + qkv[:, n_confirmed:], a[:, n_confirmed:], b[:, n_confirmed:], + conv_c, ssm_c, mask_d, + ) + out = mx.concatenate([out_c, out_d], axis=1) + else: + out, conv_f, ssm_f = self._process_chunk(qkv, a, b, conv_state, ssm_state, mask) if cache is not None: - cache[1] = state + cache[0] = conv_f + cache[1] = ssm_f cache.advance(S) out = self.norm(out, z) @@ -228,9 +249,10 @@ def __call__( x: mx.array, mask: Optional[mx.array] = None, cache: Optional[Any] = None, + n_confirmed: int = 0, ) -> mx.array: if self.is_linear: - r = self.linear_attn(self.input_layernorm(x), mask, cache) + r = self.linear_attn(self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed) else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -263,18 +285,10 @@ def __call__( class MTPModule(nn.Module): - """Multi-Token Prediction head. - - Predicts the token at position t+2 given: - - h_t : backbone hidden state at the last accepted position t - - t_main: the main model's sampled prediction for t+1 - - Forward: - fused = fc(cat([pre_fc_norm_embedding(embed(t_main)), - pre_fc_norm_hidden(h_t)])) - → MTPDecoderLayer(s) - → norm - → (caller applies lm_head, shared with backbone) + """Multi-Token Prediction head (Qwen3.5 native speculative decoding). + + Predicts token t+2 from the backbone hidden state h_t and the sampled + token t+1, using a shared lm_head with the backbone. 
""" def __init__(self, args: TextModelArgs): @@ -327,6 +341,7 @@ def __call__( inputs: mx.array, cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, + n_confirmed: int = 0, ) -> mx.array: if input_embeddings is not None: hidden_states = input_embeddings @@ -341,7 +356,8 @@ def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - hidden_states = layer(hidden_states, mask=mask, cache=c) + kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {} + hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) return self.norm(hidden_states) @@ -363,8 +379,9 @@ def __call__( cache: Optional[Any] = None, input_embeddings: Optional[mx.array] = None, return_hidden: bool = False, + n_confirmed: int = 0, ) -> mx.array: - hidden = self.model(inputs, cache, input_embeddings=input_embeddings) + hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(hidden) else: @@ -493,12 +510,14 @@ def __call__( cache=None, input_embeddings: Optional[mx.array] = None, return_hidden: bool = False, + n_confirmed: int = 0, ): return self.language_model( inputs, cache=cache, input_embeddings=input_embeddings, return_hidden=return_hidden, + n_confirmed=n_confirmed, ) def sanitize(self, weights): From 573cde59eecfe117d1705407e0ffc5d8a4afa920 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 21:50:57 +0100 Subject: [PATCH 03/10] fix(mtp): server integration (yield types, cache fallback, batching) - Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate) - Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache) - Disable batch generation for MTP models (draft/verify loop requires single-sequence processing) Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator --- mlx_lm/generate.py | 13 +++++++------ mlx_lm/server.py | 4 ++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index b5d91cc78..3cb368f6f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -677,10 +677,11 @@ def mtp_generate_step( model_cache = cache.make_prompt_cache(model) mtp_cache = model.make_mtp_cache() else: - # When a pre-built cache is provided, split at backbone length + # Split a pre-built cache at backbone length. If MTP entries are + # absent (e.g. cache created by make_prompt_cache), create them. 
n_main = len(model.layers) model_cache = prompt_cache[:n_main] - mtp_cache = prompt_cache[n_main:] + mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache() sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) @@ -756,7 +757,7 @@ def _prefill(y): main_lp = lps[0] ntoks += 1 - yield main_tok, main_lp, False + yield main_tok.item(), main_lp, False if ntoks >= max_tokens: break @@ -785,12 +786,12 @@ def _prefill(y): c.rollback_state = None ntoks += 1 - yield draft_tok, draft_lp, True + yield draft_tok.item(), draft_lp, True if ntoks >= max_tokens: break ntoks += 1 - yield bonus_tok, bonus_lp, False + yield bonus_tok.item(), bonus_lp, False if ntoks >= max_tokens: break @@ -814,7 +815,7 @@ def _prefill(y): cache.trim_prompt_cache(mtp_cache, 1) ntoks += 1 - yield verify_pred, verify_lp, False + yield verify_pred.item(), verify_lp, False if ntoks >= max_tokens: break diff --git a/mlx_lm/server.py b/mlx_lm/server.py index c5d1f95c3..9ac67802f 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -399,6 +399,10 @@ def validate_draft_tokenizer(draft_tokenizer): self.is_batchable = all( hasattr(c, "merge") for c in make_prompt_cache(self.model) ) + # MTP speculative decoding requires single-sequence generation + # (draft/verify loop is incompatible with batch generation). + if hasattr(self.model, "mtp_forward"): + self.is_batchable = False return self.model, self.tokenizer From 937c7210091e4bfffc84d14796b78fcdbf7ce07f Mon Sep 17 00:00:00 2001 From: AirRunner Date: Thu, 12 Mar 2026 22:26:12 +0100 Subject: [PATCH 04/10] fix(mtp): address @janhilgard code review feedback (double-norm, quant_predicate) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only (avoiding double normalization (model.norm + pre_fc_norm_hidden). - Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy). 27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x). --- mlx_lm/models/qwen3_5.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 71d61e857..08b52306f 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -359,7 +359,7 @@ def __call__( kw = {"n_confirmed": n_confirmed} if layer.is_linear and n_confirmed > 0 else {} hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) - return self.norm(hidden_states) + return hidden_states class TextModel(nn.Module): @@ -382,12 +382,13 @@ def __call__( n_confirmed: int = 0, ) -> mx.array: hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) + normed = self.model.norm(hidden) if self.args.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(hidden) + out = self.model.embed_tokens.as_linear(normed) else: - out = self.lm_head(hidden) + out = self.lm_head(normed) if return_hidden: - return out, hidden + return out, hidden # pre-norm hidden for MTP head return out def mtp_forward( @@ -465,14 +466,16 @@ def sanitize(self, weights): @property def quant_predicate(self): - if self.args.num_experts <= 0: - return None - def predicate(path, _): if path.endswith("mlp.gate") or path.endswith("shared_expert_gate"): return {"group_size": 64, "bits": 8} + # Keep the MTP fusion projection in full precision. 
+ if path.endswith("mtp.fc"): + return False return True + if self.args.num_experts <= 0 and self.args.mtp_num_hidden_layers <= 0: + return None return predicate @property From fdd36c75005fdf8a7d06f8721c592af6093d01da Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:07:52 +0100 Subject: [PATCH 05/10] feat(mtp): add --mtp CLI flag for generate and server Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding. MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional. --- mlx_lm/generate.py | 12 +++++++++++- mlx_lm/server.py | 13 +++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3cb368f6f..af831ceaf 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -215,6 +215,12 @@ def setup_arg_parser(): help="Number of tokens to draft when using speculative decoding.", default=3, ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. Qwen3.5).", + ) return parser @@ -833,6 +839,7 @@ def stream_generate( prompt: Union[str, mx.array, List[int]], max_tokens: int = 256, draft_model: Optional[nn.Module] = None, + mtp: bool = False, **kwargs, ) -> Generator[GenerationResponse, None, None]: """ @@ -848,6 +855,8 @@ def stream_generate( draft_model (Optional[nn.Module]): An optional draft model. If provided then speculative decoding is used. The draft model must use the same tokenizer as the main model. Default: ``None``. + mtp (bool): Use native Multi-Token Prediction for speculative + decoding. Requires a model with an MTP head. Default: ``False``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. @@ -877,7 +886,7 @@ def stream_generate( token_generator = speculative_generate_step( prompt, model, draft_model, **kwargs ) - elif hasattr(model, "mtp_forward"): + elif mtp and hasattr(model, "mtp_forward"): kwargs.pop("max_kv_size", None) kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) @@ -1714,6 +1723,7 @@ def main(): quantized_kv_start=args.quantized_kv_start, draft_model=draft_model, num_draft_tokens=args.num_draft_tokens, + mtp=args.mtp, ) if not args.verbose: print(response) diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 9ac67802f..171f68f76 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -399,9 +399,11 @@ def validate_draft_tokenizer(draft_tokenizer): self.is_batchable = all( hasattr(c, "merge") for c in make_prompt_cache(self.model) ) - # MTP speculative decoding requires single-sequence generation + # MTP speculative decoding uses single-sequence generation # (draft/verify loop is incompatible with batch generation). - if hasattr(self.model, "mtp_forward"): + # TODO: dynamically switch between MTP (1 request) and + # BatchGenerator (>= 2 concurrent requests). 
+ if self.cli_args.mtp and hasattr(self.model, "mtp_forward"): self.is_batchable = False return self.model, self.tokenizer @@ -865,6 +867,7 @@ def progress(tokens_processed, tokens_total): num_draft_tokens=args.num_draft_tokens, prompt_progress_callback=progress, prefill_step_size=self.cli_args.prefill_step_size, + mtp=getattr(self.cli_args, "mtp", False), ): rqueue.put( Response( @@ -1846,6 +1849,12 @@ def main(): action="store_true", help="Use pipelining instead of tensor parallelism", ) + parser.add_argument( + "--mtp", + action="store_true", + help="Use native Multi-Token Prediction for speculative decoding " + "(requires a model with an MTP head, e.g. Qwen3.5).", + ) args = parser.parse_args() if mx.metal.is_available(): wired_limit = mx.device_info()["max_recommended_working_set_size"] From 7e06f36e2216ca89d6c7ce4bc62f3f126f75a2b4 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:25:53 +0100 Subject: [PATCH 06/10] test(mtp): add unit tests for MTP speculative decoding 8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers. - MTP module instantiation and cache creation - return_hidden shape and pre-norm verification - mtp_forward output shape - quant_predicate excludes mtp.fc - Token identity: mtp_generate_step == generate_step (greedy) - End-to-end mtp_generate_step completion --- tests/test_mtp.py | 182 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 tests/test_mtp.py diff --git a/tests/test_mtp.py b/tests/test_mtp.py new file mode 100644 index 000000000..ba1a0e92f --- /dev/null +++ b/tests/test_mtp.py @@ -0,0 +1,182 @@ +import importlib +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.generate import generate_step, mtp_generate_step + + +def _make_qwen3_5_mtp_model(): + """Create a tiny Qwen3.5 model with an MTP head for testing.""" + module = importlib.import_module("mlx_lm.models.qwen3_5") + args = module.ModelArgs.from_dict( + { + "model_type": "qwen3_5", + "text_config": { + "model_type": "qwen3_5", + "hidden_size": 64, + "intermediate_size": 128, + "num_hidden_layers": 4, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "vocab_size": 256, + "linear_num_value_heads": 2, + "linear_num_key_heads": 2, + "linear_key_head_dim": 16, + "linear_value_head_dim": 16, + "linear_conv_kernel_dim": 3, + "full_attention_interval": 2, + "tie_word_embeddings": True, + "rms_norm_eps": 1e-5, + "head_dim": 32, + "rope_theta": 1000.0, + "partial_rotary_factor": 0.5, + "max_position_embeddings": 128, + "mtp_num_hidden_layers": 1, + }, + } + ) + model = module.Model(args) + model.set_dtype(mx.float32) + mx.eval(model.parameters()) + return model + + +class TestMTP(unittest.TestCase): + """Tests for native MTP (Multi-Token Prediction) speculative decoding. + + Uses a tiny synthetic Qwen3.5 model (4 layers, hidden=64, vocab=256) + with mtp_num_hidden_layers=1 and full_attention_interval=2, giving a + mix of GatedDeltaNet (SSM) and full-attention layers. 
+ + Not tested here (would require a real tokenizer loaded from HF): + - stream_generate() with mtp=True/False flag dispatch + - Server integration (--mtp flag, is_batchable) + """ + + @classmethod + def setUpClass(cls): + cls.model = _make_qwen3_5_mtp_model() + + def test_mtp_module_exists(self): + """Model with mtp_num_hidden_layers=1 should have MTP head.""" + self.assertTrue(hasattr(self.model, "mtp_forward")) + self.assertTrue(hasattr(self.model, "make_mtp_cache")) + lm = self.model.language_model + self.assertTrue(hasattr(lm, "mtp")) + self.assertEqual(len(lm.mtp.layers), 1) + + def test_make_mtp_cache(self): + """make_mtp_cache should return one KVCache per MTP layer.""" + mtp_cache = self.model.make_mtp_cache() + self.assertEqual(len(mtp_cache), 1) + self.assertTrue(mtp_cache[0].is_trimmable()) + + def test_return_hidden(self): + """return_hidden=True should return (logits, hidden) with correct shapes.""" + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + out, hidden = self.model(inputs, cache=cache, return_hidden=True) + self.assertEqual(out.shape, (1, 3, 256)) + self.assertEqual(hidden.shape, (1, 3, 64)) + + def test_mtp_forward_shape(self): + """mtp_forward should produce logits of shape (B, 1, vocab).""" + hidden = mx.random.normal((1, 1, 64)) + next_ids = mx.array([[5]]) + mtp_cache = self.model.make_mtp_cache() + logits = self.model.mtp_forward(hidden, next_ids, mtp_cache) + self.assertEqual(logits.shape, (1, 1, 256)) + + def test_hidden_is_pre_norm(self): + """Hidden states returned with return_hidden should be pre-norm. + + This verifies the fix for double normalization: the backbone returns + pre-norm hidden states, and the final norm is applied only before + lm_head (not before the MTP head). + """ + lm = self.model.language_model + inputs = mx.array([[0, 1, 2]]) + cache = make_prompt_cache(self.model) + + _, hidden = lm(inputs, cache=cache, return_hidden=True) + + # Apply the final norm manually and check it changes the values. + normed = lm.model.norm(hidden) + self.assertFalse(mx.allclose(hidden, normed, atol=1e-5).item()) + + def test_quant_predicate_excludes_mtp_fc(self): + """quant_predicate should exclude mtp.fc from quantization.""" + lm = self.model.language_model + predicate = lm.quant_predicate + self.assertIsNotNone(predicate) + # mtp.fc should not be quantized + self.assertFalse(predicate("mtp.fc", None)) + # Regular layers should be quantized + self.assertTrue(predicate("layers.0.mlp.gate_proj", None)) + + def test_mtp_generate_identity(self): + """mtp_generate_step should produce the same greedy tokens as generate_step. + + This is the most important correctness test: it proves that the + draft/verify loop, SSM state rollback on rejection, and MTP cache + management are all correct. Any bug in these would cause the MTP + path to diverge from standard generation. 
+ """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def greedy(logprobs): + return mx.argmax(logprobs, axis=-1) + + # Standard generation + std_cache = make_prompt_cache(self.model) + std_tokens = [] + for i, (tok, _) in enumerate( + generate_step(prompt, self.model, sampler=greedy, prompt_cache=std_cache) + ): + std_tokens.append(int(tok)) + if i + 1 >= n_tokens: + break + + # MTP generation + mtp_tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, sampler=greedy, max_tokens=n_tokens + ): + mtp_tokens.append(int(tok)) + if len(mtp_tokens) >= n_tokens: + break + + self.assertEqual( + std_tokens, + mtp_tokens, + f"Token mismatch: std={std_tokens}, mtp={mtp_tokens}", + ) + + def test_mtp_generate_runs(self): + """mtp_generate_step should complete without errors. + + Exercises the full end-to-end path: prefill, backbone forward with + return_hidden, MTP draft generation, verification with n_confirmed, + SSM rollback on rejection, and MTP cache trimming. + """ + prompt = mx.array([0, 1, 2, 3], dtype=mx.uint32) + n_tokens = 10 + + def greedy(logprobs): + return mx.argmax(logprobs, axis=-1) + + tokens = [] + for tok, _, _ in mtp_generate_step( + prompt, self.model, sampler=greedy, max_tokens=n_tokens + ): + tokens.append(int(tok)) + if len(tokens) >= n_tokens: + break + + self.assertEqual(len(tokens), n_tokens) + + +if __name__ == "__main__": + unittest.main() From 826ce800068b36be87822ae4a6aef55d2f696afe Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 00:48:59 +0100 Subject: [PATCH 07/10] fix(mtp): warn when --mtp flag is used with a model without MTP head Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect. --- mlx_lm/generate.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index af831ceaf..43b489f95 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -3,6 +3,7 @@ import argparse import contextlib import functools +import warnings import json import sys import time @@ -891,6 +892,17 @@ def stream_generate( kwargs.pop("prompt_progress_callback", None) kwargs.pop("num_draft_tokens", None) token_generator = mtp_generate_step(prompt, model, **kwargs) + elif mtp: + warnings.warn( + "--mtp flag ignored: model does not have an MTP head. 
" + "Falling back to standard generation.", + stacklevel=2, + ) + kwargs.pop("num_draft_tokens", None) + token_generator = generate_step(prompt, model, **kwargs) + token_generator = ( + (token, logprobs, False) for token, logprobs in token_generator + ) else: kwargs.pop("num_draft_tokens", None) token_generator = generate_step(prompt, model, **kwargs) From ac3aae5e36a39fcc761928105682dbd65ceeafc1 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Fri, 13 Mar 2026 01:05:55 +0100 Subject: [PATCH 08/10] style: apply black and isort formatting --- mlx_lm/generate.py | 19 +++++++++----- mlx_lm/models/qwen3_5.py | 57 +++++++++++++++++++++++++++++----------- tests/test_mtp.py | 3 ++- 3 files changed, 57 insertions(+), 22 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 43b489f95..2f21800d7 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -3,10 +3,10 @@ import argparse import contextlib import functools -import warnings import json import sys import time +import warnings from dataclasses import dataclass from functools import partial from typing import ( @@ -710,7 +710,9 @@ def _process_and_sample(tokens, logits): def _step_backbone(y, n_predict=1, n_confirmed=0): """Run the backbone on ``y`` and return (tokens, logprobs, hidden).""" with mx.stream(generation_stream): - logits, hidden = model(y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed) + logits, hidden = model( + y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed + ) logits = logits[:, -n_predict:, :] quantize_cache_fn(model_cache) nonlocal prev_tokens @@ -778,11 +780,13 @@ def _prefill(y): y_with_draft = mx.concatenate( [y, mx.array([draft_tok.item()], mx.uint32)] ) - toks, lps, hidden = _step_backbone(y_with_draft, n_predict=2, n_confirmed=1) + toks, lps, hidden = _step_backbone( + y_with_draft, n_predict=2, n_confirmed=1 + ) mx.eval(toks, draft_tok) - verify_pred = toks[0] # backbone prediction after y → verify draft - bonus_tok = toks[1] # backbone prediction after draft_tok + verify_pred = toks[0] # backbone prediction after y → verify draft + bonus_tok = toks[1] # backbone prediction after draft_tok verify_lp = lps[0] bonus_lp = lps[1] @@ -812,7 +816,10 @@ def _prefill(y): # by GatedDeltaNet after the confirmed token. # Attention layers (KVCache): trim the draft-token entry. 
for c in model_cache: - if hasattr(c, "rollback_state") and c.rollback_state is not None: + if ( + hasattr(c, "rollback_state") + and c.rollback_state is not None + ): conv_snap, ssm_snap = c.rollback_state c[0] = conv_snap c[1] = ssm_snap diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index 08b52306f..bbc3bff46 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -159,8 +159,15 @@ def _process_chunk( k = inv_scale * mx.fast.rms_norm(k, None, 1e-6) out, new_ssm_state = gated_delta_update( - q, k, v, a_chunk, b_chunk, - self.A_log, self.dt_bias, ssm_state, ssm_mask, + q, + k, + v, + a_chunk, + b_chunk, + self.A_log, + self.dt_bias, + ssm_state, + ssm_mask, use_kernel=not self.training, ) return out, new_conv_state, new_ssm_state @@ -185,7 +192,9 @@ def __call__( conv_state = ( cache[0] if cache is not None and cache[0] is not None - else mx.zeros((B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype) + else mx.zeros( + (B, self.conv_kernel_size - 1, self.conv_dim), dtype=inputs.dtype + ) ) ssm_state = cache[1] if cache else None @@ -198,18 +207,28 @@ def __call__( mask_c = mask[:, :n_confirmed] if mask is not None else None mask_d = mask[:, n_confirmed:] if mask is not None else None out_c, conv_c, ssm_c = self._process_chunk( - qkv[:, :n_confirmed], a[:, :n_confirmed], b[:, :n_confirmed], - conv_state, ssm_state, mask_c, + qkv[:, :n_confirmed], + a[:, :n_confirmed], + b[:, :n_confirmed], + conv_state, + ssm_state, + mask_c, ) if cache is not None: cache.rollback_state = (conv_c, ssm_c) out_d, conv_f, ssm_f = self._process_chunk( - qkv[:, n_confirmed:], a[:, n_confirmed:], b[:, n_confirmed:], - conv_c, ssm_c, mask_d, + qkv[:, n_confirmed:], + a[:, n_confirmed:], + b[:, n_confirmed:], + conv_c, + ssm_c, + mask_d, ) out = mx.concatenate([out_c, out_d], axis=1) else: - out, conv_f, ssm_f = self._process_chunk(qkv, a, b, conv_state, ssm_state, mask) + out, conv_f, ssm_f = self._process_chunk( + qkv, a, b, conv_state, ssm_state, mask + ) if cache is not None: cache[0] = conv_f @@ -252,7 +271,9 @@ def __call__( n_confirmed: int = 0, ) -> mx.array: if self.is_linear: - r = self.linear_attn(self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed) + r = self.linear_attn( + self.input_layernorm(x), mask, cache, n_confirmed=n_confirmed + ) else: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -267,7 +288,9 @@ def __init__(self, args: TextModelArgs): super().__init__() self.self_attn = Attention(args) self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) if args.num_experts > 0: self.mlp = SparseMoeBlock(args) else: @@ -296,9 +319,7 @@ def __init__(self, args: TextModelArgs): self.pre_fc_norm_hidden = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.pre_fc_norm_embedding = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) self.fc = nn.Linear(args.hidden_size * 2, args.hidden_size, bias=False) - self.layers = [ - MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers) - ] + self.layers = [MTPDecoderLayer(args) for _ in range(args.mtp_num_hidden_layers)] self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) def __call__( @@ -356,7 +377,11 @@ def __call__( for layer, c in zip(self.layers, cache): mask = ssm_mask if layer.is_linear else fa_mask - kw = {"n_confirmed": n_confirmed} if layer.is_linear and 
n_confirmed > 0 else {} + kw = ( + {"n_confirmed": n_confirmed} + if layer.is_linear and n_confirmed > 0 + else {} + ) hidden_states = layer(hidden_states, mask=mask, cache=c, **kw) return hidden_states @@ -381,7 +406,9 @@ def __call__( return_hidden: bool = False, n_confirmed: int = 0, ) -> mx.array: - hidden = self.model(inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed) + hidden = self.model( + inputs, cache, input_embeddings=input_embeddings, n_confirmed=n_confirmed + ) normed = self.model.norm(hidden) if self.args.tie_word_embeddings: out = self.model.embed_tokens.as_linear(normed) diff --git a/tests/test_mtp.py b/tests/test_mtp.py index ba1a0e92f..1db571bde 100644 --- a/tests/test_mtp.py +++ b/tests/test_mtp.py @@ -2,8 +2,9 @@ import unittest import mlx.core as mx -from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.generate import generate_step, mtp_generate_step +from mlx_lm.models.cache import make_prompt_cache def _make_qwen3_5_mtp_model(): From 4acf17579a155419a84d959d6324b5e1a0c465ba Mon Sep 17 00:00:00 2001 From: AirRunner Date: Tue, 17 Mar 2026 12:07:02 +0100 Subject: [PATCH 09/10] fix(mtp): stack per-expert MTP weights for MoE models in sanitize() MTP layers in MoE models (35B-A3B, 122B-A10B) ship unfused per-expert weights (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) whereas the backbone uses pre-fused switch_mlp format. Conversion was failing with ~768 parameters not in model. Add a stacking loop in qwen3_5_moe.py sanitize() after the backbone expert loop, mirroring the same pattern for MTP prefixes. Co-authored-by: Thump604 --- mlx_lm/models/qwen3_5_moe.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlx_lm/models/qwen3_5_moe.py b/mlx_lm/models/qwen3_5_moe.py index 53ab8530e..23a6216de 100644 --- a/mlx_lm/models/qwen3_5_moe.py +++ b/mlx_lm/models/qwen3_5_moe.py @@ -2,6 +2,8 @@ from dataclasses import dataclass +import mlx.core as mx + from .base import BaseModelArgs from .qwen3_5 import Model as Qwen3_5Model @@ -49,4 +51,20 @@ def sanitize(self, weights): f"{prefix}.experts.down_proj" ) + # Stack per-expert MTP weights into switch_mlp format. + # MTP layers use unfused per-expert weights (experts.{i}.gate_proj etc) + # unlike backbone layers which use fused gate_up_proj. + mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0) + num_experts = self.language_model.args.num_experts + for l in range(mtp_num): + prefix = f"language_model.mtp.layers.{l}.mlp" + test_key = f"{prefix}.experts.0.gate_proj.weight" + if test_key in new_weights: + for n in ["gate_proj", "up_proj", "down_proj"]: + to_join = [ + new_weights.pop(f"{prefix}.experts.{e}.{n}.weight") + for e in range(num_experts) + ] + new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join) + return self.language_model.sanitize(new_weights) From 04da2467c396e83904f0c5d20524d76f0791ff37 Mon Sep 17 00:00:00 2001 From: AirRunner Date: Sun, 22 Mar 2026 12:07:04 +0100 Subject: [PATCH 10/10] fix(mtp): raise clear error when config has MTP but weights do not When mtp_num_hidden_layers > 0 but the model weights contain no MTP parameters, the previous error was a cryptic 'Missing N parameters'. Now raises a ValueError with an actionable message. 
--- mlx_lm/models/qwen3_5.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlx_lm/models/qwen3_5.py b/mlx_lm/models/qwen3_5.py index bbc3bff46..d399220db 100644 --- a/mlx_lm/models/qwen3_5.py +++ b/mlx_lm/models/qwen3_5.py @@ -468,6 +468,11 @@ def sanitize(self, weights): # Keep MTP weights if this model has an MTP head; drop them otherwise if not hasattr(self, "mtp"): weights = {k: v for k, v in weights.items() if "mtp." not in k} + elif not any("mtp." in k for k in weights): + raise ValueError( + "Config specifies mtp_num_hidden_layers > 0 but the model weights " + "contain no MTP parameters. Set mtp_num_hidden_layers=0 to disable MTP." + ) if self.args.tie_word_embeddings: weights.pop("lm_head.weight", None)
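A minimal usage sketch of the generation path this series adds, assuming an MLX-converted Qwen3.5 checkpoint whose config sets mtp_num_hidden_layers > 0 (the repo id below is hypothetical):

    # CLI equivalent: mlx_lm.generate --model <converted-qwen3.5-path> --prompt "..." --mtp
    from mlx_lm import load, stream_generate

    model, tokenizer = load("mlx-community/Qwen3.5-27B-4bit")  # hypothetical repo id
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "Explain speculative decoding."}],
        add_generation_prompt=True,
    )
    # mtp=True routes stream_generate() through mtp_generate_step() (patch 05);
    # without the flag the model falls back to the standard generate_step() path.
    for response in stream_generate(model, tokenizer, prompt, max_tokens=128, mtp=True):
        print(response.text, end="", flush=True)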