diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index 8dd9cf25..64667995 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -249,17 +249,33 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: - """Create shallow copies of KVCache layers with offset reduced by *trim_by*. + """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced. This is used when returning a cached KV state to the scheduler so that the last N positions are "freed" and the model will recompute them on the next forward pass (preventing duplicate KV entries). + + Supports both KVCache (keys/values are arrays) and QuantizedKVCache + (keys/values are 3-tuples of arrays). """ from mlx_lm.models.cache import KVCache + try: + from mlx_lm.models.cache import QuantizedKVCache as _QKVCache + except ImportError: + _QKVCache = None + trimmed: list[Any] = [] for layer_cache in cache: - if ( + if _QKVCache is not None and isinstance(layer_cache, _QKVCache): + tc = _QKVCache.__new__(_QKVCache) + tc.keys = layer_cache.keys + tc.values = layer_cache.values + tc.offset = max(layer_cache.offset - trim_by, 0) + tc.group_size = layer_cache.group_size + tc.bits = layer_cache.bits + trimmed.append(tc) + elif ( hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys") and not isinstance(layer_cache.keys, (list, tuple))