diff --git a/docs/superpowers/plans/2026-03-16-hybrid-model-prefix-cache.md b/docs/superpowers/plans/2026-03-16-hybrid-model-prefix-cache.md new file mode 100644 index 00000000..c28eff7e --- /dev/null +++ b/docs/superpowers/plans/2026-03-16-hybrid-model-prefix-cache.md @@ -0,0 +1,1104 @@ +# Hybrid Model Prefix Cache — Implementation Plan + +> **For agentic workers:** REQUIRED: Use superpowers:subagent-driven-development (if subagents available) or superpowers:executing-plans to implement this plan. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Make `BlockAwarePrefixCache` work with hybrid models (Qwen 3.5, Nemotron) that mix KVCache and ArraysCache layers, fixing issues #142 and #136. + +**Architecture:** Classify cache layers as positional (KVCache — block-sliceable) vs non-positional (ArraysCache — stored whole). KV layers are block-sliced as before. Non-KV layers are stored once per prefix entry and restored via `from_state()`. Partial prefix reuse is rejected for hybrid models (mismatched KV/SSM state would corrupt output). 
+ +**Tech Stack:** Python 3.12, mlx, mlx-lm (KVCache/ArraysCache), pytest + +**Spec:** `docs/superpowers/specs/2026-03-16-hybrid-model-prefix-cache-design.md` + +--- + +## File Map + +| File | Action | Responsibility | +|------|--------|----------------| +| `vllm_mlx/prefix_cache.py` | Modify | Layer classification, hybrid-aware slicing/storage/reconstruction | +| `vllm_mlx/scheduler.py` | Modify | Robustness guard in `_reconstruct_cache_from_states()` fallback | +| `tests/test_prefix_cache_hybrid.py` | Create | All new tests for hybrid prefix cache behavior | + +--- + +## Chunk 1: Layer Classification and Extract Fix + +### Task 1: Add `_is_kv_layer` helper and `NonKVCacheData` dataclass + +**Files:** +- Modify: `vllm_mlx/prefix_cache.py` (add near top, after existing imports/dataclasses) +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the test file with layer classification tests** + +Create `tests/test_prefix_cache_hybrid.py`: + +```python +# SPDX-License-Identifier: Apache-2.0 +"""Tests for BlockAwarePrefixCache with hybrid model caches (issues #142, #136).""" + +from unittest.mock import MagicMock + +import pytest + +try: + import mlx.core as mx + from mlx_lm.models.cache import ArraysCache, KVCache + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +from vllm_mlx.prefix_cache import ( + BlockAwarePrefixCache, + NonKVCacheData, + _is_kv_layer, +) + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_kv_state(seq_len=64, n_kv_heads=4, head_dim=32): + """Simulate a KVCache layer_state dict (4D tensors).""" + keys = mx.zeros((1, n_kv_heads, seq_len, head_dim)) + values = mx.zeros((1, n_kv_heads, seq_len, head_dim)) + return { + "state": (keys, values), + "meta_state": (str(seq_len),), + "class_name": "KVCache", + "class_ref": 
KVCache, + } + + +def _make_arrays_state(seq_len=64, conv_dim=128, ssm_heads=8, ssm_dim=64): + """Simulate an ArraysCache layer_state dict (conv_3D + recurrent_4D).""" + conv_state = mx.zeros((1, 3, conv_dim)) # (B, kernel-1, conv_dim) + ssm_state = mx.zeros((1, ssm_heads, ssm_dim, ssm_dim)) # (B, H, D, D) + return { + "state": [conv_state, ssm_state], + "meta_state": "", + "class_name": "ArraysCache", + "class_ref": ArraysCache, + } + + +def _make_hybrid_cache_data(n_total=48, attn_interval=4, seq_len=128): + """Simulate extracted cache states for a hybrid model. + + Qwen 3.5 pattern: every 4th layer is attention, rest are GatedDeltaNet. + So layers 3,7,11,... are KVCache; layers 0,1,2,4,5,6,... are ArraysCache. + """ + cache_data = [] + for i in range(n_total): + is_attn = (i + 1) % attn_interval == 0 + if is_attn: + cache_data.append(_make_kv_state(seq_len=seq_len)) + else: + cache_data.append(_make_arrays_state()) + return cache_data + + +def _make_pure_kv_cache_data(n_layers=32, seq_len=128): + """Simulate extracted cache states for a pure attention model.""" + return [_make_kv_state(seq_len=seq_len) for _ in range(n_layers)] + + +# --------------------------------------------------------------------------- +# Tests: Layer Classification +# --------------------------------------------------------------------------- + +class TestIsKVLayer: + def test_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "KVCache"}) is True + + def test_rotating_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "RotatingKVCache"}) is True + + def test_quantized_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "QuantizedKVCache"}) is True + + def test_batch_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "BatchKVCache"}) is True + + def test_arrays_cache_is_not_kv(self): + assert _is_kv_layer({"class_name": "ArraysCache"}) is False + + def test_cache_list_is_not_kv(self): + assert _is_kv_layer({"class_name": "CacheList"}) is False + + def 
test_missing_class_name_is_not_kv(self): + assert _is_kv_layer({}) is False + + def test_empty_class_name_is_not_kv(self): + assert _is_kv_layer({"class_name": ""}) is False +``` + +- [ ] **Step 2: Run test to verify it fails (import error)** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestIsKVLayer -v 2>&1 | tail -5` +Expected: FAIL — `ImportError: cannot import name '_is_kv_layer' from 'vllm_mlx.prefix_cache'` + +- [ ] **Step 3: Implement `_is_kv_layer` and `NonKVCacheData` in `prefix_cache.py`** + +Add after the existing `PrefixCacheStats` class (around line 67), before `class PrefixCacheManager`: + +```python +# Known KV cache class names — positional caches that can be block-sliced +# along the sequence dimension (axis=2). All produce 4D tensors: +# (batch, n_kv_heads, seq_len, head_dim). +_KV_CACHE_CLASSES = frozenset({ + "KVCache", + "RotatingKVCache", + "QuantizedKVCache", + "ChunkedKVCache", + "ConcatenateKVCache", + "BatchKVCache", + "BatchRotatingKVCache", +}) + + +def _is_kv_layer(layer_state: dict) -> bool: + """Check if a layer state dict represents a positional KV cache layer. + + Args: + layer_state: Dict from _extract_cache_states() containing + 'class_name', 'state', 'meta_state', 'class_ref'. + + Returns: + True if the layer is a KV cache that can be sliced along seq_len. + """ + return layer_state.get("class_name", "") in _KV_CACHE_CLASSES + + +@dataclass +class NonKVCacheData: + """Full state for non-positional cache layers (SSM, linear attention). + + Hybrid models (Qwen 3.5, Nemotron) have layers that use ArraysCache + instead of KVCache. These store cumulative state (conv + recurrent) + that cannot be sliced into blocks. This dataclass stores the full + state for reconstruction alongside block-sliced KV layers. 
+ """ + + layer_indices: List[int] # Position in the full layer list + states: List[Any] # From cache.state for each non-KV layer + meta_states: List[Any] # From cache.meta_state for each non-KV layer + class_refs: List[Any] # type refs for from_state() reconstruction + total_layers: int # Total layer count (KV + non-KV) +``` + +Also update the module's import block — add `NonKVCacheData` and `_is_kv_layer` to the public API at the top level (no `__all__` needed, just ensure they're importable). + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestIsKVLayer -v 2>&1 | tail -15` +Expected: 8 passed + +- [ ] **Step 5: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/prefix_cache.py tests/test_prefix_cache_hybrid.py +git commit -m "feat: add _is_kv_layer() and NonKVCacheData for hybrid model prefix cache + +Adds layer classification to distinguish positional KV cache layers +(block-sliceable) from non-positional layers like ArraysCache (SSM, +linear attention). Foundation for fixing #142 and #136." 
+``` + +--- + +### Task 2: Fix `_extract_block_tensor_slice()` to skip non-KV layers + +**Files:** +- Modify: `vllm_mlx/prefix_cache.py:638-691` (`_extract_block_tensor_slice`) +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_prefix_cache_hybrid.py`: + +```python +class TestExtractBlockTensorSlice: + """Test _extract_block_tensor_slice with hybrid cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_pure_kv_slicing_unchanged(self, cache): + """Pure KV model: all layers sliced as before.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + assert len(result) == 4 + for keys_slice, values_slice in result: + assert keys_slice.shape == (1, 4, 64, 32) + assert values_slice.shape == (1, 4, 64, 32) + + def test_hybrid_skips_non_kv_layers(self, cache): + """Hybrid model: KV layers sliced, ArraysCache layers return None.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + # Layers: 0=Arr, 1=Arr, 2=Arr, 3=KV, 4=Arr, 5=Arr, 6=Arr, 7=KV + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + assert len(result) == 8 + # Non-KV layers are None + assert result[0] is None + assert result[1] is None + assert result[2] is None + assert result[4] is None + assert result[5] is None + assert result[6] is None + # KV layers are sliced + assert result[3] is not None + keys, values = result[3] + assert keys.shape == (1, 4, 64, 32) + assert result[7] is not None + + def test_hybrid_does_not_crash(self, cache): + """Regression: the original bug — no IndexError on ArraysCache layers.""" + data = _make_hybrid_cache_data(n_total=48, attn_interval=4, seq_len=256) + # This used to 
raise "Too many indices for array with 3 dimensions" + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + + def test_slice_beyond_seq_len(self, cache): + """KV slice beyond available seq_len clips correctly.""" + data = _make_pure_kv_cache_data(n_layers=2, seq_len=100) + result = cache._extract_block_tensor_slice(data, 64, 128) + assert result is not None + keys, values = result[0] + assert keys.shape[2] == 36 # 100 - 64 +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestExtractBlockTensorSlice -v 2>&1 | tail -10` +Expected: FAIL — `test_hybrid_skips_non_kv_layers` and `test_hybrid_does_not_crash` fail + +- [ ] **Step 3: Rewrite `_extract_block_tensor_slice()` in `prefix_cache.py`** + +Replace lines 638-691 (the entire method) with: + +```python + def _extract_block_tensor_slice( + self, + cache_data: List[Dict[str, Any]], + start_idx: int, + end_idx: int, + ) -> Optional[List[Optional[Tuple[Any, Any]]]]: + """ + Extract tensor slices for a single block from cache data. + + For KV cache layers (positional), slices keys/values along the + sequence dimension. For non-KV layers (ArraysCache, etc.), + returns None at that position — these are stored separately + by store_cache() as whole-sequence state. + + Args: + cache_data: List of layer states from _extract_cache_states() + start_idx: Start token index in the sequence + end_idx: End token index in the sequence + + Returns: + List with one entry per layer: + - (keys_slice, values_slice) for KV layers + - None for non-KV layers + Returns None only on complete failure. 
+ """ + if not HAS_MLX or not cache_data: + return None + + try: + block_slices: List[Optional[Tuple[Any, Any]]] = [] + for layer_state in cache_data: + if "state" not in layer_state: + block_slices.append(None) + continue + + # Skip non-KV layers — they can't be block-sliced + if not _is_kv_layer(layer_state): + block_slices.append(None) + continue + + keys, values = layer_state["state"] + + # KV cache shape: (batch, n_kv_heads, seq_len, head_dim) + # Slice along seq_len dimension (axis 2) + seq_len = keys.shape[2] if hasattr(keys, "shape") else 0 + + if end_idx > seq_len: + actual_end = min(end_idx, seq_len) + if start_idx >= actual_end: + block_slices.append(None) + continue + keys_slice = keys[:, :, start_idx:actual_end, :] + values_slice = values[:, :, start_idx:actual_end, :] + else: + keys_slice = keys[:, :, start_idx:end_idx, :] + values_slice = values[:, :, start_idx:end_idx, :] + + block_slices.append((keys_slice, values_slice)) + + # Return None only if no layers produced data at all + if all(s is None for s in block_slices): + return None + + return block_slices + + except Exception as e: + logger.warning(f"Failed to extract block tensor slice: {e}") + return None +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestExtractBlockTensorSlice -v 2>&1 | tail -10` +Expected: 4 passed + +- [ ] **Step 5: Run existing prefix cache tests for regression** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache.py -v 2>&1 | tail -15` +Expected: All pass (no regression) + +- [ ] **Step 6: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/prefix_cache.py tests/test_prefix_cache_hybrid.py +git commit -m "fix: _extract_block_tensor_slice skips non-KV layers + +Fixes the 'Too many indices for array with 3 dimensions' crash on +hybrid models (Qwen 3.5, Nemotron). 
ArraysCache layers are skipped +during block slicing — they're stored separately as whole-sequence +state. Fixes #142, #136." +``` + +--- + +## Chunk 2: Store and Reconstruct Hybrid Cache + +### Task 3: Update `store_cache()` to store non-KV states separately + +**Files:** +- Modify: `vllm_mlx/prefix_cache.py:380-636` (`BlockAwarePrefixCache.__init__` and `store_cache`) +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_prefix_cache_hybrid.py`: + +```python +class TestStoreHybridCache: + """Test store_cache with hybrid model cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_stores_non_kv_states(self, cache): + """Hybrid cache data stores non-KV states in _non_kv_states.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 1 + # Check the stored non-KV data + non_kv = list(cache._non_kv_states.values())[0] + assert isinstance(non_kv, NonKVCacheData) + assert non_kv.total_layers == 8 + assert len(non_kv.layer_indices) == 6 # 6 ArraysCache layers + assert non_kv.layer_indices == [0, 1, 2, 4, 5, 6] + + def test_pure_kv_no_non_kv_states(self, cache): + """Pure KV model does not create non-KV states.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 0 + + def test_has_non_kv_flag_on_entry(self, cache): + """BlockCacheEntry gets has_non_kv flag.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + entry = cache._request_tables["req-1"] + assert entry.has_non_kv is 
True + + def test_has_non_kv_false_for_pure_kv(self, cache): + """Pure KV entries have has_non_kv=False.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + entry = cache._request_tables["req-1"] + assert entry.has_non_kv is False +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestStoreHybridCache -v 2>&1 | tail -10` +Expected: FAIL — `_non_kv_states` doesn't exist, `has_non_kv` not on `BlockCacheEntry` + +- [ ] **Step 3: Implement changes to `BlockAwarePrefixCache`** + +3a. Add `has_non_kv` field to `BlockCacheEntry` (around line 375): + +```python +@dataclass +class BlockCacheEntry: + """Entry mapping a token sequence to cache blocks.""" + + block_table: BlockTable + cache_data: List[Any] # Actual KV cache data per block + last_access: float + has_non_kv: bool = False # True if model has non-positional cache layers +``` + +3b. Add `_non_kv_states` dict to `__init__` (after `self._tokens_saved = 0`, around line 434): + +```python + # Non-KV layer states for hybrid models (SSM, linear attention). + # Keyed by tuple(block_ids) for lookup during reconstruction. + self._non_kv_states: Dict[Tuple[int, ...], NonKVCacheData] = {} +``` + +3c. 
In `store_cache()`, after block allocation loop and before creating `BlockCacheEntry` (around line 613), add non-KV extraction: + +```python + # Extract and store non-KV layer states for hybrid models + has_non_kv = False + if is_tensor_data and cache_data: + non_kv_indices = [] + non_kv_states_list = [] + non_kv_meta_list = [] + non_kv_refs = [] + for idx, layer_state in enumerate(cache_data): + if not _is_kv_layer(layer_state): + non_kv_indices.append(idx) + non_kv_states_list.append(layer_state.get("state")) + non_kv_meta_list.append(layer_state.get("meta_state")) + non_kv_refs.append(layer_state.get("class_ref")) + if non_kv_indices: + has_non_kv = True + block_key = tuple(block_table.block_ids) + self._non_kv_states[block_key] = NonKVCacheData( + layer_indices=non_kv_indices, + states=non_kv_states_list, + meta_states=non_kv_meta_list, + class_refs=non_kv_refs, + total_layers=len(cache_data), + ) +``` + +3d. Update `BlockCacheEntry` creation (around line 617) to include `has_non_kv`: + +```python + self._request_tables[request_id] = BlockCacheEntry( + block_table=block_table, + cache_data=cache_data, + last_access=time.time(), + has_non_kv=has_non_kv, + ) +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestStoreHybridCache -v 2>&1 | tail -10` +Expected: 4 passed + +- [ ] **Step 5: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/prefix_cache.py tests/test_prefix_cache_hybrid.py +git commit -m "feat: store_cache separates KV and non-KV layer states + +Hybrid models' non-KV states (ArraysCache) are stored once per +prefix in _non_kv_states dict, keyed by block_ids. KV layers +continue to be block-sliced as before. No memory duplication." 
+``` + +--- + +### Task 4: Rewrite `reconstruct_cache()` for hybrid models + +**Files:** +- Modify: `vllm_mlx/prefix_cache.py:772-891` (`reconstruct_cache`) +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_prefix_cache_hybrid.py`: + +```python +class TestReconstructHybridCache: + """Test reconstruct_cache with hybrid model cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_pure_kv_reconstruct_unchanged(self, cache): + """Pure KV model: reconstruct works as before.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + result = cache.reconstruct_cache(bt) + assert result is not None + assert len(result) == 4 + for c in result: + assert hasattr(c, "keys") + assert hasattr(c, "values") + + def test_hybrid_reconstruct_all_layers(self, cache): + """Hybrid model: reconstructs both KV and non-KV layers.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + result = cache.reconstruct_cache(bt) + assert result is not None + assert len(result) == 8 # All layers present + # KV layers (3, 7) should be KVCache + assert hasattr(result[3], "keys") + assert hasattr(result[7], "keys") + # Non-KV layers (0,1,2,4,5,6) should be ArraysCache + assert isinstance(result[0], ArraysCache) + assert isinstance(result[4], ArraysCache) + + def test_hybrid_reconstruct_missing_non_kv_returns_none(self, cache): + """If non-KV states are missing, return None (can't reconstruct).""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + # Delete non-KV states 
to simulate missing data + cache._non_kv_states.clear() + result = cache.reconstruct_cache(bt) + assert result is None +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestReconstructHybridCache -v 2>&1 | tail -10` +Expected: FAIL — `test_hybrid_reconstruct_all_layers` fails (current code only handles KV) + +- [ ] **Step 3: Rewrite `reconstruct_cache()` in `prefix_cache.py`** + +Replace lines 772-891 (the entire method) with: + +```python + def reconstruct_cache( + self, + block_table: BlockTable, + ) -> Optional[List[Any]]: + """ + Reconstruct cache objects from stored block tensor data. + + For pure-KV models: concatenates KV block slices into KVCache objects. + For hybrid models: also restores non-KV layers (ArraysCache, etc.) + from stored whole-sequence state via from_state(). + + Args: + block_table: BlockTable containing block IDs to reconstruct from + + Returns: + List of reconstructed cache objects (one per layer), + or None if reconstruction fails + """ + if not block_table or not block_table.block_ids: + return None + + if not HAS_MLX: + logger.warning("Cannot reconstruct cache: MLX not available") + return None + + try: + # Check for non-KV states (hybrid model) + block_key = tuple(block_table.block_ids) + non_kv_data = self._non_kv_states.get(block_key) + has_non_kv = non_kv_data is not None + + # Collect cache data from all blocks + all_block_data = [] + for block_id in block_table.block_ids: + block = self.paged_cache.allocated_blocks.get(block_id) + if not block: + logger.warning( + f"Block {block_id} not found in allocated blocks" + ) + return None + + if block.cache_data is None: + logger.debug(f"Block {block_id} has no tensor data stored") + return None + + all_block_data.append(block.cache_data) + + if not all_block_data: + return None + + # Determine total number of layers from block data + num_layers = len(all_block_data[0]) + if num_layers == 0: + return 
None + + # Sanity check: stored non-KV bookkeeping must cover the same layer count as the block data + if has_non_kv and non_kv_data.total_layers != num_layers: + logger.warning( + f"Layer count mismatch: blocks have {num_layers}, " + f"non-KV data expects {non_kv_data.total_layers}" + ) + return None + + # Build set of non-KV layer indices for fast lookup + non_kv_idx_set = set() + if has_non_kv: + non_kv_idx_set = set(non_kv_data.layer_indices) + + # Check if any non-KV layers exist in the data but we have + # no stored states — indicates hybrid model with missing data + if not has_non_kv: + for layer_idx in range(num_layers): + # Check first block: if a layer slot is None, it's non-KV + if all_block_data[0][layer_idx] is None: + logger.debug( + "Hybrid model detected but no non-KV states " + "stored — cannot reconstruct" + ) + return None + + # Reconstruct each layer + reconstructed_caches = [] + + for layer_idx in range(num_layers): + if layer_idx in non_kv_idx_set: + # Non-KV layer: restore from stored whole-sequence state + pos = non_kv_data.layer_indices.index(layer_idx) + state = non_kv_data.states[pos] + meta = non_kv_data.meta_states[pos] + cls = non_kv_data.class_refs[pos] + + if cls is not None and hasattr(cls, "from_state"): + cache_obj = cls.from_state(state, meta) + else: + # Can't reconstruct without class ref + logger.warning( + f"No class_ref for non-KV layer {layer_idx}" + ) + return None + + reconstructed_caches.append(cache_obj) + else: + # KV layer: concatenate block slices + layer_keys = [] + layer_values = [] + + for block_data in all_block_data: + if layer_idx < len(block_data): + entry = block_data[layer_idx] + if entry is not None: + keys_slice, values_slice = entry + layer_keys.append(keys_slice) + layer_values.append(values_slice) + + if not layer_keys: + logger.debug( + f"No KV data for layer {layer_idx}" + ) + return None + + # Concatenate along sequence dimension (axis 2) + concat_keys = mx.concatenate(layer_keys, axis=2) + concat_values =
mx.concatenate(layer_values, axis=2) + + try: + from mlx_lm.models.cache import KVCache + + cache_obj = KVCache() + cache_obj.keys = concat_keys + cache_obj.values = concat_values + cache_obj.offset = concat_keys.shape[2] + reconstructed_caches.append(cache_obj) + + except ImportError: + class SimpleKVCache: + def __init__(self, keys, values): + self.keys = keys + self.values = values + self.offset = keys.shape[2] + + @property + def state(self): + return (self.keys, self.values) + + @property + def meta_state(self): + return (str(self.offset),) + + reconstructed_caches.append( + SimpleKVCache(concat_keys, concat_values) + ) + + if not reconstructed_caches: + return None + + logger.debug( + f"Reconstructed cache: {len(reconstructed_caches)} layers " + f"({len(non_kv_idx_set)} non-KV), " + f"{block_table.num_tokens} tokens from " + f"{len(block_table.block_ids)} blocks" + ) + + return reconstructed_caches + + except Exception as e: + logger.warning(f"Failed to reconstruct cache: {e}") + import traceback + + logger.debug(traceback.format_exc()) + return None +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestReconstructHybridCache -v 2>&1 | tail -10` +Expected: 3 passed + +- [ ] **Step 5: Run full test suite for regression** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache.py tests/test_prefix_cache_hybrid.py -v 2>&1 | tail -20` +Expected: All pass + +- [ ] **Step 6: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/prefix_cache.py tests/test_prefix_cache_hybrid.py +git commit -m "feat: reconstruct_cache handles hybrid model layers + +KV layers are concatenated from block slices. Non-KV layers +(ArraysCache) are restored via from_state() from stored whole- +sequence state. If non-KV states are missing, returns None to +force safe recomputation." 
+``` + +--- + +## Chunk 3: Fetch Guard, Cleanup, and Scheduler Fix + +### Task 5: Add partial prefix rejection guard to `fetch_cache()` + +**Files:** +- Modify: `vllm_mlx/prefix_cache.py:436-510` (`fetch_cache`) +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_prefix_cache_hybrid.py`: + +```python +class TestFetchHybridCache: + """Test fetch_cache with hybrid model prefix matching.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_full_match_hybrid_returns_cache(self, cache): + """Full prefix match on a hybrid model must not crash (hit optional).""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + # Same tokens → should find prefix + bt, remaining = cache.fetch_cache("req-2", tokens) + # May or may not hit depending on paged cache internals, + # but should NOT crash + # The key assertion: no exception raised + + def test_pure_kv_partial_match_still_works(self, cache): + """Pure KV model: partial prefix reuse is allowed.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=192) + tokens = list(range(192)) + cache.store_cache("req-1", tokens, data) + # Different request with shorter prefix → partial match OK for pure KV + shorter_tokens = list(range(128)) + bt, remaining = cache.fetch_cache("req-2", shorter_tokens) + # Should not crash regardless of match result + + def test_cleanup_removes_non_kv_states(self, cache): + """release_cache removes non-KV states for the request.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 1 + cache.release_cache("req-1") + # Non-KV states for this request's blocks
should be cleaned up + # (they may persist for shared blocks, but the entry is gone) + + def test_clear_removes_all_non_kv_states(self, cache): + """clear() removes all non-KV states.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + cache.store_cache("req-1", list(range(128)), data) + cache.store_cache("req-2", list(range(64, 192)), data) + cache.clear() + assert len(cache._non_kv_states) == 0 +``` + +- [ ] **Step 2: Run test to verify failures** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestFetchHybridCache -v 2>&1 | tail -10` +Expected: At least `test_cleanup_removes_non_kv_states` and `test_clear_removes_all_non_kv_states` fail + +- [ ] **Step 3: Implement fetch guard and cleanup** + +3a. In `fetch_cache()`, after finding shared blocks via `paged_cache.find_shared_prefix()` (around line 459), add a guard before returning: + +```python + # Guard: reject partial prefix match for hybrid models. + # If non-KV states don't exist for this exact block set, + # a hybrid model can't be correctly reconstructed (SSM/KV mismatch). + if self._non_kv_states: + candidate_key = tuple(block_table.block_ids) + if candidate_key not in self._non_kv_states: + # Hybrid model with partial match — reject + logger.debug( + f"Rejecting partial prefix match for {request_id}: " + f"hybrid model requires full match with non-KV states" + ) + self.paged_cache.delete_block_table(request_id) + self._misses += 1 + return None, tokens +``` + +3b. Same guard in the `_find_best_prefix_match` branch (around line 505): + +```python + # Same hybrid guard for prefix index matches + if self._non_kv_states: + candidate_key = tuple(block_table.block_ids) + if candidate_key not in self._non_kv_states: + logger.debug( + f"Rejecting prefix index match for {request_id}: " + f"hybrid model requires full match with non-KV states" + ) + self.paged_cache.delete_block_table(request_id) + self._misses += 1 + return None, tokens +``` + +3c. 
In `release_cache()` (around line 724), clean up non-KV states: + +```python + def release_cache(self, request_id: str) -> None: + """Release cache blocks for a completed request.""" + entry = self._request_tables.pop(request_id, None) + if entry: + # Clean up non-KV states if this was the last reference + block_key = tuple(entry.block_table.block_ids) + # Only remove if no other request uses the same blocks + other_uses = any( + tuple(e.block_table.block_ids) == block_key + for e in self._request_tables.values() + ) + if not other_uses: + self._non_kv_states.pop(block_key, None) + self.paged_cache.delete_block_table(request_id) + logger.debug(f"Released cache for {request_id}") +``` + +3d. In `clear()` (around line 954), add cleanup: + +```python + def clear(self) -> None: + """Clear all cached data.""" + self._request_tables.clear() + self._prefix_index.clear() + self._non_kv_states.clear() + self.paged_cache.clear() + self.reset_stats() +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestFetchHybridCache -v 2>&1 | tail -10` +Expected: 4 passed + +- [ ] **Step 5: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/prefix_cache.py tests/test_prefix_cache_hybrid.py +git commit -m "feat: fetch_cache rejects partial prefix for hybrid models + +Partial prefix matches on hybrid models would produce mismatched +KV/SSM state. Guard checks non-KV states exist for exact block +set. Cleanup in release_cache and clear." 
+``` + +--- + +### Task 6: Fix scheduler fallback in `_reconstruct_cache_from_states()` + +**Files:** +- Modify: `vllm_mlx/scheduler.py:1486-1496` +- Test: `tests/test_prefix_cache_hybrid.py` + +- [ ] **Step 1: Write the failing test** + +Add to `tests/test_prefix_cache_hybrid.py`: + +```python +class TestSchedulerRobustness: + """Test scheduler cache reconstruction with non-KV layers.""" + + def test_reconstruct_with_arrays_cache_uses_from_state(self): + """ArraysCache layer uses from_state, not manual KV reconstruction.""" + from vllm_mlx.scheduler import Scheduler, SchedulerConfig + + config = SchedulerConfig(enable_prefix_cache=False) + # We can't easily construct a full Scheduler, so test the method + # indirectly by verifying the data path. + # The key check: _reconstruct_cache_from_states handles ArraysCache + # via from_state (class_ref path), NOT the fallback. + arrays_state = { + "state": [mx.zeros((1, 3, 128)), mx.zeros((1, 8, 64, 64))], + "meta_state": "", + "class_name": "ArraysCache", + "class_ref": ArraysCache, + } + kv_state = _make_kv_state(seq_len=100) + extracted = [arrays_state, kv_state, arrays_state, kv_state] + + # Test reconstruction + # ArraysCache has from_state, KVCache has from_state + # Both should work through the class_ref path + for layer_state in extracted: + cls = layer_state["class_ref"] + state = layer_state["state"] + meta = layer_state.get("meta_state", "") + # This should not raise + obj = cls.from_state(state, meta) + assert obj is not None +``` + +- [ ] **Step 2: Run test to verify it passes (sanity check)** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache_hybrid.py::TestSchedulerRobustness -v 2>&1 | tail -5` +Expected: PASS (from_state works for both types) + +- [ ] **Step 3: Add defensive guard to scheduler fallback** + +In `vllm_mlx/scheduler.py`, modify the fallback path at lines 1486-1496. 
Replace: + +```python + else: + # Fallback: try KVCache manual reconstruction + from mlx_lm.models.cache import KVCache + + if len(state) != 2: + return None + cache = KVCache() + cache.keys, cache.values = state + cache.offset = ( + int(meta_state[0]) if meta_state else cache.keys.shape[2] + ) +``` + +With: + +```python + else: + # Fallback: try KVCache manual reconstruction + from mlx_lm.models.cache import KVCache + + if ( + not isinstance(state, (tuple, list)) + or len(state) != 2 + or not hasattr(state[0], "shape") + or state[0].ndim != 4 + ): + logger.debug( + f"[mid_prefill_cache] skipping non-KV layer " + f"(state type={type(state).__name__})" + ) + return None + cache = KVCache() + cache.keys, cache.values = state + cache.offset = ( + int(meta_state[0]) if meta_state else cache.keys.shape[2] + ) +``` + +- [ ] **Step 4: Run existing scheduler tests for regression** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/ -k "scheduler" -v 2>&1 | tail -15` +Expected: All pass + +- [ ] **Step 5: Commit** + +```bash +cd ~/code/vllm-mlx +git add vllm_mlx/scheduler.py tests/test_prefix_cache_hybrid.py +git commit -m "fix: scheduler fallback guards against non-KV cache states + +The fallback path in _reconstruct_cache_from_states now checks +that state looks like a KV tuple (2-element, 4D tensors) before +assuming shape[2] for offset. Prevents crash on ArraysCache states." 
+``` + +--- + +### Task 7: Run full test suite and format + +- [ ] **Step 1: Run ruff linter** + +Run: `cd ~/code/vllm-mlx && ruff check vllm_mlx/prefix_cache.py vllm_mlx/scheduler.py tests/test_prefix_cache_hybrid.py --select E,F,W --ignore E501 2>&1` +Expected: All checks passed + +- [ ] **Step 2: Run black formatter** + +Run: `cd ~/code/vllm-mlx && black vllm_mlx/prefix_cache.py vllm_mlx/scheduler.py tests/test_prefix_cache_hybrid.py 2>&1` +Expected: Files reformatted (or already formatted) + +- [ ] **Step 3: Run full prefix cache test suite** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/test_prefix_cache.py tests/test_prefix_cache_hybrid.py tests/test_paged_cache.py -v 2>&1 | tail -25` +Expected: All pass + +- [ ] **Step 4: Run all tests** + +Run: `cd ~/code/vllm-mlx && python -m pytest tests/ -v --timeout 60 2>&1 | tail -30` +Expected: All pass, 0 failures + +- [ ] **Step 5: Commit formatting if needed** + +```bash +cd ~/code/vllm-mlx +git add -u +git diff --cached --stat +# Only commit if there are formatting changes +git commit -m "style: formatting for hybrid prefix cache changes" || true +``` + +- [ ] **Step 6: Push to PR #165 branch** + +```bash +cd ~/code/vllm-mlx +git push origin fix/mllm-continuous-batching-hybrid-models +``` diff --git a/docs/superpowers/specs/2026-03-16-hybrid-model-prefix-cache-design.md b/docs/superpowers/specs/2026-03-16-hybrid-model-prefix-cache-design.md new file mode 100644 index 00000000..fd571b17 --- /dev/null +++ b/docs/superpowers/specs/2026-03-16-hybrid-model-prefix-cache-design.md @@ -0,0 +1,129 @@ +# Hybrid Model Prefix Cache — Design Spec + +**Date:** 2026-03-16 +**Issues:** #142, #136 +**Branch:** `fix/mllm-continuous-batching-hybrid-models` (PR #165) + +## Problem + +`BlockAwarePrefixCache` crashes on hybrid models (Qwen 3.5 MoE, Nemotron) with: +``` +WARNING: Failed to extract block tensor slice: Too many indices for array with 3 dimensions. +``` +Prefix cache silently degrades to no-op. 
No TTFT benefit on repeated contexts. + +## Root Cause + +The issue reporters hypothesized that MoE models produce 3D KV tensors. **This is wrong.** All KV cache tensors in mlx-lm are always 4D `(B, n_kv_heads, seq_len, head_dim)` regardless of model architecture. + +The actual cause: hybrid models have two cache layer types: + +| Layer Type | Cache Class | `.state` Format | Positional? | +|-----------|-------------|-----------------|-------------| +| Attention | KVCache | `(keys_4D, values_4D)` | Yes — indexed by token position | +| GatedDeltaNet / Mamba | ArraysCache | `[conv_3D, recurrent_4D]` | No — cumulative summary | + +`_extract_cache_states()` extracts state from ALL layers. `_extract_block_tensor_slice()` then tries to slice every layer as 4D KV: + +```python +keys, values = layer_state["state"] # Unpacks ArraysCache [conv_3D, ssm_4D] +keys[:, :, start:end, :] # 4 indices on 3D conv_state → IndexError +``` + +## Design + +### Layer Classification + +New helper function identifies cache layer types by `class_name` (already stored by `_extract_cache_states()`): + +```python +_KV_CACHE_CLASSES = frozenset({ + "KVCache", "RotatingKVCache", "QuantizedKVCache", + "ChunkedKVCache", "ConcatenateKVCache", "BatchKVCache", + "BatchRotatingKVCache", +}) + +def _is_kv_layer(layer_state: dict) -> bool: + return layer_state.get("class_name", "") in _KV_CACHE_CLASSES +``` + +### Separate Storage Model + +KV and non-KV layers are stored differently — no duplication: + +- **KV layers** → block-sliced along seq_dim (axis=2), stored per-block in `KVCacheBlock.cache_data` +- **Non-KV layers** → stored once as whole-sequence state in `BlockAwarePrefixCache._non_kv_states`, keyed by `tuple(block_ids)` + +```python +@dataclass +class NonKVCacheData: + """Full state for non-positional cache layers (SSM, linear attention).""" + layer_indices: List[int] # Position in the full layer list + states: List[Any] # From cache.state (list/tuple of arrays) + meta_states: List[Any] # From 
cache.meta_state + class_refs: List[type] # For from_state() reconstruction + total_layers: int # Total layer count (KV + non-KV) +``` + +### Partial Prefix Rejection + +For hybrid models, partial prefix reuse (matching some but not all blocks) would restore KV cache but NOT SSM state. The model would generate with mismatched attention/SSM state, producing incorrect output. + +Guard: when `fetch_cache()` finds a partial match and the model has non-KV layers, return cache miss instead. + +Full matches work because non-KV states are stored keyed by exact `tuple(block_ids)`. + +### Changes + +#### `prefix_cache.py` + +1. **`_is_kv_layer()`** — classify layers by class_name + +2. **`_extract_block_tensor_slice()`**: + - Skip non-KV layers (append `None` placeholder) + - KV layers: slice `keys[:,:,start:end,:]` as before + - Return list with `None` gaps at non-KV positions + +3. **`store_cache()`**: + - Classify layers into KV and non-KV + - KV: block-slice via `_extract_block_tensor_slice()` + - Non-KV: extract states, store in `self._non_kv_states[tuple(block_ids)]` + - Set `has_non_kv` flag on `BlockCacheEntry` for fast checks + +4. **`reconstruct_cache()`**: + - Look up non-KV states via `tuple(block_table.block_ids)` + - KV layers: concatenate block slices → KVCache (as before) + - Non-KV layers: `class_ref.from_state(state, meta_state)` + - Interleave in correct layer order using stored indices + - If non-KV states missing for hybrid model → return `None` + +5. **`fetch_cache()` partial match guard**: + - After `find_shared_prefix()` or `_find_best_prefix_match()` + - Check if non-KV states exist for the matched block set + - Missing → return `(None, tokens)` (force recomputation) + +6. 
**Cleanup**: `release_cache()`, `clear()` clean `_non_kv_states` + +#### `scheduler.py` + +- `_reconstruct_cache_from_states()` fallback path (L1490): guard `shape[2]` access with `hasattr(state[0], 'shape') and len(state) == 2` before assuming KV format + +### What This Does NOT Change + +- `PrefixCacheManager` (trie-based) — already works for hybrid models (stores whole cache objects) +- `_extract_cache_states()` — already correctly stores `class_name` and `class_ref` +- `_can_trim_cache()` / `_trim_cache()` — already handles hybrid models (checks all layers) +- KVCache tensor shape handling — always 4D, no ndim checks needed +- Pure-KV model behavior — unchanged, no non-KV layers detected + +### Tests + +New test file: `tests/test_prefix_cache_hybrid.py` + +1. **`test_is_kv_layer`** — classification of KVCache vs ArraysCache +2. **`test_extract_block_tensor_slice_hybrid`** — mixed KV/ArraysCache layers, KV sliced, ArraysCache skipped +3. **`test_store_and_reconstruct_hybrid`** — full roundtrip with simulated hybrid model cache +4. **`test_partial_prefix_rejected_for_hybrid`** — partial match → cache miss +5. **`test_full_prefix_hit_for_hybrid`** — exact match → correct reconstruction +6. **`test_pure_kv_model_unchanged`** — regression: existing behavior for pure attention models +7. **`test_cleanup_non_kv_states`** — release_cache and clear remove non-KV data diff --git a/tests/smoke_test_specdec.py b/tests/smoke_test_specdec.py new file mode 100644 index 00000000..452202c7 --- /dev/null +++ b/tests/smoke_test_specdec.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +""" +Smoke test for speculative decoding with real models. + +Usage: python tests/smoke_test_specdec.py + +Uses Qwen3.5-35B-A3B-8bit as target, Qwen3.5-4B-4bit as draft. +Tests the SimpleEngine path (mlx_lm.stream_generate with draft_model). 
+""" + +import os +import sys +import time + +# Add project to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +TARGET = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-35B-A3B-8bit") +DRAFT = os.path.expanduser("~/ai-models/mlx_models/Qwen3.5-4B-4bit") +PROMPT = "What is the capital of France? Answer in one sentence." +MAX_TOKENS = 64 +NUM_DRAFT = 3 + + +def test_without_draft(): + """Baseline: generate without speculative decoding.""" + from mlx_lm import load, stream_generate + + print("=" * 60) + print("Loading target model (no draft)...") + model, tokenizer = load(TARGET) + print(f"Target loaded. Generating {MAX_TOKENS} tokens...") + + tokens = [] + t0 = time.perf_counter() + for resp in stream_generate(model, tokenizer, prompt=PROMPT, max_tokens=MAX_TOKENS): + tokens.append(resp.token) + elapsed = time.perf_counter() - t0 + text = tokenizer.decode(tokens) + print(f"Output ({len(tokens)} tokens, {len(tokens)/elapsed:.1f} tok/s):") + print(f" {text}") + print() + return len(tokens), elapsed + + +def test_with_draft(): + """Speculative: generate with draft model.""" + from mlx_lm import load, stream_generate + + print("=" * 60) + print("Loading target + draft model...") + model, tokenizer = load(TARGET) + draft_model, _ = load(DRAFT) + + # Verify vocab match — walk model structure to find embed_tokens + def _get_vocab_size(m): + for attr in ["model", "language_model"]: + sub = getattr(m, attr, None) + if sub is not None: + et = getattr(sub, "embed_tokens", None) + if et is not None: + return et.weight.shape[0] + return None + + target_vocab = _get_vocab_size(model) + draft_vocab = _get_vocab_size(draft_model) + print(f"Target vocab: {target_vocab}, Draft vocab: {draft_vocab}") + if target_vocab and draft_vocab: + assert target_vocab == draft_vocab, "Vocab size mismatch!" 
+ + print(f"Generating {MAX_TOKENS} tokens with num_draft_tokens={NUM_DRAFT}...") + + tokens = [] + from_draft_count = 0 + t0 = time.perf_counter() + for resp in stream_generate( + model, + tokenizer, + prompt=PROMPT, + max_tokens=MAX_TOKENS, + draft_model=draft_model, + num_draft_tokens=NUM_DRAFT, + ): + tokens.append(resp.token) + if resp.from_draft: + from_draft_count += 1 + elapsed = time.perf_counter() - t0 + + text = tokenizer.decode(tokens) + accept_rate = from_draft_count / len(tokens) * 100 if tokens else 0 + print(f"Output ({len(tokens)} tokens, {len(tokens)/elapsed:.1f} tok/s):") + print(f" {text}") + print(f"Draft acceptance: {from_draft_count}/{len(tokens)} ({accept_rate:.0f}%)") + print() + return len(tokens), elapsed + + +if __name__ == "__main__": + print("Speculative Decoding Smoke Test") + print("Target:", TARGET) + print("Draft:", DRAFT) + print() + + n1, t1 = test_without_draft() + # Clear model from memory + import gc + import mlx.core as mx + + gc.collect() + mx.clear_cache() + + n2, t2 = test_with_draft() + + print("=" * 60) + print("RESULTS:") + print(f" Without draft: {n1} tokens in {t1:.2f}s ({n1/t1:.1f} tok/s)") + print(f" With draft: {n2} tokens in {t2:.2f}s ({n2/t2:.1f} tok/s)") + if t1 > 0 and t2 > 0: + speedup = (n1 / t1) / (n2 / t2) if n2 / t2 > 0 else 0 + print(f" Speedup: {1/speedup:.2f}x" if speedup > 0 else " N/A") diff --git a/tests/test_mllm_mtp_routing.py b/tests/test_mllm_mtp_routing.py new file mode 100644 index 00000000..e2394cf6 --- /dev/null +++ b/tests/test_mllm_mtp_routing.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for MLLM + MTP per-request routing.""" + + +def test_has_media_content_text_only(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Hello"}]) is False + + +def test_has_media_content_with_image(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": 
"text", "text": "What's this?"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_with_video(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "video_url", "video_url": {"url": "file:///tmp/v.mp4"}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_empty(): + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([]) is False + + +def test_has_media_content_string_content(): + """String content (not list) should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + assert _has_media_content([{"role": "user", "content": "Just text"}]) is False + + +def test_has_media_content_audio(): + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}} + ], + } + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_multi_turn(): + """Media in earlier turns should still be detected.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Look at this"}, + { + "type": "image_url", + "image_url": {"url": "data:image/png;base64,..."}, + }, + ], + }, + {"role": "assistant", "content": "I see an image."}, + {"role": "user", "content": "Tell me more about it."}, + ] + assert _has_media_content(messages) is True + + +def test_has_media_content_text_list(): + """List content with only text parts should return False.""" + from vllm_mlx.engine.simple import _has_media_content + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": "World"}, + ], + } + ] + assert _has_media_content(messages) is False + + 
+# --- MLXMultimodalLM extraction method tests --- + +from unittest.mock import MagicMock + + +def test_get_language_model(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_lm = MagicMock() + mllm.model = MagicMock() + mllm.model.language_model = inner_lm + assert MLXMultimodalLM.get_language_model(mllm) is inner_lm + + +def test_get_tokenizer(): + from vllm_mlx.models.mllm import MLXMultimodalLM + + mllm = MagicMock(spec=MLXMultimodalLM) + inner_tok = MagicMock() + mllm.processor = MagicMock() + mllm.processor.tokenizer = inner_tok + assert MLXMultimodalLM.get_tokenizer(mllm) is inner_tok diff --git a/tests/test_speculative_decoding.py b/tests/test_speculative_decoding.py new file mode 100644 index 00000000..3f9a8b72 --- /dev/null +++ b/tests/test_speculative_decoding.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for speculative decoding with a separate draft model (SimpleEngine path).""" + +import pytest + +try: + import mlx.core as mx # noqa: F401 + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +# --------------------------------------------------------------------------- +# Tests: CLI args +# --------------------------------------------------------------------------- + + +class TestCLIArgs: + def test_draft_model_arg_parsing(self): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--draft-model", type=str, default=None) + parser.add_argument("--num-draft-tokens", type=int, default=3) + + args = parser.parse_args( + ["--draft-model", "/path/to/model", "--num-draft-tokens", "5"] + ) + assert args.draft_model == "/path/to/model" + assert args.num_draft_tokens == 5 + + def test_default_num_draft_tokens(self): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--num-draft-tokens", type=int, default=3) + + args = parser.parse_args([]) + assert 
args.num_draft_tokens == 3 + + +# --------------------------------------------------------------------------- +# Tests: SimpleEngine draft model +# --------------------------------------------------------------------------- + + +class TestSimpleEngineDraftModel: + def test_draft_model_params_stored(self): + from vllm_mlx.engine.simple import SimpleEngine + + engine = SimpleEngine( + model_name="test-model", + draft_model_path="/path/to/draft", + num_draft_tokens=5, + ) + assert engine._draft_model_path == "/path/to/draft" + assert engine._num_draft_tokens == 5 + + def test_no_draft_model_by_default(self): + from vllm_mlx.engine.simple import SimpleEngine + + engine = SimpleEngine(model_name="test-model") + assert engine._draft_model_path is None + assert engine._num_draft_tokens == 3 + + +class TestMLXLanguageModelDraftModel: + def test_draft_model_params_stored(self): + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel( + model_name="test-model", + draft_model_path="/path/to/draft", + num_draft_tokens=5, + ) + assert model._draft_model_path == "/path/to/draft" + assert model._num_draft_tokens == 5 + assert model.draft_model is None + + def test_no_draft_model_by_default(self): + from vllm_mlx.models.llm import MLXLanguageModel + + model = MLXLanguageModel(model_name="test-model") + assert model._draft_model_path is None + assert model.draft_model is None diff --git a/tests/test_text_model_from_vlm.py b/tests/test_text_model_from_vlm.py new file mode 100644 index 00000000..037ff810 --- /dev/null +++ b/tests/test_text_model_from_vlm.py @@ -0,0 +1,140 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for building mlx_lm TextModel from mlx_vlm-loaded weights.""" + +import json +from pathlib import Path + +import pytest + +from vllm_mlx.text_model_from_vlm import build_text_model + +# VLM+MTP model (created by merging mlx-community VLM + our MTP weights) +VLM_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-VLM-MTP-8bit" + 
+# Text-only MTP model (no vision tower — can't test VLM loading) +TEXT_MTP_MODEL = Path.home() / "ai-models/mlx_models/Qwen3.5-35B-A3B-8bit" + + +def test_build_text_model_no_config(): + """Returns None when model path has no config.json.""" + result = build_text_model(None, "/nonexistent/path") + assert result is None + + +def test_build_text_model_none_vlm(): + """Returns None when vlm_model is None.""" + result = build_text_model(None, TEXT_MTP_MODEL) + assert result is None + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_build_text_model_moe(): + """build_text_model creates a TextModel with shared weights and MTP (MoE).""" + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, processor = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + assert text_model is not None, "build_text_model returned None" + + # TextModel should have MTP (config has mtp_num_hidden_layers=1) + assert hasattr(text_model, "mtp"), "TextModel missing .mtp attribute" + assert text_model.mtp is not None, "TextModel.mtp is None" + assert hasattr(text_model, "mtp_forward"), "TextModel missing mtp_forward method" + assert hasattr( + text_model, "make_mtp_cache" + ), "TextModel missing make_mtp_cache method" + + # Verify MoE layer exists in MTP + mtp_layer = text_model.mtp.layers[0] + assert hasattr(mtp_layer, "mlp"), "MTP layer missing mlp" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_mtp_forward(): + """TextModel.mtp_forward returns logits of correct vocab_size shape.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = 
config.get("text_config", config) + + mtp_cache = text_model.make_mtp_cache() + assert len(mtp_cache) > 0 + + hidden = mx.zeros((1, 1, text_config["hidden_size"])) + next_ids = mx.array([[0]]) + logits = text_model.mtp_forward(hidden, next_ids, mtp_cache) + + assert ( + logits.shape[-1] == text_config["vocab_size"] + ), f"Expected vocab_size={text_config['vocab_size']}, got {logits.shape[-1]}" + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_text_model_return_hidden(): + """TextModel supports return_hidden=True (required by mtp_generate_step).""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + config = json.loads((VLM_MTP_MODEL / "config.json").read_text()) + text_config = config.get("text_config", config) + + cache = text_model.make_cache() + tokens = mx.array([[1, 2, 3]]) # Dummy token IDs + + # return_hidden=True should return (logits, hidden_states) + result = text_model(tokens, cache=cache, return_hidden=True) + + # Should be a tuple of (logits, hidden) + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + logits, hidden = result + assert logits.shape[-1] == text_config["vocab_size"] + assert hidden.shape[-1] == text_config["hidden_size"] + + +@pytest.mark.skipif(not VLM_MTP_MODEL.exists(), reason="VLM+MTP model not on disk") +def test_weight_sharing(): + """Backbone weights are shared (zero-copy) between vlm and TextModel.""" + import mlx.core as mx + import runtime_patches + + runtime_patches.apply() + + from mlx_vlm import load as vlm_load + + vlm_model, _ = vlm_load(str(VLM_MTP_MODEL)) + text_model = build_text_model(vlm_model, VLM_MTP_MODEL) + + # Compare a backbone weight reference. 
+ # Layer 0 may be linear_attn (GatedDeltaNet) on MoE models, so find a layer + # with self_attn (full attention layers are at indices 11, 15, 19, 23, 27). + for i in range(len(vlm_model.language_model.model.layers)): + layer = vlm_model.language_model.model.layers[i] + if hasattr(layer, "self_attn"): + vlm_weight = layer.self_attn.q_proj.weight + tm_weight = text_model.model.layers[i].self_attn.q_proj.weight + assert mx.array_equal( + vlm_weight, tm_weight + ), f"Weights at layer {i} should be identical" + break + else: + pytest.fail("No layer with self_attn found") diff --git a/vllm_mlx/api/models.py b/vllm_mlx/api/models.py index aa7cb245..f2faf233 100644 --- a/vllm_mlx/api/models.py +++ b/vllm_mlx/api/models.py @@ -174,6 +174,11 @@ class ChatCompletionRequest(BaseModel): video_max_frames: int | None = None # Request timeout in seconds (None = use server default) timeout: float | None = None + # SpecPrefill per-request overrides (via extra_body) + specprefill: bool | None = ( + None # True=force, False=disable, None=use server default + ) + specprefill_keep_pct: float | None = None # Override keep percentage (0.1-1.0) class AssistantMessage(BaseModel): diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index dcbee8ac..53d25968 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -178,6 +178,16 @@ def serve_command(args): print(f"Prefix cache: max_entries={args.prefix_cache_size}") else: print("Mode: Simple (maximum throughput)") + if args.enable_mtp: + print("MTP: enabled (native speculative decoding)") + if args.enable_mtp and getattr(args, "mllm", False): + print("MTP + MLLM: per-request routing (text-only → MTP, media → MLLM)") + if args.specprefill and args.specprefill_draft_model: + print( + f"SpecPrefill: enabled (draft={args.specprefill_draft_model}, " + f"threshold={args.specprefill_threshold}, " + f"keep={args.specprefill_keep_pct*100:.0f}%)" + ) # Load model with unified server load_model( @@ -187,6 +197,11 @@ def serve_command(args): 
stream_interval=args.stream_interval if args.continuous_batching else 1, max_tokens=args.max_tokens, force_mllm=args.mllm, + mtp=args.enable_mtp, + specprefill_enabled=args.specprefill, + specprefill_draft_model_path=args.specprefill_draft_model, + specprefill_threshold=args.specprefill_threshold, + specprefill_keep_pct=args.specprefill_keep_pct, ) # Start server @@ -827,6 +842,36 @@ def main(): default=None, help="Pre-load an embedding model at startup (e.g. mlx-community/embeddinggemma-300m-6bit)", ) + # SpecPrefill (attention-based sparse prefill using draft model) + serve_parser.add_argument( + "--specprefill", + action="store_true", + default=False, + help="Enable SpecPrefill: use a small draft model to score token importance, " + "then sparse-prefill only the important tokens on the target model. " + "Reduces TTFT on long prompts. Requires --specprefill-draft-model.", + ) + serve_parser.add_argument( + "--specprefill-threshold", + type=int, + default=8192, + help="Minimum suffix tokens to trigger SpecPrefill (default: 8192). " + "Shorter prompts use full prefill (scoring overhead > savings).", + ) + serve_parser.add_argument( + "--specprefill-keep-pct", + type=float, + default=0.3, + help="Fraction of tokens to keep during sparse prefill (default: 0.3). " + "Lower = faster prefill but more quality loss.", + ) + serve_parser.add_argument( + "--specprefill-draft-model", + type=str, + default=None, + help="Path to small draft model for SpecPrefill importance scoring. " + "Must share the same tokenizer as the target model.", + ) # Bench command bench_parser = subparsers.add_parser("bench", help="Run benchmark") bench_parser.add_argument("model", type=str, help="Model to benchmark") diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628..c83519ee 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -11,16 +11,40 @@ LLM engine), so text-only requests must also be routed through it. 
""" +import asyncio import logging from collections.abc import AsyncIterator from typing import Any from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_output_text, extract_multimodal_content, is_mllm_model +from ..message_utils import _normalize_messages from .base import BaseEngine, GenerationOutput logger = logging.getLogger(__name__) +_MEDIA_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def _has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio).""" + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: + return True + return False + def _extract_media_from_messages(messages: list[dict[str, Any]]) -> tuple: """ @@ -137,6 +161,12 @@ def __init__( scheduler_config: Any | None = None, stream_interval: int = 1, force_mllm: bool = False, + mtp: bool = False, + prefill_step_size: int | None = None, + specprefill_enabled: bool = False, + specprefill_draft_model_path: str | None = None, + specprefill_threshold: int = 8192, + specprefill_keep_pct: float = 0.3, ): """ Initialize the batched engine. 
@@ -147,12 +177,28 @@ def __init__( scheduler_config: Optional scheduler configuration stream_interval: Tokens to batch before streaming (1=every token) force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable MTP per-request routing (text-only → TextModel, media → MLLM) + prefill_step_size: Chunk size for prompt prefill (default 2048) + specprefill_enabled: Enable SpecPrefill sparse prefill + specprefill_draft_model_path: Draft model directory name under ~/ai-models/mlx_models/ + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill (default 8192) + specprefill_keep_pct: Fraction of tokens to keep (default 0.3) """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._scheduler_config = scheduler_config self._stream_interval = stream_interval self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp + self._prefill_step_size = prefill_step_size or 2048 + + # SpecPrefill configuration + self._specprefill_enabled = specprefill_enabled + self._specprefill_draft_model_path = specprefill_draft_model_path + self._specprefill_threshold = specprefill_threshold + self._specprefill_keep_pct = specprefill_keep_pct + self._specprefill_lock = asyncio.Lock() + self._draft_model = None self._model = None self._processor = None # For MLLM @@ -162,6 +208,16 @@ def __init__( self._mllm_instance = None # MLXMultimodalLM instance self._loaded = False + # Per-request routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + self._text_generation_lock = asyncio.Lock() + + # System prompt KV cache (reduces repeated prefill across requests) + self._system_kv_snapshot = None # List of (keys, values) per backbone layer + self._system_kv_hash = None # Hash of system prefix text + self._system_kv_token_count = 0 # Tokens in cached prefix + @property def model_name(self) -> str: """Get the model name.""" @@ -241,6 +297,73 @@ async def _start_mllm(self) -> None: 
f"completion_batch={completion_batch_size}" ) + # Build TextModel for MTP per-request routing (text-only → MTP, media → MLLM) + if self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model( + self._mllm_instance.model, self._model_name + ) + if self._text_model is not None: + # Get tokenizer from the MLLM instance (same model, shared tokenizer) + self._text_tokenizer = self._mllm_instance.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches SimpleEngine pattern) + if "qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + # Check if TextModel actually has MTP + has_mtp = ( + hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ) + if has_mtp: + logger.info( + "BatchedEngine MLLM+MTP routing: " + "text-only → TextModel (MTP), media → MLLM" + ) + else: + logger.warning( + "TextModel built but no MTP head — " + "text-only won't use MTP" + ) + self._text_model = None + self._text_tokenizer = None + except Exception as e: + logger.error(f"MTP TextModel build failed: {e}") + self._text_model = None + self._text_tokenizer = None + + # Load SpecPrefill draft model (for TextModel path — sparse cache + # is incompatible with MTP, so specprefill generates autoregressively) + if self._specprefill_enabled and self._specprefill_draft_model_path: + try: + from pathlib import Path + + from mlx_lm import load as mlx_lm_load + + draft_path = str( + Path.home() + / "ai-models" + / "mlx_models" + / self._specprefill_draft_model_path + ) + self._draft_model, _ = mlx_lm_load(draft_path) + logger.info( + "SpecPrefill draft model loaded: %s (threshold=%d, keep=%.0f%%)", + self._specprefill_draft_model_path, + self._specprefill_threshold, + self._specprefill_keep_pct * 100, + ) + except Exception as e: + logger.warning("Failed to load SpecPrefill draft model: %s", e) + 
self._specprefill_enabled = False + self._draft_model = None + async def _start_llm(self) -> None: """Start the LLM engine with AsyncEngineCore.""" from ..engine_core import AsyncEngineCore, EngineConfig @@ -327,6 +450,12 @@ async def stop(self) -> None: self._tokenizer = None self._processor = None self._mllm_instance = None + self._text_model = None + self._text_tokenizer = None + self._draft_model = None + self._system_kv_snapshot = None + self._system_kv_hash = None + self._system_kv_token_count = 0 self._loaded = False logger.info("BatchedEngine stopped") @@ -612,6 +741,20 @@ async def chat( if not self._loaded: await self.start() + # Normalize messages before any path (developer->system, merge consecutive) + messages = _normalize_messages(messages) + + # Per-request MTP routing: text-only → TextModel, media → MLLM + if self._text_model is not None and not _has_media_content(messages): + return await self._chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ) + # Extract images/videos from messages (OpenAI multimodal format) # Note: We only use extracted media here, messages are already processed by server _, extracted_images, extracted_videos = extract_multimodal_content(messages) @@ -723,6 +866,22 @@ async def stream_chat( if not self._loaded: await self.start() + # Normalize messages before any path (developer->system, merge consecutive) + messages = _normalize_messages(messages) + + # Per-request MTP routing: text-only → TextModel, media → MLLM + if self._text_model is not None and not _has_media_content(messages): + async for output in self._stream_chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ): + yield output + return + # Extract images/videos from messages (OpenAI multimodal format) # Note: We only use extracted media here, messages are already processed by server _, extracted_images, extracted_videos = 
extract_multimodal_content(messages) @@ -755,6 +914,469 @@ async def stream_chat( ): yield output + async def _chat_text_model( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + **kwargs, + ) -> GenerationOutput: + """Non-streaming text-only generation via mlx_lm TextModel with MTP. + + Collects all streaming output into a single GenerationOutput. + Used when MLLM+MTP routing is active and the request has no media. + """ + logger.info("Text-only request → TextModel (MTP) [non-streaming]") + accumulated_text = "" + last_chunk = None + async for chunk in self._stream_chat_text_model( + messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + **kwargs, + ): + accumulated_text = chunk.text + last_chunk = chunk + if last_chunk is not None: + return GenerationOutput( + text=accumulated_text, + prompt_tokens=last_chunk.prompt_tokens, + completion_tokens=last_chunk.completion_tokens, + finish_reason=last_chunk.finish_reason, + ) + return GenerationOutput(text="", finish_reason="stop") + + async def _stream_chat_text_model( + self, + messages: list[dict[str, Any]], + max_tokens: int = 256, + temperature: float = 0.7, + top_p: float = 0.9, + tools: list[dict] | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Streaming text-only generation via mlx_lm TextModel with MTP. + + Used when MLLM+MTP routing is active and the request has no media. + Runs the full generation in a single thread to maintain Metal safety. + + System prompt KV caching: on the first request, prefills system tokens + and snapshots backbone KV state. Subsequent requests with the same + system prompt restore the snapshot and only prefill the suffix tokens. + + SpecPrefill: when a draft model is loaded and the prompt exceeds the + threshold, uses attention-based sparse prefill for faster TTFT. 
+ Composes with system KV cache (sparse-prefill only the suffix when + cache hits). Falls back to normal path on any error. + """ + import hashlib + import os + + import mlx.core as mx + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.models.cache import make_prompt_cache + from mlx_lm.sample_utils import make_sampler + + # Per-request specprefill overrides (from extra_body) + specprefill_override = kwargs.pop("specprefill", None) + specprefill_keep_pct_override = kwargs.pop("specprefill_keep_pct", None) + + # Convert tools for template + template_tools = convert_tools_for_template(tools) if tools else None + + # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Apply chat template + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if template_tools: + template_kwargs["tools"] = template_tools + + try: + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + # Template doesn't accept tools= or enable_thinking= + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + # Build sampler + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + # --- System KV cache: find system prefix boundary --- + # ChatML (Qwen 3.5): everything before first <|im_start|>user is the system prefix + USER_MARKER = "<|im_start|>user" + marker_pos = prompt.find(USER_MARKER) + if marker_pos > 0: + system_prefix = prompt[:marker_pos] + suffix = prompt[marker_pos:] + prefix_hash = hashlib.sha256(system_prefix.encode()).hexdigest()[:16] + else: + system_prefix = None + suffix = prompt + prefix_hash = None + + # Check 
for cache hit + cache_hit = ( + prefix_hash is not None + and prefix_hash == self._system_kv_hash + and self._system_kv_snapshot is not None + ) + + if cache_hit: + logger.info( + "Text-only request → TextModel (MTP) [streaming, system KV cache HIT: " + "reusing %d cached tokens, hash=%s]", + self._system_kv_token_count, + prefix_hash, + ) + else: + logger.info("Text-only request → TextModel (MTP) [streaming]") + + prefill_step_size = self._prefill_step_size + + # --- SpecPrefill decision --- + # Determine whether to use specprefill for this request. + # Must be decided before entering the generation lock so we can + # tokenize and check the threshold outside the critical section. + _SPECPREFILL_MAX_TOKENS = 196608 + use_specprefill = False + if self._draft_model is not None: + if specprefill_override is True: + use_specprefill = True + elif specprefill_override is None and self._specprefill_enabled: + use_specprefill = True + # specprefill_override=False explicitly disables + + # Tokenize to determine token count for specprefill threshold check. + # We need this for both specprefill and normal paths anyway. 
+ sp_tokens = None # tokens to score (suffix or full prompt) + sp_offset = 0 # position offset for sparse_prefill + sp_n_total = 0 # total prompt tokens (for logging / threshold) + + if use_specprefill: + if cache_hit: + # Score only the suffix — system prefix is already cached + sp_tokens = self._text_tokenizer.encode(suffix) + sp_offset = self._system_kv_token_count + sp_n_total = sp_offset + len(sp_tokens) + else: + # Score the full prompt + sp_tokens = self._text_tokenizer.encode(prompt) + sp_offset = 0 + sp_n_total = len(sp_tokens) + + n_sp_tokens = len(sp_tokens) + + # Threshold check (skip when force-enabled via per-request override) + if ( + specprefill_override is not True + and n_sp_tokens <= self._specprefill_threshold + ): + use_specprefill = False + + # Upper bound: cap to avoid draft model OOM + if use_specprefill and n_sp_tokens > _SPECPREFILL_MAX_TOKENS: + logger.warning( + "SpecPrefill: prompt %d tokens exceeds max %d, " + "falling back to normal path", + n_sp_tokens, + _SPECPREFILL_MAX_TOKENS, + ) + use_specprefill = False + + # Run under generation lock, all tokens in single thread (Metal safety) + async with self._text_generation_lock: + + def _run_with_cache(): + if use_specprefill: + try: + return _run_specprefill() + except Exception as e: + logger.error( + "SpecPrefill failed, falling back to normal path: %s", e + ) + # Fall through to normal path + if cache_hit: + return _run_cache_hit() + else: + return _run_cache_miss() + + def _run_specprefill(): + """Score tokens, sparse prefill, generate autoregressively. + + Composes with system KV cache: when cache_hit, restores the + system KV snapshot first, then sparse-prefills only the suffix + tokens with position_offset = system_kv_token_count. + + Does NOT use MTP (sparse cache is incompatible with MTP + speculative decoding). 
+ """ + import time + from types import SimpleNamespace + + from ..specprefill import ( + cleanup_rope, + score_tokens, + select_chunks, + sparse_prefill, + ) + + # Build target cache (optionally restore system KV snapshot) + target_cache = make_prompt_cache(self._text_model) + if cache_hit: + for layer_idx, snapshot_state in enumerate( + self._system_kv_snapshot + ): + if layer_idx < len(target_cache): + target_cache[layer_idx].state = snapshot_state + mx.eval([c.state for c in target_cache if hasattr(c, "state")]) + + try: + # Phase 1: Score with draft model + t0 = time.monotonic() + importance = score_tokens( + self._draft_model, + sp_tokens, + prefill_step_size=prefill_step_size, + ) + t_score = time.monotonic() - t0 + + # Phase 2: Select important chunks + effective_keep = ( + specprefill_keep_pct_override or self._specprefill_keep_pct + ) + selected = select_chunks(importance, keep_pct=effective_keep) + n_selected = selected.shape[0] + n_scored = len(sp_tokens) + + # Phase 3: Sparse prefill on target model + t0 = time.monotonic() + logits = sparse_prefill( + self._text_model, + sp_tokens, + selected, + target_cache, + step_size=prefill_step_size, + position_offset=sp_offset, + ) + t_prefill = time.monotonic() - t0 + + logger.info( + "SpecPrefill: scored %d tokens in %.1fs, " + "sparse prefill %d/%d (keep=%.0f%%) in %.1fs " + "(offset=%d, effective_keep=%.2f)", + n_scored, + t_score, + n_selected, + n_scored, + n_selected / n_scored * 100, + t_prefill, + sp_offset, + effective_keep, + ) + + # Phase 4: Generate (simple autoregressive, no MTP) + eos_id = self._text_tokenizer.eos_token_id + y = sampler(logits[:, -1, :]) + mx.eval(y) + + results = [] + generated_ids = [] + prev_decoded = "" + + for _ in range(max_tokens): + tok_id = y.item() + generated_ids.append(tok_id) + + # Incremental text decode + decoded = self._text_tokenizer.decode(generated_ids) + new_text = decoded[len(prev_decoded) :] + prev_decoded = decoded + + is_eos = tok_id == eos_id + 
results.append( + SimpleNamespace( + text=new_text, + finish_reason="stop" if is_eos else None, + ) + ) + + if is_eos: + break + + # Next token + logits = self._text_model(y.reshape(1, -1), cache=target_cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) + + return results, sp_n_total + + finally: + cleanup_rope(self._text_model) + + def _run_cache_hit(): + """Restore system KV snapshot, prefill only suffix, generate.""" + # Restore cached KV state into a fresh cache + restored_cache = make_prompt_cache(self._text_model) + for layer_idx, snapshot_state in enumerate(self._system_kv_snapshot): + if layer_idx < len(restored_cache): + restored_cache[layer_idx].state = snapshot_state + mx.eval([c.state for c in restored_cache if hasattr(c, "state")]) + + # Tokenize just the suffix and generate with the primed cache. + # stream_generate accepts mx.array prompt (skips tokenization) + # and prompt_cache is forwarded to mtp_generate_step. + suffix_tokens = self._text_tokenizer.encode(suffix) + suffix_array = mx.array(suffix_tokens) + n_suffix = len(suffix_tokens) + + logger.info( + "System KV cache HIT: prefilling %d suffix tokens " + "(skipped %d cached tokens)", + n_suffix, + self._system_kv_token_count, + ) + + results = [] + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=suffix_array, + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prompt_cache=restored_cache, + prefill_step_size=prefill_step_size, + ): + results.append(resp) + return results, self._system_kv_token_count + len(suffix_tokens) + + def _run_cache_miss(): + """Full prefill + generation, then snapshot system KV for next time.""" + results = [] + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + prefill_step_size=prefill_step_size, + ): + results.append(resp) + + # Snapshot system KV for next request (if we found a system prefix) + if prefix_hash is not None 
and system_prefix is not None: + try: + _snapshot_system_kv() + except Exception as e: + logger.warning("Failed to snapshot system KV cache: %s", e) + + # Get total prompt token count from generation response + prompt_tokens = 0 + if results and hasattr(results[0], "prompt_tokens"): + prompt_tokens = results[0].prompt_tokens + return results, prompt_tokens + + def _snapshot_system_kv(): + """Prefill just the system prefix on a fresh cache and save snapshot.""" + snapshot_cache = make_prompt_cache(self._text_model) + prefix_tokens = self._text_tokenizer.encode(system_prefix) + prefix_ids = mx.array(prefix_tokens) + + # Chunked prefill of system prefix + for i in range(0, prefix_ids.size, prefill_step_size): + chunk = prefix_ids[i : i + prefill_step_size] + self._text_model(chunk[None], cache=snapshot_cache) + mx.eval([c.state for c in snapshot_cache if hasattr(c, "state")]) + + # Save snapshot: deep copy of each cache layer's state + self._system_kv_snapshot = [] + for c in snapshot_cache: + state = c.state + if isinstance(state, tuple) and len(state) == 2: + # KVCache: (keys, values) — copy to detach from cache + keys, values = state + self._system_kv_snapshot.append( + (mx.array(keys), mx.array(values)) + ) + elif isinstance(state, list): + # ArraysCache: list of arrays (Mamba/hybrid) + self._system_kv_snapshot.append( + [mx.array(a) if a is not None else None for a in state] + ) + else: + # Unknown cache type — store as-is + self._system_kv_snapshot.append(state) + + self._system_kv_token_count = len(prefix_tokens) + self._system_kv_hash = prefix_hash + + cache_bytes = 0 + for entry in self._system_kv_snapshot: + if isinstance(entry, tuple) and len(entry) == 2: + cache_bytes += entry[0].nbytes + entry[1].nbytes + elif isinstance(entry, list): + cache_bytes += sum(a.nbytes for a in entry if a is not None) + logger.info( + "System KV cache: stored %d-token snapshot " "(%.1f MB), hash=%s", + len(prefix_tokens), + cache_bytes / 1e6, + prefix_hash, + ) + + result = 
await asyncio.to_thread(_run_with_cache) + all_resps, prompt_token_count = result + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=prompt_token_count, + completion_tokens=token_count, + finished=finished, + finish_reason="stop" if finished else None, + ) + + if finished: + break + + if not finished: + yield GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=prompt_token_count, + completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = { @@ -779,6 +1401,29 @@ def get_stats(self) -> dict[str, Any]: elif self._engine: stats.update(self._engine.get_stats()) + # SpecPrefill stats + if self._draft_model is not None: + stats["specprefill"] = { + "enabled": self._specprefill_enabled, + "draft_model": self._specprefill_draft_model_path, + "threshold": self._specprefill_threshold, + "keep_pct": self._specprefill_keep_pct, + } + + # System KV cache stats + if self._system_kv_snapshot is not None: + cache_bytes = 0 + for entry in self._system_kv_snapshot: + if isinstance(entry, tuple) and len(entry) == 2: + cache_bytes += entry[0].nbytes + entry[1].nbytes + elif isinstance(entry, list): + cache_bytes += sum(a.nbytes for a in entry if a is not None) + stats["system_kv_cache"] = { + "tokens": self._system_kv_token_count, + "hash": self._system_kv_hash, + "memory_mb": round(cache_bytes / 1e6, 1), + } + return stats def get_cache_stats(self) -> dict[str, Any] | None: diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 3a7e14e2..d168f79e 100644 --- 
a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -13,11 +13,35 @@ from ..api.tool_calling import convert_tools_for_template from ..api.utils import clean_output_text, is_mllm_model +from ..message_utils import _normalize_messages from .base import BaseEngine, GenerationOutput logger = logging.getLogger(__name__) +_MEDIA_TYPES = frozenset( + { + "image_url", + "video_url", + "audio_url", + "image", + "video", + "audio", + } +) + + +def _has_media_content(messages: list) -> bool: + """Check if any message contains media content (images, video, audio).""" + for msg in messages: + content = msg.get("content") + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") in _MEDIA_TYPES: + return True + return False + + class SimpleEngine(BaseEngine): """ Simple engine for direct model calls. @@ -32,6 +56,11 @@ def __init__( trust_remote_code: bool = True, enable_cache: bool = True, force_mllm: bool = False, + mtp: bool = False, + specprefill_enabled: bool = False, + specprefill_draft_model_path: str | None = None, + specprefill_threshold: int = 8192, + specprefill_keep_pct: float = 0.3, ): """ Initialize the simple engine. 
@@ -41,15 +70,29 @@ def __init__( trust_remote_code: Whether to trust remote code enable_cache: Enable VLM cache for multimodal models force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (model must have MTP head) + specprefill_enabled: Enable SpecPrefill (attention-based sparse prefill) + specprefill_draft_model_path: Path to draft model for SpecPrefill scoring + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill + specprefill_keep_pct: Fraction of tokens to keep during sparse prefill """ self._model_name = model_name self._trust_remote_code = trust_remote_code self._enable_cache = enable_cache self._is_mllm = force_mllm or is_mllm_model(model_name) + self._mtp = mtp + self._specprefill_enabled = specprefill_enabled + self._specprefill_draft_model_path = specprefill_draft_model_path + self._specprefill_threshold = specprefill_threshold + self._specprefill_keep_pct = specprefill_keep_pct self._model = None self._loaded = False + # Per-request routing state (MLLM+MTP mode) + self._text_model = None + self._text_tokenizer = None + # Lock to serialize MLX operations (prevents Metal command buffer conflicts) self._generation_lock = asyncio.Lock() @@ -91,11 +134,54 @@ async def start(self) -> None: self._model = MLXLanguageModel( self._model_name, trust_remote_code=self._trust_remote_code, + mtp=self._mtp, ) self._model.load() self._loaded = True - logger.info(f"SimpleEngine loaded: {self._model_name} (MLLM={self._is_mllm})") + + # Build parallel mlx_lm TextModel for text-only MTP routing + if self._is_mllm and self._mtp: + try: + from ..text_model_from_vlm import build_text_model + + self._text_model = build_text_model(self._model.model, self._model_name) + + if ( + self._text_model is not None + and hasattr(self._text_model, "mtp") + and self._text_model.mtp is not None + ): + self._text_tokenizer = self._model.get_tokenizer() + + # Apply Qwen3.5 eos_token fix (matches MLXLanguageModel.load) + if 
"qwen3" in self._model_name.lower(): + self._text_tokenizer.eos_token = "<|im_end|>" + self._text_tokenizer.eos_token_id = ( + self._text_tokenizer.convert_tokens_to_ids("<|im_end|>") + ) + + logger.info( + "MLLM+MTP routing: text-only → mlx_lm TextModel (MTP=True), " + "media → mlx_vlm" + ) + else: + logger.warning( + "TextModel built but no MTP — text-only requests won't use MTP" + ) + self._text_model = None + + except Exception as e: + logger.error("MLLM+MTP routing setup failed: %s", e) + self._text_model = None + self._text_tokenizer = None + + mtp_info = f", MTP={self._mtp}" if self._mtp else "" + routing = ", routing=per-request" if self._text_model is not None else "" + logger.info( + f"SimpleEngine loaded: {self._model_name} " + f"(MLLM={self._is_mllm}{mtp_info}{routing})" + ) async def stop(self) -> None: """Stop the engine and cleanup resources.""" @@ -264,9 +350,40 @@ async def chat( if not self._loaded: await self.start() + messages = _normalize_messages(messages) + # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None + # Per-request routing: text-only through mlx_lm with MTP + if ( + self._is_mllm + and self._text_model is not None + and not _has_media_content(messages) + ): + logger.info("Text-only request → LLM path (MTP=True) [non-streaming]") + # Collect streaming output into single response + accumulated_text = "" + last_chunk = None + async for chunk in self._stream_generate_text( + messages, + max_tokens, + temperature, + top_p, + tools=template_tools, + **kwargs, + ): + accumulated_text = chunk.text + last_chunk = chunk + if last_chunk is not None: + return GenerationOutput( + text=accumulated_text, + prompt_tokens=last_chunk.prompt_tokens, + completion_tokens=last_chunk.completion_tokens, + finish_reason=last_chunk.finish_reason, + ) + return GenerationOutput(text="", finish_reason="stop") + async with self._generation_lock: if self._is_mllm: # For MLLM, use the chat method which 
handles images/videos @@ -336,47 +453,73 @@ async def stream_chat( if not self._loaded: await self.start() + # Normalize messages before any path (developer->system, merge consecutive) + messages = _normalize_messages(messages) + # Convert tools for template template_tools = convert_tools_for_template(tools) if tools else None + # Per-request routing: text-only through mlx_lm with MTP + if ( + self._is_mllm + and self._text_model is not None + and not _has_media_content(messages) + ): + logger.info("Text-only request → LLM path (MTP=True)") + async for chunk in self._stream_generate_text( + messages, + max_tokens, + temperature, + top_p, + tools=template_tools, + **kwargs, + ): + yield chunk + return + # Build prompt using tokenizer if self._is_mllm: - # For MLLM, use stream_chat which yields tokens incrementally - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, + if self._text_model is not None: + logger.info("Media request → MLLM path") + # For MLLM, use stream_chat which yields tokens incrementally. + # Must hold _generation_lock to prevent concurrent Metal access + # (e.g. OpenCode sends title + main request simultaneously). 
+ async with self._generation_lock: + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) ) - ) - chunks = await asyncio.to_thread(run_stream) + chunks = await asyncio.to_thread(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream @@ -415,6 +558,106 @@ def run_stream(): ): yield output + async def _stream_generate_text( + self, + messages: list[dict[str, Any]], + max_tokens: int, + temperature: float, + top_p: float, + tools: list | None = None, + **kwargs, + ) -> AsyncIterator[GenerationOutput]: + """Text-only generation via mlx_lm TextModel with MTP. + + Used when MLLM+MTP routing is active and the request has no media. + Runs the full generation in a single thread to maintain Metal safety. 
+ """ + import os + + from mlx_lm import stream_generate as mlx_stream_generate + from mlx_lm.sample_utils import make_sampler + + # Read enable_thinking from env (set by runtime_patches, consistent with MLLM path) + enable_thinking_env = os.environ.get("VLLM_MLX_ENABLE_THINKING", "true") + enable_thinking = enable_thinking_env.lower() in ("true", "1", "yes") + + # Apply chat template + template_kwargs = { + "tokenize": False, + "add_generation_prompt": True, + "enable_thinking": enable_thinking, + } + if tools: + template_kwargs["tools"] = tools + + try: + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + except TypeError: + # Template doesn't accept tools= or enable_thinking= + template_kwargs.pop("tools", None) + template_kwargs.pop("enable_thinking", None) + prompt = self._text_tokenizer.apply_chat_template( + messages, **template_kwargs + ) + + # Build sampler + sampler = make_sampler(temp=temperature, top_p=top_p) + max_tokens = max_tokens or 4096 + + # Run under generation lock, all tokens in single thread (Metal safety) + async with self._generation_lock: + + def _run_all(): + results = [] + for resp in mlx_stream_generate( + self._text_model, + self._text_tokenizer, + prompt=prompt, + max_tokens=max_tokens, + sampler=sampler, + mtp=True, + ): + results.append(resp) + return results + + all_resps = await asyncio.to_thread(_run_all) + + # Yield results as GenerationOutput + accumulated_text = "" + token_count = 0 + finished = False + for i, resp in enumerate(all_resps): + token_count += 1 + new_text = resp.text if hasattr(resp, "text") else str(resp) + accumulated_text += new_text + + is_last = i == len(all_resps) - 1 + finished = is_last or token_count >= max_tokens + + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=0, + completion_tokens=token_count, + finished=finished, + finish_reason="stop" if finished else None, + ) + + if finished: + break + + if not finished: + yield 
GenerationOutput( + text=accumulated_text, + new_text="", + prompt_tokens=0, + completion_tokens=token_count, + finished=True, + finish_reason="length", + ) + def get_stats(self) -> dict[str, Any]: """Get engine statistics.""" stats = { diff --git a/vllm_mlx/message_utils.py b/vllm_mlx/message_utils.py new file mode 100644 index 00000000..621ac057 --- /dev/null +++ b/vllm_mlx/message_utils.py @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Shared message normalization utilities. + +Provides ``_normalize_messages()`` which maps non-standard roles, merges +consecutive same-role messages, and hoists system messages to position [0]. +Used by both SimpleEngine and BatchedEngine before ``apply_chat_template``. +""" + +import logging + +logger = logging.getLogger(__name__) + + +def _normalize_messages(messages: list[dict]) -> list[dict]: + """Normalize message roles and merge consecutive same-role messages. + + 1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``). + 2. Merges consecutive same-role messages to satisfy chat template constraints + (Qwen 3.5, Llama, etc. require alternating roles). + + Only merges when both messages have string content. Messages with list + content (multimodal) are left as-is to preserve image/video attachments. 
+ """ + _ROLE_MAP = {"developer": "system"} + + if not messages: + return messages + + merged = [messages[0].copy()] + if merged[0]["role"] in _ROLE_MAP: + merged[0]["role"] = _ROLE_MAP[merged[0]["role"]] + for msg in messages[1:]: + prev = merged[-1] + role = _ROLE_MAP.get(msg["role"], msg["role"]) + if ( + role == prev["role"] + and isinstance(prev.get("content"), str) + and isinstance(msg.get("content"), str) + ): + prev["content"] = prev["content"] + "\n\n" + msg["content"] + logger.debug( + f"Merged consecutive {role} messages " + f"({len(prev['content'])} chars total)" + ) + else: + copy = msg.copy() + copy["role"] = role + merged.append(copy) + + # Hoist system messages to position [0] and merge them. + # Many CLIs (OpenCode, Qwen Code, Kilo) send system messages mid-conversation; + # the Qwen 3.5 chat template rejects any system message not at position [0]. + system_msgs = [m for m in merged if m["role"] == "system"] + non_system = [m for m in merged if m["role"] != "system"] + if system_msgs and (len(system_msgs) > 1 or merged[0]["role"] != "system"): + # Combine all system message content (string only) into one + parts = [] + for m in system_msgs: + c = m.get("content") + if isinstance(c, str): + parts.append(c) + elif isinstance(c, list): + # Multimodal system message — extract text parts + for part in c: + if isinstance(part, str): + parts.append(part) + elif isinstance(part, dict) and part.get("type") == "text": + parts.append(part.get("text", "")) + if parts: + combined_system = {"role": "system", "content": "\n\n".join(parts)} + merged = [combined_system] + non_system + logger.info( + f"Hoisted {len(system_msgs)} system message(s) to position [0] " + f"({len(combined_system['content'])} chars)" + ) + else: + # No string content — just move the first system msg to front + merged = system_msgs[:1] + non_system + logger.info("Hoisted system message to position [0]") + + merged_count = len(messages) - len(merged) + if merged_count: + 
logger.info(f"Normalized messages: merged {len(messages)} -> {len(merged)}") + + return merged diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 092c060e..72182037 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -50,6 +50,7 @@ def __init__( model_name: str, tokenizer_name: str | None = None, trust_remote_code: bool = False, + mtp: bool = False, ): """ Initialize the MLX language model. @@ -58,10 +59,12 @@ def __init__( model_name: HuggingFace model name or local path tokenizer_name: Optional separate tokenizer name trust_remote_code: Whether to trust remote code + mtp: Enable native MTP speculative decoding (model must have MTP head) """ self.model_name = model_name self.tokenizer_name = tokenizer_name or model_name self.trust_remote_code = trust_remote_code + self._mtp = mtp self.model = None self.tokenizer = None @@ -203,12 +206,17 @@ def stream_generate( token_count = 0 accumulated_text = "" + mtp_kwargs = {} + if self._mtp: + mtp_kwargs["mtp"] = True + for response in stream_generate( self.model, self.tokenizer, prompt=prompt, max_tokens=max_tokens, sampler=sampler, + **mtp_kwargs, ): token_count += 1 # response.text is the new token text (not accumulated) diff --git a/vllm_mlx/models/mllm.py b/vllm_mlx/models/mllm.py index 22b36963..df20bc10 100644 --- a/vllm_mlx/models/mllm.py +++ b/vllm_mlx/models/mllm.py @@ -33,6 +33,59 @@ logger = logging.getLogger(__name__) +def _hoist_system_messages(chat_messages: list[dict]) -> list[dict]: + """Hoist system messages to position [0] in mlx_vlm chat_messages format. + + Qwen 3.5 chat template rejects system messages not at the beginning. + CLIs like OpenCode/Qwen Code/Kilo send system messages mid-conversation. + Also maps ``developer`` role to ``system`` and merges consecutive same-role. 
+ """ + if not chat_messages: + return chat_messages + + # Map developer -> system + for msg in chat_messages: + if msg.get("role") == "developer": + msg["role"] = "system" + + system_msgs = [m for m in chat_messages if m["role"] == "system"] + if not system_msgs: + return chat_messages + if len(system_msgs) == 1 and chat_messages[0]["role"] == "system": + return chat_messages + + non_system = [m for m in chat_messages if m["role"] != "system"] + + # Extract text from all system messages (handles both str and list content) + parts = [] + for m in system_msgs: + c = m.get("content") + if isinstance(c, str): + parts.append(c) + elif isinstance(c, list): + for item in c: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text", "")) + + if parts: + combined_text = "\n\n".join(parts) + combined = { + "role": "system", + "content": [ + {"type": "text", "text": combined_text, "content": combined_text} + ], + } + logger.info( + f"Hoisted {len(system_msgs)} system message(s) to position [0] " + f"({len(combined_text)} chars)" + ) + return [combined] + non_system + + return system_msgs[:1] + non_system + + class TempFileManager: """Thread-safe manager for tracking and cleaning up temporary files.""" @@ -740,6 +793,14 @@ def load(self) -> None: logger.error(f"Failed to load MLLM: {e}") raise + def get_language_model(self): + """Extract the underlying language model for mlx_lm TextModel construction.""" + return self.model.language_model + + def get_tokenizer(self): + """Get the text tokenizer (not the multimodal processor).""" + return self.processor.tokenizer + def _prepare_images(self, images: list) -> list[str]: """Process image inputs and return local file paths.""" processed = [] @@ -1143,6 +1204,9 @@ def chat( all_images.extend(frames) logger.info(f"Added {len(frames)} frames from video: {video_path}") + # Hoist system messages to position [0] for Qwen 3.5 template + chat_messages = 
_hoist_system_messages(chat_messages) + # Apply chat template directly - messages are already properly structured logger.info( f"Applying chat template with {len(chat_messages)} messages, {len(all_images)} images" @@ -1511,6 +1575,9 @@ def stream_chat( ) all_images.extend(frames) + # Hoist system messages to position [0] for Qwen 3.5 template + chat_messages = _hoist_system_messages(chat_messages) + # Apply chat template directly - messages are already properly structured # Pop tools so they don't leak into mlx_vlm.generate()/stream_generate() tools = kwargs.pop("tools", None) diff --git a/vllm_mlx/repetition_detector.py b/vllm_mlx/repetition_detector.py new file mode 100644 index 00000000..0b2c604c --- /dev/null +++ b/vllm_mlx/repetition_detector.py @@ -0,0 +1,73 @@ +"""Detect degenerate repeating token patterns during generation.""" + + +class RepetitionDetector: + """Sliding-window detector for repeating token sequences. + + Checks periodically (every ``check_interval`` tokens) whether the + last ``window`` tokens contain a pattern of length 2-``max_pattern`` + repeated at least ``min_repeats`` times consecutively. 
+ + Usage:: + + det = RepetitionDetector() + for token_id in generate(): + if det.check(token_id): + break # degenerate loop detected + """ + + def __init__( + self, + window: int = 200, + max_pattern: int = 50, + min_repeats: int = 3, + check_interval: int = 20, + ): + self.window = window + self.max_pattern = max_pattern + self.min_repeats = min_repeats + self.check_interval = check_interval + self._tokens: list[int] = [] + self._count = 0 + + def check(self, token_id: int) -> bool: + """Record a token and return True if a repetition loop is detected.""" + self._tokens.append(token_id) + self._count += 1 + + # Only keep the sliding window + if len(self._tokens) > self.window: + self._tokens = self._tokens[-self.window :] + + # Check periodically to stay lightweight + if self._count % self.check_interval != 0: + return False + + return self._is_repeating() + + def _is_repeating(self) -> bool: + tokens = self._tokens + n = len(tokens) + # Need at least min_repeats * 2 tokens for shortest pattern (len 2) + if n < self.min_repeats * 2: + return False + + for pat_len in range(2, min(self.max_pattern + 1, n // self.min_repeats + 1)): + pattern = tokens[-pat_len:] + repeats = 1 + pos = n - 2 * pat_len + while pos >= 0: + if tokens[pos : pos + pat_len] == pattern: + repeats += 1 + if repeats >= self.min_repeats: + return True + pos -= pat_len + else: + break + + return False + + def reset(self): + """Clear state for a new generation.""" + self._tokens.clear() + self._count = 0 diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e..58c1f13f 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -467,6 +467,11 @@ def load_model( stream_interval: int = 1, max_tokens: int = 32768, force_mllm: bool = False, + mtp: bool = False, + specprefill_enabled: bool = False, + specprefill_draft_model_path: str | None = None, + specprefill_threshold: int = 8192, + specprefill_keep_pct: float = 0.3, ): """ Load a model (auto-detects MLLM vs LLM). 
@@ -478,6 +483,11 @@ def load_model( stream_interval: Tokens to batch before streaming (batched mode only) max_tokens: Default max tokens for generation force_mllm: Force loading as MLLM even if not auto-detected + mtp: Enable native MTP speculative decoding (per-request routing in both engines) + specprefill_enabled: Enable SpecPrefill (attention-based sparse prefill) + specprefill_draft_model_path: Path to draft model for SpecPrefill scoring + specprefill_threshold: Minimum suffix tokens to trigger SpecPrefill + specprefill_keep_pct: Fraction of tokens to keep during sparse prefill """ global _engine, _model_name, _default_max_tokens, _tool_parser_instance @@ -496,13 +506,26 @@ def load_model( scheduler_config=scheduler_config, stream_interval=stream_interval, force_mllm=force_mllm, + mtp=mtp, + specprefill_enabled=specprefill_enabled, + specprefill_draft_model_path=specprefill_draft_model_path, + specprefill_threshold=specprefill_threshold, + specprefill_keep_pct=specprefill_keep_pct, ) # BatchedEngine will be started in lifespan (uvicorn's event loop) # Just log for now logger.info(f"Model loaded (batched mode): {model_name}") else: logger.info(f"Loading model with SimpleEngine: {model_name}") - _engine = SimpleEngine(model_name=model_name, force_mllm=force_mllm) + _engine = SimpleEngine( + model_name=model_name, + force_mllm=force_mllm, + mtp=mtp, + specprefill_enabled=specprefill_enabled, + specprefill_draft_model_path=specprefill_draft_model_path, + specprefill_threshold=specprefill_threshold, + specprefill_keep_pct=specprefill_keep_pct, + ) # Start SimpleEngine synchronously (no background loop) # Use new_event_loop() for Python 3.10+ compatibility (get_event_loop() is deprecated) loop = asyncio.new_event_loop() @@ -1359,6 +1382,12 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re if request.video_max_frames: chat_kwargs["video_max_frames"] = request.video_max_frames + # SpecPrefill: per-request overrides + if 
request.specprefill is not None: + chat_kwargs["specprefill"] = request.specprefill + if request.specprefill_keep_pct is not None: + chat_kwargs["specprefill_keep_pct"] = request.specprefill_keep_pct + # Add tools if provided if request.tools: chat_kwargs["tools"] = convert_tools_for_template(request.tools) diff --git a/vllm_mlx/specprefill.py b/vllm_mlx/specprefill.py new file mode 100644 index 00000000..5ea985ff --- /dev/null +++ b/vllm_mlx/specprefill.py @@ -0,0 +1,742 @@ +# SPDX-License-Identifier: Apache-2.0 +"""SpecPrefill: Attention-based sparse prefill for MLX. + +Full pipeline for reducing TTFT on long prompts: + Step 1 (score_tokens): Use a small draft model to identify important tokens + Step 2 (sparse_prefill): Prefill target model with only selected tokens, + preserving original positional encoding via manual RoPE + +Usage: + from specprefill import score_tokens, select_chunks, sparse_prefill, cleanup_rope + + # 1. Score with draft model + importance = score_tokens(draft_model, tokens) + + # 2. Select important token chunks + selected = select_chunks(importance, keep_pct=0.3) + + # 3. Sparse prefill on target model + target_cache = make_prompt_cache(target_model) + logits = sparse_prefill(target_model, tokens, selected, target_cache) + + # 4. Generate normally using target_cache... + + # 5. Cleanup + cleanup_rope(target_model) + +Design notes: + - RoPE is relative: Q_m @ K_p^T depends only on (m - p). Selected keys stored + contiguously in the cache buffer with correct RoPE angles produce correct + attention during decode. + - After sparse prefill of N tokens from a total prompt of M, cache.offset = N + but decode RoPE needs position M. The _OffsetAdjustedRoPE adds (M - N) to + each RoPE offset call, so decode position = N + i + (M - N) = M + i. + - GatedDeltaNet (linear attention) layers process sparse tokens through their + conv/SSM state normally. 
This is lossy but acceptable per the SpecPrefill + paper — attention layers are the primary long-range mechanism. + +Reference: arxiv.org/abs/2502.02789 (SpecPrefill: Speculative Prefilling) +""" + +import math + +import mlx.core as mx + +from mlx_lm.models.cache import make_prompt_cache +from mlx_lm.sample_utils import make_sampler + +# =========================================================================== +# Step 1: Token importance scoring (draft model) +# =========================================================================== + + +class _AttentionCapture: + """Wrapper that captures post-RoPE query vectors and delegates to original. + + Installed on attention layers during lookahead decode to capture query + vectors for importance scoring. Supports multiple architectures via + query_extractor callback. + """ + + def __init__(self, original, buf_idx, query_buffer, query_extractor=None): + self._original = original + self._buf_idx = buf_idx + self._query_buffer = query_buffer + self._query_extractor = query_extractor or _qwen35_extract_queries + + def __call__(self, x, mask=None, cache=None): + queries = self._query_extractor(self._original, x, cache) + self._query_buffer[self._buf_idx].append(queries) + return self._original(x, mask=mask, cache=cache) + + def __getattr__(self, name): + return getattr(self._original, name) + + +def _qwen35_extract_queries(attn, x, cache=None): + """Extract post-RoPE queries from Qwen3.5 attention (gate split + q_norm). + + Qwen3.5 q_proj output is 2x wider: [queries, gate]. We split, normalize, + then apply RoPE. 
+ """ + B, L, D = x.shape + q_out = attn.q_proj(x) + queries, _gate = mx.split( + q_out.reshape(B, L, attn.num_attention_heads, -1), 2, axis=-1 + ) + queries = attn.q_norm(queries).transpose(0, 2, 1, 3) + if cache is not None: + queries = attn.rope(queries, offset=cache.offset) + else: + queries = attn.rope(queries) + return queries + + +def _llama_extract_queries(attn, x, cache=None): + """Extract post-RoPE queries from standard transformer attention. + + Standard architecture: q_proj → reshape → RoPE. No gate, no q_norm. + Works for Llama 3.x, Mistral, Gemma, GPT-OSS, and other GQA models. + """ + B, L, D = x.shape + n_heads = getattr( + attn, + "num_attention_heads", + getattr(attn, "n_heads", getattr(attn, "num_heads", None)), + ) + queries = attn.q_proj(x) + queries = queries.reshape(B, L, n_heads, -1).transpose(0, 2, 1, 3) + if cache is not None: + queries = attn.rope(queries, offset=cache.offset) + else: + queries = attn.rope(queries) + return queries + + +def _nemotron_h_extract_queries(attn, x, cache=None): + """Extract queries from Nemotron-H attention (no RoPE, no gate, no q_norm). + + Nemotron-H attention layers have NO positional encoding — RoPE is absent. + Positional modeling comes from Mamba2 layers. Attention is content-based only. + """ + B, L, D = x.shape + queries = attn.q_proj(x).reshape(B, L, attn.num_heads, -1).transpose(0, 2, 1, 3) + # No RoPE to apply — queries are used as-is for content-based scoring + return queries + + +def _patch_attention_for_capture(model, query_buffer, query_extractor=None): + """Replace attention modules on full-attention layers with capture wrappers. + + Supports both `self_attn` (Qwen3.5/Llama/GPT-OSS) and `mixer` + (Nemotron-H block_type="*") attribute conventions. + + Returns (originals, attn_layer_indices) for cleanup. 
+ """ + originals = [] + attn_indices = [] + for layer_idx, layer in _find_attention_layers(model): + buf_idx = len(attn_indices) + attn_indices.append(layer_idx) + orig = _get_attn_module(layer) + _set_attn_module( + layer, + _AttentionCapture( + orig, buf_idx, query_buffer, query_extractor=query_extractor + ), + ) + originals.append((layer_idx, orig)) + return originals, attn_indices + + +def _unpatch_attention_capture(model, originals): + """Restore original attention modules after capture.""" + for layer_idx, orig in originals: + _set_attn_module(model.layers[layer_idx], orig) + + +def _prefill_draft(model, tokens, cache, step_size=2048): + """Prefill prompt tokens into cache. Returns logits from last token.""" + prompt = mx.array(tokens) if not isinstance(tokens, mx.array) else tokens + n = len(tokens) + processed = 0 + while n - processed > 1: + chunk = min(step_size, n - processed - 1) + model(prompt[processed : processed + chunk][None], cache=cache) + mx.eval([c.state for c in cache]) + processed += chunk + mx.clear_cache() + logits = model(prompt[processed:][None], cache=cache) + mx.eval(logits) + return logits + + +def _lookahead_decode(model, first_logits, cache, n_steps, temp=0.6, top_p=0.95): + """Run n_steps autoregressive decode, returning generated token ids. + + Query vectors are captured by the monkey-patched attention layers. + """ + sampler = make_sampler(temp=temp, top_p=top_p) + y = sampler(first_logits[:, -1, :]) + mx.eval(y) + generated = [y.item()] + for _ in range(n_steps): + logits = model(y.reshape(1, -1), cache=cache) + y = sampler(logits[:, -1, :]) + mx.eval(y) + generated.append(y.item()) + return generated + + +def _avg_pool1d(x, kernel_size): + """1D average pooling along last axis via prefix-sum. 
+ + Args: + x: (..., M) input + kernel_size: window size (odd for centered) + + Returns: + (..., M) pooled (same size, zero-padded at edges) + """ + if kernel_size <= 1: + return x + pad = kernel_size // 2 + padded = mx.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad, pad)]) + zeros = mx.zeros(x.shape[:-1] + (1,), dtype=x.dtype) + prefix = mx.concatenate([zeros, mx.cumsum(padded, axis=-1)], axis=-1) + return (prefix[..., kernel_size:] - prefix[..., :-kernel_size]) / kernel_size + + +def _compute_importance( + query_buffer, attn_caches, n_prompt, n_attn_heads, n_kv_heads, pool_kernel=13 +): + """Compute per-token importance from captured queries and cached keys. + + Aggregation (SpecPrefill paper): + 1. softmax(Q @ K^T / sqrt(d)) per head, per layer, per lookahead token + 2. avg_pool1d smoothing + 3. max across (layers × heads) + 4. mean across lookahead tokens + + Returns: (n_prompt,) importance scores. + """ + heads_per_group = n_attn_heads // n_kv_heads + all_scores = [] + + for layer_i, captures in enumerate(query_buffer): + if not captures: + continue + cache = attn_caches[layer_i] + prompt_keys = cache.keys[..., :n_prompt, :] + # Skip layers with windowed/rotating caches that don't span + # the full prompt (e.g., GPT-OSS sliding_attention with 128-token window). + # These lack global context and would produce mismatched score shapes. 
+ if prompt_keys.shape[-2] < n_prompt: + continue + head_dim = prompt_keys.shape[-1] + q_stack = mx.concatenate(captures, axis=2) + if heads_per_group > 1: + expanded_keys = mx.repeat(prompt_keys, heads_per_group, axis=1) + else: + expanded_keys = prompt_keys + scale = head_dim**-0.5 + scores = (q_stack @ expanded_keys.transpose(0, 1, 3, 2)) * scale + weights = mx.softmax(scores.astype(mx.float32), axis=-1) + all_scores.append(weights.squeeze(0)) + + if not all_scores: + raise RuntimeError("No attention scores captured — check model/patching") + + combined = mx.concatenate(all_scores, axis=0) + if pool_kernel and pool_kernel > 1: + combined = _avg_pool1d(combined, pool_kernel) + max_scores = mx.max(combined, axis=0) + importance = mx.mean(max_scores, axis=0) + return importance + + +def score_tokens( + model, + tokens, + n_lookahead=8, + pool_kernel=13, + temp=0.6, + top_p=0.95, + prefill_step_size=2048, + query_extractor=None, +): + """Score token importance using attention-based analysis on a draft model. + + Runs the full scoring pipeline: + 1. Prefill the draft model with all tokens + 2. N lookahead decode steps, capturing query vectors from attention layers + 3. Compute importance: Q_lookahead @ K_prompt^T, aggregated across heads/layers + + The draft model's cache is created internally and discarded after scoring. + + Args: + model: Draft model (small, fast — e.g. 4B) + tokens: list or mx.array of token IDs + n_lookahead: decode steps for query capture (default 8) + pool_kernel: smoothing kernel for avg_pool1d (default 13, 0=disable) + temp: sampling temperature for lookahead (default 0.6) + top_p: top-p for lookahead (default 0.95) + prefill_step_size: chunk size for draft prefill (default 2048) + query_extractor: function(attn, x, cache) → queries tensor. + Default: _qwen35_extract_queries. Use _llama_extract_queries for + standard Llama/Mistral/Gemma models. 
+ + Returns: + importance: (M,) mx.array of per-token importance scores + """ + if isinstance(tokens, mx.array): + tokens = tokens.tolist() + n_prompt = len(tokens) + + # Model topology — detect attribute names across architectures + attn_layers = _find_attention_layers(model) + n_attn_layers = len(attn_layers) + attn_obj = _get_attn_module(attn_layers[0][1]) + # Attribute names vary: num_attention_heads (Qwen3.5), n_heads (Llama), + # num_heads (Nemotron-H) + n_attn_heads = getattr( + attn_obj, + "num_attention_heads", + getattr(attn_obj, "n_heads", getattr(attn_obj, "num_heads", None)), + ) + n_kv_heads = getattr( + attn_obj, "num_key_value_heads", getattr(attn_obj, "n_kv_heads", None) + ) + + # Auto-detect query extractor if not specified + if query_extractor is None: + if hasattr(attn_obj, "q_norm"): + query_extractor = _qwen35_extract_queries + elif not hasattr(attn_obj, "rope"): + # No RoPE attribute → Nemotron-H style (content-based attention) + query_extractor = _nemotron_h_extract_queries + else: + query_extractor = _llama_extract_queries + + # Phase 1: Prefill + cache = make_prompt_cache(model) + logits = _prefill_draft(model, tokens, cache, step_size=prefill_step_size) + + # Phase 2: Lookahead decode with query capture + query_buffer = [[] for _ in range(n_attn_layers)] + patches, attn_indices = _patch_attention_for_capture( + model, query_buffer, query_extractor=query_extractor + ) + try: + _lookahead_decode(model, logits, cache, n_lookahead, temp=temp, top_p=top_p) + mx.eval(query_buffer) + finally: + _unpatch_attention_capture(model, patches) + + # Phase 3: Compute importance + # Map layer indices to cache indices (identity for standard models, + # compacted for Nemotron-H where only M/* layers have cache entries) + layer_to_cache = _build_layer_to_cache_map(model) + attn_caches = [cache[layer_to_cache[i]] for i in attn_indices] + importance = _compute_importance( + query_buffer, + attn_caches, + n_prompt, + n_attn_heads, + n_kv_heads, + 
pool_kernel=pool_kernel if pool_kernel > 0 else None, + ) + mx.eval(importance) + + # Draft cache is no longer needed — let GC reclaim it + del cache, logits, query_buffer, attn_caches + mx.clear_cache() + + return importance + + +def select_chunks(importance, keep_pct=0.3, chunk_size=32): + """Select top-k% token chunks by average importance. + + Args: + importance: (M,) per-token importance scores + keep_pct: fraction of chunks to keep (default 0.3) + chunk_size: tokens per chunk (default 32) + + Returns: + sorted mx.array of kept token indices + """ + M = importance.shape[0] + if keep_pct >= 1.0: + return mx.arange(M) + + n_chunks = math.ceil(M / chunk_size) + keep_n = max(1, math.ceil(n_chunks * keep_pct)) + + chunk_scores = [] + for i in range(n_chunks): + start = i * chunk_size + end = min(start + chunk_size, M) + chunk_scores.append(mx.mean(importance[start:end]).item()) + + top_chunks = sorted(range(n_chunks), key=lambda i: chunk_scores[i], reverse=True)[ + :keep_n + ] + top_chunks.sort() + + indices = [] + for ci in top_chunks: + start = ci * chunk_size + end = min(start + chunk_size, M) + indices.extend(range(start, end)) + + return mx.array(indices) + + +# =========================================================================== +# Step 2: Sparse prefill with non-contiguous position IDs (target model) +# =========================================================================== + + +# --------------------------------------------------------------------------- +# Manual RoPE at arbitrary positions +# --------------------------------------------------------------------------- + + +def manual_rope(x, positions, dims, base=10000.0, scale=1.0): + """Apply RoPE at arbitrary (non-contiguous) positions. + + Uses non-traditional (interleaved) layout matching Qwen3.5: + rotates first `dims` dimensions as pairs [0,half), [half,dims), + passes through [dims:] unchanged. 
+ + Args: + x: (B, n_heads, L, head_dim) input tensor + positions: (L,) position indices (can be non-contiguous) + dims: number of dimensions to rotate (head_dim * partial_rotary_factor) + base: RoPE base frequency (default 10000.0) + scale: position scale divisor (default 1.0, higher = compressed positions) + + Returns: + (B, n_heads, L, head_dim) with RoPE applied + """ + half = dims // 2 + inv_freq = 1.0 / (base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)) + scaled_pos = positions.astype(mx.float32) / scale + angles = scaled_pos[:, None] * inv_freq[None, :] # (L, half) + cos_a = mx.cos(angles)[None, None, :, :] # (1, 1, L, half) + sin_a = mx.sin(angles)[None, None, :, :] + x_rot, x_pass = x[..., :dims], x[..., dims:] + x1, x2 = x_rot[..., :half], x_rot[..., half:] + rotated = mx.concatenate( + [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1 + ) + return mx.concatenate([rotated, x_pass], axis=-1) + + +def manual_rope_with_freqs(x, positions, dims, freqs, pre_scale=1.0): + """Apply RoPE at arbitrary positions using pre-computed frequencies. + + For custom RoPE variants (Llama3, Yarn, SuScaled) that store _freqs. + """ + half = dims // 2 + inv_freq = (1.0 / freqs).astype(mx.float32) + angles = positions[:, None].astype(mx.float32) * inv_freq[None, :] + cos_a = mx.cos(angles)[None, None, :, :] + sin_a = mx.sin(angles)[None, None, :, :] + x_rot, x_pass = x[..., :dims], x[..., dims:] + if pre_scale != 1.0: + x_rot = pre_scale * x_rot + x1, x2 = x_rot[..., :half], x_rot[..., half:] + rotated = mx.concatenate( + [x1 * cos_a - x2 * sin_a, x1 * sin_a + x2 * cos_a], axis=-1 + ) + return mx.concatenate([rotated, x_pass], axis=-1) + + +# --------------------------------------------------------------------------- +# RoPE wrappers +# --------------------------------------------------------------------------- + + +class _PositionMappedRoPE: + """Wraps a RoPE module to apply rotation at non-contiguous positions. + + Used during sparse prefill. 
The `offset` parameter from the cache tells us + which slice of the position array to use for the current chunk: + positions = all_positions[(offset - cache_start) : (offset - cache_start) + L] + + When composing with a pre-populated cache (e.g., system KV cache), cache_start + is the initial cache offset so indexing into the position array is correct. + """ + + def __init__(self, original_rope, all_positions, cache_start=0): + self._original = original_rope + self._all_positions = all_positions + self._cache_start = cache_start + self._has_custom_freqs = hasattr(original_rope, "_freqs") + + if self._has_custom_freqs: + self._freqs = original_rope._freqs + self._dims = _get_dims(original_rope) + self._pre_scale = _get_pre_scale(original_rope) + else: + # Standard nn.RoPE: attributes are dims, base, scale (no underscore) + self._dims = original_rope.dims + self._base = original_rope.base + self._scale = original_rope.scale + + def __call__(self, x, offset=0): + L = x.shape[2] + idx = offset - self._cache_start + positions = self._all_positions[idx : idx + L] + if self._has_custom_freqs: + return manual_rope_with_freqs( + x, positions, self._dims, self._freqs, pre_scale=self._pre_scale + ) + return manual_rope(x, positions, self._dims, base=self._base, scale=self._scale) + + +class _OffsetAdjustedRoPE: + """Wraps a RoPE module to add a constant offset for decode after sparse prefill. 
+ + After sparse prefill of N tokens from a prompt of M total tokens: + cache.offset = N + i (i = decode step) + desired RoPE position = M + i + adjustment = M - N + + So: RoPE(x, offset = cache.offset + adjustment) = RoPE(x, M + i) + """ + + def __init__(self, original_rope, adjustment): + self._original = original_rope + self._adjustment = adjustment + + def __call__(self, x, offset=0): + return self._original(x, offset=offset + self._adjustment) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_dims(rope_module): + """Extract rotary dimensions from any RoPE variant.""" + for attr in ("_dims", "dim", "dims"): + if hasattr(rope_module, attr): + return getattr(rope_module, attr) + raise ValueError(f"Cannot determine dims from {type(rope_module)}") + + +def _get_pre_scale(rope_module): + """Extract pre-scale factor from custom RoPE variants (SuScaled, Yarn).""" + if hasattr(rope_module, "mscale"): + return rope_module.mscale + if hasattr(rope_module, "_scale") and hasattr(rope_module, "dim"): + return rope_module._scale + return 1.0 + + +def _find_attention_layers(model): + """Find all full-attention layers across architectures. + + Supports: + - Qwen3.5 / Llama / GPT-OSS: layers with `self_attn` attribute + - Nemotron-H: layers with `block_type == "*"` (attention blocks use `mixer`) + + Returns list of (layer_idx, layer) tuples. 
+ """ + results = [] + for idx, layer in enumerate(model.layers): + if hasattr(layer, "self_attn"): + results.append((idx, layer)) + elif getattr(layer, "block_type", None) == "*": + results.append((idx, layer)) + return results + + +def _get_attn_module(layer): + """Get the attention module from a layer (self_attn or mixer).""" + if hasattr(layer, "self_attn"): + return layer.self_attn + if getattr(layer, "block_type", None) == "*": + return layer.mixer + return None + + +def _set_attn_module(layer, module): + """Set the attention module on a layer (self_attn or mixer).""" + if hasattr(layer, "self_attn"): + layer.self_attn = module + elif getattr(layer, "block_type", None) == "*": + layer.mixer = module + + +def _build_layer_to_cache_map(model): + """Build mapping from model layer index to cache index. + + Standard models (Qwen3.5, Llama, GPT-OSS): one cache entry per layer, + so the mapping is identity (layer_idx → layer_idx). + + Nemotron-H: only M (Mamba2) and * (attention) layers have cache entries. + MLP (-) and MoE (E) layers get no cache. The mapping is compacted. + + Returns dict {layer_idx: cache_idx}. + """ + has_block_type = any(hasattr(layer, "block_type") for layer in model.layers) + if not has_block_type: + # Standard model: identity mapping + return {i: i for i in range(len(model.layers))} + + # Nemotron-H style: count cache entries for M/* layers + layer_to_cache = {} + cache_idx = 0 + for layer_idx, layer in enumerate(model.layers): + bt = getattr(layer, "block_type", None) + if bt in ("M", "*"): + layer_to_cache[layer_idx] = cache_idx + cache_idx += 1 + return layer_to_cache + + +# --------------------------------------------------------------------------- +# Core API — sparse prefill +# --------------------------------------------------------------------------- + + +def sparse_prefill( + model, tokens, selected_indices, cache, step_size=2048, position_offset=0 +): + """Prefill the model cache with selected tokens at their original positions. 
def sparse_prefill(
    model, tokens, selected_indices, cache, step_size=2048, position_offset=0
):
    """Prefill the model cache with selected tokens at their original positions.

    Runs the model forward on only the selected tokens while preserving their
    original positional encoding via manual RoPE. After this call, the cache
    contains KV entries with correct RoPE positions, and attention layers have
    _OffsetAdjustedRoPE installed for correct decode positioning.

    Args:
        model: Language model with .layers property (TextModel or VLM Model)
        tokens: (M,) all prompt token IDs (mx.array or list)
        selected_indices: (N,) sorted indices into tokens to keep (mx.array or list)
        cache: list of KVCache/ArraysCache from make_prompt_cache()
        step_size: chunk size for processing (default 2048)
        position_offset: added to selected_indices for RoPE positions (default 0).
            Use when the cache already has tokens from a prior prefill (e.g.,
            system prompt KV cache with S tokens → position_offset=S).

    Returns:
        logits: (1, 1, vocab_size) from the last selected token

    Side effects:
        - Populates cache with KV for selected tokens
        - Installs _OffsetAdjustedRoPE on attention layers for decode
        - Call cleanup_rope(model) after generation to restore original RoPE
    """
    # Normalize both inputs to mx.array so indexing/astype below works.
    if not isinstance(tokens, mx.array):
        tokens = mx.array(tokens)
    if not isinstance(selected_indices, mx.array):
        selected_indices = mx.array(selected_indices)

    M = tokens.shape[0]

    # Detect RotatingKVCache and ensure tail tokens are included.
    # Models with sliding window attention (e.g., GPT-OSS) use RotatingKVCache
    # which evicts old entries. We must include the last `max_size` positions
    # so sliding window layers have valid recent context for decode.
    max_rotating_size = 0
    for c in cache:
        if type(c).__name__ == "RotatingKVCache":
            max_rotating_size = max(max_rotating_size, getattr(c, "max_size", 0))
    if max_rotating_size > 0:
        tail_start = max(0, M - max_rotating_size)
        tail_indices = set(range(tail_start, M))
        # Merge via sets: keeps indices unique and sorted for the gather below.
        existing = set(selected_indices.tolist())
        merged = sorted(existing | tail_indices)
        selected_indices = mx.array(merged)

    # RoPE positions: absolute positions accounting for any prefix
    selected_positions = selected_indices.astype(mx.int32) + position_offset
    selected_tokens = tokens[selected_indices]
    N = selected_tokens.shape[0]

    # Determine initial cache offset (non-zero when system KV cache is restored)
    # NOTE(review): assumes the model has at least one attention layer —
    # attn_layers[0] raises IndexError otherwise; confirm against supported models.
    attn_layers = _find_attention_layers(model)
    layer_to_cache = _build_layer_to_cache_map(model)
    first_attn_layer_idx = attn_layers[0][0]
    first_attn_cache_idx = layer_to_cache[first_attn_layer_idx]
    cache_start = (
        cache[first_attn_cache_idx].offset
        if hasattr(cache[first_attn_cache_idx], "offset")
        else 0
    )

    # Check if attention layers use RoPE (Nemotron-H has none)
    first_attn = _get_attn_module(attn_layers[0][1])
    has_rope = hasattr(first_attn, "rope")

    # Patch RoPE on attention layers for position-mapped prefill
    # (skipped for architectures without RoPE, e.g. Nemotron-H)
    original_ropes = {}
    if has_rope:
        for layer_idx, layer in attn_layers:
            attn = _get_attn_module(layer)
            original_ropes[layer_idx] = attn.rope
            attn.rope = _PositionMappedRoPE(
                attn.rope, selected_positions, cache_start=cache_start
            )

    try:
        prompt = selected_tokens
        n = int(N)
        processed = 0

        # Chunked prefill; the final token is deliberately held back so its
        # forward pass below can return logits.
        while n - processed > 1:
            chunk = min(step_size, n - processed - 1)
            model(prompt[processed : processed + chunk][None], cache=cache)
            # Force evaluation per chunk to bound lazy-graph memory growth.
            mx.eval([c.state for c in cache])
            processed += chunk
            mx.clear_cache()

        # Last token → logits
        logits = model(prompt[processed:][None], cache=cache)
        mx.eval(logits)

    finally:
        # Replace position-mapped RoPE with offset-adjusted RoPE for decode.
        # Skipped for architectures without RoPE (e.g. Nemotron-H).
        #
        # Total prompt length = position_offset + M (prefix + current tokens).
        # After prefill, cache offset = cache_start + N.
        # Decode needs RoPE position = total_len + i, cache gives offset = cache_start + N + i.
        # Adjustment = total_len - (cache_start + N) = position_offset + M - cache_start - N.
        # When cache_start == position_offset (normal case): adjustment = M - N.
        if has_rope:
            total_prompt_len = position_offset + M
            final_cache_offset = cache_start + N
            adjustment = int(total_prompt_len) - int(final_cache_offset)
            for layer_idx, layer in attn_layers:
                attn = _get_attn_module(layer)
                original = original_ropes[layer_idx]
                if adjustment > 0:
                    attn.rope = _OffsetAdjustedRoPE(original, adjustment)
                else:
                    # No positions were skipped: restore the unwrapped RoPE.
                    attn.rope = original

    return logits
def cleanup_rope(model):
    """Undo the RoPE patching performed by sparse_prefill().

    Walks every attention layer and, wherever a _OffsetAdjustedRoPE or
    _PositionMappedRoPE wrapper is installed, swaps the original RoPE
    module back in. Layers without a rope attribute (e.g. Nemotron-H)
    are left untouched, so this is a safe no-op for such architectures.
    """
    wrapper_types = (_OffsetAdjustedRoPE, _PositionMappedRoPE)
    for _, layer in _find_attention_layers(model):
        module = _get_attn_module(layer)
        if module is None:
            continue
        current = getattr(module, "rope", None)
        if isinstance(current, wrapper_types):
            module.rope = current._original
def build_text_model(vlm_model: Any, model_path: str | Path) -> Any | None:
    """Build an mlx_lm TextModel from a vlm-loaded model's weights.

    Backbone + lm_head weights are shared with the vlm model (zero-copy);
    MTP weights (stripped by mlx_vlm's sanitize()) are re-read from the
    safetensors shards on disk.

    Args:
        vlm_model: The mlx_vlm-loaded model (has .language_model attribute)
        model_path: Path to the model directory (contains config.json + safetensors)

    Returns:
        mlx_lm TextModel with MTP support, or None on failure.
    """
    # Guard clauses: nothing to build without a source model and an on-disk config.
    if vlm_model is None:
        return None
    if not model_path:
        return None
    model_path = Path(model_path)
    if not (model_path / "config.json").exists():
        return None

    try:
        config = json.loads((model_path / "config.json").read_text())
        text_config = config.get("text_config", config)

        # Always import from qwen3_5 — TextModel and TextModelArgs handle both
        # dense and MoE natively (MTPDecoderLayer auto-selects SparseMoeBlock
        # when args.num_experts > 0). qwen3_5_moe.py does NOT export these.
        from mlx_lm.models.qwen3_5 import TextModel, TextModelArgs

        # Build args with proper __post_init__ (handles partial_rotary_factor,
        # rope_scaling, head_dim derivation)
        args = TextModelArgs.from_dict(text_config)
        text_model = TextModel(args)

        # Collect all weights first: backbone from vlm + MTP from safetensors
        vlm_lm = vlm_model.language_model
        vlm_weights = mlx.utils.tree_flatten(vlm_lm.parameters())
        mtp_weights = _load_mtp_weights(model_path)

        all_weight_names = {name for name, _ in vlm_weights}
        all_weight_names.update(name for name, _ in mtp_weights)

        # Quantize the TextModel skeleton to match source weights.
        # Use a predicate that only quantizes layers that have .scales in source.
        # This prevents quantizing layers like mtp.fc which are BF16.
        quantization = text_config.get("quantization", config.get("quantization"))
        if quantization is not None:

            def _class_predicate(path, module):
                # Only quantize modules whose quantized scales exist in the
                # source checkpoint; everything else keeps its original dtype.
                if not hasattr(module, "to_quantized"):
                    return False
                return f"{path}.scales" in all_weight_names

            nn.quantize(
                text_model,
                group_size=quantization.get("group_size", 64),
                bits=quantization.get("bits", 8),
                class_predicate=_class_predicate,
            )

        # Transfer backbone + lm_head weights from vlm language_model (zero-copy).
        # strict=False because TextModel has MTP params that vlm doesn't have yet.
        text_model.load_weights(vlm_weights, strict=False)

        logger.info(
            "Transferred %d weight arrays from vlm language_model", len(vlm_weights)
        )

        # Load MTP weights from safetensors
        if mtp_weights:
            text_model.load_weights(mtp_weights, strict=False)
            logger.info("Loaded %d MTP weights from safetensors", len(mtp_weights))
        else:
            logger.warning("No MTP weights found in %s", model_path.name)

        # Verify MTP is functional (forces materialization of the MTP params)
        if hasattr(text_model, "mtp") and text_model.mtp is not None:
            mx.eval(text_model.mtp.parameters())
            logger.info(
                "TextModel built with MTP support (%d layers)",
                args.mtp_num_hidden_layers,
            )
        else:
            logger.info("TextModel built without MTP (mtp_num_hidden_layers=0)")

        return text_model

    except ImportError as e:
        logger.error("Cannot import mlx_lm TextModel (need PR #990): %s", e)
        return None
    except Exception as e:
        # Boundary function: log and degrade to None rather than crash loading.
        logger.error("Failed to build TextModel from vlm: %s", e)
        return None
prefix to match mlx_lm namespace + clean = ( + key.replace("language_model.", "", 1) + if key.startswith("language_model.") + else key + ) + mtp_keys[key] = (clean, shard) + + if not mtp_keys: + return [] + + # Group by shard to minimize I/O + shards: dict[str, list[tuple[str, str]]] = {} + for orig, (clean, shard) in mtp_keys.items(): + shards.setdefault(shard, []).append((orig, clean)) + + weights = [] + for shard_file, key_pairs in shards.items(): + shard_path = model_path / shard_file + if not shard_path.exists(): + logger.warning("MTP shard not found: %s", shard_file) + continue + shard_data = mx.load(str(shard_path)) + for orig, clean in key_pairs: + if orig in shard_data: + weights.append((clean, shard_data[orig])) + + return weights