1,104 changes: 1,104 additions & 0 deletions docs/superpowers/plans/2026-03-16-hybrid-model-prefix-cache.md

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions docs/superpowers/specs/2026-03-16-hybrid-model-prefix-cache-design.md
@@ -0,0 +1,129 @@
# Hybrid Model Prefix Cache — Design Spec

**Date:** 2026-03-16
**Issues:** #142, #136
**Branch:** `fix/mllm-continuous-batching-hybrid-models` (PR #165)

## Problem

`BlockAwarePrefixCache` crashes on hybrid models (Qwen 3.5 MoE, Nemotron) with:
```
WARNING: Failed to extract block tensor slice: Too many indices for array with 3 dimensions.
```
The prefix cache silently degrades to a no-op, so repeated contexts see no TTFT benefit.

## Root Cause

The issue reporters hypothesized that MoE models produce 3D KV tensors. **This is wrong.** All KV cache tensors in mlx-lm are always 4D `(B, n_kv_heads, seq_len, head_dim)` regardless of model architecture.

The actual cause: hybrid models have two cache layer types:

| Layer Type | Cache Class | `.state` Format | Positional? |
|-----------|-------------|-----------------|-------------|
| Attention | KVCache | `(keys_4D, values_4D)` | Yes — indexed by token position |
| GatedDeltaNet / Mamba | ArraysCache | `[conv_3D, recurrent_4D]` | No — cumulative summary |

`_extract_cache_states()` extracts state from ALL layers. `_extract_block_tensor_slice()` then tries to slice every layer as 4D KV:

```python
keys, values = layer_state["state"] # Unpacks ArraysCache [conv_3D, ssm_4D]
keys[:, :, start:end, :] # 4 indices on 3D conv_state → IndexError
```
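The failure is easy to reproduce in isolation. Here is a minimal stand-in for the 3D conv state, using NumPy in place of `mx.array` (both raise `IndexError` when more indices are supplied than the array has dimensions):

```python
import numpy as np

# Stand-in for an ArraysCache conv state: 3D, not (B, n_kv_heads, seq_len, head_dim)
conv_state = np.zeros((1, 4, 8))

try:
    # The 4-index KV-style slice applied to a 3D array
    conv_state[:, :, 0:4, :]
except IndexError as e:
    # "too many indices for array" — the error the cache logs as a warning
    print(e)
```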

## Design

### Layer Classification

New helper function identifies cache layer types by `class_name` (already stored by `_extract_cache_states()`):

```python
_KV_CACHE_CLASSES = frozenset({
    "KVCache", "RotatingKVCache", "QuantizedKVCache",
    "ChunkedKVCache", "ConcatenateKVCache", "BatchKVCache",
    "BatchRotatingKVCache",
})

def _is_kv_layer(layer_state: dict) -> bool:
    return layer_state.get("class_name", "") in _KV_CACHE_CLASSES
```

### Separate Storage Model

KV and non-KV layers are stored differently — no duplication:

- **KV layers** → block-sliced along seq_dim (axis=2), stored per-block in `KVCacheBlock.cache_data`
- **Non-KV layers** → stored once as whole-sequence state in `BlockAwarePrefixCache._non_kv_states`, keyed by `tuple(block_ids)`

```python
@dataclass
class NonKVCacheData:
    """Full state for non-positional cache layers (SSM, linear attention)."""

    layer_indices: List[int]  # Position in the full layer list
    states: List[Any]         # From cache.state (list/tuple of arrays)
    meta_states: List[Any]    # From cache.meta_state
    class_refs: List[type]    # For from_state() reconstruction
    total_layers: int         # Total layer count (KV + non-KV)
```
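The storage discipline can be sketched as an exact-match map keyed by the full block-ID tuple. The class and method names below are illustrative, not the PR's API:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class NonKVStore:
    """Illustrative store for whole-sequence non-KV states (hypothetical names)."""

    _states: Dict[Tuple[int, ...], List[Any]] = field(default_factory=dict)

    def put(self, block_ids: List[int], states: List[Any]) -> None:
        # Non-KV state is a cumulative summary, so it is only valid for this
        # exact block sequence — key by the full tuple, never by a prefix.
        self._states[tuple(block_ids)] = states

    def get(self, block_ids: List[int]):
        # Exact-match lookup: a partial prefix of the blocks has no entry.
        return self._states.get(tuple(block_ids))


store = NonKVStore()
store.put([1, 2, 3], ["conv_state", "ssm_state"])
assert store.get([1, 2, 3]) == ["conv_state", "ssm_state"]
assert store.get([1, 2]) is None  # partial prefix → miss
```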

### Partial Prefix Rejection

For hybrid models, partial prefix reuse (matching some but not all blocks) would restore KV cache but NOT SSM state. The model would generate with mismatched attention/SSM state, producing incorrect output.

Guard: when `fetch_cache()` finds a partial match and the model has non-KV layers, return cache miss instead.

Full matches work because non-KV states are stored keyed by exact `tuple(block_ids)`.
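The guard reduces to one exact-match check after prefix matching. A minimal sketch, with names and return shapes assumed for illustration:

```python
def fetch_with_guard(non_kv_states, matched_block_ids, has_non_kv, tokens):
    """Return (cache, tokens); illustrative, not the PR's signature."""
    key = tuple(matched_block_ids)
    if has_non_kv and key not in non_kv_states:
        # Partial match: restoring KV blocks alone would pair trimmed
        # attention context with missing SSM state → force recomputation.
        return None, tokens
    return non_kv_states.get(key), tokens


states = {(1, 2, 3): ["conv", "ssm"]}
# Full match: non-KV state present for exactly these blocks
cache, _ = fetch_with_guard(states, [1, 2, 3], True, [7, 8])
assert cache == ["conv", "ssm"]
# Partial match (only blocks 1 and 2 matched): rejected
cache, _ = fetch_with_guard(states, [1, 2], True, [7, 8])
assert cache is None
```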

### Changes

#### `prefix_cache.py`

1. **`_is_kv_layer()`** — classify layers by class_name

2. **`_extract_block_tensor_slice()`**:
- Skip non-KV layers (append `None` placeholder)
- KV layers: slice `keys[:,:,start:end,:]` as before
- Return list with `None` gaps at non-KV positions

3. **`store_cache()`**:
- Classify layers into KV and non-KV
- KV: block-slice via `_extract_block_tensor_slice()`
- Non-KV: extract states, store in `self._non_kv_states[tuple(block_ids)]`
- Set `has_non_kv` flag on `BlockCacheEntry` for fast checks

4. **`reconstruct_cache()`**:
- Look up non-KV states via `tuple(block_table.block_ids)`
- KV layers: concatenate block slices → KVCache (as before)
- Non-KV layers: `class_ref.from_state(state, meta_state)`
- Interleave in correct layer order using stored indices
- If non-KV states missing for hybrid model → return `None`

5. **`fetch_cache()` partial match guard**:
- After `find_shared_prefix()` or `_find_best_prefix_match()`
- Check if non-KV states exist for the matched block set
- Missing → return `(None, tokens)` (force recomputation)

6. **Cleanup**: `release_cache()`, `clear()` clean `_non_kv_states`
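The interleaving step in `reconstruct_cache()` can be sketched as follows; the helper name and list-of-strings stand-ins are illustrative:

```python
def interleave_layers(kv_layers, non_kv_layers, non_kv_indices, total_layers):
    """Rebuild the full per-layer cache list in original model order."""
    result = [None] * total_layers
    # Place non-KV layers at their stored indices first
    for idx, state in zip(non_kv_indices, non_kv_layers):
        result[idx] = state
    # Fill the remaining slots with KV layers, preserving their order
    kv_iter = iter(kv_layers)
    for i in range(total_layers):
        if result[i] is None:
            result[i] = next(kv_iter)
    return result


# Hybrid layout: layers 1 and 3 are SSM, layers 0 and 2 are attention
full = interleave_layers(["kv0", "kv2"], ["ssm1", "ssm3"], [1, 3], 4)
assert full == ["kv0", "ssm1", "kv2", "ssm3"]
```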

#### `scheduler.py`

- `_reconstruct_cache_from_states()` fallback path (L1490): guard `shape[2]` access with `hasattr(state[0], 'shape') and len(state) == 2` before assuming KV format

### What This Does NOT Change

- `PrefixCacheManager` (trie-based) — already works for hybrid models (stores whole cache objects)
- `_extract_cache_states()` — already correctly stores `class_name` and `class_ref`
- `_can_trim_cache()` / `_trim_cache()` — already handles hybrid models (checks all layers)
- KVCache tensor shape handling — always 4D, no ndim checks needed
- Pure-KV model behavior — unchanged, no non-KV layers detected

### Tests

New test file: `tests/test_prefix_cache_hybrid.py`

1. **`test_is_kv_layer`** — classification of KVCache vs ArraysCache
2. **`test_extract_block_tensor_slice_hybrid`** — mixed KV/ArraysCache layers, KV sliced, ArraysCache skipped
3. **`test_store_and_reconstruct_hybrid`** — full roundtrip with simulated hybrid model cache
4. **`test_partial_prefix_rejected_for_hybrid`** — partial match → cache miss
5. **`test_full_prefix_hit_for_hybrid`** — exact match → correct reconstruction
6. **`test_pure_kv_model_unchanged`** — regression: existing behavior for pure attention models
7. **`test_cleanup_non_kv_states`** — release_cache and clear remove non-KV data
120 changes: 120 additions & 0 deletions tests/smoke_test_specdec.py
@@ -0,0 +1,120 @@
#!/usr/bin/env python3
"""
Smoke test for speculative decoding with real models.

Usage: python tests/smoke_test_specdec.py

Uses Qwen3.5-35B-A3B-8bit as target, Qwen3.5-4B-4bit as draft.
Tests the SimpleEngine path (mlx_lm.stream_generate with draft_model).
"""

import os
import sys
import time

# Add project to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

TARGET = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-35B-A3B-8bit")
DRAFT = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-4B-4bit")
PROMPT = "What is the capital of France? Answer in one sentence."
MAX_TOKENS = 64
NUM_DRAFT = 3


def test_without_draft():
    """Baseline: generate without speculative decoding."""
    from mlx_lm import load, stream_generate

    print("=" * 60)
    print("Loading target model (no draft)...")
    model, tokenizer = load(TARGET)
    print(f"Target loaded. Generating {MAX_TOKENS} tokens...")

    tokens = []
    t0 = time.perf_counter()
    for resp in stream_generate(model, tokenizer, prompt=PROMPT, max_tokens=MAX_TOKENS):
        tokens.append(resp.token)
    elapsed = time.perf_counter() - t0
    text = tokenizer.decode(tokens)
    print(f"Output ({len(tokens)} tokens, {len(tokens)/elapsed:.1f} tok/s):")
    print(f"  {text}")
    print()
    return len(tokens), elapsed


def test_with_draft():
    """Speculative: generate with draft model."""
    from mlx_lm import load, stream_generate

    print("=" * 60)
    print("Loading target + draft model...")
    model, tokenizer = load(TARGET)
    draft_model, _ = load(DRAFT)

    # Verify vocab match — walk model structure to find embed_tokens
    def _get_vocab_size(m):
        for attr in ["model", "language_model"]:
            sub = getattr(m, attr, None)
            if sub is not None:
                et = getattr(sub, "embed_tokens", None)
                if et is not None:
                    return et.weight.shape[0]
        return None

    target_vocab = _get_vocab_size(model)
    draft_vocab = _get_vocab_size(draft_model)
    print(f"Target vocab: {target_vocab}, Draft vocab: {draft_vocab}")
    if target_vocab and draft_vocab:
        assert target_vocab == draft_vocab, "Vocab size mismatch!"

    print(f"Generating {MAX_TOKENS} tokens with num_draft_tokens={NUM_DRAFT}...")

    tokens = []
    from_draft_count = 0
    t0 = time.perf_counter()
    for resp in stream_generate(
        model,
        tokenizer,
        prompt=PROMPT,
        max_tokens=MAX_TOKENS,
        draft_model=draft_model,
        num_draft_tokens=NUM_DRAFT,
    ):
        tokens.append(resp.token)
        if resp.from_draft:
            from_draft_count += 1
    elapsed = time.perf_counter() - t0

    text = tokenizer.decode(tokens)
    accept_rate = from_draft_count / len(tokens) * 100 if tokens else 0
    print(f"Output ({len(tokens)} tokens, {len(tokens)/elapsed:.1f} tok/s):")
    print(f"  {text}")
    print(f"Draft acceptance: {from_draft_count}/{len(tokens)} ({accept_rate:.0f}%)")
    print()
    return len(tokens), elapsed


if __name__ == "__main__":
    print("Speculative Decoding Smoke Test")
    print("Target:", TARGET)
    print("Draft:", DRAFT)
    print()

    n1, t1 = test_without_draft()

    # Clear the baseline model from memory before loading target + draft
    import gc
    import mlx.core as mx

    gc.collect()
    mx.clear_cache()

    n2, t2 = test_with_draft()

    print("=" * 60)
    print("RESULTS:")
    print(f"  Without draft: {n1} tokens in {t1:.2f}s ({n1/t1:.1f} tok/s)")
    print(f"  With draft:    {n2} tokens in {t2:.2f}s ({n2/t2:.1f} tok/s)")
    if t1 > 0 and t2 > 0 and n1 > 0:
        # Speedup = draft throughput relative to baseline throughput
        speedup = (n2 / t2) / (n1 / t1)
        print(f"  Speedup: {speedup:.2f}x")
    else:
        print("  Speedup: N/A")
129 changes: 129 additions & 0 deletions tests/test_mllm_mtp_routing.py
@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for MLLM + MTP per-request routing."""


def test_has_media_content_text_only():
    from vllm_mlx.engine.simple import _has_media_content

    assert _has_media_content([{"role": "user", "content": "Hello"}]) is False


def test_has_media_content_with_image():
    from vllm_mlx.engine.simple import _has_media_content

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What's this?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,..."},
                },
            ],
        }
    ]
    assert _has_media_content(messages) is True


def test_has_media_content_with_video():
    from vllm_mlx.engine.simple import _has_media_content

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}}
            ],
        }
    ]
    assert _has_media_content(messages) is True


def test_has_media_content_empty():
    from vllm_mlx.engine.simple import _has_media_content

    assert _has_media_content([]) is False


def test_has_media_content_string_content():
    """String content (not list) should return False."""
    from vllm_mlx.engine.simple import _has_media_content

    assert _has_media_content([{"role": "user", "content": "Just text"}]) is False


def test_has_media_content_audio():
    from vllm_mlx.engine.simple import _has_media_content

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}}
            ],
        }
    ]
    assert _has_media_content(messages) is True


def test_has_media_content_multi_turn():
    """Media in earlier turns should still be detected."""
    from vllm_mlx.engine.simple import _has_media_content

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Look at this"},
                {
                    "type": "image_url",
                    "image_url": {"url": "data:image/png;base64,..."},
                },
            ],
        },
        {"role": "assistant", "content": "I see an image."},
        {"role": "user", "content": "Tell me more about it."},
    ]
    assert _has_media_content(messages) is True


def test_has_media_content_text_list():
    """List content with only text parts should return False."""
    from vllm_mlx.engine.simple import _has_media_content

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Hello"},
                {"type": "text", "text": "World"},
            ],
        }
    ]
    assert _has_media_content(messages) is False


# --- MLXMultimodalLM extraction method tests ---

from unittest.mock import MagicMock


def test_get_language_model():
    from vllm_mlx.models.mllm import MLXMultimodalLM

    mllm = MagicMock(spec=MLXMultimodalLM)
    inner_lm = MagicMock()
    mllm.model = MagicMock()
    mllm.model.language_model = inner_lm
    assert MLXMultimodalLM.get_language_model(mllm) is inner_lm


def test_get_tokenizer():
    from vllm_mlx.models.mllm import MLXMultimodalLM

    mllm = MagicMock(spec=MLXMultimodalLM)
    inner_tok = MagicMock()
    mllm.processor = MagicMock()
    mllm.processor.tokenizer = inner_tok
    assert MLXMultimodalLM.get_tokenizer(mllm) is inner_tok