diff --git a/vllm_mlx/cli.py b/vllm_mlx/cli.py index f62d42397..e9087105b 100644 --- a/vllm_mlx/cli.py +++ b/vllm_mlx/cli.py @@ -569,6 +569,34 @@ def bench_kv_cache_command(args): ) +def _add_kv_cache_quantization_args(parser: argparse.ArgumentParser) -> None: + """Add KV cache quantization arguments to an argparse parser.""" + parser.add_argument( + "--kv-cache-quantization", + action="store_true", + help="Quantize stored KV caches to reduce memory (8-bit by default)", + ) + parser.add_argument( + "--kv-cache-quantization-bits", + type=int, + default=8, + choices=[4, 8], + help="Bit width for KV cache quantization (default: 8)", + ) + parser.add_argument( + "--kv-cache-quantization-group-size", + type=int, + default=64, + help="Group size for KV cache quantization (default: 64)", + ) + parser.add_argument( + "--kv-cache-min-quantize-tokens", + type=int, + default=256, + help="Minimum tokens for quantization to apply (default: 256)", + ) + + def main(): parser = argparse.ArgumentParser( description="vllm-mlx: Apple Silicon MLX backend for vLLM", @@ -633,30 +661,7 @@ def main(): help="Disable memory-aware cache, use legacy entry-count based cache", ) # KV cache quantization options - serve_parser.add_argument( - "--kv-cache-quantization", - action="store_true", - help="Quantize stored KV caches to reduce memory (8-bit by default)", - ) - serve_parser.add_argument( - "--kv-cache-quantization-bits", - type=int, - default=8, - choices=[4, 8], - help="Bit width for KV cache quantization (default: 8)", - ) - serve_parser.add_argument( - "--kv-cache-quantization-group-size", - type=int, - default=64, - help="Group size for KV cache quantization (default: 64)", - ) - serve_parser.add_argument( - "--kv-cache-min-quantize-tokens", - type=int, - default=256, - help="Minimum tokens for quantization to apply (default: 256)", - ) + _add_kv_cache_quantization_args(serve_parser) serve_parser.add_argument( "--stream-interval", type=int, @@ -847,30 +852,7 @@ def main(): help="Disable memory-aware cache, use legacy entry-count based cache", ) # KV cache quantization options - bench_parser.add_argument( - "--kv-cache-quantization", - action="store_true", - help="Quantize stored KV caches to reduce memory (8-bit by default)", - ) - bench_parser.add_argument( - "--kv-cache-quantization-bits", - type=int, - default=8, - choices=[4, 8], - help="Bit width for KV cache quantization (default: 8)", - ) - bench_parser.add_argument( - "--kv-cache-quantization-group-size", - type=int, - default=64, - help="Group size for KV cache quantization (default: 64)", - ) - bench_parser.add_argument( - "--kv-cache-min-quantize-tokens", - type=int, - default=256, - help="Minimum tokens for quantization to apply (default: 256)", - ) + _add_kv_cache_quantization_args(bench_parser) # Paged cache options (experimental) bench_parser.add_argument( "--use-paged-cache", diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py index f43763541..01856a181 100644 --- a/vllm_mlx/memory_cache.py +++ b/vllm_mlx/memory_cache.py @@ -98,6 +98,11 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: Returns: Estimated memory usage in bytes. """ + try: + from mlx_lm.models.cache import QuantizedKVCache + except ImportError: + QuantizedKVCache = None # noqa: N806 + if not cache: return 0 @@ -112,9 +117,7 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int: total_bytes += _array_memory(keys) total_bytes += _array_memory(values) # Handle QuantizedKVCache: keys/values are tuples of (data, scales, biases) - elif hasattr(layer_cache, "keys") and isinstance( - getattr(layer_cache, "keys", None), (list, tuple) - ): + elif QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache): for arr in layer_cache.keys: total_bytes += _array_memory(arr) for arr in layer_cache.values: @@ -274,12 +277,12 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]: trimmed: list[Any] = [] for layer_cache in cache: if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache): - tc = QuantizedKVCache.__new__(QuantizedKVCache) + tc = QuantizedKVCache( + group_size=layer_cache.group_size, bits=layer_cache.bits + ) tc.keys = layer_cache.keys tc.values = layer_cache.values tc.offset = max(layer_cache.offset - trim_by, 0) - tc.group_size = layer_cache.group_size - tc.bits = layer_cache.bits trimmed.append(tc) elif ( hasattr(layer_cache, "offset") @@ -406,6 +409,12 @@ class MemoryAwarePrefixCache: This class is NOT thread-safe. Use external locking if needed. """ + def _maybe_dequantize(self, cache: list[Any]) -> list[Any]: + """Return dequantized cache if KV quantization is enabled, else pass through.""" + if self._config.kv_quantize: + return _dequantize_cache(cache) + return cache + def __init__( self, model: Any, @@ -479,12 +488,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._stats.hits += 1 self._stats.tokens_saved += len(tokens) self._last_match_type = "exact" - cache_out = ( - _dequantize_cache(entry.cache) - if self._config.kv_quantize - else entry.cache - ) - return cache_out, [] + return self._maybe_dequantize(entry.cache), [] # --- O(log N) prefix & supersequence match via sorted index --- best_match: _CacheEntry | None = None @@ -554,23 +558,13 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._stats.hits += 1 self._stats.tokens_saved += n_requested self._last_match_type = "supersequence" - trimmed_cache = ( - _dequantize_cache(trimmed_cache) - if self._config.kv_quantize - else trimmed_cache - ) - return trimmed_cache, [] + return self._maybe_dequantize(trimmed_cache), [] else: self._entries.move_to_end(best_super.tokens) self._stats.hits += 1 self._stats.tokens_saved += n_requested self._last_match_type = "supersequence" - cache_out = ( - _dequantize_cache(best_super.cache) - if self._config.kv_quantize - else best_super.cache - ) - return cache_out, [] + return self._maybe_dequantize(best_super.cache), [] # --- Prefix match --- if best_match is not None: @@ -579,12 +573,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: self._stats.tokens_saved += best_length remaining = tokens[best_length:] self._last_match_type = "prefix" - cache_out = ( - _dequantize_cache(best_match.cache) - if self._config.kv_quantize - else best_match.cache - ) - return cache_out, remaining + return self._maybe_dequantize(best_match.cache), remaining # --- LCP (Longest Common Prefix) for divergent sequences --- # This handles the agentic pattern: same system+context prefix @@ -646,12 +635,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]: f"trimmed={excess} remaining={len(remaining)}" ) self._last_match_type = "lcp" - trimmed_cache = ( - _dequantize_cache(trimmed_cache) - if self._config.kv_quantize - else trimmed_cache - ) - return trimmed_cache, remaining + return self._maybe_dequantize(trimmed_cache), remaining self._stats.misses += 1 self._last_match_type = "miss"