diff --git a/tests/test_chunked_prefill_compat.py b/tests/test_chunked_prefill_compat.py new file mode 100644 index 000000000..be2022c5b --- /dev/null +++ b/tests/test_chunked_prefill_compat.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for chunked prefill compatibility with mlx-lm tuple format changes. + +Regression test for issue #155: mlx-lm >= 0.31.0 added prompt_checkpoints +as a 7th element to BatchGenerator.unprocessed_prompts tuples. The chunked +prefill code in scheduler.py hardcoded a 6-element unpacking which crashed +with ValueError on the new format. +""" + +from unittest.mock import MagicMock + +import pytest + +try: + from mlx_lm.models import cache + from mlx_lm.generate import BatchGenerator + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +@pytest.fixture +def batch_gen(): + """BatchGenerator with a minimal mock model. + + make_prompt_cache checks hasattr(model, 'make_cache') and MagicMock + returns True for any attribute, so we delete it to fall through to + the default KVCache-per-layer path. + """ + model = MagicMock() + model.layers = [MagicMock(), MagicMock()] + del model.make_cache + + gen = BatchGenerator( + model=model, + max_tokens=100, + prefill_batch_size=1, + completion_batch_size=1, + ) + return gen + + +class TestChunkedPrefillTupleCompat: + """Verify chunked prefill handles varying tuple sizes from mlx-lm.""" + + def test_7_element_tuples_unpack_correctly(self): + """Issue #155: 7-element tuples from mlx-lm >= 0.31.0 must not crash.""" + mock_cache = MagicMock() + mock_cache.empty.return_value = True + batch_prompts = [ + (0, [1, 2, 3, 4, 5], 100, [mock_cache], None, [], -1), + (1, [10, 20, 30], 50, [mock_cache], None, [], -1), + ] + + ( + uids, + inputs_raw, + max_tokens_list, + caches, + samplers, + logits_processors, + *_extra, + ) = zip(*batch_prompts) + + assert uids == (0, 1) + assert inputs_raw == ([1, 2, 3, 4, 5], [10, 20, 30]) + assert max_tokens_list == (100, 50) + assert len(_extra) == 1 # prompt_checkpoints + assert _extra[0] == (-1, -1) + + def test_6_element_tuples_still_work(self): + """Backward compat: old mlx-lm without prompt_checkpoints.""" + mock_cache = MagicMock() + mock_cache.empty.return_value = True + batch_prompts = [ + (0, [1, 2, 3], 100, [mock_cache], None, []), + ] + + ( + uids, + inputs_raw, + max_tokens_list, + caches, + samplers, + logits_processors, + *_extra, + ) = zip(*batch_prompts) + + assert uids == (0,) + assert len(_extra) == 0 + + def test_8_element_tuples_forward_compat(self): + """Future-proofing: if mlx-lm adds more fields, still works.""" + mock_cache = MagicMock() + mock_cache.empty.return_value = True + batch_prompts = [ + (0, [1, 2, 3], 100, [mock_cache], None, [], -1, "future"), + ] + + ( + uids, + inputs_raw, + max_tokens_list, + caches, + samplers, + logits_processors, + *_extra, + ) = zip(*batch_prompts) + + assert uids == (0,) + assert len(_extra) == 2 + + def test_batch_generator_insert_creates_7_element_tuples(self, batch_gen): + """Verify mlx-lm 0.31.x BatchGenerator.insert creates 7-element tuples.""" + prompt_cache = cache.make_prompt_cache(batch_gen.model) + + batch_gen.insert([[1, 2, 3, 4, 5]], max_tokens=[50], caches=[prompt_cache]) + + assert len(batch_gen.unprocessed_prompts) == 1 + prompt_tuple = batch_gen.unprocessed_prompts[0] + assert len(prompt_tuple) >= 7, ( + f"Expected >= 7 elements in prompt tuple, got {len(prompt_tuple)}. " + f"mlx-lm may have changed tuple format again." 
+ ) + + def test_chunked_prefill_with_7_element_tuples(self, batch_gen): + """Integration: _install_chunked_prefill works with 7-element tuples.""" + from vllm_mlx.scheduler import _install_chunked_prefill + + prompt_cache = cache.make_prompt_cache(batch_gen.model) + + batch_gen.insert( + [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], + max_tokens=[50], + caches=[prompt_cache], + ) + + _install_chunked_prefill(batch_gen, budget=4) + + # Must NOT crash with "too many values to unpack". + # Later errors (AttributeError, etc.) are expected since the mock + # model doesn't do real inference — only the unpacking matters. + try: + batch_gen._next() + except ValueError as e: + if "unpack" in str(e).lower(): + pytest.fail( + f"Issue #155 regression: chunked prefill crashed on " + f"7-element tuple unpacking: {e}" + ) + except Exception: + pass # Expected: mock model can't do real forward pass diff --git a/tests/test_mllm_hybrid_cache.py b/tests/test_mllm_hybrid_cache.py new file mode 100644 index 000000000..6e11f6739 --- /dev/null +++ b/tests/test_mllm_hybrid_cache.py @@ -0,0 +1,524 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for MLLM continuous batching with hybrid model caches. + +Hybrid models (Qwen 3.5, Nemotron 3 Super) mix attention layers (KVCache) +with recurrent/SSM layers (ArraysCache). The MLLM batch generator must +handle both cache types during merge, filter, extract, and extend operations. +""" + +import pytest + +try: + import mlx.core as mx + from mlx_lm.models.cache import ( + ArraysCache, + BatchKVCache, + KVCache, + RotatingKVCache, + BatchRotatingKVCache, + ) + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +# --------------------------------------------------------------------------- +# Helpers — simulate Qwen 3.5 cache layout (12 KVCache + 36 ArraysCache) +# --------------------------------------------------------------------------- + + +def _make_hybrid_cache(n_kv=12, n_arrays=36, arrays_size=2): + """Create a hybrid cache list like Qwen 3.5's make_cache(). + + Qwen 3.5 layout: is_linear = (layer_idx + 1) % 4 != 0 + So layers 0,1,2 are ArraysCache, layer 3 is KVCache, etc. + For simplicity, we just create n_arrays ArraysCache + n_kv KVCache + interleaved with the real pattern. 
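+    Example layout for the first 8 layers (Arr = ArraysCache), as produced
+    by the (i + 1) % 4 != 0 rule below:
+        [Arr, Arr, Arr, KV, Arr, Arr, Arr, KV, ...]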
+ """ + full_attention_interval = 4 + total = n_kv + n_arrays + cache = [] + for i in range(total): + is_linear = (i + 1) % full_attention_interval != 0 + if is_linear: + cache.append(ArraysCache(size=arrays_size)) + else: + cache.append(KVCache()) + return cache + + +def _populate_kv_cache( + cache: KVCache, seq_len: int, n_kv_heads: int = 4, head_dim: int = 8 +): + """Populate a KVCache with dummy data to simulate a completed prefill.""" + # KVCache.update_and_fetch expects 4D: (batch, n_kv_heads, seq_len, head_dim) + keys = mx.random.normal((1, n_kv_heads, seq_len, head_dim)) + values = mx.random.normal((1, n_kv_heads, seq_len, head_dim)) + cache.update_and_fetch(keys, values) + + +def _populate_arrays_cache( + cache: ArraysCache, batch_size: int = 1, state_dim: int = 16 +): + """Populate an ArraysCache with dummy SSM state.""" + for i in range(len(cache.cache)): + cache.cache[i] = mx.random.normal((batch_size, state_dim)) + + +def _make_populated_hybrid_cache( + seq_len: int = 10, n_kv_heads: int = 4, head_dim: int = 8, state_dim: int = 16 +): + """Create and populate a hybrid cache simulating a completed vision encoding prefill.""" + cache = _make_hybrid_cache() + for c in cache: + if isinstance(c, KVCache): + _populate_kv_cache(c, seq_len, n_kv_heads, head_dim) + elif isinstance(c, ArraysCache): + _populate_arrays_cache(c, batch_size=1, state_dim=state_dim) + return cache + + +# --------------------------------------------------------------------------- +# Test: _make_batch_cache handles all cache types +# --------------------------------------------------------------------------- + + +class TestMakeBatchCache: + """Test _make_batch_cache() with hybrid model caches.""" + + def test_hybrid_cache_creates_correct_types(self): + """_make_batch_cache returns BatchKVCache for KVCache layers, ArraysCache for ArraysCache layers.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + # Mock model with make_cache returning hybrid layout + class FakeModel: + def make_cache(self): + return _make_hybrid_cache() + + left_padding = [0, 2] # 2-request batch, different prompt lengths + batch_cache = _make_batch_cache(FakeModel(), left_padding) + + assert len(batch_cache) == 48 # 12 KV + 36 Arrays + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear: + # ArraysCache is returned as-is with left_padding set + assert isinstance(c, ArraysCache), f"Layer {i} should be ArraysCache" + assert c.left_padding is not None + else: + assert isinstance(c, BatchKVCache), f"Layer {i} should be BatchKVCache" + + def test_pure_kv_cache_still_works(self): + """Regression: pure attention models (all KVCache) still work.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [KVCache() for _ in range(24)] + + batch_cache = _make_batch_cache(FakeModel(), [0, 1]) + assert all(isinstance(c, BatchKVCache) for c in batch_cache) + + def test_pure_arrays_cache_works(self): + """Pure SSM models (all ArraysCache) work.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [ArraysCache(size=2) for _ in range(24)] + + batch_cache = _make_batch_cache(FakeModel(), [0, 1]) + assert all(isinstance(c, ArraysCache) for c in batch_cache) + + def test_rotating_kv_cache_works(self): + """RotatingKVCache (keep=0) gets wrapped in BatchRotatingKVCache.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return 
[RotatingKVCache(max_size=1024, keep=0) for _ in range(4)] + + batch_cache = _make_batch_cache(FakeModel(), [0]) + assert all(isinstance(c, BatchRotatingKVCache) for c in batch_cache) + + def test_rotating_kv_cache_with_keep_rejected(self): + """RotatingKVCache with keep > 0 is rejected.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class FakeModel: + def make_cache(self): + return [RotatingKVCache(max_size=1024, keep=4)] + + with pytest.raises(ValueError, match="keep tokens is not supported"): + _make_batch_cache(FakeModel(), [0]) + + def test_unsupported_cache_type_rejected(self): + """Cache types without batching support are rejected with clear error.""" + from vllm_mlx.mllm_batch_generator import _make_batch_cache + + class UnsupportedCache: + pass + + class FakeModel: + def make_cache(self): + return [UnsupportedCache()] + + with pytest.raises(ValueError, match="does not support"): + _make_batch_cache(FakeModel(), [0]) + + +# --------------------------------------------------------------------------- +# Test: Merge loop works with mixed cache types +# --------------------------------------------------------------------------- + + +class TestHybridCacheMerge: + """Test the per-layer merge loop from _process_prompts.""" + + def test_merge_hybrid_per_request_caches(self): + """Merging per-request hybrid caches produces correct batched types.""" + # Simulate 2 requests, each with a populated hybrid cache + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + ] + + # This is the exact merge loop from _process_prompts (lines 679-685) + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + + assert len(batch_cache) == 48 + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear: + assert isinstance( + c, ArraysCache + ), f"Layer {i}: merged ArraysCache should stay ArraysCache" + # Merged arrays should have batch dimension = 2 + for arr in c.cache: + if arr is not None: + assert arr.shape[0] == 2, f"Layer {i}: batch dim should be 2" + else: + assert isinstance( + c, BatchKVCache + ), f"Layer {i}: merged KVCache should become BatchKVCache" + + def test_merge_single_request(self): + """Single-request merge works (degenerate case).""" + caches = [_make_populated_hybrid_cache(seq_len=10)] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + assert len(batch_cache) == 48 + + def test_type_guard_rejects_unmergeable_cache(self): + """The capability check rejects caches without merge().""" + from vllm_mlx.mllm_batch_generator import _validate_caches_mergeable + + class NoMergeCache: + pass + + per_request_caches = [[NoMergeCache(), KVCache()]] + with pytest.raises(ValueError, match="lacks a merge"): + _validate_caches_mergeable(per_request_caches) + + +# --------------------------------------------------------------------------- +# Test: Filter on merged batch +# --------------------------------------------------------------------------- + + +class TestHybridCacheFilter: + """Test filter() on merged hybrid batches.""" + + def test_filter_keeps_correct_requests(self): + """Filter on merged batch keeps correct batch elements for both cache types.""" + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + _make_populated_hybrid_cache(seq_len=12), + ] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in 
caches]) + for layer_idx in range(len(caches[0])) + ] + + # Keep only request 0 and 2 + keep_idx = mx.array([0, 2], mx.int32) + for c in batch_cache: + if hasattr(c, "filter"): + c.filter(keep_idx) + + # Verify batch dimension is now 2 + for i, c in enumerate(batch_cache): + is_linear = (i + 1) % 4 != 0 + if is_linear and isinstance(c, ArraysCache): + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 2 + ), f"Layer {i}: filtered batch dim should be 2" + + +# --------------------------------------------------------------------------- +# Test: Extract from merged batch +# --------------------------------------------------------------------------- + + +class TestHybridCacheExtract: + """Test extract() on merged hybrid batches.""" + + def test_extract_returns_correct_types(self): + """Extracting a single request returns correct unbatched types.""" + caches = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + ] + batch_cache = [ + caches[0][layer_idx].merge([c[layer_idx] for c in caches]) + for layer_idx in range(len(caches[0])) + ] + + # Extract request 0 + extracted = [ + c.extract(0) if hasattr(c, "extract") else None for c in batch_cache + ] + + for i, c in enumerate(extracted): + is_linear = (i + 1) % 4 != 0 + if c is None: + continue + if is_linear: + assert isinstance( + c, ArraysCache + ), f"Layer {i}: extracted should be ArraysCache" + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 1 + ), f"Layer {i}: extracted batch dim should be 1" + else: + assert isinstance(c, KVCache), f"Layer {i}: extracted should be KVCache" + + +# --------------------------------------------------------------------------- +# Test: Extend merged batches +# --------------------------------------------------------------------------- + + +class TestHybridCacheExtend: + """Test extend() combining two merged hybrid batches.""" + + def test_extend_combines_batches(self): + """Extending one merged batch with another works for both cache types.""" + caches_a = [ + _make_populated_hybrid_cache(seq_len=10), + _make_populated_hybrid_cache(seq_len=15), + ] + caches_b = [ + _make_populated_hybrid_cache(seq_len=12), + ] + batch_a = [ + caches_a[0][layer_idx].merge([c[layer_idx] for c in caches_a]) + for layer_idx in range(len(caches_a[0])) + ] + batch_b = [ + caches_b[0][layer_idx].merge([c[layer_idx] for c in caches_b]) + for layer_idx in range(len(caches_b[0])) + ] + + # Extend batch_a with batch_b + for c, o in zip(batch_a, batch_b): + if c is not None and o is not None and hasattr(c, "extend"): + if not c.empty() and not o.empty(): + c.extend(o) + + # Verify combined batch has 3 elements + for i, c in enumerate(batch_a): + is_linear = (i + 1) % 4 != 0 + if is_linear and isinstance(c, ArraysCache): + for arr in c.cache: + if arr is not None: + assert ( + arr.shape[0] == 3 + ), f"Layer {i}: extended batch dim should be 3" + + +# --------------------------------------------------------------------------- +# Test: Message normalization +# --------------------------------------------------------------------------- + + +class TestNormalizeMessages: + """Test _normalize_messages() for handling real-world client formats.""" + + def test_merge_consecutive_system_messages(self): + """Consecutive system messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": "Always respond in JSON."}, + {"role": "user", 
"content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert "helpful assistant" in result[0]["content"] + assert "JSON" in result[0]["content"] + assert result[1]["role"] == "user" + assert result[1]["content"] == "Hello" + + def test_merge_consecutive_user_messages(self): + """Consecutive user messages are merged into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "First part"}, + {"role": "user", "content": "Second part"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[1]["role"] == "user" + assert "First part" in result[1]["content"] + assert "Second part" in result[1]["content"] + + def test_opencode_format(self): + """OpenCode's system+system+user+user format is normalized.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "System prompt part 1"}, + {"role": "system", "content": "System prompt part 2"}, + {"role": "user", "content": "User instruction"}, + {"role": "user", "content": "User question"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_already_alternating_unchanged(self): + """Well-formed alternating messages pass through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "Bye"}, + ] + result = _normalize_messages(messages) + assert result == messages + + def test_single_message_unchanged(self): + """Single message passes through unchanged.""" + from vllm_mlx.server import _normalize_messages + + messages = [{"role": "user", "content": "Hello"}] + result = _normalize_messages(messages) + assert result == messages + + def test_empty_messages(self): + """Empty message list passes through.""" + from vllm_mlx.server import _normalize_messages + + assert _normalize_messages([]) == [] + + def test_multimodal_content_preserved(self): + """Messages with list content (multimodal) are preserved during merge.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "user", "content": "Describe this:"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "What is in this image?"}, + { + "type": "image_url", + "image_url": {"url": "http://example.com/img.png"}, + }, + ], + }, + ] + result = _normalize_messages(messages) + # When one message has list content and previous has string, + # they can't be trivially merged — keep them or convert + # At minimum, no crash + assert len(result) >= 1 + + def test_preserves_non_content_fields(self): + """Fields other than role/content are preserved.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1", "name": "sys1"}, + {"role": "system", "content": "Part 2"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + # First message retains extra fields from first of the merged group + assert result[0]["role"] == "system" + + def test_null_content_not_merged(self): + """Messages with None content (tool_calls pattern) are not merged.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": 
"assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "assistant", "content": "Follow-up"}, + ] + result = _normalize_messages(messages) + # None content can't be merged with string — kept separate + assert len(result) == 2 + + def test_three_consecutive_system_messages(self): + """Three consecutive system messages merge into one.""" + from vllm_mlx.server import _normalize_messages + + messages = [ + {"role": "system", "content": "Part 1"}, + {"role": "system", "content": "Part 2"}, + {"role": "system", "content": "Part 3"}, + {"role": "user", "content": "Hello"}, + ] + result = _normalize_messages(messages) + assert len(result) == 2 + assert "Part 1" in result[0]["content"] + assert "Part 3" in result[0]["content"] + + +# --------------------------------------------------------------------------- +# Test: Empty cache extend guard +# --------------------------------------------------------------------------- + + +class TestEmptyCacheExtend: + """Test the empty() guard in extend prevents crashes on unpopulated caches.""" + + def test_extend_skips_empty_caches(self): + """Extending when one cache is empty does not crash.""" + populated = _make_populated_hybrid_cache(seq_len=10) + + # Merge populated into single-request batch + batch_pop = [populated[i].merge([populated[i]]) for i in range(len(populated))] + + # Create empty caches directly (don't merge — merge() can't handle all-None) + batch_empty = _make_hybrid_cache() + + # Extend should not crash — empty guard should skip + for c, o in zip(batch_pop, batch_empty): + if c is not None and o is not None and hasattr(c, "extend"): + if not c.empty() and not o.empty(): + c.extend(o) + # If we get here without crash, test passes diff --git a/tests/test_prefix_cache_hybrid.py b/tests/test_prefix_cache_hybrid.py new file mode 100644 index 000000000..eb9920495 --- /dev/null +++ b/tests/test_prefix_cache_hybrid.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for BlockAwarePrefixCache with hybrid model caches (issues #142, #136).""" + +from unittest.mock import MagicMock + +import pytest + +try: + import mlx.core as mx + from mlx_lm.models.cache import ArraysCache, KVCache + + HAS_MLX = True +except ImportError: + HAS_MLX = False + +from vllm_mlx.prefix_cache import ( + BlockAwarePrefixCache, + NonKVCacheData, + _is_kv_layer, +) + +pytestmark = pytest.mark.skipif(not HAS_MLX, reason="MLX not available") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_kv_state(seq_len=64, n_kv_heads=4, head_dim=32): + """Simulate a KVCache layer_state dict (4D tensors).""" + keys = mx.zeros((1, n_kv_heads, seq_len, head_dim)) + values = mx.zeros((1, n_kv_heads, seq_len, head_dim)) + return { + "state": (keys, values), + "meta_state": (str(seq_len),), + "class_name": "KVCache", + "class_ref": KVCache, + } + + +def _make_arrays_state(seq_len=64, conv_dim=128, ssm_heads=8, ssm_dim=64): + """Simulate an ArraysCache layer_state dict (conv_3D + recurrent_4D).""" + conv_state = mx.zeros((1, 3, conv_dim)) + ssm_state = mx.zeros((1, ssm_heads, ssm_dim, ssm_dim)) + return { + "state": [conv_state, ssm_state], + "meta_state": "", + "class_name": "ArraysCache", + "class_ref": ArraysCache, + } + + +def _make_hybrid_cache_data(n_total=48, attn_interval=4, seq_len=128): + """Simulate extracted cache states for a hybrid model. + + Qwen 3.5 pattern: every 4th layer is attention, rest are GatedDeltaNet. 
+ """ + cache_data = [] + for i in range(n_total): + is_attn = (i + 1) % attn_interval == 0 + if is_attn: + cache_data.append(_make_kv_state(seq_len=seq_len)) + else: + cache_data.append(_make_arrays_state()) + return cache_data + + +def _make_pure_kv_cache_data(n_layers=32, seq_len=128): + """Simulate extracted cache states for a pure attention model.""" + return [_make_kv_state(seq_len=seq_len) for _ in range(n_layers)] + + +# --------------------------------------------------------------------------- +# Tests: Layer Classification +# --------------------------------------------------------------------------- + + +class TestIsKVLayer: + def test_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "KVCache"}) is True + + def test_rotating_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "RotatingKVCache"}) is True + + def test_quantized_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "QuantizedKVCache"}) is True + + def test_batch_kvcache_is_kv(self): + assert _is_kv_layer({"class_name": "BatchKVCache"}) is True + + def test_arrays_cache_is_not_kv(self): + assert _is_kv_layer({"class_name": "ArraysCache"}) is False + + def test_cache_list_is_not_kv(self): + assert _is_kv_layer({"class_name": "CacheList"}) is False + + def test_missing_class_name_is_not_kv(self): + assert _is_kv_layer({}) is False + + def test_empty_class_name_is_not_kv(self): + assert _is_kv_layer({"class_name": ""}) is False + + +class TestExtractBlockTensorSlice: + """Test _extract_block_tensor_slice with hybrid cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_pure_kv_slicing_unchanged(self, cache): + """Pure KV model: all layers sliced as before.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + assert len(result) == 4 + for entry in result: + assert entry is not None + keys_slice, values_slice = entry + assert keys_slice.shape == (1, 4, 64, 32) + assert values_slice.shape == (1, 4, 64, 32) + + def test_hybrid_skips_non_kv_layers(self, cache): + """Hybrid model: KV layers sliced, ArraysCache layers return None.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + # Layers: 0=Arr, 1=Arr, 2=Arr, 3=KV, 4=Arr, 5=Arr, 6=Arr, 7=KV + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + assert len(result) == 8 + # Non-KV layers are None + assert result[0] is None + assert result[1] is None + assert result[2] is None + assert result[4] is None + assert result[5] is None + assert result[6] is None + # KV layers are sliced + assert result[3] is not None + keys, values = result[3] + assert keys.shape == (1, 4, 64, 32) + assert result[7] is not None + + def test_hybrid_does_not_crash(self, cache): + """Regression: the original bug — no IndexError on ArraysCache layers.""" + data = _make_hybrid_cache_data(n_total=48, attn_interval=4, seq_len=256) + result = cache._extract_block_tensor_slice(data, 0, 64) + assert result is not None + + def test_slice_beyond_seq_len(self, cache): + """KV slice beyond available seq_len clips correctly.""" + data = _make_pure_kv_cache_data(n_layers=2, seq_len=100) + result = cache._extract_block_tensor_slice(data, 64, 128) + assert result is not None + keys, values = result[0] + assert keys.shape[2] == 36 # 100 - 64 + 
+ +# --------------------------------------------------------------------------- +# Tests: store_cache with hybrid model cache data +# --------------------------------------------------------------------------- + + +class TestStoreHybridCache: + """Test store_cache with hybrid model cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_stores_non_kv_states(self, cache): + """Hybrid cache data stores non-KV states in _non_kv_states.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 1 + non_kv = list(cache._non_kv_states.values())[0] + assert isinstance(non_kv, NonKVCacheData) + assert non_kv.total_layers == 8 + assert len(non_kv.layer_indices) == 6 # 6 ArraysCache layers + assert non_kv.layer_indices == [0, 1, 2, 4, 5, 6] + + def test_pure_kv_no_non_kv_states(self, cache): + """Pure KV model does not create non-KV states.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 0 + + def test_has_non_kv_flag_on_entry(self, cache): + """BlockCacheEntry gets has_non_kv flag.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + entry = cache._request_tables["req-1"] + assert entry.has_non_kv is True + + def test_has_non_kv_false_for_pure_kv(self, cache): + """Pure KV entries have has_non_kv=False.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + entry = cache._request_tables["req-1"] + assert entry.has_non_kv is False + + +# --------------------------------------------------------------------------- +# Tests: reconstruct_cache with hybrid model cache data +# --------------------------------------------------------------------------- + + +class TestReconstructHybridCache: + """Test reconstruct_cache with hybrid model cache data.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_pure_kv_reconstruct_unchanged(self, cache): + """Pure KV model: reconstruct works as before.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + result = cache.reconstruct_cache(bt) + assert result is not None + assert len(result) == 4 + for c in result: + assert hasattr(c, "keys") + assert hasattr(c, "values") + + def test_hybrid_reconstruct_all_layers(self, cache): + """Hybrid model: reconstructs both KV and non-KV layers.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + result = cache.reconstruct_cache(bt) + assert result is not None + assert len(result) == 8 # All layers present + # KV layers (3, 7) should be KVCache + assert hasattr(result[3], "keys") + assert hasattr(result[7], "keys") + # Non-KV layers (0,1,2,4,5,6) should be ArraysCache + assert isinstance(result[0], ArraysCache) + assert 
isinstance(result[4], ArraysCache) + + def test_hybrid_reconstruct_missing_non_kv_returns_none(self, cache): + """If non-KV states are missing, return None (can't reconstruct).""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + bt = cache.store_cache("req-1", tokens, data) + # Delete non-KV states to simulate missing data + cache._non_kv_states.clear() + result = cache.reconstruct_cache(bt) + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: fetch_cache with hybrid model prefix matching +# --------------------------------------------------------------------------- + + +class TestFetchHybridCache: + """Test fetch_cache with hybrid model prefix matching.""" + + @pytest.fixture + def cache(self): + mock_model = MagicMock() + from vllm_mlx.paged_cache import PagedCacheManager + + paged = PagedCacheManager(block_size=64, max_blocks=100) + return BlockAwarePrefixCache(mock_model, paged) + + def test_full_match_hybrid_no_crash(self, cache): + """Full prefix match with non-KV states does not crash.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + # Same tokens — should not crash + bt, remaining = cache.fetch_cache("req-2", tokens) + + def test_pure_kv_partial_match_no_crash(self, cache): + """Pure KV model: partial prefix does not crash.""" + data = _make_pure_kv_cache_data(n_layers=4, seq_len=192) + tokens = list(range(192)) + cache.store_cache("req-1", tokens, data) + shorter_tokens = list(range(128)) + bt, remaining = cache.fetch_cache("req-2", shorter_tokens) + + def test_cleanup_removes_non_kv_states(self, cache): + """release_cache cleans up non-KV states when no other request uses same blocks.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + tokens = list(range(128)) + cache.store_cache("req-1", tokens, data) + assert len(cache._non_kv_states) == 1 + cache.release_cache("req-1") + assert len(cache._non_kv_states) == 0 + + def test_clear_removes_all_non_kv_states(self, cache): + """clear() removes all non-KV states.""" + data = _make_hybrid_cache_data(n_total=8, attn_interval=4, seq_len=128) + cache.store_cache("req-1", list(range(128)), data) + cache.store_cache("req-2", list(range(64, 192)), data) + cache.clear() + assert len(cache._non_kv_states) == 0 + + +# --------------------------------------------------------------------------- +# Tests: scheduler robustness in cache reconstruction +# --------------------------------------------------------------------------- + + +class TestSchedulerRobustness: + """Test scheduler cache reconstruction with non-KV layers.""" + + def test_from_state_works_for_both_cache_types(self): + """Both ArraysCache and KVCache can be reconstructed via from_state. + + KVCache.meta_state is always "" (inherits _BaseCache which has no + meta_state). Only subclasses like RotatingKVCache define meta_state + as a tuple. The scheduler guard should handle non-4D state tensors. 
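+        For example, the ArraysCache state below is a list of
+        [conv (3-D), ssm (4-D)] arrays: it passes the len(state) == 2
+        check but is caught by the state[0].ndim != 4 guard.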
+ """ + arrays_state = { + "state": [mx.zeros((1, 3, 128)), mx.zeros((1, 8, 64, 64))], + "meta_state": "", + "class_name": "ArraysCache", + "class_ref": ArraysCache, + } + # KVCache.meta_state is "" (inherits _BaseCache), NOT a tuple + kv_state = { + "state": (mx.zeros((1, 4, 100, 32)), mx.zeros((1, 4, 100, 32))), + "meta_state": "", + "class_name": "KVCache", + "class_ref": KVCache, + } + extracted = [arrays_state, kv_state, arrays_state, kv_state] + + for layer_state in extracted: + cls = layer_state["class_ref"] + state = layer_state["state"] + meta = layer_state.get("meta_state", "") + obj = cls.from_state(state, meta) + assert obj is not None diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 3a7e14e24..26a5b14d9 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -341,42 +341,45 @@ async def stream_chat( # Build prompt using tokenizer if self._is_mllm: - # For MLLM, use stream_chat which yields tokens incrementally - accumulated_text = "" - token_count = 0 - - # Run stream_chat in thread pool since it's synchronous - def run_stream(): - return list( - self._model.stream_chat( - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - tools=template_tools, - **kwargs, + # For MLLM, use stream_chat which yields tokens incrementally. + # Must hold _generation_lock to prevent concurrent Metal access + # (e.g. OpenCode sends title + main request simultaneously). + async with self._generation_lock: + accumulated_text = "" + token_count = 0 + + # Run stream_chat in thread pool since it's synchronous + def run_stream(): + return list( + self._model.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + tools=template_tools, + **kwargs, + ) ) - ) - chunks = await asyncio.to_thread(run_stream) + chunks = await asyncio.to_thread(run_stream) - for chunk in chunks: - token_count += 1 - new_text = chunk.text if hasattr(chunk, "text") else str(chunk) - accumulated_text += new_text + for chunk in chunks: + token_count += 1 + new_text = chunk.text if hasattr(chunk, "text") else str(chunk) + accumulated_text += new_text - finished = chunk.finish_reason is not None + finished = chunk.finish_reason is not None - yield GenerationOutput( - text=accumulated_text, - new_text=new_text, - prompt_tokens=getattr(chunk, "prompt_tokens", 0), - completion_tokens=token_count, - finished=finished, - finish_reason=chunk.finish_reason if finished else None, - ) + yield GenerationOutput( + text=accumulated_text, + new_text=new_text, + prompt_tokens=getattr(chunk, "prompt_tokens", 0), + completion_tokens=token_count, + finished=finished, + finish_reason=chunk.finish_reason if finished else None, + ) - if finished: - break + if finished: + break return # For LLM, apply chat template and stream diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py index ee8d8da7b..5dde962e9 100644 --- a/vllm_mlx/mllm_batch_generator.py +++ b/vllm_mlx/mllm_batch_generator.py @@ -30,6 +30,21 @@ logger = logging.getLogger(__name__) +def _validate_caches_mergeable(per_request_caches: List[List[Any]]) -> None: + """Validate that all cache layers support merge() for batch creation. + + Raises ValueError if any layer lacks a merge() method (e.g. QuantizedKVCache). + Called before the merge loop in _process_prompts(). 
+ """ + for layer_idx, layer_cache in enumerate(per_request_caches[0]): + if not hasattr(layer_cache, "merge"): + raise ValueError( + f"MLLM continuous batching requires mergeable cache types " + f"but layer {layer_idx} has {type(layer_cache).__name__} " + f"which lacks a merge() method." + ) + + @dataclass class MLLMBatchRequest: """ @@ -139,20 +154,17 @@ def extend(self, other: "MLLMBatch") -> None: self.max_tokens.extend(other.max_tokens) self.requests.extend(other.requests) - # Extend cache - handle None and incompatible caches + # Extend cache - each cache type's extend() handles its own validation. + # Uses empty() (universal via _BaseCache) instead of checking .keys + # (KVCache-specific). This supports hybrid models with ArraysCache + # layers that use .cache instead of .keys/.values. for c, o in zip(self.cache, other.cache): if c is not None and o is not None and hasattr(c, "extend"): try: - # Only extend if both caches have valid keys - if ( - hasattr(c, "keys") - and c.keys is not None - and hasattr(o, "keys") - and o.keys is not None - ): + if not c.empty() and not o.empty(): c.extend(o) except Exception as e: - logger.warning(f"Failed to extend cache: {e}") + logger.warning(f"Failed to extend cache layer: {e}") def extract_cache(self, idx: int) -> List[Any]: """ @@ -207,22 +219,52 @@ def to_dict(self) -> Dict[str, Any]: def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]: """ - Create batch-aware KV cache for the language model. + Create batch-aware cache for the language model. + + Handles all cache types from hybrid models: + - KVCache → BatchKVCache (attention layers) + - ArraysCache → ArraysCache with left_padding (SSM/recurrent layers) + - RotatingKVCache → BatchRotatingKVCache + - CacheList → recursive conversion Args: model: The language model (model.language_model from VLM) left_padding: Padding amounts for left-padded prompts Returns: - List of BatchKVCache objects for each layer + List of batch-aware cache objects for each layer """ - from mlx_lm.models.cache import BatchKVCache, KVCache + from mlx_lm.models.cache import ( + ArraysCache, + BatchKVCache, + BatchRotatingKVCache, + CacheList, + KVCache, + RotatingKVCache, + ) def to_batch_cache(c): - if isinstance(c, KVCache): + # Strict type identity for KVCache — avoid catching QuantizedKVCache + if type(c) is KVCache: return BatchKVCache(left_padding) + elif isinstance(c, ArraysCache): + # ArraysCache handles batching natively — just set left_padding + c.left_padding = mx.array(left_padding) + return c + elif isinstance(c, RotatingKVCache): + if c.keep > 0: + raise ValueError( + "RotatingKVCache with keep tokens is not supported " + "in MLLM continuous batching." + ) + return BatchRotatingKVCache(c.max_size, left_padding) + elif isinstance(c, CacheList): + return CacheList(*(to_batch_cache(sub_c) for sub_c in c.caches)) else: - raise ValueError(f"{type(c)} does not yet support batching") + raise ValueError( + f"MLLM continuous batching does not support {type(c).__name__}. " + f"Supported: KVCache, ArraysCache, RotatingKVCache, CacheList." + ) if hasattr(model, "make_cache"): cache = model.make_cache() @@ -662,19 +704,11 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch: per_request_caches.append(request_cache) - # Merge per-request KVCaches into a single BatchKVCache. - # KVCache.merge() creates a BatchKVCache with proper left-padding - # alignment, so all requests share a single batched cache for - # subsequent generation steps. 
- from mlx_lm.models.cache import KVCache - - sample_cache = per_request_caches[0][0] - if not isinstance(sample_cache, KVCache): - raise ValueError( - f"MLLM continuous batching requires standard KVCache but got " - f"{type(sample_cache).__name__}. Disable --kv-cache-quantization " - f"when using multimodal models with --continuous-batching." - ) + # Merge per-request caches into a single batched cache. + # Each cache type's merge() returns the correct batched representation: + # KVCache.merge() → BatchKVCache, ArraysCache.merge() → batched ArraysCache. + # This supports hybrid models mixing attention + SSM layers. + _validate_caches_mergeable(per_request_caches) try: batch_cache = [ diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index e8f47a324..c75d3dd60 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -66,6 +66,52 @@ def to_dict(self) -> Dict[str, Any]: } +# Known KV cache class names — positional caches that can be block-sliced +# along the sequence dimension (axis=2). All produce 4D tensors: +# (batch, n_kv_heads, seq_len, head_dim). +_KV_CACHE_CLASSES = frozenset( + { + "KVCache", + "RotatingKVCache", + "QuantizedKVCache", + "ChunkedKVCache", + "ConcatenateKVCache", + "BatchKVCache", + "BatchRotatingKVCache", + } +) + + +def _is_kv_layer(layer_state: dict) -> bool: + """Check if a layer state dict represents a positional KV cache layer. + + Args: + layer_state: Dict from _extract_cache_states() containing + 'class_name', 'state', 'meta_state', 'class_ref'. + + Returns: + True if the layer is a KV cache that can be sliced along seq_len. + """ + return layer_state.get("class_name", "") in _KV_CACHE_CLASSES + + +@dataclass +class NonKVCacheData: + """Full state for non-positional cache layers (SSM, linear attention). + + Hybrid models (Qwen 3.5, Nemotron) have layers that use ArraysCache + instead of KVCache. These store cumulative state (conv + recurrent) + that cannot be sliced into blocks. This dataclass stores the full + state for reconstruction alongside block-sliced KV layers. + """ + + layer_indices: List[int] # Position in the full layer list + states: List[Any] # From cache.state for each non-KV layer + meta_states: List[Any] # From cache.meta_state for each non-KV layer + class_refs: List[Any] # type refs for from_state() reconstruction + total_layers: int # Total layer count (KV + non-KV) + + class PrefixCacheManager: """ Manages prefix caching for vllm-mlx using a trie-based LRU cache. @@ -317,18 +363,29 @@ def _delete_cache(self, model_key: Any, tokens: List[int]) -> None: del parent[tok] def _can_trim_cache(self, prompt_cache: List[Any]) -> bool: - """Check if cache can be trimmed.""" + """Check if ALL cache layers can be trimmed. + + For hybrid models (Qwen 3.5 MoE, Nemotron Mamba+Attention), the + prompt_cache is a mix of KVCache (trimmable) and ArraysCache + (not trimmable). Trimming only KVCache while leaving ArraysCache + untouched creates inconsistent state — the attention layers think + the sequence is shorter while SSM/MoE layers retain the old state. + + Previously this only checked the first cache layer, which for + hybrid models was always KVCache → incorrectly returned True. 
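+        With the all() check below, a single non-trimmable layer (e.g.
+        ArraysCache) makes the entire cache non-trimmable, so hybrid
+        caches are never partially trimmed into an inconsistent state.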
+ """ if not prompt_cache: return False - # Check if first cache layer has is_trimmable method - first_cache = prompt_cache[0] - if hasattr(first_cache, "is_trimmable"): - return first_cache.is_trimmable() - return hasattr(first_cache, "trim") + return all( + c.is_trimmable() if hasattr(c, "is_trimmable") else hasattr(c, "trim") + for c in prompt_cache + ) def _trim_cache(self, prompt_cache: List[Any], num_tokens: int) -> List[Any]: """Trim cache by removing num_tokens from the end.""" for cache in prompt_cache: + if hasattr(cache, "is_trimmable") and not cache.is_trimmable(): + continue if hasattr(cache, "trim"): cache.trim(num_tokens) return prompt_cache @@ -364,6 +421,7 @@ class BlockCacheEntry: block_table: BlockTable cache_data: List[Any] # Actual KV cache data per block last_access: float + has_non_kv: bool = False # True if model has non-positional cache layers class BlockAwarePrefixCache: @@ -422,6 +480,10 @@ def __init__( self._misses = 0 self._tokens_saved = 0 + # Non-KV layer states for hybrid models (SSM, linear attention). + # Keyed by tuple(block_ids) for lookup during reconstruction. + self._non_kv_states: Dict[Tuple[int, ...], NonKVCacheData] = {} + def fetch_cache( self, request_id: str, @@ -458,6 +520,21 @@ def fetch_cache( block_table.num_tokens += block.token_count num_prefix_tokens = len(tokens) - len(remaining) + + # Guard: reject partial prefix match for hybrid models. + # If non-KV states don't exist for this exact block set, + # a hybrid model can't be correctly reconstructed (SSM/KV mismatch). + if self._non_kv_states: + candidate_key = tuple(block_table.block_ids) + if candidate_key not in self._non_kv_states: + logger.debug( + f"Rejecting partial prefix match for {request_id}: " + f"hybrid model requires full match with non-KV states" + ) + self.paged_cache.delete_block_table(request_id) + self._misses += 1 + return None, tokens + self._hits += 1 self._tokens_saved += num_prefix_tokens @@ -483,6 +560,19 @@ def fetch_cache( block_table.num_tokens += block.token_count remaining = tokens[len(matched_tokens) :] + + # Same hybrid guard for prefix index matches + if self._non_kv_states: + candidate_key = tuple(block_table.block_ids) + if candidate_key not in self._non_kv_states: + logger.debug( + f"Rejecting prefix index match for {request_id}: " + f"hybrid model requires full match with non-KV states" + ) + self.paged_cache.delete_block_table(request_id) + self._misses += 1 + return None, tokens + self._hits += 1 self._tokens_saved += len(matched_tokens) @@ -602,11 +692,36 @@ def store_cache( # Update prefix index self._update_prefix_index(tokens, block_table.block_ids) + # Extract and store non-KV layer states for hybrid models + has_non_kv = False + if is_tensor_data and cache_data: + non_kv_indices = [] + non_kv_states_list = [] + non_kv_meta_list = [] + non_kv_refs = [] + for idx, layer_state in enumerate(cache_data): + if not _is_kv_layer(layer_state): + non_kv_indices.append(idx) + non_kv_states_list.append(layer_state.get("state")) + non_kv_meta_list.append(layer_state.get("meta_state")) + non_kv_refs.append(layer_state.get("class_ref")) + if non_kv_indices: + has_non_kv = True + block_key = tuple(block_table.block_ids) + self._non_kv_states[block_key] = NonKVCacheData( + layer_indices=non_kv_indices, + states=non_kv_states_list, + meta_states=non_kv_meta_list, + class_refs=non_kv_refs, + total_layers=len(cache_data), + ) + # Store entry for request (for legacy compatibility) self._request_tables[request_id] = BlockCacheEntry( block_table=block_table, 
cache_data=cache_data, last_access=time.time(), + has_non_kv=has_non_kv, ) blocks_with_data = sum( @@ -629,25 +744,39 @@ def _extract_block_tensor_slice( cache_data: List[Dict[str, Any]], start_idx: int, end_idx: int, - ) -> Optional[List[Tuple[Any, Any]]]: + ) -> Optional[List[Optional[Tuple[Any, Any]]]]: """ Extract tensor slices for a single block from cache data. + For KV cache layers (positional), slices keys/values along the + sequence dimension. For non-KV layers (ArraysCache, etc.), + returns None at that position — these are stored separately + by store_cache() as whole-sequence state. + Args: - cache_data: List of layer states, each containing 'state': (keys, values) + cache_data: List of layer states from _extract_cache_states() start_idx: Start token index in the sequence end_idx: End token index in the sequence Returns: - List of (keys_slice, values_slice) for each layer, or None on failure + List with one entry per layer: + - (keys_slice, values_slice) for KV layers + - None for non-KV layers + Returns None only on complete failure. """ if not HAS_MLX or not cache_data: return None try: - block_slices = [] + block_slices: List[Optional[Tuple[Any, Any]]] = [] for layer_state in cache_data: if "state" not in layer_state: + block_slices.append(None) + continue + + # Skip non-KV layers — they can't be block-sliced + if not _is_kv_layer(layer_state): + block_slices.append(None) continue keys, values = layer_state["state"] @@ -657,13 +786,9 @@ def _extract_block_tensor_slice( seq_len = keys.shape[2] if hasattr(keys, "shape") else 0 if end_idx > seq_len: - # Requested range extends beyond available data - logger.debug( - f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" - ) - # Use whatever is available actual_end = min(end_idx, seq_len) if start_idx >= actual_end: + block_slices.append(None) continue keys_slice = keys[:, :, start_idx:actual_end, :] values_slice = values[:, :, start_idx:actual_end, :] @@ -673,7 +798,11 @@ def _extract_block_tensor_slice( block_slices.append((keys_slice, values_slice)) - return block_slices if block_slices else None + # Return None only if no layers produced data at all + if all(s is None for s in block_slices): + return None + + return block_slices except Exception as e: logger.warning(f"Failed to extract block tensor slice: {e}") @@ -719,6 +848,14 @@ def release_cache(self, request_id: str) -> None: """ entry = self._request_tables.pop(request_id, None) if entry: + # Clean up non-KV states if this was the last reference + block_key = tuple(entry.block_table.block_ids) + other_uses = any( + tuple(e.block_table.block_ids) == block_key + for e in self._request_tables.values() + ) + if not other_uses: + self._non_kv_states.pop(block_key, None) self.paged_cache.delete_block_table(request_id) logger.debug(f"Released cache for {request_id}") @@ -752,6 +889,7 @@ def fork_cache( block_table=forked_table, cache_data=source_entry.cache_data, # Shared reference last_access=time.time(), + has_non_kv=source_entry.has_non_kv, ) logger.debug(f"Forked cache: {source_request_id} -> {new_request_id}") @@ -763,16 +901,17 @@ def reconstruct_cache( block_table: BlockTable, ) -> Optional[List[Any]]: """ - Reconstruct KVCache objects from stored block tensor data. + Reconstruct cache objects from stored block tensor data. - This method concatenates tensor slices from all blocks and - creates new KVCache objects that can be used for inference. + For pure-KV models: concatenates KV block slices into KVCache objects. 
+ For hybrid models: also restores non-KV layers (ArraysCache, etc.) + from stored whole-sequence state via from_state(). Args: block_table: BlockTable containing block IDs to reconstruct from Returns: - List of reconstructed KVCache objects (one per layer), + List of reconstructed cache objects (one per layer), or None if reconstruction fails """ if not block_table or not block_table.block_ids: @@ -783,6 +922,10 @@ def reconstruct_cache( return None try: + # Check for non-KV states (hybrid model) + block_key = tuple(block_table.block_ids) + non_kv_data = self._non_kv_states.get(block_key) + # Collect cache data from all blocks all_block_data = [] for block_id in block_table.block_ids: @@ -805,69 +948,113 @@ def reconstruct_cache( if num_layers == 0: return None - # Concatenate tensors for each layer + # If hybrid model but no non-KV states, can't reconstruct + if non_kv_data is not None and non_kv_data.total_layers != num_layers: + logger.warning( + f"Layer count mismatch: blocks have {num_layers}, " + f"non-KV data expects {non_kv_data.total_layers}" + ) + return None + + # Build set of non-KV layer indices for fast lookup + non_kv_idx_set = set() + if non_kv_data is not None: + non_kv_idx_set = set(non_kv_data.layer_indices) + + # Check if any non-KV layers exist in the data but we have + # no stored states — indicates hybrid model with missing data + if not non_kv_data: + for layer_idx in range(num_layers): + if all_block_data[0][layer_idx] is None: + logger.debug( + "Hybrid model detected but no non-KV states " + "stored — cannot reconstruct" + ) + return None + + # Build layer_idx → position mapping for O(1) lookup + non_kv_pos_map = {} + if non_kv_data is not None: + non_kv_pos_map = { + idx: pos for pos, idx in enumerate(non_kv_data.layer_indices) + } + + # Reconstruct each layer reconstructed_caches = [] for layer_idx in range(num_layers): - layer_keys = [] - layer_values = [] - - for block_data in all_block_data: - if layer_idx < len(block_data): - keys_slice, values_slice = block_data[layer_idx] - layer_keys.append(keys_slice) - layer_values.append(values_slice) - - if not layer_keys: - continue - - # Concatenate along sequence dimension (axis 2) - # Shape: (batch, n_kv_heads, seq_len, head_dim) - concat_keys = mx.concatenate(layer_keys, axis=2) - concat_values = mx.concatenate(layer_values, axis=2) - - # Create KVCache object - # Try to use mlx_lm's KVCache.from_state if available - try: - from mlx_lm.models.cache import KVCache - - # Create new cache and set its state - cache = KVCache() - seq_len = concat_keys.shape[2] - - # Set internal state directly - # KVCache stores keys/values and offset - cache.keys = concat_keys - cache.values = concat_values - cache.offset = seq_len - - reconstructed_caches.append(cache) - - except ImportError: - # Fallback: create a simple cache-like object - class SimpleKVCache: - def __init__(self, keys, values): - self.keys = keys - self.values = values - self.offset = keys.shape[2] - - @property - def state(self): - return (self.keys, self.values) - - @property - def meta_state(self): - return (str(self.offset),) - - cache = SimpleKVCache(concat_keys, concat_values) - reconstructed_caches.append(cache) + if layer_idx in non_kv_idx_set: + # Non-KV layer: restore from stored whole-sequence state + pos = non_kv_pos_map[layer_idx] + state = non_kv_data.states[pos] + meta = non_kv_data.meta_states[pos] + cls = non_kv_data.class_refs[pos] + + if cls is not None and hasattr(cls, "from_state"): + cache_obj = cls.from_state(state, meta) + else: + 
logger.warning(f"No class_ref for non-KV layer {layer_idx}") + return None + + reconstructed_caches.append(cache_obj) + else: + # KV layer: concatenate block slices + layer_keys = [] + layer_values = [] + + for block_data in all_block_data: + if layer_idx < len(block_data): + entry = block_data[layer_idx] + if entry is not None: + keys_slice, values_slice = entry + layer_keys.append(keys_slice) + layer_values.append(values_slice) + + if not layer_keys: + logger.debug(f"No KV data for layer {layer_idx}") + return None + + # Concatenate along sequence dimension (axis 2) + concat_keys = mx.concatenate(layer_keys, axis=2) + concat_values = mx.concatenate(layer_values, axis=2) + + try: + from mlx_lm.models.cache import KVCache + + cache_obj = KVCache() + cache_obj.keys = concat_keys + cache_obj.values = concat_values + cache_obj.offset = concat_keys.shape[2] + reconstructed_caches.append(cache_obj) + + except ImportError: + + class SimpleKVCache: + def __init__(self, keys, values): + self.keys = keys + self.values = values + self.offset = keys.shape[2] + + @property + def state(self): + return (self.keys, self.values) + + @property + def meta_state(self): + return (str(self.offset),) + + reconstructed_caches.append( + SimpleKVCache(concat_keys, concat_values) + ) if not reconstructed_caches: return None logger.debug( - f"Reconstructed cache: {len(reconstructed_caches)} layers, " - f"{block_table.num_tokens} tokens from {len(block_table.block_ids)} blocks" + f"Reconstructed cache: {len(reconstructed_caches)} layers " + f"({len(non_kv_idx_set)} non-KV), " + f"{block_table.num_tokens} tokens from " + f"{len(block_table.block_ids)} blocks" ) return reconstructed_caches @@ -944,6 +1131,7 @@ def clear(self) -> None: """Clear all cached data.""" self._request_tables.clear() self._prefix_index.clear() + self._non_kv_states.clear() self.paged_cache.clear() self.reset_stats() diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 88d144cb7..92285f238 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -392,6 +392,7 @@ def _chunked_next(self=batch_gen): # noqa: C901 caches, samplers, logits_processors, + *_extra, ) = zip(*batch_prompts) lengths = [len(p) for p in inputs_raw] max_length = max(lengths) @@ -1487,7 +1488,16 @@ def _reconstruct_cache_from_states( # Fallback: try KVCache manual reconstruction from mlx_lm.models.cache import KVCache - if len(state) != 2: + if ( + not isinstance(state, (tuple, list)) + or len(state) != 2 + or not hasattr(state[0], "shape") + or state[0].ndim != 4 + ): + logger.debug( + f"[mid_prefill_cache] skipping non-KV layer " + f"(state type={type(state).__name__})" + ) return None cache = KVCache() cache.keys, cache.values = state diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index f0328d4e6..6380ee731 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1326,12 +1326,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re messages.append(msg_dict) images, videos = [], [] # MLLM extracts these from messages logger.debug(f"MLLM: Processing {len(messages)} messages") + messages = _normalize_messages(messages) else: # For LLM, extract text, images, and videos separately messages, images, videos = extract_multimodal_content( request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) has_media = bool(images or videos) @@ -1434,6 +1436,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) +def 
_normalize_messages(messages: list[dict]) -> list[dict]: + """Normalize message roles and merge consecutive same-role messages. + + 1. Maps non-standard roles to standard ones (e.g. ``developer`` → ``system``). + 2. Merges consecutive same-role messages to satisfy chat template constraints + (Qwen 3.5, Llama, etc. require alternating roles). + + Only merges when both messages have string content. Messages with list + content (multimodal) are left as-is to preserve image/video attachments. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + New list with normalized roles and consecutive same-role messages merged. + """ + # OpenAI Responses API uses "developer" instead of "system". + # Map it so chat templates don't fail and fall back to raw prefill. + _ROLE_MAP = {"developer": "system"} + + if not messages: + return messages + + merged = [messages[0].copy()] + if merged[0]["role"] in _ROLE_MAP: + merged[0]["role"] = _ROLE_MAP[merged[0]["role"]] + for msg in messages[1:]: + prev = merged[-1] + role = _ROLE_MAP.get(msg["role"], msg["role"]) + if ( + role == prev["role"] + and isinstance(prev.get("content"), str) + and isinstance(msg.get("content"), str) + ): + # Merge string content with double newline separator + prev["content"] = prev["content"] + "\n\n" + msg["content"] + logger.debug( + f"Merged consecutive {role} messages " + f"({len(prev['content'])} chars total)" + ) + else: + copy = msg.copy() + copy["role"] = role + merged.append(copy) + + mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP) + merged_count = len(messages) - len(merged) + if mapped_roles or merged_count: + parts = [] + if mapped_roles: + parts.append(f"mapped {mapped_roles} role(s)") + if merged_count: + parts.append(f"merged {len(messages)} → {len(merged)}") + logger.info(f"Normalized messages: {', '.join(parts)}") + + return merged + + def _inject_json_instruction(messages: list, instruction: str) -> list: """ Inject JSON instruction into messages. @@ -1529,6 +1589,7 @@ async def create_anthropic_message( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens, @@ -1686,6 +1747,7 @@ async def _stream_anthropic_messages( openai_request.messages, preserve_native_format=engine.preserve_native_tool_format, ) + messages = _normalize_messages(messages) chat_kwargs = { "max_tokens": openai_request.max_tokens or _default_max_tokens,