feat: Add MTP speculative decoding for Qwen3-Next #82

waybarrios merged 2 commits into waybarrios:main
Conversation
@waybarrios look what you made me do 😄 You gave me write access and I got a little... carried away. 1.76x speedup, MTP speculative decoding, optimistic mode... I blame you for this — one "welcome aboard" and suddenly I'm rewriting the inference engine 🚀
Force-pushed 51adb7d to 3ee6623
Force-pushed 3ee6623 to 04c2d7b
How do I test this with …?
Force-pushed 35b316f to e8827f0
Multi-Token Prediction (MTP) support for Qwen3-Next models with speculative decoding integration and related infrastructure.

MTP speculative decoding:
- Monkey-patch MTP head onto Qwen3-Next models at startup
- Draft 1 extra token per step using the model's own MTP head
- Always-advance verify with trim-on-reject for KVCache models
- Hybrid model support: snapshot/restore RNN state on draft reject to prevent recurrent state corruption in mixed attention+RNN models
- Configurable via --enable-mtp flag
- ~1.4x throughput improvement on Qwen3-Next-80B-A3B-Instruct-6bit

Weight conversion script (scripts/add_mtp_weights.py):
- Convert MTP projection weights from HF format to MLX quantized
- CPU-only processing to avoid Metal GPU crashes with 512 MoE experts
- Per-projection quantize with eager eval to cap peak memory (~3.4GB)
- Supports 4-bit and 8-bit quantization with group_size detection

Chunked prefill OOM fix:
- mx.contiguous() on chunked prefill slices to break lazy dependency on full padded tensor (84GB peak vs 212GB OOM)
- _recover_from_generation_error(): abort running requests + clear batch state on fatal errors (OOM, Metal crash)
- Exception handler in step() recovers instead of re-raising (prevents infinite loop in engine_core)

Additional features:
- GPT-OSS reasoning parser (--reasoning-parser gpt_oss)
- KV cache quantization support (--kv-cache-quantization)
- CLI args: --enable-mtp, --kv-cache-quantization, --reasoning-parser
- API utils for reasoning content extraction
- Comprehensive test coverage for new features

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
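The always-advance verify with trim-on-reject described above can be sketched in miniature. This is a hedged illustration built around a toy deterministic "model" — `target_forward`, `speculative_step`, and all shapes here are invented for the sketch, not the PR's actual API:

```python
def argmax(xs):
    return max(range(len(xs)), key=xs.__getitem__)

def target_forward(seq):
    """Toy deterministic 'target model': next token is (prev + 1) mod 5.
    Returns per-position logits and per-position hidden states."""
    logits = []
    for t in seq:
        row = [0.0] * 5
        row[(t + 1) % 5] = 1.0
        logits.append(row)
    hiddens = list(seq)                      # toy hidden state = token id
    return logits, hiddens

def speculative_step(tokens, mtp_head):
    """One always-advance MTP step: draft 1 token, verify, trim on reject."""
    logits, hiddens = target_forward(tokens)
    next_tok = argmax(logits[-1])            # target's next token
    draft = mtp_head(hiddens[-1])            # MTP head drafts one more
    # Verify pass scores the next_tok and draft positions together.
    v_logits, _ = target_forward(tokens + [next_tok, draft])
    check = argmax(v_logits[-2])             # target's pick after next_tok
    if check == draft:
        return [next_tok, draft]             # accepted: 2 tokens this step
    # Trim-on-reject: drop the draft (rolling back its KV/RNN state in the
    # real engine) but still advance with the target's own token.
    return [next_tok, check]
```

In the real engine the verify pass reuses the KV cache, so the accepted-draft case amortizes to roughly one target pass per two emitted tokens — which is where the ~1.4x throughput gain comes from.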
Force-pushed e8827f0 to e6dba59
@TomLucidor Good question! When you quantize with mlx-my-repo (or a similar converter), the MTP weights get dropped. We have a script in this PR that fixes that: it downloads the MTP shard from the original BF16 model, quantizes the weights to match your target quant level, and injects them into the MLX model directory.

Usage:

python scripts/add_mtp_weights.py \
  --mlx-model-path /path/to/your/quantized/model \
  --source-model bknyaz/Qwen3-Coder-Next-REAM \
  --bits 4  # or 3, matching your target quant

Caveat for the REAM model: the script was written for Qwen3-Next-80B (512 experts). REAM uses 384 experts (expert merging), so the MTP shard layout might differ. The script reads …

Simplest path: quantize the model first, then try running the script pointing …

Once MTP weights are in place, just add --enable-mtp:

vllm-mlx serve /path/to/quantized/model --enable-mtp --port 8080
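If you want to confirm the REAM expert count before running the script, a small sketch — the config field names here are assumptions, since different MoE releases spell them differently:

```python
import json
from pathlib import Path

def expert_count(model_dir):
    """Read the MoE expert count from a local model's config.json.

    Field names vary across releases, so try the common spellings
    (all three names are assumptions, not a guaranteed schema).
    """
    cfg = json.loads((Path(model_dir) / "config.json").read_text())
    for key in ("num_experts", "num_local_experts", "num_routed_experts"):
        if key in cfg:
            return cfg[key]
    return None
```

For the original Qwen3-Next-80B this should report 512, and for the REAM merge 384 — a quick way to see whether the shard-layout caveat applies to your checkpoint.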
Interesting, we are providing post-processing code now; that is awesome! Let me take a deep look at this PR.
Thanks! Yeah, the post-processing script (…)

I've been actively testing this PR against opencode — MTP verified mode with continuous batching, prefix cache, tool calling — everything works well so far. No hangs, no stale state between requests, and prefix cache + MTP batch dimension is stable.

Take your time with the review, happy to address any feedback!
I suspect the Opencode problem lies elsewhere. I'm still looking into it and waiting to be 100% sure before posting updates, but I'll share more soon. This idea is interesting, though.
Lines 687-689 unconditionally quantize the cache before the duplicate-entry early return at line 694 and before …

vllm-mlx/vllm_mlx/memory_cache.py, lines 686 to 709 in e6dba59
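In outline, the fix is presumably just reordering the two steps so the duplicate check comes first. A sketch with hypothetical names, not the actual `memory_cache.py` code:

```python
def put_quantized(cache, key, entry, quantize):
    """Insert a cache entry, quantizing only when it's actually new.

    Checking for a duplicate before quantizing means the quantization
    cost is paid exactly once per distinct key, and a repeat insertion
    can never double-quantize an already-stored entry.
    """
    if key in cache:            # duplicate-entry early return comes first
        return cache[key]
    cache[key] = quantize(entry)  # pay the quantization cost exactly once
    return cache[key]
```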
@janhilgard I found one issue, but I fixed it in the last commit.
Take a look at this: https://docs.vllm.ai/projects/ascend/en/main/user_guide/feature_guide/speculative_decoding.html

Since they validated this across many models, we should aim to replicate that. We can test MTP with various models to see how it impacts TTFT and throughput.

Remind me, what are your hardware specs?
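For reference, the TTFT and throughput numbers reduce to simple arithmetic over per-token timestamps. A minimal sketch of the metrics (the benchmark script posted later in this thread measures the same quantities from streaming responses):

```python
def latency_metrics(token_times, t_start):
    """TTFT, inter-token latencies, and decode throughput from timestamps.

    token_times: wall-clock arrival time of each generated token.
    t_start:     when the request was sent.
    """
    ttft = token_times[0] - t_start                      # time to first token
    itls = [b - a for a, b in zip(token_times, token_times[1:])]
    gen_time = token_times[-1] - token_times[0]          # decode phase only
    # Throughput counts inter-token gaps, so it excludes the prefill cost.
    tps = (len(token_times) - 1) / gen_time if gen_time > 0 else 0.0
    return ttft, itls, tps
```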
Good catch on the duplicate quantization, thanks for the fix! Pulled …

Re: hardware — Apple M3 Ultra, 256GB unified memory. Happy to run MTP benchmarks across different models. The vLLM Ascend validation matrix is a great reference — we should at minimum cover: …

I can set up a benchmark script that measures both TTFT and throughput with/without `--enable-mtp`.
OK, I'll accept this, but then please do the benchmarking.
MTP Benchmark Results — GPU-standard metrics

Ran a comprehensive MTP benchmark on Qwen3-Next-80B-A3B-Instruct-6bit (M3 Ultra 256GB) with …

Single-request Latency Profile (21 runs, short/medium/long prompts)
Concurrency Matrix (MTP enabled, 2 rounds averaged)
Key observations
Comparison with GPU MTP results
Benchmark script

Attached below — standalone Python script that measures TTFT, ITL distribution (p50/p95/p99), throughput, and E2E latency at configurable concurrency levels. Outputs CSV.

#!/usr/bin/env python3
"""
MTP Benchmark Suite for vllm-mlx.
Usage:
python3 benchmark_mtp.py # full benchmark
python3 benchmark_mtp.py --test latency # latency only
python3 benchmark_mtp.py --test concurrency --concurrency 1,4,8,16
python3 benchmark_mtp.py --compare localhost:1239,localhost:8080 # A/B
"""
import argparse
import asyncio
import csv
import json
import os
import statistics
import sys
import time
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional
import aiohttp
DEFAULT_SERVER = "http://localhost:1239"
RESULTS_DIR = "benchmark_results"
PROMPTS = {
"short": [
"What is 2+2?",
"Say hello.",
"What color is the sky?",
],
"medium": [
"Explain the difference between TCP and UDP in networking.",
"Write a Python function that checks if a string is a palindrome.",
"Describe how garbage collection works in modern programming languages.",
],
"long": [
(
"You are a senior software architect reviewing a microservices system. "
"The system has 12 services communicating via gRPC and Kafka. "
"Service A (user-auth) handles authentication with JWT tokens. "
"Service B (order-processing) manages orders and communicates with "
"Service C (inventory) and Service D (payment-gateway). "
"Service E (notification) sends emails and push notifications. "
"The team reports that during peak hours (10K req/s), the system "
"experiences cascading failures starting from Service D. "
"Analyze the potential causes and propose a comprehensive solution "
"including circuit breakers, bulkheads, and retry strategies."
),
],
}
@dataclass
class TokenTiming:
index: int
timestamp: float
itl: float
@dataclass
class RequestMetrics:
prompt_category: str
prompt_length: int
completion_tokens: int
ttft: float
total_time: float
throughput: float
itl_mean: float
itl_p50: float
itl_p95: float
itl_p99: float
itl_min: float
itl_max: float
itl_std: float
error: Optional[str] = None
@dataclass
class ConcurrencyResult:
concurrency: int
num_requests: int
wall_time: float
aggregate_tps: float
per_request_tps_mean: float
per_request_tps_min: float
per_request_tps_max: float
ttft_mean: float
ttft_p50: float
ttft_p95: float
ttft_min: float
ttft_max: float
itl_mean: float
itl_p50: float
itl_p95: float
e2e_mean: float
e2e_p50: float
e2e_p95: float
errors: int
def _percentile(data: list[float], p: float) -> float:
if not data:
return 0.0
s = sorted(data)
k = (len(s) - 1) * p / 100
f = int(k)
c = f + 1
if c >= len(s):
return s[-1]
return s[f] + (k - f) * (s[c] - s[f])
async def streaming_request_with_itl(
session, url, model, messages, max_tokens=300,
):
payload = {
"model": model, "messages": messages,
"max_tokens": max_tokens, "stream": True, "temperature": 0.7,
}
timings = []
t_start = time.perf_counter()
first_token_time = None
token_count = 0
last_token_time = None
prompt_tokens = 0
try:
async with session.post(f"{url}/v1/chat/completions", json=payload) as resp:
if resp.status != 200:
body = await resp.text()
m = RequestMetrics("", 0, 0, 0, time.perf_counter()-t_start,
0, 0, 0, 0, 0, 0, 0, 0, error=f"HTTP {resp.status}: {body[:200]}")
return m, []
async for line in resp.content:
decoded = line.decode("utf-8").strip()
if not decoded.startswith("data: "): continue
data_str = decoded[6:]
if data_str == "[DONE]": break
try: chunk = json.loads(data_str)
except json.JSONDecodeError: continue
choices = chunk.get("choices", [])
if not choices: continue
content = choices[0].get("delta", {}).get("content", "")
if content:
now = time.perf_counter()
if first_token_time is None: first_token_time = now
itl = now - last_token_time if last_token_time else now - t_start
timings.append(TokenTiming(token_count, now, itl))
last_token_time = now
token_count += 1
usage = chunk.get("usage")
if usage: prompt_tokens = usage.get("prompt_tokens", 0)
except Exception as e:
m = RequestMetrics("", 0, 0, 0, time.perf_counter()-t_start,
0, 0, 0, 0, 0, 0, 0, 0, error=str(e))
return m, []
t_end = time.perf_counter()
if first_token_time is None: first_token_time = t_end
ttft = first_token_time - t_start
gen_time = t_end - first_token_time
throughput = (token_count - 1) / gen_time if gen_time > 0 and token_count > 1 else 0
itl_values = [t.itl for t in timings[1:]] if len(timings) > 1 else [0]
return RequestMetrics(
"", prompt_tokens, token_count, ttft, t_end - t_start, throughput,
statistics.mean(itl_values), _percentile(itl_values, 50),
_percentile(itl_values, 95), _percentile(itl_values, 99),
min(itl_values), max(itl_values),
statistics.stdev(itl_values) if len(itl_values) > 1 else 0,
), timings
async def check_server(url):
try:
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5)) as s:
async with s.get(f"{url}/v1/models") as r:
if r.status == 200:
data = await r.json()
models = data.get("data", [])
if models: return models[0].get("id", "unknown")
except Exception: pass
return None
async def test_latency_profile(url, model, name, num_runs=3, max_tokens=300):
print(f"\n{'─'*70}\n Latency Profile — {name}\n{'─'*70}")
all_metrics = []
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=180)) as session:
for cat, prompts in PROMPTS.items():
for prompt in prompts:
for run in range(num_runs):
msgs = [{"role": "user", "content": prompt}]
m, _ = await streaming_request_with_itl(session, url, model, msgs, max_tokens)
m.prompt_category = cat
if not m.error:
print(f" [{cat}] Run {run+1}: {m.completion_tokens} tok, "
f"TTFT={m.ttft:.3f}s, {m.throughput:.1f} tok/s, "
f"ITL_p50={m.itl_p50*1000:.1f}ms, ITL_p95={m.itl_p95*1000:.1f}ms")
all_metrics.append(m)
await asyncio.sleep(0.3)
valid = [m for m in all_metrics if not m.error]
if valid:
print(f"\n Summary ({len(valid)} runs):")
print(f" TTFT: avg={statistics.mean(m.ttft for m in valid):.3f}s")
print(f" TPS: avg={statistics.mean(m.throughput for m in valid):.1f}")
print(f" ITL: mean={statistics.mean(m.itl_mean for m in valid)*1000:.1f}ms, "
f"p95={statistics.mean(m.itl_p95 for m in valid)*1000:.1f}ms")
return all_metrics
async def test_concurrency_matrix(url, model, name, levels, max_tokens=200, rounds=2):
print(f"\n{'─'*70}\n Concurrency Matrix — {name}\n{'─'*70}")
prompts = PROMPTS["short"][:2] + PROMPTS["medium"][:2] + PROMPTS["long"][:1]
results = []
for n in levels:
round_results = []
for rnd in range(rounds):
req_prompts = [prompts[i % len(prompts)] for i in range(n)]
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300)) as session:
t0 = time.perf_counter()
tasks = [streaming_request_with_itl(session, url, model,
[{"role":"user","content":p}], max_tokens) for p in req_prompts]
responses = await asyncio.gather(*tasks)
wall = time.perf_counter() - t0
mlist = [r[0] for r in responses]
valid = [m for m in mlist if not m.error]
if not valid: continue
total_tok = sum(m.completion_tokens for m in valid)
cr = ConcurrencyResult(
n, len(mlist), wall, total_tok/wall if wall>0 else 0,
statistics.mean(m.throughput for m in valid),
min(m.throughput for m in valid), max(m.throughput for m in valid),
statistics.mean(m.ttft for m in valid),
_percentile([m.ttft for m in valid], 50),
_percentile([m.ttft for m in valid], 95),
min(m.ttft for m in valid), max(m.ttft for m in valid),
statistics.mean(m.itl_mean for m in valid),
_percentile([m.itl_mean for m in valid], 50),
_percentile([m.itl_mean for m in valid], 95),
statistics.mean(m.total_time for m in valid),
_percentile([m.total_time for m in valid], 50),
_percentile([m.total_time for m in valid], 95),
len(mlist)-len(valid),
)
round_results.append(cr)
print(f" ×{n} round {rnd+1}: agg={cr.aggregate_tps:.1f} tok/s, "
f"per-req={cr.per_request_tps_mean:.1f}, TTFT_p50={cr.ttft_p50:.3f}s")
await asyncio.sleep(2)
if round_results:
avg = ConcurrencyResult(
n, round_results[0].num_requests,
statistics.mean(r.wall_time for r in round_results),
statistics.mean(r.aggregate_tps for r in round_results),
statistics.mean(r.per_request_tps_mean for r in round_results),
min(r.per_request_tps_min for r in round_results),
max(r.per_request_tps_max for r in round_results),
statistics.mean(r.ttft_mean for r in round_results),
statistics.mean(r.ttft_p50 for r in round_results),
statistics.mean(r.ttft_p95 for r in round_results),
min(r.ttft_min for r in round_results),
max(r.ttft_max for r in round_results),
statistics.mean(r.itl_mean for r in round_results),
statistics.mean(r.itl_p50 for r in round_results),
statistics.mean(r.itl_p95 for r in round_results),
statistics.mean(r.e2e_mean for r in round_results),
statistics.mean(r.e2e_p50 for r in round_results),
statistics.mean(r.e2e_p95 for r in round_results),
sum(r.errors for r in round_results),
)
results.append(avg)
if results:
print(f"\n {'Conc':>5} {'Agg tok/s':>10} {'Per-req':>10} {'TTFT p50':>10} "
f"{'ITL mean':>10} {'E2E p50':>10}")
for r in results:
print(f" {r.concurrency:>5} {r.aggregate_tps:>10.1f} "
f"{r.per_request_tps_mean:>10.1f} {r.ttft_p50:>9.3f}s "
f"{r.itl_mean*1000:>9.1f}ms {r.e2e_p50:>9.2f}s")
return results
def save_csv(metrics, filepath, fieldnames):
os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
with open(filepath, "w", newline="") as f:
w = csv.DictWriter(f, fieldnames=fieldnames)
w.writeheader()
for m in metrics: w.writerow(asdict(m))
print(f" Saved: {filepath}")
async def main():
parser = argparse.ArgumentParser(description="MTP Benchmark Suite")
parser.add_argument("--server", default=DEFAULT_SERVER)
parser.add_argument("--model", default=None)
parser.add_argument("--test", choices=["latency","concurrency","all"], default="all")
parser.add_argument("--concurrency", default="1,4,8,16")
parser.add_argument("--runs", type=int, default=3)
parser.add_argument("--rounds", type=int, default=2)
parser.add_argument("--max-tokens", type=int, default=300)
parser.add_argument("--compare", default=None,
help="A/B compare: host1:port1,host2:port2")
args = parser.parse_args()
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
url = args.server
model = args.model or await check_server(url)
if not model:
print(f"Server {url} unavailable!"); sys.exit(1)
print(f"Server: {url}\nModel: {model}")
# Warmup
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as s:
wr, _ = await streaming_request_with_itl(s, url, model,
[{"role":"user","content":"Hello."}], 10)
print(f"Warmup: TTFT={wr.ttft:.3f}s, {wr.throughput:.1f} tok/s")
await asyncio.sleep(1)
if args.test in ("latency", "all"):
m = await test_latency_profile(url, model, model, args.runs, args.max_tokens)
save_csv(m, f"{RESULTS_DIR}/mtp_latency_{ts}.csv",
[f.name for f in RequestMetrics.__dataclass_fields__.values()])
await asyncio.sleep(2)
if args.test in ("concurrency", "all"):
levels = [int(x) for x in args.concurrency.split(",")]
r = await test_concurrency_matrix(url, model, model, levels, args.max_tokens, args.rounds)
save_csv(r, f"{RESULTS_DIR}/mtp_concurrency_{ts}.csv",
[f.name for f in ConcurrencyResult.__dataclass_fields__.values()])
if __name__ == "__main__":
    asyncio.run(main())

Next steps — I'll open a separate issue proposing:
@janhilgard I need to compare the metadata of the original Qwen3-Coder-Next to the REAM version for the MTP feature, but I can't seem to locate the metadata/namespace for it. Where is it, so I can check whether the REAM version can be made compatible (and that it's not just an MLX stripping issue)?
@TomLucidor The MTP weights live under the … namespace.

However — Qwen3-Coder-Next doesn't ship with MTP weights at all. The non-Coder variant (Qwen3-Next-80B-A3B-Instruct) has them in its last shard (…). So the compatibility chain breaks at the source:

What you can check to confirm: download the Coder model's index and look for MTP keys:

curl -s https://huggingface.co/Qwen/Qwen3-Coder-Next/resolve/main/model.safetensors.index.json | python3 -c "import sys,json; wm=json.load(sys.stdin)['weight_map']; mtp=[k for k in wm if 'mtp' in k]; print(f'{len(mtp)} MTP keys found') if mtp else print('No MTP keys — MTP not included in this model')"

If Qwen releases MTP weights for the Coder variant in the future, the …
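The same check works offline against a downloaded copy of the index. A small helper, assuming the Qwen naming convention where MTP tensor names contain the substring `mtp`:

```python
import json

def mtp_keys(index_path):
    """Return the MTP-related tensor names from a safetensors index file.

    The 'mtp' substring match is an assumption based on Qwen3-Next's
    published weight names; other model families may differ.
    """
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    return sorted(k for k in weight_map if "mtp" in k)
```

An empty result means the checkpoint shipped without MTP weights, which is exactly the Coder/REAM situation described above.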
@janhilgard okay now I am slightly spooked
REAM has no MTP support https://huggingface.co/SamsungSAILMontreal/Qwen3-Next-80B-A3B-Instruct-REAM/raw/main/model.safetensors.index.json compared to the original https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct/raw/main/model.safetensors.index.json (Now I'm gonna ask the REAM dev to see if they can support MTP, just in case)
Merge 17 upstream commits including:
- KV cache quantization for prefix cache memory reduction (waybarrios#62)
- Streaming tool call parsing via ToolParser integration (waybarrios#46)
- MTP speculative decoding for Qwen3-Next (waybarrios#82)
- GPT-OSS reasoning parser and Harmony format parsers
- mlx-lm >= 0.30.5 requirement, transformers >= 5.0.0
- BatchMambaCache fix for mlx-lm >= 0.30.6 (waybarrios#89)
- MLLM continuous batching fixes (waybarrios#76)
- Force MLLM mode option (waybarrios#81)
- Various bug fixes

Conflict resolution:
- server.py: Replaced local tool_call_buffering with upstream's ToolParser-based streaming (more robust)
- cli.py: Deduplicated --mllm, --default-temperature, --default-top-p args (upstream already added them), kept local --embedding-model
- mamba_cache.py: Took upstream's conditional HAS_MAMBA_CACHE approach
- pyproject.toml: Took upstream's version and dependency changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary
- `_generation_step()` → `self._generation_step()` (3 occurrences)
- `_chunked_next` now processes small prompt batches directly instead of delegating to `_orig_next`, fixing the DeltaRNN conv_state batch-dimension mismatch when mixing prefix-cached and fresh prompts
- `_process_prompts`: when the cache doesn't belong to the active batch
- `_skip_state`/`_deferred_drafts`: when `active_batch` is None, and drain pending async_eval work via `mx.clear_cache()` before starting a new prefill, preventing a generation hang on consecutive requests after BatchGenerator recreation

New CLI options
- `--enable-mtp` — enable MTP speculative decoding
- `--mtp-num-draft-tokens N` — draft tokens per step (default: 1)
- `--mtp-optimistic` — skip acceptance check for max speed (~5-10% wrong tokens)

How it works
- The target forward pass returns hidden states (`return_hidden=True`), which the MTP head uses to draft the next token

Weight preparation
Models need MTP weights added to their safetensors files:
python scripts/add_mtp_weights.py \
  --mlx-model-path ~/.cache/huggingface/hub/models--mlx-community--Qwen3-Next-80B-A3B-Instruct-6bit \
  --source-model Qwen/Qwen3-Next-80B-A3B-Instruct

Benchmark: Single-user (Qwen3-Next-80B-A3B-Instruct-6bit, Apple M3 Ultra 256GB)
Benchmark: Concurrent requests (MTP verified, continuous batching, max_tokens=300)
All 16 concurrent requests complete successfully with MTP enabled and prefix caching active.
Test plan
🤖 Generated with Claude Code