
feat: Add MTP speculative decoding for Qwen3-Next #82

Merged
waybarrios merged 2 commits into waybarrios:main from janhilgard:feat/mtp-speculative-clean
Feb 14, 2026

Conversation

@janhilgard
Collaborator

@janhilgard janhilgard commented Feb 13, 2026

Summary

  • Add MTP (Multi-Token Prediction) speculative decoding for models with built-in MTP heads (e.g. Qwen3-Next-80B)
  • Always-advance strategy: verify [primary, draft] in one model call, trim KVCache on reject (no snapshot/restore needed)
  • Two modes: verified (fused eval, 1.43x speedup) and optimistic (zero sync, 1.76x speedup)
  • Deferred draft emission: drafts emitted in the NEXT generation step for correct token ordering
  • Weight preparation script for adding MTP weights to quantized MLX models
  • Chunked prefill bugfix: _generation_step() → self._generation_step() (3 occurrences)
  • Fix: batch prefill with prefix cache: _chunked_next now processes small prompt batches directly instead of delegating to _orig_next, fixing a DeltaRNN conv_state batch-dimension mismatch when mixing prefix-cached and fresh prompts
  • Fix: MTP prefill guard — prevent MTP from running during _process_prompts when cache doesn't belong to the active batch
  • Fix: stale MTP state between requests — clear _skip_state/_deferred_drafts when active_batch is None and drain pending async_eval work via mx.clear_cache() before starting new prefill, preventing generation hang on consecutive requests after BatchGenerator recreation

New CLI options

  • --enable-mtp — enable MTP speculative decoding
  • --mtp-num-draft-tokens N — draft tokens per step (default: 1)
  • --mtp-optimistic — skip acceptance check for max speed (~5-10% wrong tokens)

How it works

  1. Primary forward: Model generates logits + hidden states (return_hidden=True)
  2. MTP draft: MTP head predicts token n+2 from hidden states + primary token n+1
  3. Always-advance verify: Feed [primary, draft] to model in one call (cache advances by 2)
  4. Accept: Store skip_state from position 1, defer draft token for next step emission
  5. Reject: Trim KVCache by 1, skip_state from position 0 (no cold restart)
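
The accept/reject bookkeeping in steps 3–5 can be sketched as a small simulation (illustrative names only; the real logic lives in the patched BatchGenerator):

```python
def always_advance_verify(kv_len, draft_token, verified_token):
    """Always-advance verify: the cache has already grown by 2 because
    [primary, draft] were fed to the model in a single call.

    Returns (new_kv_len, deferred_draft). On accept, the draft token is
    deferred and emitted in the NEXT generation step; on reject, the KV
    cache is trimmed by one position instead of restoring a snapshot.
    """
    if verified_token == draft_token:
        return kv_len, draft_token   # accept: keep both new positions
    return kv_len - 1, None          # reject: trim one, no cold restart

# Accept path: cache keeps both positions, draft deferred to the next step.
assert always_advance_verify(10, 42, 42) == (10, 42)
# Reject path: cache trimmed back by one, nothing deferred.
assert always_advance_verify(10, 42, 7) == (9, None)
```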

Weight preparation

Models need MTP weights added to their safetensors files:

python scripts/add_mtp_weights.py \
  --mlx-model-path ~/.cache/huggingface/hub/models--mlx-community--Qwen3-Next-80B-A3B-Instruct-6bit \
  --source-model Qwen/Qwen3-Next-80B-A3B-Instruct

Benchmark: Single-user (Qwen3-Next-80B-A3B-Instruct-6bit, Apple M3 Ultra 256GB)

| Mode | tok/s | Speedup |
|---|---|---|
| Baseline | 55.1 | 1.0x |
| MTP verified | 78.8 | 1.43x |
| MTP optimistic | 96.8 | 1.76x |

Benchmark: Concurrent requests (MTP verified, continuous batching, max_tokens=300)

| Parallel | Total tok | Time | Aggregate tok/s | Per-request tok/s |
|---|---|---|---|---|
| 1 | 300 | 4.3s | 70.4 | 70.4 |
| 4 | 1,200 | 11.0s | 108.8 | 27.2 |
| 8 | 2,400 | 16.5s | 145.4 | 18.2 |
| 12 | 3,310 | 17.1s | 194.0 | 16.7 |
| 16 | 4,241 | 20.9s | 202.6 | 13.6 |

All 16 concurrent requests complete successfully with MTP enabled and prefix caching active.

Test plan

  • Import + syntax verification
  • SchedulerConfig MTP fields
  • CLI argument parsing
  • validate_mtp_support with mock models (no MTP, MTP config without weights)
  • validate_mtp_support with real Qwen3-Next model
  • model.mtp_forward() forward pass
  • _install_mtp patches BatchGenerator correctly
  • E2E: MTP verified mode generates correct tokens
  • E2E: MTP optimistic mode generates coherent text
  • Benchmark: verified >= 1.2x speedup
  • Benchmark: optimistic >= 1.4x speedup
  • Concurrent batch test: 8 parallel OK (145 tok/s aggregate)
  • Concurrent batch test: 16 parallel OK (202 tok/s aggregate)
  • Prefix cache + MTP: no batch dimension mismatch
  • Consecutive requests after sampler param change: no hang
  • Black formatting passes

🤖 Generated with Claude Code

@janhilgard
Collaborator Author

@waybarrios look what you made me do 😄 You gave me write access and I got a little... carried away.

1.76x speedup, MTP speculative decoding, optimistic mode... I blame you for this — one "welcome aboard" and suddenly I'm rewriting the inference engine 🚀

@janhilgard force-pushed the feat/mtp-speculative-clean branch from 51adb7d to 3ee6623 on February 13, 2026 12:52
This was referenced Feb 13, 2026
@janhilgard force-pushed the feat/mtp-speculative-clean branch from 3ee6623 to 04c2d7b on February 13, 2026 22:21
@TomLucidor

TomLucidor commented Feb 13, 2026

How do I test this with bknyaz/Qwen3-Coder-Next-REAM? Trying to quant it now with https://huggingface.co/spaces/mlx-community/mlx-my-repo (ideally Q4/Q3) and testing it locally

@janhilgard force-pushed the feat/mtp-speculative-clean branch 2 times, most recently from 35b316f to e8827f0 on February 13, 2026 23:44
Multi-Token Prediction (MTP) support for Qwen3-Next models with
speculative decoding integration and related infrastructure.

MTP speculative decoding:
- Monkey-patch MTP head onto Qwen3-Next models at startup
- Draft 1 extra token per step using the model's own MTP head
- Always-advance verify with trim-on-reject for KVCache models
- Hybrid model support: snapshot/restore RNN state on draft reject
  to prevent recurrent state corruption in mixed attention+RNN models
- Configurable via --enable-mtp flag
- ~1.4x throughput improvement on Qwen3-Next-80B-A3B-Instruct-6bit

Weight conversion script (scripts/add_mtp_weights.py):
- Convert MTP projection weights from HF format to MLX quantized
- CPU-only processing to avoid Metal GPU crashes with 512 MoE experts
- Per-projection quantize with eager eval to cap peak memory (~3.4GB)
- Supports 4-bit and 8-bit quantization with group_size detection

Chunked prefill OOM fix:
- mx.contiguous() on chunked prefill slices to break lazy dependency
  on full padded tensor (84GB peak vs 212GB OOM)
- _recover_from_generation_error(): abort running requests + clear
  batch state on fatal errors (OOM, Metal crash)
- Exception handler in step() recovers instead of re-raising
  (prevents infinite loop in engine_core)

Additional features:
- GPT-OSS reasoning parser (--reasoning-parser gpt_oss)
- KV cache quantization support (--kv-cache-quantization)
- CLI args: --enable-mtp, --kv-cache-quantization, --reasoning-parser
- API utils for reasoning content extraction
- Comprehensive test coverage for new features

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@janhilgard force-pushed the feat/mtp-speculative-clean branch from e8827f0 to e6dba59 on February 14, 2026 09:09
@janhilgard
Collaborator Author

@TomLucidor Good question! When you quantize with mlx-my-repo (or mlx_lm.convert), the mtp.* weights get stripped — so MTP won't work out of the box on the quantized model.

We have a script in this PR that fixes that: scripts/add_mtp_weights.py

It downloads the MTP shard from the original BF16 model, quantizes the weights to match your target quant level, and injects them into the MLX model directory. Usage:

python scripts/add_mtp_weights.py \
    --mlx-model-path /path/to/your/quantized/model \
    --source-model bknyaz/Qwen3-Coder-Next-REAM \
    --bits 4  # or 3, matching your target quant

Caveat for the REAM model: The script was written for Qwen3-Next-80B (512 experts). REAM uses 384 experts (expert merging), so the MTP shard layout might differ. The script reads num_experts from the target model's config.json and adapts, but you'd need to verify that the REAM model's source repo actually has the MTP weights in a compatible shard. If the MTP head's experts were also merged to 384, the shard name/structure might be different from the default model-00041-of-00041.safetensors.
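
As a quick sanity check before running the script against a REAM-derived model, you can read the expert count yourself (a minimal sketch; it assumes the num_experts key mentioned above and uses an inline sample config rather than a real download):

```python
import json

def read_num_experts(config_text: str, default: int = 512) -> int:
    """Return num_experts from a config.json payload (512 if absent)."""
    return json.loads(config_text).get("num_experts", default)

# REAM-style merged-expert config vs. a config without the key
assert read_num_experts('{"num_experts": 384}') == 384
assert read_num_experts('{}') == 512
```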

Simplest path: quantize the model first, then try running the script pointing --source-model at bknyaz/Qwen3-Coder-Next-REAM. If that fails (wrong shard name or missing MTP weights), let us know and we can debug the shard mapping together.

Once MTP weights are in place, just add --enable-mtp to vllm-mlx serve:

vllm-mlx serve /path/to/quantized/model --enable-mtp --port 8080

@waybarrios
Owner

Interesting, we are providing post-processing code, which is awesome! Let me take a deep look at this PR.

@janhilgard
Collaborator Author

Thanks! Yeah, the post-processing script (add_mtp_weights.py) is necessary because all MLX quantization tools (mlx_lm.convert, mlx-my-repo) strip the mtp.* weights during conversion. The script re-adds them from the original BF16 source model, quantized to match the target bit width.

I've been actively testing this PR against opencode — MTP verified mode with continuous batching, prefix cache, tool calling — everything works well so far. No hangs, no stale state between requests, prefix cache + MTP batch dimension is stable.

Take your time with the review, happy to address any feedback!

@waybarrios
Owner

I suspect the Opencode problem lies elsewhere. I'm still looking into it and waiting to be 100% sure before posting updates, but I'll share more soon. This idea is interesting, though.

@waybarrios
Owner

Lines 687-689 unconditionally quantize the cache before the duplicate-entry early return at line 694 and before _trim_to_offset() at line 700, but then lines 703-707 quantize again with the proper kv_min_quantize_tokens threshold check. This means duplicate entries get quantized for nothing (defeating the "skip expensive trim/quantize" optimization from PR #73), new entries get quantized twice, and the first quantization runs on untrimmed oversized KV buffers. The fix is to simply remove lines 687-689.

# Quantize cache layers if configured
if self._config.kv_quantize:
    cache = _quantize_cache(
        cache, self._config.kv_bits, self._config.kv_group_size
    )

tokens_key = tuple(tokens)

# If already cached, just update LRU order (skip expensive trim/quantize)
if tokens_key in self._entries:
    self._entries.move_to_end(tokens_key)
    return True

# Trim oversized KV arrays to actual used size
cache = _trim_to_offset(cache)

# Quantize if enabled and sequence is long enough
if (
    self._config.kv_quantize
    and len(tokens) >= self._config.kv_min_quantize_tokens
):
    cache = _quantize_cache(
        cache, self._config.kv_bits, self._config.kv_group_size
    )

@waybarrios
Owner

@janhilgard I found one issue, but I fixed it in the last commit

@waybarrios
Owner

Take a look at this: https://docs.vllm.ai/projects/ascend/en/main/user_guide/feature_guide/speculative_decoding.html

Since they validated this across many models, we should aim to replicate that. We can test MTP with various models to see how it impacts TTFT and throughput. Remind me, what are your hardware specs?

@janhilgard
Collaborator Author

Good catch on the duplicate quantization, thanks for the fix! Pulled 9ff75f5 and merged into fork/main.

Re: hardware — Apple M3 Ultra, 256GB unified memory. Happy to run MTP benchmarks across different models. The vLLM Ascend validation matrix is a great reference — we should at minimum cover:

  • Qwen3-Next-80B (MoE, already validated — 1.4-1.76x with MTP)
  • Qwen3-Coder-Next (if/when REAM gets MTP weights sorted)
  • A dense model with MTP heads (if any MLX-converted ones exist)

I can set up a benchmark script that measures both TTFT and throughput with/without --enable-mtp, similar to what Ascend did. Want me to open a separate issue for tracking the benchmark matrix?

@waybarrios
Owner

Ok will accept this. But then do the benchmarking thing.

@waybarrios waybarrios merged commit 2d48123 into waybarrios:main Feb 14, 2026
7 checks passed
@janhilgard
Collaborator Author

MTP Benchmark Results — GPU-standard metrics

Ran a comprehensive MTP benchmark on Qwen3-Next-80B-A3B-Instruct-6bit (M3 Ultra 256GB) with --enable-mtp, measuring GPU-standard metrics (TTFT, ITL, throughput, E2E latency) at multiple concurrency levels.

Single-request Latency Profile (21 runs, short/medium/long prompts)

| Metric | Value |
|---|---|
| Throughput | 43.1 tok/s avg (40.7 – 45.1) |
| TTFT | 78ms p50, 150ms p95 |
| ITL (Inter-Token Latency) | 23.2ms mean, 31.3ms p95 |
| E2E latency | 3.25s avg (300 max_tokens) |

Concurrency Matrix (MTP enabled, 2 rounds averaged)

| Concurrency | Agg tok/s | Per-req tok/s | TTFT p50 | ITL mean | E2E p50 |
|---|---|---|---|---|---|
| ×1 | 16.5 | 41.1 | 475ms | 24.3ms | 0.57s |
| ×4 | 68.2 | 27.0 | 370ms | 34.3ms | 3.40s |
| ×8 | 105.4 | 19.7 | 219ms | 47.9ms | 4.29s |
| ×16 | 152.1 | 13.4 | 331ms | 72.7ms | 12.18s |

Key observations

  1. Aggregate throughput scales well — ×16 gives 152 tok/s (9.2× single-request), good linear scaling
  2. Per-request throughput degrades gracefully — from 41 tok/s (×1) to 13.4 tok/s (×16), ITL grows ~3×
  3. ITL is very consistent — p95 is only ~35% above mean (31ms vs 23ms), suggesting MTP overhead is predictable
  4. TTFT first-run cold penalty — first request after warmup shows ~870ms TTFT (prefix cache miss), subsequent requests drop to 65-80ms

Comparison with GPU MTP results

| | vllm-mlx (M3 Ultra) | SGLang (H200) | vLLM (MI300X) |
|---|---|---|---|
| Single-user throughput | 43 tok/s | 51 tok/s/rank | |
| ITL p50 | 23ms | | |
| Concurrency scaling (×16/×1) | 9.2× agg | 1.2× at conc=64 | |

Benchmark script

Attached below — standalone Python script that measures TTFT, ITL distribution (p50/p95/p99), throughput, E2E latency at configurable concurrency levels. Outputs CSV.

#!/usr/bin/env python3
"""
MTP Benchmark Suite for vllm-mlx.

Usage:
  python3 benchmark_mtp.py                                    # full benchmark
  python3 benchmark_mtp.py --test latency                     # latency only
  python3 benchmark_mtp.py --test concurrency --concurrency 1,4,8,16
  python3 benchmark_mtp.py --compare localhost:1239,localhost:8080  # A/B
"""

import argparse
import asyncio
import csv
import json
import os
import statistics
import sys
import time
from dataclasses import dataclass, asdict
from datetime import datetime
from typing import Optional

import aiohttp

DEFAULT_SERVER = "http://localhost:1239"
RESULTS_DIR = "benchmark_results"

PROMPTS = {
    "short": [
        "What is 2+2?",
        "Say hello.",
        "What color is the sky?",
    ],
    "medium": [
        "Explain the difference between TCP and UDP in networking.",
        "Write a Python function that checks if a string is a palindrome.",
        "Describe how garbage collection works in modern programming languages.",
    ],
    "long": [
        (
            "You are a senior software architect reviewing a microservices system. "
            "The system has 12 services communicating via gRPC and Kafka. "
            "Service A (user-auth) handles authentication with JWT tokens. "
            "Service B (order-processing) manages orders and communicates with "
            "Service C (inventory) and Service D (payment-gateway). "
            "Service E (notification) sends emails and push notifications. "
            "The team reports that during peak hours (10K req/s), the system "
            "experiences cascading failures starting from Service D. "
            "Analyze the potential causes and propose a comprehensive solution "
            "including circuit breakers, bulkheads, and retry strategies."
        ),
    ],
}


@dataclass
class TokenTiming:
    index: int
    timestamp: float
    itl: float


@dataclass
class RequestMetrics:
    prompt_category: str
    prompt_length: int
    completion_tokens: int
    ttft: float
    total_time: float
    throughput: float
    itl_mean: float
    itl_p50: float
    itl_p95: float
    itl_p99: float
    itl_min: float
    itl_max: float
    itl_std: float
    error: Optional[str] = None


@dataclass
class ConcurrencyResult:
    concurrency: int
    num_requests: int
    wall_time: float
    aggregate_tps: float
    per_request_tps_mean: float
    per_request_tps_min: float
    per_request_tps_max: float
    ttft_mean: float
    ttft_p50: float
    ttft_p95: float
    ttft_min: float
    ttft_max: float
    itl_mean: float
    itl_p50: float
    itl_p95: float
    e2e_mean: float
    e2e_p50: float
    e2e_p95: float
    errors: int


def _percentile(data: list[float], p: float) -> float:
    if not data:
        return 0.0
    s = sorted(data)
    k = (len(s) - 1) * p / 100
    f = int(k)
    c = f + 1
    if c >= len(s):
        return s[-1]
    return s[f] + (k - f) * (s[c] - s[f])


async def streaming_request_with_itl(
    session, url, model, messages, max_tokens=300,
):
    payload = {
        "model": model, "messages": messages,
        "max_tokens": max_tokens, "stream": True, "temperature": 0.7,
    }
    timings = []
    t_start = time.perf_counter()
    first_token_time = None
    token_count = 0
    last_token_time = None
    prompt_tokens = 0

    try:
        async with session.post(f"{url}/v1/chat/completions", json=payload) as resp:
            if resp.status != 200:
                body = await resp.text()
                m = RequestMetrics("", 0, 0, 0, time.perf_counter()-t_start,
                    0, 0, 0, 0, 0, 0, 0, 0, error=f"HTTP {resp.status}: {body[:200]}")
                return m, []
            async for line in resp.content:
                decoded = line.decode("utf-8").strip()
                if not decoded.startswith("data: "): continue
                data_str = decoded[6:]
                if data_str == "[DONE]": break
                try: chunk = json.loads(data_str)
                except json.JSONDecodeError: continue
                choices = chunk.get("choices", [])
                if not choices: continue
                content = choices[0].get("delta", {}).get("content", "")
                if content:
                    now = time.perf_counter()
                    if first_token_time is None: first_token_time = now
                    itl = now - last_token_time if last_token_time else now - t_start
                    timings.append(TokenTiming(token_count, now, itl))
                    last_token_time = now
                    token_count += 1
                usage = chunk.get("usage")
                if usage: prompt_tokens = usage.get("prompt_tokens", 0)
    except Exception as e:
        m = RequestMetrics("", 0, 0, 0, time.perf_counter()-t_start,
            0, 0, 0, 0, 0, 0, 0, 0, error=str(e))
        return m, []

    t_end = time.perf_counter()
    if first_token_time is None: first_token_time = t_end
    ttft = first_token_time - t_start
    gen_time = t_end - first_token_time
    throughput = (token_count - 1) / gen_time if gen_time > 0 and token_count > 1 else 0
    itl_values = [t.itl for t in timings[1:]] if len(timings) > 1 else [0]

    return RequestMetrics(
        "", prompt_tokens, token_count, ttft, t_end - t_start, throughput,
        statistics.mean(itl_values), _percentile(itl_values, 50),
        _percentile(itl_values, 95), _percentile(itl_values, 99),
        min(itl_values), max(itl_values),
        statistics.stdev(itl_values) if len(itl_values) > 1 else 0,
    ), timings


async def check_server(url):
    try:
        async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=5)) as s:
            async with s.get(f"{url}/v1/models") as r:
                if r.status == 200:
                    data = await r.json()
                    models = data.get("data", [])
                    if models: return models[0].get("id", "unknown")
    except Exception: pass
    return None


async def test_latency_profile(url, model, name, num_runs=3, max_tokens=300):
    print(f"\n{'─'*70}\n  Latency Profile — {name}\n{'─'*70}")
    all_metrics = []
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=180)) as session:
        for cat, prompts in PROMPTS.items():
            for prompt in prompts:
                for run in range(num_runs):
                    msgs = [{"role": "user", "content": prompt}]
                    m, _ = await streaming_request_with_itl(session, url, model, msgs, max_tokens)
                    m.prompt_category = cat
                    if not m.error:
                        print(f"  [{cat}] Run {run+1}: {m.completion_tokens} tok, "
                              f"TTFT={m.ttft:.3f}s, {m.throughput:.1f} tok/s, "
                              f"ITL_p50={m.itl_p50*1000:.1f}ms, ITL_p95={m.itl_p95*1000:.1f}ms")
                    all_metrics.append(m)
                    await asyncio.sleep(0.3)
    valid = [m for m in all_metrics if not m.error]
    if valid:
        print(f"\n  Summary ({len(valid)} runs):")
        print(f"    TTFT:  avg={statistics.mean(m.ttft for m in valid):.3f}s")
        print(f"    TPS:   avg={statistics.mean(m.throughput for m in valid):.1f}")
        print(f"    ITL:   mean={statistics.mean(m.itl_mean for m in valid)*1000:.1f}ms, "
              f"p95={statistics.mean(m.itl_p95 for m in valid)*1000:.1f}ms")
    return all_metrics


async def test_concurrency_matrix(url, model, name, levels, max_tokens=200, rounds=2):
    print(f"\n{'─'*70}\n  Concurrency Matrix — {name}\n{'─'*70}")
    prompts = PROMPTS["short"][:2] + PROMPTS["medium"][:2] + PROMPTS["long"][:1]
    results = []
    for n in levels:
        round_results = []
        for rnd in range(rounds):
            req_prompts = [prompts[i % len(prompts)] for i in range(n)]
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=300)) as session:
                t0 = time.perf_counter()
                tasks = [streaming_request_with_itl(session, url, model,
                    [{"role":"user","content":p}], max_tokens) for p in req_prompts]
                responses = await asyncio.gather(*tasks)
                wall = time.perf_counter() - t0
            mlist = [r[0] for r in responses]
            valid = [m for m in mlist if not m.error]
            if not valid: continue
            total_tok = sum(m.completion_tokens for m in valid)
            cr = ConcurrencyResult(
                n, len(mlist), wall, total_tok/wall if wall>0 else 0,
                statistics.mean(m.throughput for m in valid),
                min(m.throughput for m in valid), max(m.throughput for m in valid),
                statistics.mean(m.ttft for m in valid),
                _percentile([m.ttft for m in valid], 50),
                _percentile([m.ttft for m in valid], 95),
                min(m.ttft for m in valid), max(m.ttft for m in valid),
                statistics.mean(m.itl_mean for m in valid),
                _percentile([m.itl_mean for m in valid], 50),
                _percentile([m.itl_mean for m in valid], 95),
                statistics.mean(m.total_time for m in valid),
                _percentile([m.total_time for m in valid], 50),
                _percentile([m.total_time for m in valid], 95),
                len(mlist)-len(valid),
            )
            round_results.append(cr)
            print(f"  ×{n} round {rnd+1}: agg={cr.aggregate_tps:.1f} tok/s, "
                  f"per-req={cr.per_request_tps_mean:.1f}, TTFT_p50={cr.ttft_p50:.3f}s")
            await asyncio.sleep(2)
        if round_results:
            avg = ConcurrencyResult(
                n, round_results[0].num_requests,
                statistics.mean(r.wall_time for r in round_results),
                statistics.mean(r.aggregate_tps for r in round_results),
                statistics.mean(r.per_request_tps_mean for r in round_results),
                min(r.per_request_tps_min for r in round_results),
                max(r.per_request_tps_max for r in round_results),
                statistics.mean(r.ttft_mean for r in round_results),
                statistics.mean(r.ttft_p50 for r in round_results),
                statistics.mean(r.ttft_p95 for r in round_results),
                min(r.ttft_min for r in round_results),
                max(r.ttft_max for r in round_results),
                statistics.mean(r.itl_mean for r in round_results),
                statistics.mean(r.itl_p50 for r in round_results),
                statistics.mean(r.itl_p95 for r in round_results),
                statistics.mean(r.e2e_mean for r in round_results),
                statistics.mean(r.e2e_p50 for r in round_results),
                statistics.mean(r.e2e_p95 for r in round_results),
                sum(r.errors for r in round_results),
            )
            results.append(avg)
    if results:
        print(f"\n  {'Conc':>5} {'Agg tok/s':>10} {'Per-req':>10} {'TTFT p50':>10} "
              f"{'ITL mean':>10} {'E2E p50':>10}")
        for r in results:
            print(f"  {r.concurrency:>5} {r.aggregate_tps:>10.1f} "
                  f"{r.per_request_tps_mean:>10.1f} {r.ttft_p50:>9.3f}s "
                  f"{r.itl_mean*1000:>9.1f}ms {r.e2e_p50:>9.2f}s")
    return results


def save_csv(metrics, filepath, fieldnames):
    os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
    with open(filepath, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames)
        w.writeheader()
        for m in metrics: w.writerow(asdict(m))
    print(f"  Saved: {filepath}")


async def main():
    parser = argparse.ArgumentParser(description="MTP Benchmark Suite")
    parser.add_argument("--server", default=DEFAULT_SERVER)
    parser.add_argument("--model", default=None)
    parser.add_argument("--test", choices=["latency","concurrency","all"], default="all")
    parser.add_argument("--concurrency", default="1,4,8,16")
    parser.add_argument("--runs", type=int, default=3)
    parser.add_argument("--rounds", type=int, default=2)
    parser.add_argument("--max-tokens", type=int, default=300)
    parser.add_argument("--compare", default=None,
        help="A/B compare: host1:port1,host2:port2")
    args = parser.parse_args()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")

    url = args.server
    model = args.model or await check_server(url)
    if not model:
        print(f"Server {url} unavailable!"); sys.exit(1)
    print(f"Server: {url}\nModel: {model}")

    # Warmup
    async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30)) as s:
        wr, _ = await streaming_request_with_itl(s, url, model,
            [{"role":"user","content":"Hello."}], 10)
        print(f"Warmup: TTFT={wr.ttft:.3f}s, {wr.throughput:.1f} tok/s")
    await asyncio.sleep(1)

    if args.test in ("latency", "all"):
        m = await test_latency_profile(url, model, model, args.runs, args.max_tokens)
        save_csv(m, f"{RESULTS_DIR}/mtp_latency_{ts}.csv",
            [f.name for f in RequestMetrics.__dataclass_fields__.values()])
        await asyncio.sleep(2)

    if args.test in ("concurrency", "all"):
        levels = [int(x) for x in args.concurrency.split(",")]
        r = await test_concurrency_matrix(url, model, model, levels, args.max_tokens, args.rounds)
        save_csv(r, f"{RESULTS_DIR}/mtp_concurrency_{ts}.csv",
            [f.name for f in ConcurrencyResult.__dataclass_fields__.values()])

if __name__ == "__main__":
    asyncio.run(main())

Next steps — I'll open a separate issue proposing:

  1. Exposing _mtp_stats (accepted/rejected/errors) via a /v1/mtp/stats API endpoint
  2. Adaptive k (dynamic mtp_num_draft_tokens based on runtime acceptance rate)
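
For (2), a rough sketch of what adaptive k could look like (hypothetical function and thresholds, not an implemented API):

```python
def adapt_k(current_k, accepted, drafted,
            k_min=1, k_max=4, raise_at=0.8, lower_at=0.4):
    """Adjust the draft-token count from a recent acceptance rate:
    draft more when acceptance is high, back off when it is low."""
    if drafted == 0:
        return current_k
    rate = accepted / drafted
    if rate >= raise_at:
        return min(current_k + 1, k_max)
    if rate <= lower_at:
        return max(current_k - 1, k_min)
    return current_k

assert adapt_k(1, 9, 10) == 2   # 90% accepted: raise k
assert adapt_k(2, 3, 10) == 1   # 30% accepted: lower k
assert adapt_k(2, 6, 10) == 2   # middling: hold steady
```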

@TomLucidor

@janhilgard I need to compare the metadata of the original Qwen3-Coder-Next to the REAM version for the MTP feature, BUT I can't seem to locate the metadata/namespace for it. Wondering where it is so I can check if the REAM version can be made compatible (and not just an MLX stripper issue)?

@janhilgard
Collaborator Author

@TomLucidor The MTP weights live under the mtp.* key prefix in the safetensors files (e.g. mtp.fc.weight, mtp.norm.weight, mtp.layers.0.*), and the config flag is num_nextn_predict_layers (set to 1 when MTP is present).

However — Qwen3-Coder-Next doesn't ship with MTP weights at all. The non-Coder variant (Qwen3-Next-80B-A3B-Instruct) has them in its last shard (model-00041-of-00041.safetensors, ~3.3 GB) — that's where our add_mtp_weights.py script pulls from. But the Coder model has only 40 shards vs. 41, and no mtp.* keys anywhere in its weight map. This is an upstream decision by Qwen, not an MLX conversion stripping issue.

So the compatibility chain breaks at the source:

  1. Qwen3-Coder-Next — no MTP weights released → can't use --enable-mtp
  2. Qwen3-Coder-Next-REAM — derived from (1), also no MTP weights. Plus the expert count is 384 vs. 512, so even if you tried to graft the non-Coder MTP head, the MTP layer's MoE dimensions wouldn't match.

What you can check to confirm: Download the Coder model's model.safetensors.index.json and search for any mtp.* keys:

curl -s https://huggingface.co/Qwen/Qwen3-Coder-Next/resolve/main/model.safetensors.index.json | python3 -c "import sys,json; wm=json.load(sys.stdin)['weight_map']; mtp=[k for k in wm if 'mtp' in k]; print(f'{len(mtp)} MTP keys found') if mtp else print('No MTP keys — MTP not included in this model')"
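
The same check, expanded into a readable form (sample index payloads are inlined here for illustration; the real files are at the URLs above):

```python
import json

def count_mtp_keys(index_json: str) -> int:
    """Count mtp.* entries in a model.safetensors.index.json weight_map."""
    weight_map = json.loads(index_json)["weight_map"]
    return sum(1 for key in weight_map if "mtp" in key)

# A model with an MTP head exposes keys like mtp.fc.weight in its last shard.
with_mtp = json.dumps({"weight_map": {
    "mtp.fc.weight": "model-00041-of-00041.safetensors",
    "model.embed_tokens.weight": "model-00001-of-00041.safetensors",
}})
without_mtp = json.dumps({"weight_map": {
    "model.embed_tokens.weight": "model-00001-of-00040.safetensors",
}})
assert count_mtp_keys(with_mtp) == 1
assert count_mtp_keys(without_mtp) == 0
```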

If Qwen releases MTP weights for the Coder variant in the future, the add_mtp_weights.py script would work — you'd just point --source-model at the Coder repo. But for now, MTP is only available for the non-Coder Qwen3-Next line.

@TomLucidor

TomLucidor commented Feb 15, 2026

@janhilgard okay now I am slightly spooked

  1. It is better to test the following REAM model once it is converted to MLX (to see if MTP is compatible with REAM), but IDK where/how to get this working and tested in vllm-mlx https://huggingface.co/SamsungSAILMontreal/Qwen3-Next-80B-A3B-Instruct-REAM
  2. If Qwen3-Coder-Next has no MTP, then transplanting the MTP head from the Instruct model would very likely cause issues then? If not, it would be a worthy try, but if so... that would be a bit unfortunate, needing Eagle3 instead (another feature)

REAM has no MTP support https://huggingface.co/SamsungSAILMontreal/Qwen3-Next-80B-A3B-Instruct-REAM/raw/main/model.safetensors.index.json compared to the original https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct/raw/main/model.safetensors.index.json (now I'm gonna ask the REAM dev to see if they can support MTP, just in case)

sooth pushed a commit to sooth/vllm-mlx that referenced this pull request Feb 27, 2026
Merge 17 upstream commits including:
- KV cache quantization for prefix cache memory reduction (waybarrios#62)
- Streaming tool call parsing via ToolParser integration (waybarrios#46)
- MTP speculative decoding for Qwen3-Next (waybarrios#82)
- GPT-OSS reasoning parser and Harmony format parsers
- mlx-lm >= 0.30.5 requirement, transformers >= 5.0.0
- BatchMambaCache fix for mlx-lm >= 0.30.6 (waybarrios#89)
- MLLM continuous batching fixes (waybarrios#76)
- Force MLLM mode option (waybarrios#81)
- Various bug fixes

Conflict resolution:
- server.py: Replaced local tool_call_buffering with upstream's
  ToolParser-based streaming (more robust)
- cli.py: Deduplicated --mllm, --default-temperature, --default-top-p
  args (upstream already added them), kept local --embedding-model
- mamba_cache.py: Took upstream's conditional HAS_MAMBA_CACHE approach
- pyproject.toml: Took upstream's version and dependency changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
