Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 224 additions & 6 deletions mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import sys
import time
import warnings
from dataclasses import dataclass
from functools import partial
from typing import (
Expand Down Expand Up @@ -215,6 +216,12 @@ def setup_arg_parser():
help="Number of tokens to draft when using speculative decoding.",
default=3,
)
parser.add_argument(
"--mtp",
action="store_true",
help="Use native Multi-Token Prediction for speculative decoding "
"(requires a model with an MTP head, e.g. Qwen3.5).",
)
return parser


Expand Down Expand Up @@ -643,12 +650,204 @@ def _draft_generate(y, num_draft):
_rewind_cache(num_draft, n)


def mtp_generate_step(
    prompt: mx.array,
    model: nn.Module,
    *,
    max_tokens: int = 256,
    sampler: Optional[Callable[[mx.array], mx.array]] = None,
    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
    prompt_cache: Optional[Any] = None,
    prefill_step_size: int = 512,
    kv_bits: Optional[int] = None,
    kv_group_size: int = 64,
    quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
    """A generator that uses the model's native MTP head for speculative decoding.

    Each iteration runs one backbone forward pass over the current token and its
    pending draft, then one MTP forward pass to propose the next draft. Up to 2
    tokens are emitted per backbone step: one always-accepted backbone token and
    one conditionally-accepted draft token.

    The model must implement ``mtp_forward(hidden, next_tok, mtp_cache)`` and
    support ``return_hidden=True`` in its ``__call__``.

    Args:
        prompt (mx.array): The input prompt token ids.
        model (nn.Module): The model to use for generation; must expose an MTP
            head via ``mtp_forward`` and ``make_mtp_cache``.
        max_tokens (int): Maximum number of tokens to generate. Default: ``256``.
        sampler (Callable, optional): Maps log-probabilities to a sampled token.
            Defaults to greedy ``argmax``.
        logits_processors (List[Callable], optional): Applied in order to the
            logits (given the token history) before sampling.
        prompt_cache (Any, optional): Pre-built cache whose first
            ``len(model.layers)`` entries belong to the backbone; any remaining
            entries are treated as the MTP cache.
        prefill_step_size (int): Prompt chunk size during prefill. Default: ``512``.
        kv_bits (int, optional): Bits for KV-cache quantization; ``None``
            disables quantization.
        kv_group_size (int): Group size for KV-cache quantization. Default: ``64``.
        quantized_kv_start (int): Quantize the KV cache from this step onward.

    Yields:
        Tuple[mx.array, mx.array, bool]: (token, log-probabilities, from_draft).
        ``from_draft`` is ``True`` when the token came from the MTP head.
    """
    y = prompt.astype(mx.uint32)
    # Token history fed to the logits processors; grown lazily below.
    prev_tokens = None

    if prompt_cache is None:
        model_cache = cache.make_prompt_cache(model)
        mtp_cache = model.make_mtp_cache()
    else:
        # Split a pre-built cache at backbone length. If MTP entries are
        # absent (e.g. cache created by make_prompt_cache), create them.
        n_main = len(model.layers)
        model_cache = prompt_cache[:n_main]
        mtp_cache = prompt_cache[n_main:] or model.make_mtp_cache()

    # Default to greedy decoding when no sampler is supplied.
    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))

    quantize_cache_fn = functools.partial(
        maybe_quantize_kv_cache,
        quantized_kv_start=quantized_kv_start,
        kv_group_size=kv_group_size,
        kv_bits=kv_bits,
    )

    def _process_and_sample(tokens, logits):
        """Apply logits processors, normalize to log-probs, and sample a token."""
        if logits_processors:
            for processor in logits_processors:
                logits = processor(tokens, logits)
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        tok = sampler(logprobs)
        return tok, logprobs

    def _step_backbone(y, n_predict=1, n_confirmed=0):
        """Run the backbone on ``y`` and return (tokens, logprobs, hidden)."""
        with mx.stream(generation_stream):
            logits, hidden = model(
                y[None], cache=model_cache, return_hidden=True, n_confirmed=n_confirmed
            )
            logits = logits[:, -n_predict:, :]
            quantize_cache_fn(model_cache)
        nonlocal prev_tokens
        toks, lps = [], []
        # Only the confirmed prefix of ``y`` belongs in the processor history;
        # positions still being predicted are excluded.
        y_ctx = y if n_predict == 1 else y[: -(n_predict - 1)]
        for i in range(n_predict):
            if logits_processors:
                prev_tokens = (
                    mx.concatenate([prev_tokens, y_ctx])
                    if prev_tokens is not None
                    else y_ctx
                )
            tok, lp = _process_and_sample(prev_tokens, logits[:, i, :].squeeze(0))
            toks.append(tok)
            lps.append(lp)
        return mx.stack(toks), mx.stack(lps), hidden

    def _step_mtp(hidden_last, main_tok):
        """Run the MTP head and return (draft_token, draft_logprobs)."""
        next_ids = main_tok.reshape(1, 1)
        with mx.stream(generation_stream):
            mtp_logits = model.mtp_forward(hidden_last, next_ids, mtp_cache)
            quantize_cache_fn(mtp_cache)
        mtp_logits = mtp_logits[:, -1, :].squeeze(0)
        draft_tok, draft_lp = _process_and_sample(prev_tokens, mtp_logits)
        return draft_tok, draft_lp

    def _prefill(y):
        """Consume the prompt in chunks to bound peak memory; return the tail."""
        while y.size > prefill_step_size:
            model(y[:prefill_step_size][None], cache=model_cache)
            quantize_cache_fn(model_cache)
            mx.eval([c.state for c in model_cache if hasattr(c, "state")])
            y = y[prefill_step_size:]
            mx.clear_cache()
        return y

    with mx.stream(generation_stream):
        y = _prefill(y)

    ntoks = 0
    draft_tok = None
    draft_lp = None

    while True:
        if draft_tok is None:
            # No pending draft — run backbone only, then generate first draft.
            toks, lps, hidden = _step_backbone(y, n_predict=1)
            mx.eval(toks)
            main_tok = toks[0]
            main_lp = lps[0]

            ntoks += 1
            yield main_tok.item(), main_lp, False
            if ntoks >= max_tokens:
                break

            draft_tok, draft_lp = _step_mtp(hidden[:, -1:, :], main_tok)
            mx.eval(draft_tok)
            y = mx.array([main_tok.item()], mx.uint32)
        else:
            # Verify draft: run backbone over [y, draft_tok].
            # n_confirmed=1 causes GatedDeltaNet to snapshot its SSM/conv state
            # after the confirmed token y, enabling exact rollback on rejection.
            y_with_draft = mx.concatenate(
                [y, mx.array([draft_tok.item()], mx.uint32)]
            )
            toks, lps, hidden = _step_backbone(
                y_with_draft, n_predict=2, n_confirmed=1
            )
            mx.eval(toks, draft_tok)

            verify_pred = toks[0]  # backbone prediction after y → verify draft
            bonus_tok = toks[1]  # backbone prediction after draft_tok
            verify_lp = lps[0]
            bonus_lp = lps[1]

            if verify_pred.item() == draft_tok.item():
                # Draft accepted — discard rollback snapshots.
                for c in model_cache:
                    if hasattr(c, "rollback_state"):
                        c.rollback_state = None

                ntoks += 1
                yield draft_tok.item(), draft_lp, True
                if ntoks >= max_tokens:
                    break

                ntoks += 1
                yield bonus_tok.item(), bonus_lp, False
                if ntoks >= max_tokens:
                    break

                # Next draft from MTP at draft_tok's hidden state.
                draft_tok, draft_lp = _step_mtp(hidden[:, 1:2, :], bonus_tok)
                mx.eval(draft_tok)
                y = mx.array([bonus_tok.item()], mx.uint32)
            else:
                # Draft rejected — roll back all caches to the state after y.
                # SSM layers (ArraysCache): restore the conv/ssm snapshot saved
                # by GatedDeltaNet after the confirmed token.
                # Attention layers (KVCache): trim the draft-token entry.
                for c in model_cache:
                    if (
                        hasattr(c, "rollback_state")
                        and c.rollback_state is not None
                    ):
                        conv_snap, ssm_snap = c.rollback_state
                        c[0] = conv_snap
                        c[1] = ssm_snap
                        c.rollback_state = None
                    elif c.is_trimmable():
                        c.trim(1)
                cache.trim_prompt_cache(mtp_cache, 1)

                ntoks += 1
                yield verify_pred.item(), verify_lp, False
                if ntoks >= max_tokens:
                    break

                # Next draft from MTP at y's hidden state.
                draft_tok, draft_lp = _step_mtp(hidden[:, 0:1, :], verify_pred)
                mx.eval(draft_tok)
                y = mx.array([verify_pred.item()], mx.uint32)


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 256,
draft_model: Optional[nn.Module] = None,
mtp: bool = False,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
Expand All @@ -664,6 +863,8 @@ def stream_generate(
draft_model (Optional[nn.Module]): An optional draft model. If provided
then speculative decoding is used. The draft model must use the same
tokenizer as the main model. Default: ``None``.
mtp (bool): Use native Multi-Token Prediction for speculative
decoding. Requires a model with an MTP head. Default: ``False``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.

Expand All @@ -687,18 +888,34 @@ def stream_generate(

kwargs["max_tokens"] = max_tokens

if draft_model is None:
if draft_model is not None:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
)
elif mtp and hasattr(model, "mtp_forward"):
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
kwargs.pop("num_draft_tokens", None)
token_generator = mtp_generate_step(prompt, model, **kwargs)
elif mtp:
warnings.warn(
"--mtp flag ignored: model does not have an MTP head. "
"Falling back to standard generation.",
stacklevel=2,
)
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
else:
kwargs.pop("max_kv_size", None)
kwargs.pop("prompt_progress_callback", None)
token_generator = speculative_generate_step(
prompt, model, draft_model, **kwargs
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft always false for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
with wired_limit(model, [generation_stream]):
tic = time.perf_counter()
Expand Down Expand Up @@ -1525,6 +1742,7 @@ def main():
quantized_kv_start=args.quantized_kv_start,
draft_model=draft_model,
num_draft_tokens=args.num_draft_tokens,
mtp=args.mtp,
)
if not args.verbose:
print(response)
Expand Down
3 changes: 3 additions & 0 deletions mlx_lm/models/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)
instance.left_padding = None
instance.lengths = None
# Snapshot of (conv_state, ssm_state) saved after processing confirmed tokens
# in an MTP draft-verification step. Cleared after each step.
instance.rollback_state = None
return instance

def __init__(self, size, left_padding: Optional[List[int]] = None):
Expand Down
Loading