feat: probabilistic MTP acceptance (speculative sampling) #1085
Closed
Thump604 wants to merge 21 commits into ml-explore:main
Conversation
Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights. Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable).
Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls. After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state. On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before. Acceptance rate ~64% average / ~85% on 100-token run.
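The snapshot/restore idea can be sketched as follows. This is a minimal stand-in, assuming a cache object with a single opaque state field; the names `ArraysCache` and `rollback_state` follow the commit text, but the real mlx-lm cache holds conv/ssm arrays, not a string.

```python
# Simplified sketch of SSM rollback: recurrent state cannot be trimmed like a
# KV cache, so a snapshot is taken after the confirmed token and restored on
# rejection. The state representation here is an illustrative placeholder.
class ArraysCache:
    def __init__(self):
        self.state = None           # conv/ssm state after the last processed token
        self.rollback_state = None  # snapshot taken after the confirmed token

    def snapshot(self):
        # Called between the two sub-calls of the T=2 verification pass,
        # right after the confirmed token has been processed.
        self.rollback_state = self.state

    def rollback(self):
        # Called when the draft token is rejected: restore the snapshot
        # instead of trimming, which recurrent layers cannot do.
        self.state = self.rollback_state

cache = ArraysCache()
cache.state = "state_after_confirmed_token"
cache.snapshot()
cache.state = "state_after_draft_token"   # speculative update
cache.rollback()                          # draft rejected
assert cache.state == "state_after_confirmed_token"
```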
- Yield token.item() instead of a raw mx.array to match the generate_step convention (fixes a detokenizer crash via stream_generate)
- Create an MTP cache when prompt_cache lacks MTP entries (the server creates backbone-only caches via make_prompt_cache)
- Disable batch generation for MTP models (the draft/verify loop requires single-sequence processing)

Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator.
…t_predicate)
- Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only, avoiding double normalization (model.norm + pre_fc_norm_hidden).
- Exclude mtp.fc from quantization via quant_predicate: the fusion projection (2H→H) stays in bf16 for accuracy.

27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x).
Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding. MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional.
8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers:
- MTP module instantiation and cache creation
- return_hidden shape and pre-norm verification
- mtp_forward output shape
- quant_predicate excludes mtp.fc
- Token identity: mtp_generate_step == generate_step (greedy)
- End-to-end mtp_generate_step completion
Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect.
The MTP layer in Qwen3.5 MoE models (35B-A3B, 122B-A10B) uses per-expert
weights (experts.{i}.gate_proj, etc.), unlike backbone layers, which use
fused gate_up_proj. Without stacking these into switch_mlp format,
conversion fails with "768 parameters not in model".
This fix adds MTP expert weight stacking to qwen3_5_moe.py sanitize(),
enabling MTP-preserving conversion for all Qwen3.5 MoE models.
Tested: Qwen3.5-35B-A3B (256 experts, 8-bit) and Qwen3.5-122B-A10B
(256 experts, 5-bit) both convert and run successfully with --mtp.
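The stacking step can be sketched roughly like this. It assumes sanitize() works over a flat name→array weight dict; the key pattern (`experts.{i}.gate_proj` into a stacked `switch_mlp` tensor) follows the commit text, but the exact fused naming and the full set of projections handled by the real qwen3_5_moe.py sanitize() are assumptions here.

```python
# Illustrative sketch: collect E per-expert (out, in) matrices and stack them
# into one (E, out, in) tensor under the switch_mlp naming the backbone
# layers already use. Plain lists stand in for mx.array.
def stack_mtp_experts(weights, num_experts, prefix="mtp.layers.0.mlp"):
    stacked = dict(weights)
    for proj in ("gate_proj", "up_proj", "down_proj"):
        per_expert = [
            stacked.pop(f"{prefix}.experts.{i}.{proj}.weight")
            for i in range(num_experts)
        ]
        stacked[f"{prefix}.switch_mlp.{proj}.weight"] = per_expert
    return stacked

weights = {
    f"mtp.layers.0.mlp.experts.{i}.{p}.weight": [float(i)]
    for i in range(2)
    for p in ("gate_proj", "up_proj", "down_proj")
}
out = stack_mtp_experts(weights, num_experts=2)
assert out["mtp.layers.0.mlp.switch_mlp.gate_proj.weight"] == [[0.0], [1.0]]
assert "mtp.layers.0.mlp.experts.0.gate_proj.weight" not in out
```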
Add optional parameters to KVCache and QuantizedKVCache:
- `step` (int): Override the class-level step size (default 256). Larger values reduce the number of boundary reallocations and GPU sync points during generation. For example, step=1024 reduces reallocation frequency 4x.
- `max_size` (int, KVCache only): Pre-allocate the buffer to hold this many tokens on first use, eliminating all subsequent boundary reallocations and concatenations. When the maximum context length is known (e.g., from server configuration), this avoids the repeated allocate-concatenate-free cycle that causes transient memory spikes and GPU sync points.

Also adds a `max_context` parameter to `make_prompt_cache()` that passes through to the KVCache constructors. All parameters are optional with backward-compatible defaults; existing code calling `KVCache()` or `make_prompt_cache(model)` is unaffected.

Motivation: On an M2 Ultra 128GB with a 122B MoE model (~82 GB weights), the repeated KV cache boundary reallocations (every 256 tokens across 12 attention layers) create transient memory spikes of ~2x the cache size at each boundary. With step=256, a 4000-token generation crosses 15 boundaries, each requiring 24 concatenation operations (12 layers x K+V). Pre-allocation eliminates this entirely.
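The arithmetic behind `step` and `max_size` can be sketched as below. The helper is illustrative (not the actual KVCache code), assuming a grow-by-step buffer where one reallocation happens at each step boundary crossed.

```python
# Count boundary reallocations for a grow-by-step KV buffer. With a
# pre-allocated max_size large enough for the whole generation, no boundary
# growth occurs at all.
def count_boundary_crossings(tokens_generated, step=256, max_size=None):
    if max_size is not None and max_size >= tokens_generated:
        return 0  # single up-front allocation, no growth
    return tokens_generated // step

# Matches the motivation above: a 4000-token run with step=256 crosses
# 15 boundaries; step=1024 cuts that to 3; pre-allocation removes them.
assert count_boundary_crossings(4000, step=256) == 15
assert count_boundary_crossings(4000, step=1024) == 3
assert count_boundary_crossings(4000, max_size=4096) == 0
```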
Pin behavioral contracts for review findings: checkpoint persistence through repeated extraction, partial rewind safety on longer hits, refcount lifecycle, deepcopy failure resilience, single-token shorter match threshold, prefix non-eviction on longer insert, and checkpoint localization suppression at prompt boundaries. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Non-thinking models get no benefit from checkpoint caching (their cache keys don't diverge between turns), so storing checkpoint entries is pure memory overhead. Gate checkpoint creation on tokenizer.has_thinking to eliminate unnecessary cache growth for standard models. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mtp_generate_step squeezes logits to 1D (vocab,) via logits[:, i, :].squeeze(0) before calling _process_and_sample. Logits processors (presence_penalty, repetition_penalty) index with logits[:, tokens] which requires 2D (batch, vocab). Unsqueeze to 2D before processors, squeeze back after. Without this, any request using presence_penalty or repetition_penalty with MTP enabled crashes with: ValueError: Too many indices for array with 1 dimensions
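The shape fix amounts to a wrap/unwrap around the processor call. A minimal sketch, using plain lists in place of mx.array and an illustrative stand-in for mlx-lm's penalty processors:

```python
# A penalty processor that, like mlx-lm's repetition/presence penalties,
# expects a 2D (batch, vocab) logits structure and indexes rows per sequence.
def apply_penalty(logits_2d, tokens, penalty=1.5):
    out = [row[:] for row in logits_2d]
    for row in out:
        for t in tokens:
            row[t] -= penalty
    return out

def process_1d(logits_1d, tokens):
    # The fix: unsqueeze the 1D (vocab,) slice to (1, vocab) before the
    # processor, then squeeze back to 1D afterwards.
    batched = [logits_1d]              # unsqueeze(0)
    processed = apply_penalty(batched, tokens)
    return processed[0]                # squeeze(0)

out = process_1d([2.0, 5.0, 1.0], tokens=[1])
assert out == [2.0, 3.5, 1.0]
```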
Non-thinking models with non-trimmable caches (ArraysCache) need the checkpoint entry to enable cache reuse via the shorter-cache path. The early return for non-thinking models was a regression from upstream behavior where _compute_prompt_checkpoint always returns (True, -1) for user-terminal chat requests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- generate.py: scale prefill_step_size inversely with cache size when prompt_processed_tokens > 65536, preventing GPU watchdog kills on long-context inference (Metal ~5s dispatch limit) - nemotron_h.py: default time_step_limit lower bound to 0.0 when time_step_min is not set (fixes SSM initialization edge case)
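The inverse scaling in the first bullet can be sketched as follows. The threshold matches the commit text; the exact formula and the floor value are assumptions, not the actual generate.py code.

```python
# Shrink the prefill chunk as the processed context grows, so each Metal
# dispatch stays under the ~5s watchdog limit. Illustrative sketch only.
def scaled_prefill_step(base_step, processed_tokens, threshold=65536):
    if processed_tokens <= threshold:
        return base_step
    # Scale inversely with context length past the threshold, with a floor
    # so prefill always makes progress.
    scaled = base_step * threshold // processed_tokens
    return max(scaled, 64)

assert scaled_prefill_step(2048, 10_000) == 2048     # short context: unchanged
assert scaled_prefill_step(2048, 131_072) == 1024    # 2x past threshold -> half
assert scaled_prefill_step(2048, 10_000_000) == 64   # clamped at the floor
```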
Metal allocator non-determinism causes the prompt-path subtest to flake at 1.35x. A real memory leak over 120 steps would be 10x+, so 2.0x still catches the failure mode without false positives. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Address reviewer feedback from PR ml-explore#1042:
- CacheEntry.count → ref_count: the field is decremented on extraction, so it's a reference count, not an insertion counter.
- Add a default rewind() on _BaseCache and a _has_rewind_impl() helper that uses method identity to detect real overrides. This replaces the inline introspection in _can_rewind_layer_cache with a cleaner helper while preserving the same behavior: third-party _BaseCache subclasses that implement rewind() participate automatically without needing an explicit opt-in flag.
- Add targeted tests for the _has_rewind_impl contract covering the base class, a no-override subclass, a custom override, and BatchRotatingKVCache.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Resolve qwen3_5.py conflict: keep MTP cache[0]/cache[1] assignments, add cache.advance(S) from PR. Fix nemotron_h.py time_step_min regression: PR dropped time_step_min guard, restored our version that reads from config.
Replace exact-match draft verification with probabilistic acceptance: accept draft token with probability min(1, p_target/p_draft). At temp=0, both distributions are peaked so this degenerates to exact match. At temp>0, accepts "good enough" tokens that the target model considers likely, dramatically improving acceptance rates. Tested on Qwen3.5-122B-A10B at temp=0.6: 90.5% acceptance (was ~5% with exact match), 38.8 tok/s (was ~16.8 without working MTP).
AirRunner added a commit to AirRunner/mlx-lm that referenced this pull request on Apr 3, 2026
With sampler=None (greedy decoding): keep exact-match acceptance, this is the mathematically correct criterion for a deterministic point-mass distribution. For stochastic samplers (temp > 0), accept the draft token with probability min(1, p_target / p_draft), computed from the log-probability distributions already returned by _process_and_sample. No extra forward passes needed. This recovers the greedy acceptance rate (~46%) at any temperature, vs ~43% with exact-match at temp=0.6 on Qwen3.5-27B 4-bit. Suggested by @janhilgard; implementation reference in ml-explore#1085 by @Thump604.
Author
Closing — AirRunner integrated probabilistic acceptance directly into #990 (commit 66dc1c4). The implementation matches my approach with an important improvement: greedy sampling (sampler=None) uses exact-match while stochastic sampling (temp>0) uses the log-probability ratio. Benchmarks confirm +3% acceptance recovery at temp=0.6. No need for a separate PR.
Summary
Replace exact-match MTP draft verification with probabilistic acceptance using the standard speculative sampling criterion: accept the draft token with probability min(1, p_target(draft) / p_draft(draft)).

Problem
The current MTP verification uses exact match (verify_pred == draft_tok). At temperature > 0, the target model often considers the draft token highly probable but samples a different token. This wastes correct drafts.

Fix
Compute the log acceptance ratio from the target and draft log-probability distributions (both already available from _process_and_sample). At temp=0, both distributions are peaked, so this degenerates to exact match (no behavior change). At temp > 0, it accepts tokens the target considers at least as likely as the draft does.
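The criterion can be sketched in a few lines, computed in log space from the two log-probabilities of the draft token. Names here are illustrative; the real code reads these values out of the distributions returned by _process_and_sample.

```python
import math
import random

def accept_draft(logp_target, logp_draft, rng=random.random):
    # Accept with probability min(1, p_target / p_draft): always accept when
    # the target assigns the draft token at least as much probability as the
    # draft model did, otherwise accept with probability exp(log_ratio).
    log_ratio = logp_target - logp_draft
    if log_ratio >= 0:
        return True
    return rng() < math.exp(log_ratio)

# Target considers the draft token more likely than the draft model did:
assert accept_draft(math.log(0.6), math.log(0.3))
# Target assigns it zero probability: never accepted.
assert not accept_draft(float("-inf"), math.log(0.3))
```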
Results
Tested on M2 Ultra 128GB with Qwen3.5-122B-A10B-VLM-MTP-5bit at temp=0.6:
Test plan