
feat: SpecPrefill — attention-based sparse prefill for TTFT reduction #180

Merged

waybarrios merged 7 commits into waybarrios:main from Thump604:feat/specprefill on Mar 21, 2026
Conversation

@Thump604 (Collaborator) commented Mar 18, 2026

SpecPrefill: attention-based sparse prefill using a draft model to reduce TTFT on long prompts.

What: A small draft model (e.g. 2B) scores token importance via Q@K^T attention. Top-k% tokens are selected in 32-token chunks. Target model sparse-prefills only selected tokens with position-mapped RoPE preserving correct relative positions.

Results (Qwen3.5-122B, 2B draft, 20% keep):

  • 8K: 3.7x, 16K: 4.1x, 32K: 4.2x, 64K: 4.5x, 128K: 5.4x TTFT reduction

Per-request API: extra_body: {"specprefill": true, "specprefill_keep_pct": 0.2}

Depends on: #171

Files: new specprefill.py (742 lines), modified engine/simple.py, server.py, cli.py, api/models.py

Paper: doi.org/10.5281/zenodo.19120919
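The per-request API above can be sketched as a request payload for an OpenAI-compatible endpoint. The model name and prompt are placeholders; only the `specprefill` / `specprefill_keep_pct` keys come from the PR (OpenAI-compatible clients flatten `extra_body` into the JSON body, so they appear at the top level on the wire):

```python
# Hedged sketch of a per-request SpecPrefill override; model name and
# message content are illustrative placeholders.
payload = {
    "model": "qwen3.5-122b",  # placeholder model name
    "messages": [
        {"role": "user", "content": "Summarize this long document ..."}
    ],
    # Per-request SpecPrefill control (sent via extra_body from an
    # OpenAI-compatible client, which merges it into the request body):
    "specprefill": True,
    "specprefill_keep_pct": 0.2,  # keep the top 20% of prompt tokens
}
```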

When both --mllm and --enable-mtp are set, SimpleEngine builds a
parallel mlx_lm TextModel sharing the VLM backbone weights (zero-copy).
Text-only requests route to mlx_lm with MTP speculative decoding;
media requests route to the mlx_vlm MLLM path.

Key components:
- text_model_from_vlm.py: Build mlx_lm TextModel from VLM weights
- Per-request routing in stream_chat() via _has_media_content()
- _stream_generate_text() for MTP-accelerated text generation
- MTP passthrough: --enable-mtp flag through CLI/server/engine/LLM

Tested on Qwen3.5-35B-A3B VLM+MTP (8-bit):
- Text (MTP): 65.3 tok/s
- Vision (MLLM): 63.8 tok/s
- Memory: 38.7 GB (zero-copy, same as single model)
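The per-request routing described above can be sketched as follows. This is a simplified illustration, not the PR's exact `_has_media_content()`: the content-part type names and the route labels are assumptions.

```python
def has_media_content(messages):
    """Simplified sketch: a request counts as media if any message part
    is a structured image/video/audio content block."""
    for msg in messages:
        content = msg.get("content")
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") in (
                    "image_url", "input_image", "video_url", "input_audio",
                ):
                    return True
    return False

def route(messages):
    # Text-only -> mlx_lm TextModel with MTP; media -> mlx_vlm MLLM path.
    return "mlx_vlm_mllm" if has_media_content(messages) else "mlx_lm_mtp"
```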
Persist backbone KV cache after prefilling system prompt tokens.
On subsequent requests with the same system prompt, restore the
snapshot and only prefill the suffix (user + history) tokens.

For a 10K-token system prompt on the 122B model, this saves ~57s
per request by avoiding redundant system prompt prefill.

Implementation:
- Detect system prefix via ChatML boundary markers
- Hash prefix text for cache key validation
- On cache miss: prefill system tokens, snapshot backbone KV state
- On cache hit: restore snapshot into fresh cache, send suffix only
- Token prefix validation ensures correct split at tokenization boundary
- Single-entry cache (one system prompt at a time)
- Stats exposed via get_stats() → system_kv_cache
- Cache cleared on stop(), invalidated on system prompt change
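The hit/miss logic above can be sketched with a single-entry cache keyed on a hash of the system prefix. This is an illustration: the real implementation snapshots MLX backbone KV state, not a plain Python object, and the class name here is hypothetical.

```python
import hashlib

class SystemKVCache:
    """Single-entry system-prompt KV snapshot cache (illustrative sketch)."""

    def __init__(self):
        self.key = None
        self.snapshot = None

    @staticmethod
    def _hash(prefix_text):
        # Hash the prefix text for cache-key validation.
        return hashlib.sha256(prefix_text.encode()).hexdigest()

    def lookup(self, prefix_text):
        # Cache hit only if the hashed system prefix matches exactly.
        return self.snapshot if self.key == self._hash(prefix_text) else None

    def store(self, prefix_text, snapshot):
        # Single entry: a new system prompt invalidates the old snapshot.
        self.key = self._hash(prefix_text)
        self.snapshot = snapshot
```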
Uses a small draft model to identify important prompt tokens via attention
scoring, then sparse-prefills the target model with only those tokens while
preserving original positional encoding via manual RoPE. Reduces TTFT
2.8-3.1x on 122B and 1.8x on 35B at 20% keep rate.

Implementation:
- specprefill.py: Core module with score_tokens(), select_chunks(),
  sparse_prefill(), cleanup_rope() (~640 lines)
- SimpleEngine integration: draft model loading, threshold-based activation,
  composition with system prompt KV cache, graceful fallback on error
- Per-request API: specprefill (bool) + specprefill_keep_pct (float)
  via extra_body for per-request control
- CLI: --specprefill, --specprefill-threshold, --specprefill-keep-pct,
  --specprefill-draft-model, --prefill-step-size

Closes waybarrios#179. Related: waybarrios#178 (TTFT), waybarrios#57 (speculative decoding).

@Thump604 (Collaborator, Author) commented:
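The score → select step of the pipeline can be sketched as follows: aggregate per-token importance over fixed 32-token chunks, keep the highest-scoring chunks until roughly `keep_pct` of tokens are selected, and return indices in original order so positional encoding can be preserved. This is an illustrative simplification, not the PR's `select_chunks()` implementation.

```python
def select_chunks(scores, keep_pct=0.2, chunk=32):
    """Chunked top-k% selection sketch over per-token importance scores."""
    n = len(scores)
    chunks = []
    for start in range(0, n, chunk):
        window = scores[start:start + chunk]
        # Score each chunk by its mean token importance.
        chunks.append((sum(window) / len(window), start, min(start + chunk, n)))
    budget = max(1, int(n * keep_pct))
    kept, total = [], 0
    for _, start, end in sorted(chunks, reverse=True):  # best chunks first
        kept.append((start, end))
        total += end - start
        if total >= budget:
            break
    # Return selected token indices in original order.
    return sorted(i for start, end in kept for i in range(start, end))
```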

Addendum: Model Format Prerequisites

Added a comment on issue #179 documenting the VLM+MTP model format requirement:

  • SpecPrefill's SimpleEngine integration depends on #171 (feat: MLLM+MTP per-request routing for text and vision), which requires VLM+MTP format models
  • VLM+MTP models are created via merge-mtp-into-vlm.py (merges MTP weights from mlx-lm MTP model into VLM)
  • Requires mlx-lm with native MTP support (ml-explore/mlx-lm#990)
  • Pre-built models available at Thump604 on HuggingFace (Qwen3.5 all sizes, following mlx-community naming conventions with details on model cards)
  • Standalone score_tokens() + sparse_prefill() work with any mlx-lm model — VLM+MTP format only needed for the full MLLM+MTP routing composition

…refill

Add support for three model architecture families with auto-detection:

- Qwen3.5: gate split + q_norm + RoPE (existing, now refactored)
- Nemotron-H: content-based attention (no RoPE), mixer attr, compacted cache
- GPT-OSS/Llama: standard q_proj + RoPE (GQA, YarnRoPE compatible)

Key changes:
- Architecture-specific query extractors (_qwen35, _llama, _nemotron_h)
- Auto-detection in score_tokens() via model attributes (q_norm/rope/mixer)
- _get_attn_module()/_set_attn_module() abstract self_attn vs mixer access
- _find_attention_layers() handles block_type="*" (Nemotron-H attention)
- _build_layer_to_cache_map() handles compacted cache indexing
- sparse_prefill() skips RoPE patching for architectures without it
- cleanup_rope() is no-op for RoPE-less architectures
- Remove score_tokens_self() stub (CritiPrefill not viable for MoE)

Tested on Qwen3.5 4B (positions + pipeline). Nemotron-H and GPT-OSS
code paths ready for empirical validation.
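The attribute-based auto-detection can be sketched as below. The probe attributes (`q_norm`/`rope`/`mixer`) come from the commit message, but the exact mapping and function name here are illustrative assumptions.

```python
def detect_architecture(attn):
    """Sketch of attribute-based architecture detection for one attention
    module; the attribute-to-family mapping is illustrative."""
    if getattr(attn, "q_norm", None) is not None and getattr(attn, "rope", None) is not None:
        return "qwen3.5"      # gate split + q_norm + RoPE
    if getattr(attn, "mixer", None) is not None or getattr(attn, "rope", None) is None:
        return "nemotron-h"   # content-based attention (no RoPE), mixer attr
    return "llama"            # standard q_proj + RoPE (GPT-OSS/Llama, GQA)
```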
Two bugs found during cross-architecture testing on GPT-OSS 120B:

1. _llama_extract_queries() used eager evaluation in getattr fallback
   chain: getattr(attn, "num_attention_heads", attn.num_heads) evaluates
   attn.num_heads before checking if num_attention_heads exists. Fixed to
   use safe nested getattr with None default.

2. _compute_importance() concatenated score matrices with different
   shapes when mixing sliding window (128-token RotatingKVCache) and
   full attention (unlimited KVCache) layers. Fixed by skipping layers
   whose cache spans fewer tokens than the full prompt.

Validated on GPT-OSS 120B + 20B draft: importance-based selection
produces coherent output while uniform selection degrades, confirming
scoring signal from 18 full-attention layers is sufficient.
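Bug 1 can be reproduced in isolation: `getattr`'s third argument is an ordinary expression evaluated eagerly, so the fallback attribute access raises even when the primary attribute exists. The toy class below is illustrative.

```python
class Attn:
    """Toy module with only num_attention_heads set (no num_heads)."""
    num_attention_heads = 64

attn = Attn()

# Buggy pattern: attn.num_heads is evaluated before getattr runs, so this
# raises AttributeError even though num_attention_heads exists:
#   n = getattr(attn, "num_attention_heads", attn.num_heads)

# Fixed pattern, as described in the commit: safe nested getattr
# with a None default.
n = getattr(attn, "num_attention_heads",
            getattr(attn, "num_heads", None))
```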
Models with sliding window attention (e.g., GPT-OSS alternating
sliding/full layers) use RotatingKVCache that evicts old entries.
When sparse prefill inserts more tokens than the window size, the
cache loses context needed for decode.

sparse_prefill() now auto-detects RotatingKVCache and augments the
selection to include the last max_size positions, ensuring sliding
window layers have valid recent context.

Validated: GPT-OSS 120B + 20B draft produces coherent output on
2294-token prompts (was garbage before this fix). Qwen3.5 and
Nemotron-H unaffected (no RotatingKVCache in their cache).
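The selection augmentation can be sketched as a set union of the importance-selected indices with the last `max_size` positions (the function name is hypothetical):

```python
def augment_for_sliding_window(selected, prompt_len, max_size):
    """Union the importance-based selection with the last max_size
    positions so sliding-window (RotatingKVCache) layers retain valid
    recent context; returns indices in original order."""
    recent = range(max(0, prompt_len - max_size), prompt_len)
    return sorted(set(selected) | set(recent))
```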
Add _stream_generate_specprefill() method for models that don't use MTP
speculative decoding (Nemotron, GPT-OSS, etc.). The existing SpecPrefill
integration only worked in the MTP text path (_stream_generate_text).

Changes:
- stream_generate() now pops specprefill/specprefill_keep_pct from kwargs
  and dispatches to the new method when conditions are met
- _stream_generate_specprefill() follows the same pattern as the MTP path:
  score → select → sparse_prefill → autoregressive generation
- Graceful fallback to normal generation on any error
- Per-request overrides (specprefill, specprefill_keep_pct) via extra_body
- Threshold and upper-bound checks identical to MTP path
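The dispatch described above can be sketched with a minimal stand-in engine. The class and attribute names mirror the flags in the PR but are otherwise hypothetical; the point is that per-request overrides are popped before kwargs reach the sampler.

```python
class EngineSketch:
    """Minimal stand-in engine illustrating the stream_generate dispatch."""

    def __init__(self, draft_model=None, threshold=1024):
        self.draft_model = draft_model
        self.specprefill_threshold = threshold
        self.specprefill_default = False
        self.keep_pct_default = 0.2

    def _stream_generate_specprefill(self, toks, keep_pct, **kw):
        # Placeholder for score -> select -> sparse_prefill -> generate.
        yield f"specprefill keep={keep_pct}"

    def _stream_generate_normal(self, toks, **kw):
        yield "normal"

    def stream_generate(self, prompt_tokens, **kwargs):
        # Pop per-request overrides so they never leak into sampler kwargs.
        use_sp = kwargs.pop("specprefill", self.specprefill_default)
        keep_pct = kwargs.pop("specprefill_keep_pct", self.keep_pct_default)
        long_enough = len(prompt_tokens) >= self.specprefill_threshold
        if use_sp and self.draft_model is not None and long_enough:
            return self._stream_generate_specprefill(prompt_tokens, keep_pct, **kwargs)
        return self._stream_generate_normal(prompt_tokens, **kwargs)
```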
@Thump604 (Collaborator, Author) commented:

Server-Side TTFT Benchmarks (Non-MTP Integration)

Added _stream_generate_specprefill() in commit ab6c354 — SpecPrefill now works for non-MTP models via the standard LLM streaming path (stream_generate). Previously it was only available in the MTP text path (_stream_generate_text).

This enables server-side benchmarking for Nemotron and GPT-OSS (neither uses MTP). Per-request overrides (specprefill: true/false, specprefill_keep_pct) work from the OpenAI API.

Nemotron 3 Super 120B-A12B (5-bit) + Nano 4B draft (4-bit) — M2 Ultra 128GB

| Actual Tokens | Baseline TTFT | SpecPrefill 30% | SpecPrefill 20% | Speedup 30% | Speedup 20% |
|---|---|---|---|---|---|
| ~4K | 14.72s | 9.22s | 7.53s | 1.60x | 1.95x |
| ~8K | 29.09s | 17.18s | 14.05s | 1.69x | 2.07x |
| ~16K | 58.43s | 33.17s | 27.24s | 1.76x | 2.14x |

Key observations:

  • Speedup increases with prompt length (1.60x → 2.14x), consistent with fixed scoring overhead being amortized over more target prefill savings
  • Near-perfect linear baseline scaling (~1.79s/1K tokens) — Nemotron's 40 Mamba-2 layers dominate prefill cost
  • 4 attention layers in the draft (9.5% of 42 total) provide sufficient scoring signal
  • Nano 4B draft (~2.1GB) is highly memory-efficient alongside the 83GB target

GPT-OSS 120B (5-bit) + GPT-OSS 20B draft (4-bit) — M2 Ultra 128GB

| Actual Tokens | Baseline TTFT | SpecPrefill 30% | SpecPrefill 20% | Speedup 30% | Speedup 20% |
|---|---|---|---|---|---|
| ~4K | 6.39s | 6.01s | 5.63s | 1.06x | 1.14x |
| ~8K | 12.78s | 11.33s | 10.33s | 1.13x | 1.24x |
| ~16K | 26.41s | 22.53s | 20.67s | 1.17x | 1.28x |

Key observations:

  • Modest speedups (1.06-1.28x) reflect the 20B draft being 5x larger than the 4B drafts used for Qwen/Nemotron
  • Draft scoring cost is proportionally higher relative to the prefill savings
  • GPT-OSS's sliding window attention (128-token RotatingKVCache) makes baseline prefill inherently faster (~0.78s/1K tokens vs Nemotron's ~1.79s/1K)
  • Speedup would improve significantly with a smaller draft model (e.g., 4B)

Draft Model Size Impact

The benchmarks clearly demonstrate that the draft-to-target size ratio is the key factor:

| Draft Model | Size | Scoring Overhead | Target Speedup (16K) |
|---|---|---|---|
| Qwen 4B (4-bit) | ~3 GB | ~2.5s | 3.06x (122B) |
| Nano 4B (4-bit) | ~2.1 GB | ~2.5s | 2.14x (120B Nemotron) |
| GPT-OSS 20B (4-bit) | ~10 GB | ~5s est. | 1.28x (120B GPT-OSS) |

Recommendation: keep draft models as small as possible (4B 4-bit ideal). The scoring quality from 4B attention layers is sufficient — more parameters in the draft do not meaningfully improve importance detection, but they increase scoring latency.
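The amortization argument above can be made concrete with a first-order cost model. This is an illustration, not from the PR: it treats sparse prefill as roughly `keep_pct` of the baseline plus a fixed scoring overhead, and ignores selection, RoPE remapping, and cache-handling costs, so it overestimates real speedups.

```python
def predicted_speedup(baseline_ttft, scoring_overhead, keep_pct):
    """First-order SpecPrefill cost model (illustrative, not from the PR):
    speedup = baseline / (fixed scoring cost + keep_pct * baseline)."""
    return baseline_ttft / (scoring_overhead + keep_pct * baseline_ttft)
```

The model reproduces the two qualitative trends in the benchmarks: speedup grows with prompt length (the fixed scoring cost amortizes), and a heavier draft (larger scoring overhead) shrinks the speedup.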

@Thump604 (Collaborator, Author) commented:

Paper

The technique and benchmarks are documented in a preprint:

SpecPrefill on Unified Memory: Cross-Architecture Sparse Prefill for Large Language Models on Apple Silicon

Covers the algorithm, cost model, quality validation (adversarial + perplexity), and cross-architecture results (Qwen3.5, Nemotron-H, GPT-OSS).

@waybarrios waybarrios merged commit f518c07 into waybarrios:main Mar 21, 2026
7 checks passed
@Thump604 Thump604 deleted the feat/specprefill branch March 21, 2026 21:50
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Mar 22, 2026
Port SimpleEngine features to BatchedEngine for continuous batching:

- Per-request MTP routing: text-only → TextModel (MTP), media → MLLM
- message_utils.py: shared _normalize_messages (developer→system,
  merge consecutive same-role, hoist system to [0])
- SpecPrefill config + draft model lifecycle in BatchedEngine
- System KV cache with ChatML boundary detection

Replaces PR waybarrios#192 (rebased against main after merge of waybarrios#180, waybarrios#97).
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Mar 23, 2026
ChunkedDraftScorer breaks SpecPrefill draft model scoring into chunks,
yielding between chunks so active generation requests can make progress.
Composes with the admission controller for memory-safe concurrent serving.

Design: draft scoring (the slowest SpecPrefill phase) runs in chunks of
4096 tokens. Between chunks, the caller yields to the event loop via
asyncio.sleep(0), allowing MLLM or other text generation to proceed.
Selection + sparse prefill + generation run monolithically after scoring.

Depends on: merged PR waybarrios#180 (SpecPrefill integration)

5 unit tests.
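The cooperative chunked-scoring loop can be sketched as follows. `score_chunk` stands in for the draft model's scoring call; the function name and signature are illustrative.

```python
import asyncio

async def score_in_chunks(tokens, score_chunk, chunk_size=4096):
    """Sketch of ChunkedDraftScorer's cooperative loop: score chunk_size
    tokens at a time, yielding to the event loop between chunks so other
    active requests can make progress."""
    scores = []
    for start in range(0, len(tokens), chunk_size):
        scores.extend(score_chunk(tokens[start:start + chunk_size]))
        await asyncio.sleep(0)  # cooperative yield between chunks
    return scores
```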
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Mar 23, 2026
Integrates the admission controller, cooperative specprefill, and MLLM+MTP
per-request routing into the BatchedEngine for production multi-user serving.

Key changes:
- BatchedEngine: admission gates on all 4 public methods (chat, stream_chat,
  generate, stream_generate) with try/finally cleanup
- MLLM+MTP routing: text-only requests → mlx_lm TextModel with MTP
  speculative decoding, media requests → mlx_vlm MLLM path
- System KV cache: prefix boundary detection + snapshot/restore for
  repeated system prompts (7x speedup on cache hits)
- Cooperative specprefill: draft scoring outside the generation lock,
  yielding between chunks for concurrent request progress
- Thread-safe snapshot access (threading.Lock for cross-thread reads/writes)
- Cache-hit re-verification under lock (prevents stale flag after queuing)
- MLLM error loop: breaks after 10 consecutive errors (no infinite loop)
- CLI: --scheduler-policy, --scheduler-headroom-gb flags

Depends on: admission controller PR, cooperative specprefill PR, waybarrios#165, waybarrios#180

New files:
- specprefill.py: SpecPrefill scoring + sparse prefill (builds on merged waybarrios#180)
- text_model_from_vlm.py: zero-copy TextModel construction from VLM backbone
Thump604 added a commit to Thump604/vllm-mlx that referenced this pull request Mar 24, 2026
Port SimpleEngine's MLLM+MTP per-request routing to BatchedEngine.
Text-only requests route to mlx_lm TextModel with MTP speculative
decoding; media requests route to MLLM path.

Uses text_model_from_vlm.py (already upstream from PR waybarrios#180) to build
a zero-copy TextModel from VLM backbone weights. Routing decision is
per-request based on message content via _has_media_content().

Changes:
- Add mtp/prefill_step_size params to BatchedEngine.__init__
- Build TextModel in _start_mllm() when mtp=True
- Route text-only to _stream_chat_text_model in chat()/stream_chat()
- Add _chat_text_model/_stream_chat_text_model for mlx_lm generation
- Add _has_media_content helper (mirrors SimpleEngine)
- Add test_batched_mtp_routing.py (8 tests)
raullenchai pushed a commit to raullenchai/Rapid-MLX that referenced this pull request Mar 26, 2026
…ection, served-model-name

Merge 16 upstream commits (22dcbf8..d235c37) into our fork:

- feat: SpecPrefill — attention-based sparse prefill for TTFT reduction (waybarrios#180)
- feat: native Qwen3-VL video pipeline with temporal 3D conv + M-RoPE (waybarrios#150)
- fix: Disable MambaCache monkey-patch for hybrid models, add MTP auto-injection (waybarrios#97)
- feat: Add --served-model-name CLI parameter (waybarrios#125)
- feat: Add Qwen3.5 text-only loading and dynamic memory threshold (waybarrios#127)
- fix(mllm_scheduler): add adaptive periodic cache clearing (waybarrios#157)
- fix: Metal resource leak under high concurrency (waybarrios#92)

Conflict resolution strategy: keep all fork features (DeltaNet snapshots,
fast SSE templates, tool injection, cloud routing, prompt cache, etc.)
while incorporating upstream's new functionality.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>