diff --git a/pyproject.toml b/pyproject.toml index c59536fa..038c28df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,10 +28,8 @@ classifiers = [ dependencies = [ # MLX - Required for Apple Silicon GPU acceleration - "mlx>=0.20.0; platform_system == 'Darwin' and platform_machine == 'arm64'", - # mlx-lm 0.30.6 removed MambaCache; pin until we support ArraysCache-only versions. - # See: https://github.com/vllm-project/vllm-metal/issues/100 - "mlx-lm>=0.20.0,<0.30.6; platform_system == 'Darwin' and platform_machine == 'arm64'", + "mlx>=0.29.2; platform_system == 'Darwin' and platform_machine == 'arm64'", + "mlx-lm>=0.28.4; platform_system == 'Darwin' and platform_machine == 'arm64'", "mlx-vlm>=0.3.0; platform_system == 'Darwin' and platform_machine == 'arm64'", # Vision-language model support # Model loading and weights "transformers>=4.40.0", diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 714275f4..6ee9a3b5 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -5,11 +5,12 @@ from unittest.mock import MagicMock import mlx.core as mx +import pytest import vllm_metal.v1.model_runner as mr -class StubMambaCache: +class StubArraysCache: @property def state(self): return [] @@ -32,7 +33,7 @@ def fake_make_prompt_cache(model): kv.keys = mx.zeros((1, 8, 0, 64)) kv.values = mx.zeros((1, 8, 0, 64)) kv.offset = 0 - return [kv, StubMambaCache(), kv] + return [kv, StubArraysCache(), kv] monkeypatch.setattr(mr, "make_prompt_cache", fake_make_prompt_cache) @@ -92,6 +93,134 @@ def fake_make_prompt_cache(model): insert_spy.assert_called_once() +class TestHybridCacheMergeExtract: + """Regression tests for hybrid (KV + ArraysCache) batching. + + Background: + - `mlx-lm==0.30.6` removed `MambaCache` and hybrid models now use `ArraysCache`. + - Older mlx-lm versions don't provide `ArraysCache.merge()` / `extract()`. + + These tests validate that vllm-metal can merge per-request caches into a batched + cache, run a batched forward pass, and then extract per-request caches back, + without depending on `MambaCache` or new mlx-lm APIs. + """ + + _ARRAYS_CACHE_ENTRIES = 2 + _ARRAYS_CACHE_FEATURES = 4 + + _KV_NUM_HEADS = 1 + _KV_HEAD_DIM = 2 + + def _make_arrays_cache(self, v0: float | None, v1: float | None) -> mr.ArraysCache: + cache = mr.ArraysCache(self._ARRAYS_CACHE_ENTRIES) + if v0 is not None: + cache[0] = mx.full((1, self._ARRAYS_CACHE_FEATURES), v0, dtype=mx.float32) + if v1 is not None: + cache[1] = mx.full((1, self._ARRAYS_CACHE_FEATURES), v1, dtype=mx.float32) + return cache + + def _make_kv_cache(self, seq_len: int, value: float) -> mr.KVCache: + kv = mr.KVCache() + kv.keys = mx.full( + (1, self._KV_NUM_HEADS, seq_len, self._KV_HEAD_DIM), + value, + dtype=mx.float32, + ) + kv.values = mx.full( + (1, self._KV_NUM_HEADS, seq_len, self._KV_HEAD_DIM), + value + 0.5, + dtype=mx.float32, + ) + kv.offset = seq_len + return kv + + def test_arrays_cache_merge_extract_roundtrip(self) -> None: + """Merging then extracting ArraysCache round-trips per request.""" + arrays_cache_req0 = self._make_arrays_cache(1.0, 11.0) + arrays_cache_req1 = self._make_arrays_cache(2.0, 22.0) + + merged = mr._merge_kv_caches([[arrays_cache_req0], [arrays_cache_req1]]) + extracted_req0 = mr._extract_kv_cache(merged, 0)[0] + extracted_req1 = mr._extract_kv_cache(merged, 1)[0] + + assert isinstance(merged[0], mr.ArraysCache) + assert isinstance(extracted_req0, mr.ArraysCache) + assert isinstance(extracted_req1, mr.ArraysCache) + assert bool(mx.allclose(extracted_req0.state[0], arrays_cache_req0.state[0])) + assert bool(mx.allclose(extracted_req0.state[1], arrays_cache_req0.state[1])) + assert bool(mx.allclose(extracted_req1.state[0], arrays_cache_req1.state[0])) + assert bool(mx.allclose(extracted_req1.state[1], arrays_cache_req1.state[1])) + + def test_arrays_cache_merge_extract_handles_missing_entries(self) -> None: + """Missing per-request entries become zeros after merging. + + ArraysCache merging densifies per-entry state into a batch array when at + least one request has that entry populated. Requests that had `None` + for the entry are represented as zeros in the merged state. + """ + arrays_cache_req0 = self._make_arrays_cache(1.0, 11.0) + arrays_cache_req1 = self._make_arrays_cache(2.0, None) + + merged = mr._merge_kv_caches([[arrays_cache_req0], [arrays_cache_req1]]) + + extracted_req0 = mr._extract_kv_cache(merged, 0)[0] + extracted_req1 = mr._extract_kv_cache(merged, 1)[0] + + assert isinstance(extracted_req0, mr.ArraysCache) + assert isinstance(extracted_req1, mr.ArraysCache) + + assert bool(mx.allclose(extracted_req0.state[0], arrays_cache_req0.state[0])) + assert bool(mx.allclose(extracted_req0.state[1], arrays_cache_req0.state[1])) + assert bool(mx.allclose(extracted_req1.state[0], arrays_cache_req1.state[0])) + + missing = extracted_req1.state[1] + assert missing is not None + assert missing.shape == (1, self._ARRAYS_CACHE_FEATURES) + assert bool(mx.allclose(missing, mx.zeros_like(missing))) + + def test_mixed_kv_and_arrays_cache_merge_extract_roundtrip(self) -> None: + """Merging/extracting preserves both KVCache and ArraysCache layers.""" + kv_cache_req0 = self._make_kv_cache(seq_len=2, value=1.0) + kv_cache_req1 = self._make_kv_cache(seq_len=4, value=2.0) + arrays_cache_req0 = self._make_arrays_cache(3.0, 33.0) + arrays_cache_req1 = self._make_arrays_cache(4.0, 44.0) + + merged = mr._merge_kv_caches( + [[kv_cache_req0, arrays_cache_req0], [kv_cache_req1, arrays_cache_req1]] + ) + extracted_req0 = mr._extract_kv_cache(merged, 0) + extracted_req1 = mr._extract_kv_cache(merged, 1) + + assert isinstance(merged[0], mr.BatchKVCache) + assert isinstance(merged[1], mr.ArraysCache) + + kv_req0_out, arrays_req0_out = extracted_req0 + kv_req1_out, arrays_req1_out = extracted_req1 + + assert isinstance(kv_req0_out, mr.KVCache) + assert isinstance(kv_req1_out, mr.KVCache) + assert isinstance(arrays_req0_out, mr.ArraysCache) + assert isinstance(arrays_req1_out, mr.ArraysCache) + + assert kv_req0_out.offset == kv_cache_req0.offset + assert kv_req1_out.offset == kv_cache_req1.offset + assert bool(mx.allclose(kv_req0_out.keys, kv_cache_req0.keys)) + assert bool(mx.allclose(kv_req0_out.values, kv_cache_req0.values)) + assert bool(mx.allclose(kv_req1_out.keys, kv_cache_req1.keys)) + assert bool(mx.allclose(kv_req1_out.values, kv_cache_req1.values)) + + assert bool(mx.allclose(arrays_req0_out.state[0], arrays_cache_req0.state[0])) + assert bool(mx.allclose(arrays_req0_out.state[1], arrays_cache_req0.state[1])) + assert bool(mx.allclose(arrays_req1_out.state[0], arrays_cache_req1.state[0])) + assert bool(mx.allclose(arrays_req1_out.state[1], arrays_cache_req1.state[1])) + + def test_merge_kv_caches_rejects_mixed_cache_types_within_layer(self) -> None: + arrays_cache = self._make_arrays_cache(1.0, 2.0) + kv_cache = mr.KVCache() + with pytest.raises(TypeError, match="Mixed cache types in a single layer"): + mr._merge_kv_caches([[arrays_cache], [kv_cache]]) + + class TestPrefixCacheEviction: def test_eviction_under_max_bytes(self) -> None: # 1KB limit diff --git a/vllm_metal/v1/model_runner.py b/vllm_metal/v1/model_runner.py index ab7d5530..139767a9 100644 --- a/vllm_metal/v1/model_runner.py +++ b/vllm_metal/v1/model_runner.py @@ -27,7 +27,6 @@ ArraysCache, BatchKVCache, KVCache, - MambaCache, make_prompt_cache, ) @@ -149,7 +148,7 @@ class CachedPrefix: """Cached KV state for a token prefix. cache_state contains (k, v) tuples for KVCache layers, or None for - MambaCache/ArraysCache layers in hybrid models. + ArraysCache layers in hybrid models. """ token_ids: list[int] @@ -207,8 +206,8 @@ def _evict_until_fits(self, needed_bytes: int) -> None: def insert(self, token_ids: list[int], cache: list[KVCache]) -> None: """Insert a prefix cache entry with memory-based eviction. - Only KVCache layers are cached. MambaCache/ArraysCache layers are - skipped (stored as None) for hybrid model compatibility. + Only KVCache layers are cached. ArraysCache layers are skipped (stored as + None) for hybrid model compatibility. """ prefix_hash = _compute_prefix_hash(token_ids) if prefix_hash in self._cache: @@ -249,8 +248,8 @@ def restore_cache( ) -> list[KVCache]: """Restore a cached prefix to a fresh KVCache. - Only KVCache layers are restored. MambaCache/ArraysCache layers - remain in their fresh state for hybrid model compatibility. + Only KVCache layers are restored. ArraysCache layers remain in their + fresh state for hybrid model compatibility. """ cache_model = ( model.language_model @@ -284,59 +283,50 @@ def get_stats(self) -> dict: # Type alias for any cache type supported by the model -AnyCache: TypeAlias = KVCache | MambaCache | ArraysCache +AnyCache: TypeAlias = KVCache | ArraysCache -class BatchMambaCache: - """Batched cache for Mamba/SSM layers. +def _merge_arrays_caches(caches: list[ArraysCache]) -> ArraysCache: + """Merge per-request ArraysCache objects into a single batched ArraysCache. - Wraps multiple MambaCache instances into a single batched cache - for efficient batched forward passes on hybrid models. + This mirrors the behavior of `mlx_lm.models.cache.ArraysCache.merge` but is + implemented here for compatibility with older mlx-lm versions that do not + provide `merge()` / `extract()`. """ + if not caches: + raise ValueError("caches must be non-empty") - def __init__(self, caches: list[MambaCache | ArraysCache]): - """Create a batched Mamba cache from individual caches. + num_entries = len(caches[0].state) + batch_size = len(caches) - Args: - caches: List of MambaCache instances to batch - """ - self._batch_size = len(caches) - self._cache_size = len(caches[0].cache) if caches else 0 - - # Stack each state array across the batch dimension - self.cache: list[mx.array | None] = [] - for i in range(self._cache_size): - states = [c.cache[i] for c in caches] - if all(s is not None for s in states): - self.cache.append(mx.concatenate(states, axis=0)) - else: - self.cache.append(None) - - def __getitem__(self, idx: int) -> mx.array | None: - return self.cache[idx] + merged = ArraysCache(num_entries) + for entry_idx in range(num_entries): + values = [cache.state[entry_idx] for cache in caches] + template = next((value for value in values if value is not None), None) + if template is None: + continue - def __setitem__(self, idx: int, value: mx.array | None) -> None: - self.cache[idx] = value + shape = list(template.shape) + shape[0] = batch_size + merged_state = mx.zeros(tuple(shape), template.dtype) + for batch_idx, value in enumerate(values): + if value is None: + continue + merged_state[batch_idx : batch_idx + 1] = value - def extract(self, idx: int) -> MambaCache: - """Extract a single request's cache from the batch. + merged[entry_idx] = merged_state - Args: - idx: Index of the request in the batch - - Returns: - MambaCache for the individual request - """ - cache = MambaCache() - for i in range(self._cache_size): - if self.cache[i] is not None: - cache.cache[i] = self.cache[i][idx : idx + 1] - return cache + return merged -def _is_mamba_cache(cache: AnyCache) -> bool: - """Check if a cache is a Mamba-style cache (ArraysCache or MambaCache).""" - return isinstance(cache, (MambaCache, ArraysCache)) +def _extract_arrays_cache(batch_cache: ArraysCache, idx: int) -> ArraysCache: + """Extract a single request's ArraysCache from a batched ArraysCache.""" + state = batch_cache.state + extracted = ArraysCache(len(state)) + extracted.state = [ + None if value is None else value[idx : idx + 1] for value in state + ] + return extracted def _mlx_greedy_sample(logits: mx.array) -> mx.array: @@ -386,7 +376,7 @@ class RequestState: # vLLM applies repetition penalties to both prompt+output tokens, but applies # presence/frequency penalties only to generated (output) tokens. prompt_len: int - cache: list[AnyCache] # Per-layer caches (KVCache or MambaCache for hybrid models) + cache: list[AnyCache] # Per-layer caches (KVCache or ArraysCache for hybrid models) sampling_params: SamplingParams # Sampling parameters for this request generator: torch.Generator | None = None generated_tokens: int = 0 @@ -394,7 +384,7 @@ class RequestState: def _merge_kv_caches( caches_list: list[list[AnyCache]], -) -> list[BatchKVCache | BatchMambaCache]: +) -> list[BatchKVCache | ArraysCache]: """Merge multiple per-request caches into batched caches. Args: @@ -407,13 +397,20 @@ def _merge_kv_caches( return [] num_layers = len(caches_list[0]) - merged: list[BatchKVCache | BatchMambaCache] = [] + merged: list[BatchKVCache | ArraysCache] = [] for layer_idx in range(num_layers): layer_caches = [caches[layer_idx] for caches in caches_list] - if _is_mamba_cache(layer_caches[0]): - batch_cache = BatchMambaCache(layer_caches) - else: + if isinstance(layer_caches[0], ArraysCache): + arrays_caches: list[ArraysCache] = [] + for cache in layer_caches: + if not isinstance(cache, ArraysCache): + raise TypeError( + "Mixed cache types in a single layer: expected ArraysCache" + ) + arrays_caches.append(cache) + batch_cache = _merge_arrays_caches(arrays_caches) + else: # KV-like caches batch_cache = BatchKVCache.merge(layer_caches) merged.append(batch_cache) @@ -421,7 +418,7 @@ def _merge_kv_caches( def _extract_kv_cache( - batch_caches: list[BatchKVCache | BatchMambaCache], idx: int + batch_caches: list[BatchKVCache | ArraysCache], idx: int ) -> list[AnyCache]: """Extract a single request's cache from batched caches. @@ -432,7 +429,13 @@ def _extract_kv_cache( Returns: List of caches for the request, one per layer """ - return [cache.extract(idx) for cache in batch_caches] + extracted: list[AnyCache] = [] + for cache in batch_caches: + if isinstance(cache, ArraysCache): + extracted.append(_extract_arrays_cache(cache, idx)) + else: + extracted.append(cache.extract(idx)) + return extracted class MetalModelRunner: