Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions vllm_mlx/memory_cache.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -249,17 +249,33 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:


 def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
-    """Create shallow copies of KVCache layers with offset reduced by *trim_by*.
+    """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced.

     This is used when returning a cached KV state to the scheduler so that
     the last N positions are "freed" and the model will recompute them on the
     next forward pass (preventing duplicate KV entries).
+
+    Supports both KVCache (keys/values are arrays) and QuantizedKVCache
+    (keys/values are 3-tuples of arrays).
     """
     from mlx_lm.models.cache import KVCache

+    try:
+        from mlx_lm.models.cache import QuantizedKVCache as _QKVCache
+    except ImportError:
+        _QKVCache = None
+
     trimmed: list[Any] = []
     for layer_cache in cache:
-        if (
+        if _QKVCache is not None and isinstance(layer_cache, _QKVCache):
+            tc = _QKVCache.__new__(_QKVCache)
+            tc.keys = layer_cache.keys
+            tc.values = layer_cache.values
+            tc.offset = max(layer_cache.offset - trim_by, 0)
+            tc.group_size = layer_cache.group_size
+            tc.bits = layer_cache.bits
+            trimmed.append(tc)
+        elif (
             hasattr(layer_cache, "offset")
             and hasattr(layer_cache, "keys")
             and not isinstance(layer_cache.keys, (list, tuple))
Expand Down