diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index ab1e975e..ac9aeb18 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -141,9 +141,18 @@ def serve_command(args): max_cache_blocks=args.max_cache_blocks, # Chunked prefill chunked_prefill_tokens=args.chunked_prefill_tokens, + # KV cache quantization + kv_bits=args.kv_cache_bits, + kv_group_size=args.kv_cache_group_size, ) print("Mode: Continuous batching (for multiple concurrent users)") + if args.kv_cache_bits: + savings = "~75%" if args.kv_cache_bits == 4 else "~50%" + print( + f"KV cache quantization: {args.kv_cache_bits}-bit " + f"(group_size={args.kv_cache_group_size}, {savings} memory savings)" + ) if args.chunked_prefill_tokens > 0: print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step") print(f"Stream interval: {args.stream_interval} tokens") @@ -506,6 +515,21 @@ def main(): default=1000, help="Maximum number of cache blocks (default: 1000)", ) + # KV cache quantization + serve_parser.add_argument( + "--kv-cache-bits", + type=int, + default=None, + choices=[4, 8], + help="Quantize KV cache in prefix cache to reduce memory (4=~75%% savings, 8=~50%% savings). " + "Default: disabled (full precision).", + ) + serve_parser.add_argument( + "--kv-cache-group-size", + type=int, + default=64, + help="Group size for KV cache quantization (default: 64)", + ) # Chunked prefill serve_parser.add_argument( "--chunked-prefill-tokens", diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index 902f33f7..3580f56b 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -84,6 +84,13 @@ def _array_memory(arr) -> int: return 0 +def _tuple_memory(t) -> int: + """Estimate memory of a value that may be a list/tuple of arrays (quantized) or a single array.""" + if isinstance(t, (tuple, list)): + return sum(_array_memory(a) for a in t) + return _array_memory(t) + + def estimate_kv_cache_memory(cache: list[Any]) -> int: """ Estimate memory usage of a KV cache in bytes. @@ -92,6 +99,9 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: total memory footprint using shape+dtype metadata to avoid triggering lazy evaluation (which would cause a VRAM spike). + Supports both standard KVCache (keys/values are arrays) and + QuantizedKVCache (keys/values are 3-tuples of arrays). + Args: cache: List of layer cache objects, each containing keys/values tensors. @@ -107,27 +117,28 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: # Handle different cache object types # Check dict first since dicts have .keys() method that would match below if isinstance(layer_cache, dict) and "state" in layer_cache: - # Extracted state dict + # Extracted state dict — state is (keys, values) where each may be + # a tuple of arrays (QuantizedKVCache) or a single array (KVCache) keys, values = layer_cache["state"] - total_bytes += _array_memory(keys) - total_bytes += _array_memory(values) + total_bytes += _tuple_memory(keys) + total_bytes += _tuple_memory(values) elif hasattr(layer_cache, "state") and not isinstance(layer_cache, dict): # Cache with state property returning (keys, values) try: keys, values = layer_cache.state - total_bytes += _array_memory(keys) - total_bytes += _array_memory(values) + total_bytes += _tuple_memory(keys) + total_bytes += _tuple_memory(values) except (TypeError, ValueError): pass elif hasattr(layer_cache, "keys") and hasattr(layer_cache, "values"): - # Standard KVCache with keys/values attributes (not dict) + # Standard KVCache or QuantizedKVCache with keys/values attributes keys_attr = layer_cache.keys values_attr = layer_cache.values - # Ensure these are arrays, not methods + # Ensure these are arrays/tuples, not methods if not callable(keys_attr): - total_bytes += _array_memory(keys_attr) + total_bytes += _tuple_memory(keys_attr) if not callable(values_attr): - total_bytes += _array_memory(values_attr) + total_bytes += _tuple_memory(values_attr) return total_bytes @@ -234,17 +245,35 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry: def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: - """Create shallow copies of KVCache layers with offset reduced by *trim_by*. + """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced. This is used when returning a cached KV state to the scheduler so that the last N positions are "freed" and the model will recompute them on the next forward pass (preventing duplicate KV entries). + + Supports both KVCache (keys/values are arrays) and QuantizedKVCache + (keys/values are 3-tuples of arrays). """ from mlx_lm.models.cache import KVCache + try: + from mlx_lm.models.cache import ( + QuantizedKVCache as _QKVCache, # noqa: N812 + ) + except ImportError: + _QKVCache = None # noqa: N806 + trimmed: list[Any] = [] for layer_cache in cache: - if hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"): + if _QKVCache is not None and isinstance(layer_cache, _QKVCache): + tc = _QKVCache.__new__(_QKVCache) + tc.keys = layer_cache.keys + tc.values = layer_cache.values + tc.offset = max(layer_cache.offset - trim_by, 0) + tc.group_size = layer_cache.group_size + tc.bits = layer_cache.bits + trimmed.append(tc) + elif hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"): tc = KVCache.__new__(KVCache) tc.keys = layer_cache.keys tc.values = layer_cache.values @@ -255,6 +284,73 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: return trimmed +def quantize_kv_cache( + cache: list[Any], group_size: int = 64, bits: int = 8 +) -> list[Any]: + """Quantize KVCache layers to QuantizedKVCache for memory-efficient storage. + + Converts each KVCache layer to QuantizedKVCache using mlx-lm's + ``to_quantized()`` method. Non-KVCache layers (e.g. MambaCache) are + left unchanged. + + Args: + cache: List of cache layer objects (KVCache, MambaCache, etc.) + group_size: Quantization group size (default 64) + bits: Bits per element — 4 or 8 (default 8) + + Returns: + New list with KVCache layers replaced by QuantizedKVCache. + """ + result = [] + for layer_cache in cache: + if hasattr(layer_cache, "to_quantized"): + result.append(layer_cache.to_quantized(group_size=group_size, bits=bits)) + else: + result.append(layer_cache) + return result + + +def dequantize_kv_cache(cache: list[Any]) -> list[Any]: + """Dequantize QuantizedKVCache layers back to KVCache. + + Converts each QuantizedKVCache layer back to a standard KVCache so it + can be used with BatchGenerator (which does not support quantized caches). + Non-quantized layers are left unchanged. + + Args: + cache: List of cache layer objects (QuantizedKVCache, KVCache, etc.) + + Returns: + New list with QuantizedKVCache layers replaced by KVCache. + """ + import mlx.core as mx + from mlx_lm.models.cache import KVCache, QuantizedKVCache + + result = [] + for layer_cache in cache: + if isinstance(layer_cache, QuantizedKVCache): + kv = KVCache() + # Dequantize keys and values from 3-tuple format + keys_q = layer_cache.keys + values_q = layer_cache.values + if keys_q is not None and values_q is not None: + kv.keys = mx.dequantize( + *keys_q, + group_size=layer_cache.group_size, + bits=layer_cache.bits, + ) + kv.values = mx.dequantize( + *values_q, + group_size=layer_cache.group_size, + bits=layer_cache.bits, + ) + kv.offset = layer_cache.offset + result.append(kv) + else: + result.append(layer_cache) + return result + + class MemoryAwarePrefixCache: """ Prefix cache with memory-based eviction. diff --git a/vllm_mlx/scheduler.py b/vllm_mlx/scheduler.py index 26ef5315..a8238cbf 100644 --- a/vllm_mlx/scheduler.py +++ b/vllm_mlx/scheduler.py @@ -89,6 +89,12 @@ class SchedulerConfig: # 0 = disabled. Only effective when chunked_prefill_tokens > 0. mid_prefill_save_interval: int = 8192 + # KV cache quantization: quantize cached KV states to reduce memory usage. + # None = disabled (default), 4 = 4-bit (~75% savings), 8 = 8-bit (~50% savings). + # Only affects prefix cache storage — BatchGenerator uses full-precision KV cache. + kv_bits: Optional[int] = None + kv_group_size: int = 64 + @dataclass class SchedulerOutput: @@ -573,6 +579,14 @@ def __init__( f"Prefix cache enabled with max_entries={self.config.prefix_cache_size}" ) + # KV cache quantization + if self.config.kv_bits is not None: + savings = "~75%" if self.config.kv_bits == 4 else "~50%" + logger.info( + f"KV cache quantization enabled: {self.config.kv_bits}-bit, " + f"group_size={self.config.kv_group_size} ({savings} memory savings)" + ) + # Thread-safe set for deferred aborts (main thread → executor thread) # CPython GIL guarantees set.add() and `x in set` are atomic. self._pending_abort_ids: Set[str] = set() @@ -718,12 +732,24 @@ def _prompt_cache_save(uid, extracted_cache): return prompt_tokens = list(request.prompt_token_ids) + + # Quantize cache before storing if KV quantization enabled + cache_to_store = extracted_cache + if self.config.kv_bits is not None: + from .memory_cache import quantize_kv_cache + + cache_to_store = quantize_kv_cache( + cache_to_store, + group_size=self.config.kv_group_size, + bits=self.config.kv_bits, + ) + _t0 = _time.monotonic() # evict_prefixes=False: keep mid-prefill boundary entries so # that future requests with the same prefix but different # suffix get a prefix cache hit (critical for agentic multi-turn). stored = self.memory_aware_cache.store( - prompt_tokens, extracted_cache, evict_prefixes=False + prompt_tokens, cache_to_store, evict_prefixes=False ) _dt = _time.monotonic() - _t0 if stored: @@ -777,6 +803,17 @@ def _mid_prefill_save(uid, processed_tokens, prompt_cache): if not reconstructed: return + # Quantize cache before storing if KV quantization enabled + cache_to_store = reconstructed + if self.config.kv_bits is not None: + from .memory_cache import quantize_kv_cache + + cache_to_store = quantize_kv_cache( + cache_to_store, + group_size=self.config.kv_group_size, + bits=self.config.kv_bits, + ) + prefix_tokens = list(request.prompt_token_ids[:total_cached]) # Remove previous intermediate entry to avoid memory waste @@ -785,7 +822,7 @@ def _mid_prefill_save(uid, processed_tokens, prompt_cache): self.memory_aware_cache.remove(list(old_key)) _t0 = _time.monotonic() - stored = self.memory_aware_cache.store(prefix_tokens, reconstructed) + stored = self.memory_aware_cache.store(prefix_tokens, cache_to_store) _dt = _time.monotonic() - _t0 if stored: @@ -1095,6 +1132,11 @@ def add_request(self, request: Request) -> None: _fetch_dt = _time.monotonic() - _fetch_t0 request.cache_hit_type = self.memory_aware_cache._last_match_type if cache: + # Dequantize if KV cache quantization is active + if self.config.kv_bits is not None: + from .memory_cache import dequantize_kv_cache + + cache = dequantize_kv_cache(cache) request.prompt_cache = cache request.cached_tokens = len(request.prompt_token_ids) - len(remaining) request.remaining_tokens = remaining @@ -1115,6 +1157,11 @@ def add_request(self, request: Request) -> None: # Use legacy prefix cache cache, remaining = self.prefix_cache.fetch_cache(request.prompt_token_ids) if cache: + # Dequantize if KV cache quantization is active + if self.config.kv_bits is not None: + from .memory_cache import dequantize_kv_cache + + cache = dequantize_kv_cache(cache) request.cache_hit_type = "hit" request.prompt_cache = cache request.cached_tokens = len(request.prompt_token_ids) - len(remaining) @@ -1487,12 +1534,23 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: full_token_sequence = list(request.prompt_token_ids) + list( request.output_token_ids ) + # Quantize cache before storing if KV quantization enabled + cache_to_store = request._extracted_cache + if self.config.kv_bits is not None: + from .memory_cache import quantize_kv_cache + + cache_to_store = quantize_kv_cache( + cache_to_store, + group_size=self.config.kv_group_size, + bits=self.config.kv_bits, + ) + import time as _time _store_t0 = _time.monotonic() stored = self.memory_aware_cache.store( full_token_sequence, - request._extracted_cache, + cache_to_store, evict_prefixes=False, ) _store_dt = _time.monotonic() - _store_t0 @@ -1532,9 +1590,19 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None: full_token_sequence = list(request.prompt_token_ids) + list( request.output_token_ids ) + # Quantize cache before storing if KV quantization enabled + cache_to_store = request._extracted_cache + if self.config.kv_bits is not None: + from .memory_cache import quantize_kv_cache + + cache_to_store = quantize_kv_cache( + cache_to_store, + group_size=self.config.kv_group_size, + bits=self.config.kv_bits, + ) self.prefix_cache.store_cache( full_token_sequence, - request._extracted_cache, + cache_to_store, ) logger.debug( f"Stored cache for request {request_id} "