Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ classifiers = [

dependencies = [
# MLX - Required for Apple Silicon GPU acceleration
"mlx>=0.20.0; platform_system == 'Darwin' and platform_machine == 'arm64'",
# mlx-lm 0.30.6 removed MambaCache; pin until we support ArraysCache-only versions.
# See: https://github.com/vllm-project/vllm-metal/issues/100
"mlx-lm>=0.20.0,<0.30.6; platform_system == 'Darwin' and platform_machine == 'arm64'",
"mlx>=0.29.2; platform_system == 'Darwin' and platform_machine == 'arm64'",
"mlx-lm>=0.28.4; platform_system == 'Darwin' and platform_machine == 'arm64'",
"mlx-vlm>=0.3.0; platform_system == 'Darwin' and platform_machine == 'arm64'", # Vision-language model support
# Model loading and weights
"transformers>=4.40.0",
Expand Down
133 changes: 131 additions & 2 deletions tests/test_prefix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@
from unittest.mock import MagicMock

import mlx.core as mx
import pytest

import vllm_metal.v1.model_runner as mr


class StubMambaCache:
class StubArraysCache:
@property
def state(self):
return []
Expand All @@ -32,7 +33,7 @@ def fake_make_prompt_cache(model):
kv.keys = mx.zeros((1, 8, 0, 64))
kv.values = mx.zeros((1, 8, 0, 64))
kv.offset = 0
return [kv, StubMambaCache(), kv]
return [kv, StubArraysCache(), kv]

monkeypatch.setattr(mr, "make_prompt_cache", fake_make_prompt_cache)

Expand Down Expand Up @@ -92,6 +93,134 @@ def fake_make_prompt_cache(model):
insert_spy.assert_called_once()


class TestHybridCacheMergeExtract:
    """Regression coverage for batching hybrid (KV + ArraysCache) models.

    Context:
    - `mlx-lm==0.30.6` dropped `MambaCache`; hybrid models now rely on
      `ArraysCache` instead.
    - Older mlx-lm releases do not ship `ArraysCache.merge()` / `extract()`.

    The tests below check that vllm-metal can combine per-request caches into a
    single batched cache, run a batched forward pass, and split the batch back
    into per-request caches without touching `MambaCache` or the newer mlx-lm
    APIs.
    """

    _ARRAYS_CACHE_ENTRIES = 2
    _ARRAYS_CACHE_FEATURES = 4

    _KV_NUM_HEADS = 1
    _KV_HEAD_DIM = 2

    def _make_arrays_cache(self, v0: float | None, v1: float | None) -> mr.ArraysCache:
        # Build an ArraysCache whose entries are constant-filled, leaving any
        # entry whose fill value is None unset.
        built = mr.ArraysCache(self._ARRAYS_CACHE_ENTRIES)
        for slot, fill in enumerate((v0, v1)):
            if fill is not None:
                built[slot] = mx.full(
                    (1, self._ARRAYS_CACHE_FEATURES), fill, dtype=mx.float32
                )
        return built

    def _make_kv_cache(self, seq_len: int, value: float) -> mr.KVCache:
        # Keys are filled with `value`, values with `value + 0.5`, and the
        # offset reflects the sequence length.
        shape = (1, self._KV_NUM_HEADS, seq_len, self._KV_HEAD_DIM)
        cache = mr.KVCache()
        cache.keys = mx.full(shape, value, dtype=mx.float32)
        cache.values = mx.full(shape, value + 0.5, dtype=mx.float32)
        cache.offset = seq_len
        return cache

    @staticmethod
    def _same(a, b) -> bool:
        # mx.allclose yields an mx scalar; coerce it to a Python bool.
        return bool(mx.allclose(a, b))

    def test_arrays_cache_merge_extract_roundtrip(self) -> None:
        """Merging then extracting ArraysCache round-trips per request."""
        source_caches = [
            self._make_arrays_cache(1.0, 11.0),
            self._make_arrays_cache(2.0, 22.0),
        ]

        batched = mr._merge_kv_caches([[cache] for cache in source_caches])
        assert isinstance(batched[0], mr.ArraysCache)

        for req_idx, source in enumerate(source_caches):
            restored = mr._extract_kv_cache(batched, req_idx)[0]
            assert isinstance(restored, mr.ArraysCache)
            assert self._same(restored.state[0], source.state[0])
            assert self._same(restored.state[1], source.state[1])

    def test_arrays_cache_merge_extract_handles_missing_entries(self) -> None:
        """Missing per-request entries become zeros after merging.

        ArraysCache merging densifies per-entry state into a batch array when
        at least one request has that entry populated. Requests that had
        `None` for the entry come back as zeros in the merged state.
        """
        full_cache = self._make_arrays_cache(1.0, 11.0)
        sparse_cache = self._make_arrays_cache(2.0, None)

        batched = mr._merge_kv_caches([[full_cache], [sparse_cache]])

        restored_full = mr._extract_kv_cache(batched, 0)[0]
        restored_sparse = mr._extract_kv_cache(batched, 1)[0]

        assert isinstance(restored_full, mr.ArraysCache)
        assert isinstance(restored_sparse, mr.ArraysCache)

        assert self._same(restored_full.state[0], full_cache.state[0])
        assert self._same(restored_full.state[1], full_cache.state[1])
        assert self._same(restored_sparse.state[0], sparse_cache.state[0])

        zero_filled = restored_sparse.state[1]
        assert zero_filled is not None
        assert zero_filled.shape == (1, self._ARRAYS_CACHE_FEATURES)
        assert self._same(zero_filled, mx.zeros_like(zero_filled))

    def test_mixed_kv_and_arrays_cache_merge_extract_roundtrip(self) -> None:
        """Merging/extracting preserves both KVCache and ArraysCache layers."""
        kv_sources = [
            self._make_kv_cache(seq_len=2, value=1.0),
            self._make_kv_cache(seq_len=4, value=2.0),
        ]
        arr_sources = [
            self._make_arrays_cache(3.0, 33.0),
            self._make_arrays_cache(4.0, 44.0),
        ]

        batched = mr._merge_kv_caches(
            [[kv, arr] for kv, arr in zip(kv_sources, arr_sources)]
        )

        assert isinstance(batched[0], mr.BatchKVCache)
        assert isinstance(batched[1], mr.ArraysCache)

        for req_idx in range(2):
            kv_out, arr_out = mr._extract_kv_cache(batched, req_idx)

            assert isinstance(kv_out, mr.KVCache)
            assert isinstance(arr_out, mr.ArraysCache)

            assert kv_out.offset == kv_sources[req_idx].offset
            assert self._same(kv_out.keys, kv_sources[req_idx].keys)
            assert self._same(kv_out.values, kv_sources[req_idx].values)

            assert self._same(arr_out.state[0], arr_sources[req_idx].state[0])
            assert self._same(arr_out.state[1], arr_sources[req_idx].state[1])

    def test_merge_kv_caches_rejects_mixed_cache_types_within_layer(self) -> None:
        with pytest.raises(TypeError, match="Mixed cache types in a single layer"):
            mr._merge_kv_caches(
                [[self._make_arrays_cache(1.0, 2.0)], [mr.KVCache()]]
            )


class TestPrefixCacheEviction:
def test_eviction_under_max_bytes(self) -> None:
# 1KB limit
Expand Down
115 changes: 59 additions & 56 deletions vllm_metal/v1/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
ArraysCache,
BatchKVCache,
KVCache,
MambaCache,
make_prompt_cache,
)

Expand Down Expand Up @@ -149,7 +148,7 @@ class CachedPrefix:
"""Cached KV state for a token prefix.

cache_state contains (k, v) tuples for KVCache layers, or None for
MambaCache/ArraysCache layers in hybrid models.
ArraysCache layers in hybrid models.
"""

token_ids: list[int]
Expand Down Expand Up @@ -207,8 +206,8 @@ def _evict_until_fits(self, needed_bytes: int) -> None:
def insert(self, token_ids: list[int], cache: list[KVCache]) -> None:
"""Insert a prefix cache entry with memory-based eviction.

Only KVCache layers are cached. MambaCache/ArraysCache layers are
skipped (stored as None) for hybrid model compatibility.
Only KVCache layers are cached. ArraysCache layers are skipped (stored as
None) for hybrid model compatibility.
"""
prefix_hash = _compute_prefix_hash(token_ids)
if prefix_hash in self._cache:
Expand Down Expand Up @@ -249,8 +248,8 @@ def restore_cache(
) -> list[KVCache]:
"""Restore a cached prefix to a fresh KVCache.

Only KVCache layers are restored. MambaCache/ArraysCache layers
remain in their fresh state for hybrid model compatibility.
Only KVCache layers are restored. ArraysCache layers remain in their
fresh state for hybrid model compatibility.
"""
cache_model = (
model.language_model
Expand Down Expand Up @@ -284,59 +283,50 @@ def get_stats(self) -> dict:


# Type alias for any cache type supported by the model
AnyCache: TypeAlias = KVCache | MambaCache | ArraysCache
AnyCache: TypeAlias = KVCache | ArraysCache


class BatchMambaCache:
"""Batched cache for Mamba/SSM layers.
def _merge_arrays_caches(caches: list[ArraysCache]) -> ArraysCache:
"""Merge per-request ArraysCache objects into a single batched ArraysCache.

Wraps multiple MambaCache instances into a single batched cache
for efficient batched forward passes on hybrid models.
This mirrors the behavior of `mlx_lm.models.cache.ArraysCache.merge` but is
implemented here for compatibility with older mlx-lm versions that do not
provide `merge()` / `extract()`.
"""
if not caches:
raise ValueError("caches must be non-empty")

def __init__(self, caches: list[MambaCache | ArraysCache]):
"""Create a batched Mamba cache from individual caches.
num_entries = len(caches[0].state)
batch_size = len(caches)

Args:
caches: List of MambaCache instances to batch
"""
self._batch_size = len(caches)
self._cache_size = len(caches[0].cache) if caches else 0

# Stack each state array across the batch dimension
self.cache: list[mx.array | None] = []
for i in range(self._cache_size):
states = [c.cache[i] for c in caches]
if all(s is not None for s in states):
self.cache.append(mx.concatenate(states, axis=0))
else:
self.cache.append(None)

def __getitem__(self, idx: int) -> mx.array | None:
return self.cache[idx]
merged = ArraysCache(num_entries)
for entry_idx in range(num_entries):
values = [cache.state[entry_idx] for cache in caches]
template = next((value for value in values if value is not None), None)
if template is None:
continue

def __setitem__(self, idx: int, value: mx.array | None) -> None:
self.cache[idx] = value
shape = list(template.shape)
shape[0] = batch_size
merged_state = mx.zeros(tuple(shape), template.dtype)
for batch_idx, value in enumerate(values):
if value is None:
continue
merged_state[batch_idx : batch_idx + 1] = value

def extract(self, idx: int) -> MambaCache:
"""Extract a single request's cache from the batch.
merged[entry_idx] = merged_state

Args:
idx: Index of the request in the batch

Returns:
MambaCache for the individual request
"""
cache = MambaCache()
for i in range(self._cache_size):
if self.cache[i] is not None:
cache.cache[i] = self.cache[i][idx : idx + 1]
return cache
return merged


def _is_mamba_cache(cache: AnyCache) -> bool:
"""Check if a cache is a Mamba-style cache (ArraysCache or MambaCache)."""
return isinstance(cache, (MambaCache, ArraysCache))
def _extract_arrays_cache(batch_cache: ArraysCache, idx: int) -> ArraysCache:
    """Slice request `idx` out of a batched ArraysCache.

    Populated entries keep a size-1 leading batch dimension; entries that were
    never set (None) remain None in the extracted cache.
    """
    batched_state = batch_cache.state
    per_request = ArraysCache(len(batched_state))
    sliced = []
    for entry in batched_state:
        sliced.append(entry[idx : idx + 1] if entry is not None else None)
    per_request.state = sliced
    return per_request


def _mlx_greedy_sample(logits: mx.array) -> mx.array:
Expand Down Expand Up @@ -386,15 +376,15 @@ class RequestState:
# vLLM applies repetition penalties to both prompt+output tokens, but applies
# presence/frequency penalties only to generated (output) tokens.
prompt_len: int
cache: list[AnyCache] # Per-layer caches (KVCache or MambaCache for hybrid models)
cache: list[AnyCache] # Per-layer caches (KVCache or ArraysCache for hybrid models)
sampling_params: SamplingParams # Sampling parameters for this request
generator: torch.Generator | None = None
generated_tokens: int = 0


def _merge_kv_caches(
caches_list: list[list[AnyCache]],
) -> list[BatchKVCache | BatchMambaCache]:
) -> list[BatchKVCache | ArraysCache]:
"""Merge multiple per-request caches into batched caches.

Args:
Expand All @@ -407,21 +397,28 @@ def _merge_kv_caches(
return []

num_layers = len(caches_list[0])
merged: list[BatchKVCache | BatchMambaCache] = []
merged: list[BatchKVCache | ArraysCache] = []

for layer_idx in range(num_layers):
layer_caches = [caches[layer_idx] for caches in caches_list]
if _is_mamba_cache(layer_caches[0]):
batch_cache = BatchMambaCache(layer_caches)
else:
if isinstance(layer_caches[0], ArraysCache):
arrays_caches: list[ArraysCache] = []
for cache in layer_caches:
if not isinstance(cache, ArraysCache):
raise TypeError(
"Mixed cache types in a single layer: expected ArraysCache"
)
arrays_caches.append(cache)
batch_cache = _merge_arrays_caches(arrays_caches)
else: # KV-like caches
batch_cache = BatchKVCache.merge(layer_caches)
merged.append(batch_cache)

return merged


def _extract_kv_cache(
batch_caches: list[BatchKVCache | BatchMambaCache], idx: int
batch_caches: list[BatchKVCache | ArraysCache], idx: int
) -> list[AnyCache]:
"""Extract a single request's cache from batched caches.

Expand All @@ -432,7 +429,13 @@ def _extract_kv_cache(
Returns:
List of caches for the request, one per layer
"""
return [cache.extract(idx) for cache in batch_caches]
extracted: list[AnyCache] = []
for cache in batch_caches:
if isinstance(cache, ArraysCache):
extracted.append(_extract_arrays_cache(cache, idx))
else:
extracted.append(cache.extract(idx))
return extracted


class MetalModelRunner:
Expand Down