diff --git a/tests/test_paged_cache.py b/tests/test_paged_cache.py index 8e3082c3..5d5eac40 100644 --- a/tests/test_paged_cache.py +++ b/tests/test_paged_cache.py @@ -725,3 +725,93 @@ def test_clear(self): stats = cache.get_stats() # After clear, null block is still allocated (vLLM style) assert stats["allocated_blocks"] == 1 # only null block + + def test_reconstructs_hybrid_cache_from_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + tokens = list(range(8)) + kv_keys = mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3) + kv_values = mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3) + linear_state = [ + mx.arange(1 * 3 * 8).reshape(1, 3, 8), + mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4), + ] + extracted = [ + { + "state": (kv_keys, kv_values), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": linear_state, + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + + block_table = cache.store_cache("req-1", tokens, extracted) + first_block = paged_manager.allocated_blocks[block_table.block_ids[0]] + last_block = paged_manager.allocated_blocks[block_table.block_ids[-1]] + + assert first_block.cache_data[0] is not None + assert first_block.cache_data[1] is None + assert last_block.cache_data[1] is not None + + reconstructed = cache.reconstruct_cache(block_table) + + assert reconstructed is not None + assert isinstance(reconstructed[0], KVCache) + assert isinstance(reconstructed[1], ArraysCache) + assert reconstructed[0].state[0].tolist() == kv_keys.tolist() + assert reconstructed[0].state[1].tolist() == kv_values.tolist() + assert reconstructed[1].state[0].tolist() == linear_state[0].tolist() + assert reconstructed[1].state[1].tolist() == linear_state[1].tolist() + + def test_rejects_hybrid_prefix_without_boundary_snapshot(self): + from mlx_lm.models.cache import ArraysCache, KVCache + import mlx.core as mx + + from vllm_mlx.paged_cache import BlockTable, PagedCacheManager + from vllm_mlx.prefix_cache import BlockAwarePrefixCache + + paged_manager = PagedCacheManager(block_size=4, max_blocks=10) + cache = BlockAwarePrefixCache(model=None, paged_cache_manager=paged_manager) + + extracted = [ + { + "state": ( + mx.arange(1 * 2 * 8 * 3).reshape(1, 2, 8, 3), + mx.arange(1000, 1000 + (1 * 2 * 8 * 3)).reshape(1, 2, 8, 3), + ), + "meta_state": "", + "class_ref": KVCache, + "class_name": "KVCache", + }, + { + "state": [ + mx.arange(1 * 3 * 8).reshape(1, 3, 8), + mx.arange(2000, 2000 + (1 * 2 * 4 * 4)).reshape(1, 2, 4, 4), + ], + "meta_state": "", + "class_ref": ArraysCache, + "class_name": "ArraysCache", + }, + ] + + block_table = cache.store_cache("req-1", list(range(8)), extracted) + prefix_table = BlockTable( + request_id="req-prefix", + block_ids=[block_table.block_ids[0]], + num_tokens=4, + ) + + assert cache.reconstruct_cache(prefix_table) is None diff --git a/tests/test_tokenizer_utils.py b/tests/test_tokenizer_utils.py new file mode 100644 index 00000000..d95fecc7 --- /dev/null +++ b/tests/test_tokenizer_utils.py @@ -0,0 +1,47 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for tokenizer utility helpers.""" + +import platform +import sys +from unittest.mock import patch + +import pytest + +pytestmark = pytest.mark.skipif( + sys.platform != "darwin" or platform.machine() != "arm64", + reason="Requires Apple Silicon", +) + + +class TestLoadModelWithFallback: + def test_returns_successful_load_result(self): + from vllm_mlx.utils.tokenizer import load_model_with_fallback + + fake_model = object() + fake_tokenizer = object() + + with patch("mlx_lm.load", return_value=(fake_model, fake_tokenizer)) as load: + model, tokenizer = load_model_with_fallback("mlx-community/Qwen3.5-4B") + + load.assert_called_once() + assert model is fake_model + assert tokenizer is fake_tokenizer + + def test_uses_tokenizer_fallback_for_tokenizer_errors(self): + from vllm_mlx.utils.tokenizer import load_model_with_fallback + + fake_model = object() + fake_tokenizer = object() + + with patch( + "mlx_lm.load", + side_effect=ValueError("Tokenizer class Foo does not exist"), + ), patch( + "vllm_mlx.utils.tokenizer._load_with_tokenizer_fallback", + return_value=(fake_model, fake_tokenizer), + ) as fallback: + model, tokenizer = load_model_with_fallback("example/model") + + fallback.assert_called_once_with("example/model") + assert model is fake_model + assert tokenizer is fake_tokenizer diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index e8f47a32..0bfe329f 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -586,7 +586,7 @@ def store_cache( # Extract and store actual tensor slices for this block if is_tensor_data and HAS_MLX: block_kv_data = self._extract_block_tensor_slice( - cache_data, global_start, global_end + cache_data, global_start, global_end, len(tokens) ) if block_kv_data: block.cache_data = block_kv_data @@ -629,56 +629,120 @@ def _extract_block_tensor_slice( cache_data: List[Dict[str, Any]], start_idx: int, end_idx: int, - ) -> Optional[List[Tuple[Any, Any]]]: + total_tokens: int, + ) -> Optional[List[Optional[Dict[str, Any]]]]: """ - Extract tensor slices for a single block from cache data. + Extract per-layer cache data for a single block. Args: - cache_data: List of layer states, each containing 'state': (keys, values) + cache_data: List of extracted layer states start_idx: Start token index in the sequence end_idx: End token index in the sequence + total_tokens: Total number of tokens covered by cache_data Returns: - List of (keys_slice, values_slice) for each layer, or None on failure + Per-layer block cache state, or None on failure """ if not HAS_MLX or not cache_data: return None try: - block_slices = [] + block_slices: List[Optional[Dict[str, Any]]] = [] for layer_state in cache_data: if "state" not in layer_state: + block_slices.append(None) continue - keys, values = layer_state["state"] + state = layer_state["state"] + meta_state = layer_state.get("meta_state") + class_ref = layer_state.get("class_ref") + class_name = layer_state.get("class_name") - # KV cache shape: (batch, n_kv_heads, seq_len, head_dim) - # Slice along seq_len dimension (axis 2) - seq_len = keys.shape[2] if hasattr(keys, "shape") else 0 + if self._can_concatenate_cache_state(state): + state_slice = self._slice_concat_cache_state( + state, start_idx, end_idx + ) + block_slices.append( + { + "state": state_slice, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "concat", + "seq_axis": 2, + } + ) + continue - if end_idx > seq_len: - # Requested range extends beyond available data - logger.debug( - f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + if end_idx == total_tokens: + block_slices.append( + { + "state": state, + "meta_state": meta_state, + "class_ref": class_ref, + "class_name": class_name, + "storage": "latest", + } ) - # Use whatever is available - actual_end = min(end_idx, seq_len) - if start_idx >= actual_end: - continue - keys_slice = keys[:, :, start_idx:actual_end, :] - values_slice = values[:, :, start_idx:actual_end, :] else: - keys_slice = keys[:, :, start_idx:end_idx, :] - values_slice = values[:, :, start_idx:end_idx, :] + block_slices.append(None) - block_slices.append((keys_slice, values_slice)) - - return block_slices if block_slices else None + return block_slices if any(entry is not None for entry in block_slices) else None except Exception as e: logger.warning(f"Failed to extract block tensor slice: {e}") return None + def _can_concatenate_cache_state(self, state: Any) -> bool: + """Return True when cache state can be concatenated block-by-block.""" + if not isinstance(state, (list, tuple)) or not state: + return False + return all( + tensor is not None + and hasattr(tensor, "shape") + and len(tensor.shape) == 4 + for tensor in state + ) + + def _slice_concat_cache_state( + self, + state: Tuple[Any, ...] | List[Any], + start_idx: int, + end_idx: int, + ) -> Tuple[Any, ...] | List[Any]: + """Slice a sequence-backed cache state across the token axis.""" + seq_len = state[0].shape[2] + actual_end = min(end_idx, seq_len) + if start_idx >= actual_end: + raise ValueError( + f"Block slice [{start_idx}:{end_idx}] exceeds seq_len {seq_len}" + ) + + def _slice_tensor(tensor: Any) -> Any: + slices = [slice(None)] * len(tensor.shape) + slices[2] = slice(start_idx, actual_end) + return tensor[tuple(slices)] + + sliced = [_slice_tensor(tensor) for tensor in state] + return tuple(sliced) if isinstance(state, tuple) else sliced + + def _concat_cache_states( + self, + states: List[Tuple[Any, ...] | List[Any]], + seq_axis: int, + ) -> Optional[Tuple[Any, ...] | List[Any]]: + """Concatenate state fragments for a sequence-backed cache layer.""" + if not states: + return None + arity = len(states[0]) + concatenated = [] + for idx in range(arity): + parts = [state[idx] for state in states] + if any(part is None for part in parts): + return None + concatenated.append(mx.concatenate(parts, axis=seq_axis)) + return tuple(concatenated) if isinstance(states[0], tuple) else concatenated + def get_cache_for_generation( self, request_id: str, @@ -763,10 +827,11 @@ def reconstruct_cache( block_table: BlockTable, ) -> Optional[List[Any]]: """ - Reconstruct KVCache objects from stored block tensor data. + Reconstruct cache objects from stored block tensor data. - This method concatenates tensor slices from all blocks and - creates new KVCache objects that can be used for inference. + Sequence-backed caches are concatenated block-by-block. Recurrent + caches such as ArraysCache are restored from the latest sequence + boundary snapshot that was actually stored. Args: block_table: BlockTable containing block IDs to reconstruct from @@ -800,67 +865,62 @@ def reconstruct_cache( if not all_block_data: return None - # Get number of layers from first block - num_layers = len(all_block_data[0]) + # Get number of layers from the richest block + num_layers = max(len(block_data) for block_data in all_block_data) if num_layers == 0: return None - # Concatenate tensors for each layer reconstructed_caches = [] - for layer_idx in range(num_layers): - layer_keys = [] - layer_values = [] + layer_entries = [ + block_data[layer_idx] + for block_data in all_block_data + if layer_idx < len(block_data) + ] + layer_entries = [entry for entry in layer_entries if entry is not None] + if not layer_entries: + return None - for block_data in all_block_data: - if layer_idx < len(block_data): - keys_slice, values_slice = block_data[layer_idx] - layer_keys.append(keys_slice) - layer_values.append(values_slice) + layer_meta = layer_entries[-1] + state = layer_meta["state"] + if layer_meta["storage"] == "concat": + state = self._concat_cache_states( + [entry["state"] for entry in layer_entries], + layer_meta["seq_axis"], + ) + elif layer_meta["storage"] == "latest": + state = layer_entries[-1]["state"] - if not layer_keys: - continue + if state is None: + return None - # Concatenate along sequence dimension (axis 2) - # Shape: (batch, n_kv_heads, seq_len, head_dim) - concat_keys = mx.concatenate(layer_keys, axis=2) - concat_values = mx.concatenate(layer_values, axis=2) + cache_cls = layer_meta.get("class_ref") + meta_state = layer_meta.get("meta_state") - # Create KVCache object - # Try to use mlx_lm's KVCache.from_state if available - try: + if cache_cls is not None and hasattr(cache_cls, "from_state"): + from mlx_lm.models.cache import ( + BatchKVCache as _BatchKVCache, + KVCache as _KVCache, + ) + + if cache_cls is _BatchKVCache: + keys, values = state[0], state[1] + cache = _KVCache() + cache.keys = keys + cache.values = values + cache.offset = keys.shape[2] + else: + cache = cache_cls.from_state(state, meta_state) + else: from mlx_lm.models.cache import KVCache - # Create new cache and set its state + if len(state) != 2: + return None cache = KVCache() - seq_len = concat_keys.shape[2] - - # Set internal state directly - # KVCache stores keys/values and offset - cache.keys = concat_keys - cache.values = concat_values - cache.offset = seq_len - - reconstructed_caches.append(cache) - - except ImportError: - # Fallback: create a simple cache-like object - class SimpleKVCache: - def __init__(self, keys, values): - self.keys = keys - self.values = values - self.offset = keys.shape[2] - - @property - def state(self): - return (self.keys, self.values) - - @property - def meta_state(self): - return (str(self.offset),) - - cache = SimpleKVCache(concat_keys, concat_values) - reconstructed_caches.append(cache) + cache.keys, cache.values = state + cache.offset = cache.keys.shape[2] + + reconstructed_caches.append(cache) if not reconstructed_caches: return None diff --git a/vllm_mlx/utils/tokenizer.py b/vllm_mlx/utils/tokenizer.py index a5088395..aaaeae55 100644 --- a/vllm_mlx/utils/tokenizer.py +++ b/vllm_mlx/utils/tokenizer.py @@ -52,6 +52,7 @@ def load_model_with_fallback(model_name: str, tokenizer_config: dict = None): try: model, tokenizer = load(model_name, tokenizer_config=tokenizer_config) + return model, tokenizer except ValueError as e: # Fallback for models with non-standard tokenizers if "TokenizersBackend" in str(e) or "Tokenizer class" in str(e):