feat: Add speculative decoding support with draft models #45
janhilgard wants to merge 1 commit into waybarrios:main
Conversation
This is super!
Thanks @enryold!
Machine: Mac Studio M3 Ultra, 256GB unified memory
Benchmark method: Custom Python script using async
Results with GPT-OSS-20B-4bit (~11GB model):
M1 Max 64GB benchmarks would be great! The 80B model won't fit (~55GB weights), but GPT-OSS-20B-4bit (~11GB) or Qwen3-8B would be good candidates. Happy to help with setup.
Great work on this. I have a question about the accuracy tradeoff when using small draft models. In theory, speculative decoding should be lossless because the target model verifies every token the draft model proposes. If the draft model guesses wrong, the token gets rejected and the target model generates the correct one. So the final output quality should be identical to running the target model alone; you just get a variable speedup depending on how well the draft model predicts. But in practice I'm wondering about a few things:
My concern is that for production use cases where accuracy matters most (agentic tool calling, code generation), the small draft models might not help much and could even hurt. The sweet spot seems to be repetitive or predictable content, which is not where we typically need the most performance. It would be useful to add some acceptance rate logging so users can see whether speculative decoding is actually helping for their specific workload.
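For anyone curious why the verification step is lossless, here is a toy sketch of the standard rejection-sampling rule used by speculative decoding. This is illustrative only (plain Python over token-probability dicts), not mlx-lm's actual code:

```python
import random

def verify_draft(draft_tokens, draft_probs, target_probs):
    """Accept draft token t with probability min(1, p_target(t) / p_draft(t)).
    On the first rejection, resample from the residual distribution
    max(0, p_target - p_draft), renormalized, then stop. Accepted output is
    provably distributed exactly as if the target model sampled alone."""
    accepted = []
    for tok, q, p in zip(draft_tokens, draft_probs, target_probs):
        if random.random() < min(1.0, p[tok] / q[tok]):
            accepted.append(tok)  # target agrees often enough: keep the token
            continue
        # rejection: sample a replacement from max(0, p - q), renormalized
        residual = {t: max(0.0, p[t] - q[t]) for t in p}
        z = sum(residual.values())
        r, cum = random.random() * z, 0.0
        for t, w in residual.items():
            cum += w
            if cum >= r:
                accepted.append(t)
                break
        break  # everything drafted after the first rejection is discarded
    return accepted
```

The draft model never changes what gets emitted, only how many expensive target-model forward passes are needed per emitted token.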
Great questions @waybarrios, let me address them one by one.

1. Mathematically lossless — yes
mlx-lm uses full rejection sampling. The implementation in
The output distribution is identical to non-speculative decoding: accepted tokens match the main model's output exactly, and rejected tokens trigger immediate fallback to the main model's sample. This is genuine rejection sampling, not an approximation.

2. Runtime fallback — not yet, but feasible
Currently mlx-lm's
A practical approach for a follow-up:
The main challenge is that mlx-lm's generator is all-or-nothing per session — switching mid-stream would require restarting the generator. But we could at least measure and log, and disable speculation for subsequent requests if a workload consistently underperforms.

3. Acceptance rates by task type
From our benchmarks (Mac Studio M3 Ultra, Qwen3-Next-80B + Qwen3-0.6B draft):
The 70% acceptance rate was on general instruction-following. You're right that coding/reasoning tasks will likely see lower acceptance — the 0.6B model simply can't predict complex reasoning chains. The speedup is real for conversational/predictable content but marginal-to-negative for complex generation.

Why it still works for the 80B MoE case: Qwen3-Next only activates ~3B params per token despite 80B total. The 0.6B draft model has surprisingly decent overlap for common patterns, and the same tokenizer (151,643 vocab) ensures zero alignment overhead.

Important caveat: speculative decoding is workload-dependent
From our own production usage, the impact varies dramatically by scenario:

Where it helps significantly:
Where it actually hurts:
The takeaway is clear: speculative decoding is not a universal win. It can deliver substantial speedups for the right workload, but it can equally slow things down for the wrong one. Users absolutely need to benchmark their specific use case. The acceptance rate statistics are essential — they tell you immediately whether speculation is helping or hurting.

Next steps
I agree acceptance rate logging would be valuable. I'll add it as a follow-up:
This keeps the current PR focused on the core feature, with observability as a clean follow-up.
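As a rough sketch of what that follow-up logging could look like. The class and method names here are hypothetical, not part of mlx-lm or this PR:

```python
from dataclasses import dataclass

@dataclass
class SpecDecodeStats:
    """Hypothetical per-request counters for speculative-decoding observability."""
    drafted: int = 0
    accepted: int = 0

    def record_step(self, num_drafted: int, num_accepted: int) -> None:
        # called once per verification pass of the target model
        self.drafted += num_drafted
        self.accepted += num_accepted

    @property
    def acceptance_rate(self) -> float:
        return self.accepted / self.drafted if self.drafted else 0.0

stats = SpecDecodeStats()
stats.record_step(num_drafted=5, num_accepted=4)
stats.record_step(num_drafted=5, num_accepted=2)
print(f"draft acceptance: {stats.acceptance_rate:.0%}")  # prints: draft acceptance: 60%
```

Exposing this per request (or as a rolling average) would let users see immediately whether speculation pays off for their workload.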
Can you share this script and the setup so we can start running some benchmarks on different hardware?
@janhilgard is there a standard/template/example command for
Force-pushed 1967ac9 to 04252f6
@janhilgard thanks for the updates, and it seems like things are going smoothly. To avoid this whole PR working only for Qwen3-Next, I have some weird test items to stress-test this feature:
To check if they can handle models with weird tokenization file formats, maybe
As for dedicated models like Eagle3:
And I wonder how tool use and code generation can be sped up as well (outside of Eagle3).
- Speculative decoding with mlx-lm draft models (1.2-1.4x speedup)
- HybridEngine: shared model between speculative + batched modes
- JSON schema enforcement with guided generation support
- Fix false positive tool call detection for regular JSON
- Strip <think> tags from API responses to prevent JSON parse errors

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed 04252f6 to dcfa067
@janhilgard @enryold I need to clarify that most Eagle3 models are NOT compatible with vLLM-MLX since they are not MLX-based. I tried to run them to see if they are functional, but it does not work out so well (for Qwen3 and GPT-OSS-20B), e.g. RedHatAI speculator models and AngelSlim Eagle3 models. If it does work with Qwen3-30B variants, please show me the commands so I can test it ASAP #57
This is absolutely true @TomLucidor
Right, Eagle3 models ship as PyTorch safetensors with a custom architecture (fusion head on top of the base model) — there's no MLX equivalent yet, so they can't be loaded by mlx-lm or vllm-mlx.

What works today for speculative decoding in vllm-mlx:
Re: Qwen3-30B variants — @TomLucidor Qwen3-Coder-Next-48B-A3B should work with

```
vllm-mlx serve mlx-community/Qwen3-30B-A3B-4bit \
  --continuous-batching \
  --draft-model mlx-community/Qwen3-0.6B-4bit \
  --num-draft-tokens 5
```

For Eagle3 to work in MLX, someone would need to:
This is non-trivial and blocked on the fact that no Eagle3 models exist in MLX format. If someone converts the weights, the integration side would be feasible as a follow-up to this PR. In the meantime, the draft model approach gives comparable speedups (1.2-1.3x) for models with a matching smaller variant, which covers Qwen3, Granite, and Falcon families.
@janhilgard How difficult is it to port Eagle to MLX format? I am a bit curious about how it could work
@waybarrios even AngelSlim is asking for community feedback on how weights can be converted from "standard" format to MLX with a custom converter.
Closing: Superseded by merged #82 (MTP-based speculative decoding approach). This draft-model approach is no longer needed. |
Since Qwen3-Coder-Next does not have an MTP layer, and MLX doesn't support EAGLE3 yet, isn't a draft model the only way to enable speculative decoding for it?
Summary
Add support for speculative decoding using mlx-lm's draft model feature, including a new HybridEngine that shares a single model instance between speculative and batched modes.
Features
New CLI arguments
- --draft-model: Path to draft model (must share tokenizer with main model)
- --num-draft-tokens: Tokens to speculate per step (default: 4)

Usage modes
Simple mode (speculative only):
```
vllm-mlx serve mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit \
  --draft-model mlx-community/Qwen3-0.6B-4bit \
  --num-draft-tokens 5
```

Hybrid mode (speculative + batching with shared model):

```
vllm-mlx serve mlx-community/Qwen3-Next-80B-A3B-Instruct-4bit \
  --continuous-batching \
  --draft-model mlx-community/Qwen3-0.6B-4bit \
  --num-draft-tokens 5
```

HybridEngine Architecture
When both --continuous-batching and --draft-model are specified, the server uses HybridEngine, which:

Mode switching:
- active_requests < 2 → SimpleEngine (speculative decoding)
- active_requests >= 2 → BatchedEngine (continuous batching)

RAM usage: ~45GB (vs ~90GB if running separate engines)
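The switching rule amounts to a one-line decision. A minimal sketch, with function and constant names that are illustrative rather than the PR's actual API:

```python
SWITCH_THRESHOLD = 2  # per the description above: fewer than 2 active requests

def decide_mode(active_requests: int) -> str:
    """Pick an engine mode for the current number of in-flight requests."""
    # single-stream traffic benefits most from speculative decoding;
    # concurrent traffic amortizes better under continuous batching
    return "speculative" if active_requests < SWITCH_THRESHOLD else "batched"

print(decide_mode(1), decide_mode(3))  # prints: speculative batched
```

The interesting engineering is not this predicate but the ownership transfer of the shared model between engines, which the PR handles via ModelRegistry.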
Benchmark results
Tested with Qwen3-Next-80B-6bit + Qwen3-0.6B-4bit:
Speculative decoding acceptance rate varies by content type - higher for repetitive text (lists, numbers).
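One way to reason about that variation: if each draft token were accepted independently with probability alpha (an idealization; real acceptance is correlated across positions), the expected number of tokens produced per target-model verification pass has a closed form, from the original speculative decoding analysis by Leviathan et al.:

```python
def expected_tokens_per_pass(alpha: float, gamma: int) -> float:
    """Expected tokens per verification pass, assuming each of the gamma
    draft tokens is accepted i.i.d. with probability alpha:
    (1 - alpha**(gamma + 1)) / (1 - alpha)."""
    if alpha >= 1.0:
        return gamma + 1.0  # every draft accepted, plus the bonus token
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

# e.g. roughly the 70% acceptance reported above, with 5 draft tokens:
print(round(expected_tokens_per_pass(0.7, 5), 2))  # prints: 2.94
```

This is why acceptance rate dominates the speedup: at 70% acceptance you get nearly 3 tokens per expensive pass, while at low acceptance the draft model's own cost can make the end-to-end run slower than plain decoding.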
Implementation details
- HybridEngine: Manages shared model between SimpleEngine and BatchedEngine
- _inject_shared_model(start_engine=False): Lazy start for HybridEngine
- _decide_and_switch_mode(): Dynamic mode switching based on concurrent requests
- _switch_to_mode(): Handles ownership transfer via ModelRegistry
- stream_generate() with draft_model parameter

Recent fixes
Limitations
Test plan
🤖 Generated with Claude Code