
feat: probabilistic MTP acceptance (speculative sampling)#1085

Closed
Thump604 wants to merge 21 commits into ml-explore:main from Thump604:feat/probabilistic-mtp

Conversation

@Thump604 Thump604 commented Apr 1, 2026

Summary

Replace exact-match MTP draft verification with probabilistic acceptance using the standard speculative sampling criterion: accept draft token with probability min(1, p_target(draft) / p_draft(draft)).

Problem

The current MTP verification uses exact match (verify_pred == draft_tok). At temperature > 0, the target model often considers the draft token highly probable but samples a different token. This wastes correct drafts:

  • At temp=0.6 on Qwen3.5-122B-A10B: ~5% acceptance with exact match
  • The MTP head proposes good tokens that the target agrees with probabilistically, but exact match rejects them

Fix

Compute the log acceptance ratio from the target and draft log-probability distributions (both already available from _process_and_sample):

log_accept = (verify_lp[draft_tok_id] - draft_lp[draft_tok_id]).item()
accept = log_accept >= 0 or math.log(random.random()) < log_accept

At temp=0, both distributions are peaked, so this degenerates to exact match (no behavior change). At temp > 0, it accepts tokens the target considers at least as likely as the draft does.
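The acceptance rule can be sketched as a standalone helper (a minimal sketch, assuming the log-prob distributions are plain indexable sequences; the `rng` hook is invented here for deterministic testing):

```python
import math
import random

def accept_draft(verify_lp, draft_lp, draft_tok_id, rng=random.random):
    """Speculative-sampling acceptance test:
    accept with probability min(1, p_target(draft) / p_draft(draft)).

    verify_lp / draft_lp: log-probability distributions from the
    target and draft models; draft_tok_id: the proposed token.
    """
    log_accept = verify_lp[draft_tok_id] - draft_lp[draft_tok_id]
    # log_accept >= 0 means the target finds the token at least as
    # likely as the draft did: always accept. Otherwise accept with
    # probability exp(log_accept).
    return log_accept >= 0 or math.log(rng()) < log_accept
```

At temp=0 both distributions collapse to point masses, so `log_accept` is 0 when the tokens match and -inf otherwise, which recovers exact-match behavior.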

Results

Tested on M2 Ultra 128GB with Qwen3.5-122B-A10B-VLM-MTP-5bit at temp=0.6:

Metric                 Exact Match    Probabilistic
Acceptance rate        ~5%            90.5%
Throughput (1024 tok)  ~16.8 tok/s    38.8 tok/s
Speedup vs no-MTP      ~1.0x          2.3x

Test plan

  • Verified at temp=0.6: 90.5% acceptance, 38.8 tok/s on 122B
  • Verified output correctness (coherent code generation)
  • At temp=0 this degenerates to exact match (mathematical property)
  • 12-line change, 2 new imports (math, random)

AirRunner and others added 21 commits March 17, 2026 02:53
Add mtp_generate_step() in generate.py and MTPModule/MTPDecoderLayer
in qwen3_5.py. Fixes norm weight shift for MTP-specific RMSNorm weights.

Known limitation: SSM state contamination on rejection (GatedDeltaNet layers not trimmable).
Extend GatedDeltaNet.__call__ with an n_confirmed parameter that splits the T=2 verification pass into two sub-calls.
After processing the confirmed token, the intermediate conv/ssm state is snapshotted into ArraysCache.rollback_state.
On rejection, SSM layers restore this snapshot while attention layers trim their KV cache by 1 as before.

Acceptance rate ~64% average / ~85% on 100-token run.
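The snapshot/restore scheme described above can be sketched with a toy stand-in cache (the class and method names are illustrative, not mlx-lm's ArraysCache API; only the `rollback_state` field name comes from the commit):

```python
class SSMCacheSketch:
    """Toy SSM-state cache with rollback support."""

    def __init__(self):
        self.state = None           # current conv/ssm recurrent state
        self.rollback_state = None  # snapshot taken after the confirmed token

    def snapshot(self):
        # Called between the two sub-calls of the T=2 verification pass,
        # after the confirmed token has updated the state.
        self.rollback_state = self.state

    def restore(self):
        # Called on draft rejection: discard the speculative update.
        self.state = self.rollback_state

cache = SSMCacheSketch()
cache.state = "after_confirmed"
cache.snapshot()
cache.state = "after_draft"  # speculative update from the draft token
cache.restore()              # draft rejected: state rolls back
```

Attention layers can simply trim their KV cache by one entry instead, which is why only the SSM layers need the snapshot.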
- Yield token.item() instead of raw mx.array to match generate_step convention (fixes detokenizer crash via stream_generate)
- Create MTP cache when prompt_cache lacks MTP entries (server creates backbone-only caches via make_prompt_cache)
- Disable batch generation for MTP models (draft/verify loop requires single-sequence processing)

Note: batch-aware MTP would need per-sequence accept/reject and SSM rollback within BatchGenerator
…t_predicate)

- Return pre-norm hidden states from Qwen3_5TextModel: apply norm in TextModel before lm_head only, avoiding double normalization (model.norm + pre_fc_norm_hidden).
- Exclude mtp.fc from quantization via quant_predicate (the fusion projection (2H→H) stays in bf16 for accuracy).

27B results after reconversion: 80.6% acceptance, 23.3 tok/s on M4 Pro (1.52x).
Replace auto-detection of MTP head with explicit --mtp flag, consistent with existing --draft-model for speculative decoding.

MTP is now opt-in. Without the flag, models with MTP weights use standard generation and batch serving remains fully functional.
8 tests using a tiny synthetic Qwen3.5 model (4 layers, hidden=64) with mtp_num_hidden_layers=1 and hybrid SSM+attention layers.
- MTP module instantiation and cache creation
- return_hidden shape and pre-norm verification
- mtp_forward output shape
- quant_predicate excludes mtp.fc
- Token identity: mtp_generate_step == generate_step (greedy)
- End-to-end mtp_generate_step completion
Instead of silently falling back to standard generation, emit a warning so the user knows their --mtp flag had no effect.
The MTP layer in Qwen3.5 MoE models (35B-A3B, 122B-A10B) uses per-expert
weights (experts.{i}.gate_proj etc.), unlike backbone layers, which use
fused gate_up_proj. Without stacking these into switch_mlp format,
conversion fails with "768 parameters not in model".

This fix adds MTP expert weight stacking to qwen3_5_moe.py sanitize(),
enabling MTP-preserving conversion for all Qwen3.5 MoE models.

Tested: Qwen3.5-35B-A3B (256 experts, 8-bit) and Qwen3.5-122B-A10B
(256 experts, 5-bit) both convert and run successfully with --mtp.
Add optional parameters to KVCache and QuantizedKVCache:

- `step` (int): Override the class-level step size (default 256).
  Larger values reduce the number of boundary reallocations and
  GPU sync points during generation. For example, step=1024 reduces
  reallocation frequency 4x.

- `max_size` (int, KVCache only): Pre-allocate the buffer to hold
  this many tokens on first use. Eliminates ALL subsequent boundary
  reallocations and concatenations. When the maximum context length
  is known (e.g., from server configuration), this avoids the
  repeated allocate-concatenate-free cycle that causes transient
  memory spikes and GPU sync points.

Also adds `max_context` parameter to `make_prompt_cache()` to pass
through to KVCache constructors.

All parameters are optional with backward-compatible defaults.
Existing code calling `KVCache()` or `make_prompt_cache(model)` is
unaffected.

Motivation: On M2 Ultra 128GB with a 122B MoE model (~82 GB weights),
the repeated KV cache boundary reallocations (every 256 tokens across
12 attention layers) create transient memory spikes of ~2x the cache
size at each boundary. With step=256, a 4000-token generation crosses
15 boundaries, each requiring 24 concatenation operations (12 layers
x K+V). Pre-allocation eliminates this entirely.
Pin behavioral contracts for review findings: checkpoint persistence
through repeated extraction, partial rewind safety on longer hits,
refcount lifecycle, deepcopy failure resilience, single-token shorter
match threshold, prefix non-eviction on longer insert, and checkpoint
localization suppression at prompt boundaries.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Non-thinking models get no benefit from checkpoint caching (their cache
keys don't diverge between turns), so storing checkpoint entries is pure
memory overhead. Gate checkpoint creation on tokenizer.has_thinking to
eliminate unnecessary cache growth for standard models.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mtp_generate_step squeezes logits to 1D (vocab,) via
logits[:, i, :].squeeze(0) before calling _process_and_sample.
Logits processors (presence_penalty, repetition_penalty) index
with logits[:, tokens] which requires 2D (batch, vocab).

Unsqueeze to 2D before processors, squeeze back after.
Without this, any request using presence_penalty or
repetition_penalty with MTP enabled crashes with:
  ValueError: Too many indices for array with 1 dimensions
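A pure-Python stand-in makes the shape mismatch concrete (the processor below is hypothetical; the real penalty processors index an mx.array as `logits[:, tokens]`):

```python
def penalty_processor(logits, tokens, penalty=1.0):
    """Stand-in for presence/repetition penalty: requires 2D
    (batch, vocab) input, mirroring logits[:, tokens] indexing."""
    if logits and not isinstance(logits[0], list):
        raise ValueError("Too many indices for array with 1 dimensions")
    for row in logits:
        for t in tokens:
            row[t] -= penalty
    return logits

vocab_logits = [0.0] * 8  # squeezed 1D (vocab,) as in mtp_generate_step
# Unsqueeze to (1, vocab) before the processor, squeeze back after:
processed = penalty_processor([vocab_logits], tokens=[1, 3])[0]
```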
Non-thinking models with non-trimmable caches (ArraysCache) need the
checkpoint entry to enable cache reuse via the shorter-cache path.
The early return for non-thinking models was a regression from upstream
behavior where _compute_prompt_checkpoint always returns (True, -1)
for user-terminal chat requests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- generate.py: scale prefill_step_size inversely with cache size when
  prompt_processed_tokens > 65536, preventing GPU watchdog kills on
  long-context inference (Metal ~5s dispatch limit)
- nemotron_h.py: default time_step_limit lower bound to 0.0 when
  time_step_min is not set (fixes SSM initialization edge case)
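The inverse scaling of the prefill chunk can be sketched as follows (the formula is illustrative; only the 65536-token threshold and the intent come from the commit message):

```python
def scaled_prefill_step(base_step, processed_tokens, threshold=65536):
    """Shrink the prefill chunk as the KV cache grows so each Metal
    dispatch stays under the GPU watchdog limit (~5 s)."""
    if processed_tokens <= threshold:
        return base_step
    # Scale inversely with the number of cached tokens past the threshold.
    return max(1, base_step * threshold // processed_tokens)

assert scaled_prefill_step(2048, 1024) == 2048    # short context: unchanged
assert scaled_prefill_step(2048, 131072) == 1024  # 2x threshold: halved
```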
Metal allocator non-determinism causes the prompt-path subtest to
flake at 1.35x. A real memory leak over 120 steps would be 10x+,
so 2.0x still catches the failure mode without false positives.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Address reviewer feedback from PR ml-explore#1042:

- CacheEntry.count → ref_count: the field is decremented on extraction,
  so it's a reference count, not an insertion counter.

- Add default rewind() on _BaseCache and a _has_rewind_impl() helper
  that uses method identity to detect real overrides. This replaces the
  inline introspection in _can_rewind_layer_cache with a cleaner helper
  while preserving the same behavior: third-party _BaseCache subclasses
  that implement rewind() participate automatically without needing an
  explicit opt-in flag.

- Add targeted tests for the _has_rewind_impl contract covering base
  class, no-override subclass, custom override, and BatchRotatingKVCache.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Resolve qwen3_5.py conflict: keep MTP cache[0]/cache[1] assignments,
add cache.advance(S) from PR.

Fix nemotron_h.py time_step_min regression: PR dropped time_step_min
guard, restored our version that reads from config.
Replace exact-match draft verification with probabilistic acceptance:
accept draft token with probability min(1, p_target/p_draft).

At temp=0, both distributions are peaked so this degenerates to exact
match. At temp>0, accepts "good enough" tokens that the target model
considers likely, dramatically improving acceptance rates.

Tested on Qwen3.5-122B-A10B at temp=0.6: 90.5% acceptance (was ~5%
with exact match), 38.8 tok/s (was ~16.8 without working MTP).
AirRunner added a commit to AirRunner/mlx-lm that referenced this pull request Apr 3, 2026
With sampler=None (greedy decoding): keep exact-match acceptance; this is the mathematically correct criterion for a deterministic point-mass distribution.

For stochastic samplers (temp > 0), accept the draft token with probability min(1, p_target / p_draft), computed from the log-probability distributions already returned by _process_and_sample. No extra forward passes needed.

This recovers the greedy acceptance rate (~46%) at any temperature, vs ~43% with exact-match at temp=0.6 on Qwen3.5-27B 4-bit.

Suggested by @janhilgard; implementation reference in ml-explore#1085 by @Thump604.
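The hybrid criterion can be sketched in one helper (a sketch following the commit's description; the signature is invented for illustration):

```python
import math
import random

def accept_token(draft_tok, verify_tok, verify_lp, draft_lp, sampler=None):
    """Hybrid acceptance: exact match for greedy decoding, the
    speculative-sampling ratio test for stochastic samplers."""
    if sampler is None:
        # Greedy: both distributions are point masses, so exact match
        # is the mathematically correct acceptance criterion.
        return verify_tok == draft_tok
    # Stochastic: accept with probability min(1, p_target / p_draft),
    # computed from log-probs already returned by _process_and_sample.
    log_accept = verify_lp[draft_tok] - draft_lp[draft_tok]
    return log_accept >= 0 or math.log(random.random()) < log_accept
```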
Author

Thump604 commented Apr 3, 2026

Closing — AirRunner integrated probabilistic acceptance directly into #990 (commit 66dc1c4). The implementation matches my approach with an important improvement: greedy sampling (sampler=None) uses exact-match while stochastic (temp>0) uses the log-probability ratio. Benchmarks confirm +3% acceptance recovery at temp=0.6. No need for a separate PR.

@Thump604 Thump604 closed this Apr 3, 2026