feat: SpecPrefill — attention-based sparse prefill for TTFT reduction (#180)

waybarrios merged 7 commits into waybarrios:main
Conversation
When both --mllm and --enable-mtp are set, SimpleEngine builds a parallel mlx_lm TextModel sharing the VLM backbone weights (zero-copy). Text-only requests route to mlx_lm with MTP speculative decoding; media requests route to the mlx_vlm MLLM path.

Key components:
- text_model_from_vlm.py: build mlx_lm TextModel from VLM weights
- Per-request routing in stream_chat() via _has_media_content()
- _stream_generate_text() for MTP-accelerated text generation
- MTP passthrough: --enable-mtp flag through CLI/server/engine/LLM

Tested on Qwen3.5-35B-A3B VLM+MTP (8-bit):
- Text (MTP): 65.3 tok/s
- Vision (MLLM): 63.8 tok/s
- Memory: 38.7 GB (zero-copy, same as single model)
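The per-request routing described above hinges on a media-detection predicate. A minimal sketch, assuming OpenAI-style multimodal message dicts (the exact part-type names are assumptions, not confirmed by this PR):

```python
def _has_media_content(messages):
    """Return True if any message carries image/video/audio parts.

    Sketch of the routing predicate: multimodal content arrives as a list
    of typed parts; plain text arrives as a string.
    """
    for msg in messages:
        content = msg.get("content")
        if isinstance(content, list):  # multimodal content is a list of parts
            for part in content:
                if part.get("type") in ("image_url", "video_url", "input_audio"):
                    return True
    return False

def route(messages):
    """Per-request dispatch: media -> mlx_vlm MLLM path, text-only -> mlx_lm + MTP."""
    return "mllm" if _has_media_content(messages) else "text_mtp"
```

Because the decision is made per request, a single server process serves both paths from the same (zero-copy shared) weights.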
Persist backbone KV cache after prefilling system prompt tokens. On subsequent requests with the same system prompt, restore the snapshot and prefill only the suffix (user + history) tokens. For a 10K-token system prompt on the 122B model, this saves ~57s per request by avoiding redundant system prompt prefill.

Implementation:
- Detect system prefix via ChatML boundary markers
- Hash prefix text for cache key validation
- On cache miss: prefill system tokens, snapshot backbone KV state
- On cache hit: restore snapshot into a fresh cache, send suffix only
- Token prefix validation ensures a correct split at the tokenization boundary
- Single-entry cache (one system prompt at a time)
- Stats exposed via get_stats() → system_kv_cache
- Cache cleared on stop(), invalidated on system prompt change
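The single-entry hash-keyed cache above can be sketched as follows; the class and method names are hypothetical, and the real implementation snapshots actual backbone KV state rather than an opaque value:

```python
import hashlib

class SystemKVCache:
    """Single-entry snapshot cache keyed by a hash of the system prefix text.

    Sketch only: on hit, the caller restores the snapshot into a fresh cache
    and prefills just the suffix tokens; on miss, it prefills the system
    tokens and stores a new snapshot.
    """

    def __init__(self):
        self.key = None
        self.snapshot = None

    @staticmethod
    def _hash(text):
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def lookup(self, system_text):
        """Return the stored KV snapshot on a key match, else None."""
        if self._hash(system_text) == self.key:
            return self.snapshot
        return None

    def store(self, system_text, kv_state):
        """Replace the single entry (invalidates any previous system prompt)."""
        self.key = self._hash(system_text)
        self.snapshot = kv_state

    def clear(self):
        """Called on stop(): drop the snapshot entirely."""
        self.key = self.snapshot = None
```

A changed system prompt hashes to a different key, so the lookup misses and the cache is naturally invalidated.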
Uses a small draft model to identify important prompt tokens via attention scoring, then sparse-prefills the target model with only those tokens while preserving original positional encoding via manual RoPE. Reduces TTFT 2.8-3.1x on 122B and 1.8x on 35B at a 20% keep rate.

Implementation:
- specprefill.py: core module with score_tokens(), select_chunks(), sparse_prefill(), cleanup_rope() (~640 lines)
- SimpleEngine integration: draft model loading, threshold-based activation, composition with the system prompt KV cache, graceful fallback on error
- Per-request API: specprefill (bool) + specprefill_keep_pct (float) via extra_body for per-request control
- CLI: --specprefill, --specprefill-threshold, --specprefill-keep-pct, --specprefill-draft-model, --prefill-step-size

Closes waybarrios#179. Related: waybarrios#178 (TTFT), waybarrios#57 (speculative decoding).
Addendum: Model Format Prerequisites

Added a comment on issue #179 documenting the VLM+MTP model format requirement:
…refill

Add support for three model architecture families with auto-detection:
- Qwen3.5: gate split + q_norm + RoPE (existing, now refactored)
- Nemotron-H: content-based attention (no RoPE), mixer attr, compacted cache
- GPT-OSS/Llama: standard q_proj + RoPE (GQA, YarnRoPE compatible)

Key changes:
- Architecture-specific query extractors (_qwen35, _llama, _nemotron_h)
- Auto-detection in score_tokens() via model attributes (q_norm/rope/mixer)
- _get_attn_module()/_set_attn_module() abstract self_attn vs mixer access
- _find_attention_layers() handles block_type="*" (Nemotron-H attention)
- _build_layer_to_cache_map() handles compacted cache indexing
- sparse_prefill() skips RoPE patching for architectures without it
- cleanup_rope() is a no-op for RoPE-less architectures
- Remove score_tokens_self() stub (CritiPrefill not viable for MoE)

Tested on Qwen3.5 4B (positions + pipeline). Nemotron-H and GPT-OSS code paths are ready for empirical validation.
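The attribute-based auto-detection can be sketched as a simple probe, checking the distinguishing attributes named above (mixer, then q_norm+rope). The function name and return labels here are illustrative, not the module's actual API:

```python
def detect_architecture(attn):
    """Guess the query-extractor family from attention-module attributes.

    Sketch of the auto-detection heuristic: a `mixer` attribute signals
    Nemotron-H; q_norm together with rope signals Qwen3.5; otherwise fall
    back to the standard q_proj + RoPE (GPT-OSS/Llama) path.
    """
    if getattr(attn, "mixer", None) is not None:
        return "nemotron_h"
    if getattr(attn, "q_norm", None) is not None and getattr(attn, "rope", None) is not None:
        return "qwen35"
    return "llama"
```

Probing attributes rather than class names keeps the detection robust to quantized or wrapped module variants.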
Two bugs found during cross-architecture testing on GPT-OSS 120B:

1. _llama_extract_queries() used eager evaluation in a getattr fallback chain: getattr(attn, "num_attention_heads", attn.num_heads) evaluates attn.num_heads before checking whether num_attention_heads exists. Fixed to use a safe nested getattr with a None default.
2. _compute_importance() concatenated score matrices with different shapes when mixing sliding window (128-token RotatingKVCache) and full attention (unlimited KVCache) layers. Fixed by skipping layers whose cache spans fewer tokens than the full prompt.

Validated on GPT-OSS 120B + 20B draft: importance-based selection produces coherent output while uniform selection degrades, confirming that the scoring signal from 18 full-attention layers is sufficient.
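Bug 1 is worth demonstrating, since it is a classic Python pitfall: `getattr`'s default argument is an ordinary expression, evaluated before the lookup. A minimal reproduction (class names are illustrative):

```python
class NewStyle:
    num_attention_heads = 32  # only the new-style attribute exists

class OldStyle:
    num_heads = 16  # only the legacy attribute exists

def heads_buggy(attn):
    # Eager evaluation: `attn.num_heads` is computed *before* getattr checks
    # for num_attention_heads, so NewStyle raises AttributeError even though
    # num_attention_heads is present.
    return getattr(attn, "num_attention_heads", attn.num_heads)

def heads_fixed(attn):
    # Safe nested getattr with a None default: each attribute is probed
    # lazily, and the first one found wins.
    n = getattr(attn, "num_attention_heads", None)
    return n if n is not None else getattr(attn, "num_heads", None)
```

The fix mirrors the one described above for `_llama_extract_queries()`.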
Models with sliding window attention (e.g., GPT-OSS's alternating sliding/full layers) use a RotatingKVCache that evicts old entries. When sparse prefill inserts more tokens than the window size, the cache loses context needed for decode. sparse_prefill() now auto-detects RotatingKVCache and augments the selection to include the last max_size positions, ensuring sliding window layers have valid recent context.

Validated: GPT-OSS 120B + 20B draft produces coherent output on 2294-token prompts (was garbage before this fix). Qwen3.5 and Nemotron-H are unaffected (no RotatingKVCache in their caches).
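The augmentation step is a small set union; here is a sketch under the assumption that the selection is a list of kept token positions and `window` is the RotatingKVCache's max_size (the helper name is hypothetical):

```python
def augment_for_sliding_window(selected, prompt_len, window):
    """Ensure the last `window` prompt positions survive sparse selection.

    Sliding-window layers evict everything older than `window` tokens, so
    the sparse selection is unioned with the most recent positions to keep
    their caches valid for decode.
    """
    recent = range(max(0, prompt_len - window), prompt_len)
    return sorted(set(selected) | set(recent))
```

Full-attention layers simply see a few extra (cheap) tokens; sliding-window layers regain the recent context they would otherwise have evicted.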
Add a _stream_generate_specprefill() method for models that don't use MTP speculative decoding (Nemotron, GPT-OSS, etc.). The existing SpecPrefill integration only worked in the MTP text path (_stream_generate_text).

Changes:
- stream_generate() now pops specprefill/specprefill_keep_pct from kwargs and dispatches to the new method when conditions are met
- _stream_generate_specprefill() follows the same pattern as the MTP path: score → select → sparse_prefill → autoregressive generation
- Graceful fallback to normal generation on any error
- Per-request overrides (specprefill, specprefill_keep_pct) via extra_body
- Threshold and upper-bound checks identical to the MTP path
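The pop-and-dispatch logic can be sketched as below; the threshold/upper-bound defaults and the function name are assumptions for illustration, not values taken from the PR:

```python
def dispatch(kwargs, prompt_len, threshold=2048, upper_bound=32768):
    """Sketch of stream_generate()'s routing decision.

    Pops the per-request SpecPrefill overrides out of kwargs (so they never
    reach the underlying generator) and takes the sparse-prefill path only
    when enabled and the prompt length is inside [threshold, upper_bound].
    """
    use_sp = kwargs.pop("specprefill", False)
    keep_pct = kwargs.pop("specprefill_keep_pct", 0.2)
    if use_sp and threshold <= prompt_len <= upper_bound:
        return ("specprefill", keep_pct)
    return ("normal", None)
```

Popping (rather than reading) the keys matters: any leftover unknown kwargs would otherwise be forwarded to the normal generation call and rejected there.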
Server-Side TTFT Benchmarks (Non-MTP Integration)

This enables server-side benchmarking for Nemotron and GPT-OSS (neither uses MTP). Per-request overrides (…)

Nemotron 3 Super 120B-A12B (5-bit) + Nano 4B draft (4-bit) — M2 Ultra 128GB

(benchmark table and key observations not preserved in this extract)

GPT-OSS 120B (5-bit) + GPT-OSS 20B draft (4-bit) — M2 Ultra 128GB

(benchmark table and key observations not preserved in this extract)
Draft Model Size Impact

The benchmarks clearly demonstrate that the draft-to-target size ratio is the key factor:
Recommendation: keep draft models as small as possible (4B 4-bit is ideal). The scoring quality from 4B attention layers is sufficient — more parameters in the draft do not meaningfully improve importance detection, but they do increase scoring latency.
Paper

The technique and benchmarks are documented in a preprint: SpecPrefill on Unified Memory: Cross-Architecture Sparse Prefill for Large Language Models on Apple Silicon

Covers the algorithm, cost model, quality validation (adversarial + perplexity), and cross-architecture results (Qwen3.5, Nemotron-H, GPT-OSS).
Port SimpleEngine features to BatchedEngine for continuous batching:
- Per-request MTP routing: text-only → TextModel (MTP), media → MLLM
- message_utils.py: shared _normalize_messages (developer→system, merge consecutive same-role, hoist system to [0])
- SpecPrefill config + draft model lifecycle in BatchedEngine
- System KV cache with ChatML boundary detection

Replaces PR waybarrios#192 (rebased against main after merge of waybarrios#180, waybarrios#97).
ChunkedDraftScorer breaks SpecPrefill draft model scoring into chunks, yielding between chunks so active generation requests can make progress. It composes with the admission controller for memory-safe concurrent serving.

Design: draft scoring (the slowest SpecPrefill phase) runs in chunks of 4096 tokens. Between chunks, the caller yields to the event loop via asyncio.sleep(0), allowing MLLM or other text generation to proceed. Selection + sparse prefill + generation run monolithically after scoring.

Depends on: merged PR waybarrios#180 (SpecPrefill integration). 5 unit tests.
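The cooperative-yield pattern described above can be sketched in a few lines; `score_in_chunks` is an illustrative name, and `score_chunk` stands in for the draft model's per-chunk scoring call:

```python
import asyncio

async def score_in_chunks(tokens, score_chunk, chunk_size=4096):
    """Run draft scoring in fixed-size chunks, yielding between chunks.

    asyncio.sleep(0) suspends this coroutine just long enough for the event
    loop to schedule other ready tasks (e.g. active generation streams),
    without adding any real delay.
    """
    scores = []
    for start in range(0, len(tokens), chunk_size):
        scores.extend(score_chunk(tokens[start:start + chunk_size]))
        await asyncio.sleep(0)  # cooperative yield between chunks
    return scores
```

Only the scoring phase is chunked; as the PR notes, selection, sparse prefill, and generation still run monolithically afterwards.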
Integrates the admission controller, cooperative specprefill, and MLLM+MTP per-request routing into the BatchedEngine for production multi-user serving.

Key changes:
- BatchedEngine: admission gates on all 4 public methods (chat, stream_chat, generate, stream_generate) with try/finally cleanup
- MLLM+MTP routing: text-only requests → mlx_lm TextModel with MTP speculative decoding, media requests → mlx_vlm MLLM path
- System KV cache: prefix boundary detection + snapshot/restore for repeated system prompts (7x speedup on cache hits)
- Cooperative specprefill: draft scoring outside the generation lock, yielding between chunks for concurrent request progress
- Thread-safe snapshot access (threading.Lock for cross-thread reads/writes)
- Cache-hit re-verification under lock (prevents a stale flag after queuing)
- MLLM error loop: breaks after 10 consecutive errors (no infinite loop)
- CLI: --scheduler-policy, --scheduler-headroom-gb flags

Depends on: admission controller PR, cooperative specprefill PR, waybarrios#165, waybarrios#180

New files:
- specprefill.py: SpecPrefill scoring + sparse prefill (builds on merged waybarrios#180)
- text_model_from_vlm.py: zero-copy TextModel construction from VLM backbone
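The "admission gates with try/finally cleanup" pattern is worth a minimal sketch. This toy controller (names and counting scheme are assumptions) shows the invariant the try/finally protects: a slot is always released, even when generation raises:

```python
class AdmissionController:
    """Toy admission gate: bounded concurrent admissions (sketch)."""

    def __init__(self, limit):
        self.limit = limit
        self.active = 0

    def admit(self):
        if self.active >= self.limit:
            raise RuntimeError("over capacity")
        self.active += 1

    def release(self):
        self.active -= 1

def chat(ctrl, run):
    """Public-method shape: admit, run, and release in try/finally."""
    ctrl.admit()
    try:
        return run()
    finally:
        ctrl.release()  # slot is freed even if run() raises
```

Without the finally clause, a single failed request would leak a slot and permanently shrink effective capacity.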
Port SimpleEngine's MLLM+MTP per-request routing to BatchedEngine. Text-only requests route to the mlx_lm TextModel with MTP speculative decoding; media requests route to the MLLM path. Uses text_model_from_vlm.py (already upstream from PR waybarrios#180) to build a zero-copy TextModel from VLM backbone weights. The routing decision is per-request, based on message content via _has_media_content().

Changes:
- Add mtp/prefill_step_size params to BatchedEngine.__init__
- Build TextModel in _start_mllm() when mtp=True
- Route text-only to _stream_chat_text_model in chat()/stream_chat()
- Add _chat_text_model/_stream_chat_text_model for mlx_lm generation
- Add _has_media_content helper (mirrors SimpleEngine)
- Add test_batched_mtp_routing.py (8 tests)
…ection, served-model-name

Merge 16 upstream commits (22dcbf8..d235c37) into our fork:
- feat: SpecPrefill — attention-based sparse prefill for TTFT reduction (waybarrios#180)
- feat: native Qwen3-VL video pipeline with temporal 3D conv + M-RoPE (waybarrios#150)
- fix: disable MambaCache monkey-patch for hybrid models, add MTP auto-injection (waybarrios#97)
- feat: add --served-model-name CLI parameter (waybarrios#125)
- feat: add Qwen3.5 text-only loading and dynamic memory threshold (waybarrios#127)
- fix(mllm_scheduler): add adaptive periodic cache clearing (waybarrios#157)
- fix: Metal resource leak under high concurrency (waybarrios#92)

Conflict resolution strategy: keep all fork features (DeltaNet snapshots, fast SSE templates, tool injection, cloud routing, prompt cache, etc.) while incorporating upstream's new functionality.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
SpecPrefill: attention-based sparse prefill using a draft model to reduce TTFT on long prompts.
What: A small draft model (e.g. 2B) scores token importance via Q@K^T attention. Top-k% tokens are selected in 32-token chunks. Target model sparse-prefills only selected tokens with position-mapped RoPE preserving correct relative positions.
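The chunked top-k% selection can be sketched as below; this is a simplified stand-in for `select_chunks()` under the assumption that chunk importance is the mean of its token scores:

```python
def select_chunks(scores, keep_pct=0.2, chunk=32):
    """Keep the top ~keep_pct of fixed-size 32-token chunks (sketch).

    Averages per-token importance within each chunk, keeps the best
    chunks, and returns their ORIGINAL positions in sorted order so the
    target model can apply RoPE with the tokens' true offsets.
    """
    n = len(scores)
    chunk_scores = []
    for start in range(0, n, chunk):
        window = scores[start:start + chunk]
        chunk_scores.append((sum(window) / len(window), start))
    n_keep = max(1, int(len(chunk_scores) * keep_pct))
    kept = sorted(chunk_scores, reverse=True)[:n_keep]
    positions = []
    for _, start in kept:
        positions.extend(range(start, min(start + chunk, n)))
    return sorted(positions)
```

Returning original positions (rather than re-indexing the surviving tokens) is what makes the position-mapped RoPE step possible: relative distances between kept tokens are preserved exactly.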
Results (Qwen3.5-122B, 2B draft, 20% keep):
Per-request API:
extra_body: {"specprefill": true, "specprefill_keep_pct": 0.2}

Depends on: #171

Files: new specprefill.py (742 lines); modified engine/simple.py, server.py, cli.py, api/models.py

Paper: doi.org/10.5281/zenodo.19120919