78 changes: 30 additions & 48 deletions vllm_mlx/cli.py
@@ -569,6 +569,34 @@ def bench_kv_cache_command(args):
)


def _add_kv_cache_quantization_args(parser: argparse.ArgumentParser) -> None:
"""Add KV cache quantization arguments to an argparse parser."""
parser.add_argument(
"--kv-cache-quantization",
action="store_true",
help="Quantize stored KV caches to reduce memory (8-bit by default)",
)
parser.add_argument(
"--kv-cache-quantization-bits",
type=int,
default=8,
choices=[4, 8],
help="Bit width for KV cache quantization (default: 8)",
)
parser.add_argument(
"--kv-cache-quantization-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)
parser.add_argument(
"--kv-cache-min-quantize-tokens",
type=int,
default=256,
help="Minimum tokens for quantization to apply (default: 256)",
)


def main():
parser = argparse.ArgumentParser(
description="vllm-mlx: Apple Silicon MLX backend for vLLM",
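A quick way to exercise the new helper on its own, outside the full CLI wiring (a minimal sketch: the standalone parser below is purely illustrative; only _add_kv_cache_quantization_args and its flags come from this change):

import argparse

from vllm_mlx.cli import _add_kv_cache_quantization_args

# Register the four KV cache quantization flags on a throwaway parser.
parser = argparse.ArgumentParser()
_add_kv_cache_quantization_args(parser)

# Defaults: quantization off, 8-bit, group size 64, 256-token minimum.
args = parser.parse_args([])
assert args.kv_cache_quantization is False
assert args.kv_cache_quantization_bits == 8

# Opt in to 4-bit quantization with a smaller group size.
args = parser.parse_args([
    "--kv-cache-quantization",
    "--kv-cache-quantization-bits", "4",
    "--kv-cache-quantization-group-size", "32",
])
assert args.kv_cache_quantization is True
assert args.kv_cache_quantization_bits == 4
assert args.kv_cache_quantization_group_size == 32

Since the same helper is now called for both the serve and bench parsers below, the four flags stay identical across subcommands by construction.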
@@ -633,30 +661,7 @@ def main():
help="Disable memory-aware cache, use legacy entry-count based cache",
)
# KV cache quantization options
serve_parser.add_argument(
"--kv-cache-quantization",
action="store_true",
help="Quantize stored KV caches to reduce memory (8-bit by default)",
)
serve_parser.add_argument(
"--kv-cache-quantization-bits",
type=int,
default=8,
choices=[4, 8],
help="Bit width for KV cache quantization (default: 8)",
)
serve_parser.add_argument(
"--kv-cache-quantization-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)
serve_parser.add_argument(
"--kv-cache-min-quantize-tokens",
type=int,
default=256,
help="Minimum tokens for quantization to apply (default: 256)",
)
_add_kv_cache_quantization_args(serve_parser)
serve_parser.add_argument(
"--stream-interval",
type=int,
@@ -847,30 +852,7 @@ def main():
help="Disable memory-aware cache, use legacy entry-count based cache",
)
# KV cache quantization options
bench_parser.add_argument(
"--kv-cache-quantization",
action="store_true",
help="Quantize stored KV caches to reduce memory (8-bit by default)",
)
bench_parser.add_argument(
"--kv-cache-quantization-bits",
type=int,
default=8,
choices=[4, 8],
help="Bit width for KV cache quantization (default: 8)",
)
bench_parser.add_argument(
"--kv-cache-quantization-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)
bench_parser.add_argument(
"--kv-cache-min-quantize-tokens",
type=int,
default=256,
help="Minimum tokens for quantization to apply (default: 256)",
)
_add_kv_cache_quantization_args(bench_parser)
# Paged cache options (experimental)
bench_parser.add_argument(
"--use-paged-cache",
56 changes: 20 additions & 36 deletions vllm_mlx/memory_cache.py
@@ -98,6 +98,11 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
Returns:
Estimated memory usage in bytes.
"""
try:
from mlx_lm.models.cache import QuantizedKVCache
except ImportError:
QuantizedKVCache = None # noqa: N806

if not cache:
return 0

@@ -112,9 +117,7 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
total_bytes += _array_memory(keys)
total_bytes += _array_memory(values)
# Handle QuantizedKVCache: keys/values are tuples of (data, scales, biases)
elif hasattr(layer_cache, "keys") and isinstance(
getattr(layer_cache, "keys", None), (list, tuple)
):
elif QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
for arr in layer_cache.keys:
total_bytes += _array_memory(arr)
for arr in layer_cache.values:
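For reference, _array_memory itself is not touched by this diff; a plausible minimal version, assuming MLX arrays expose nbytes (which is how the per-array accounting above reads), would be:

import mlx.core as mx

def _array_memory(arr) -> int:
    # Bytes held by one MLX array; anything that is not an array counts as zero.
    return arr.nbytes if isinstance(arr, mx.array) else 0

With that, the QuantizedKVCache branch simply sums the packed data, scales, and biases arrays per layer instead of assuming plain keys/values tensors.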
@@ -274,12 +277,12 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
trimmed: list[Any] = []
for layer_cache in cache:
if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
tc = QuantizedKVCache.__new__(QuantizedKVCache)
tc = QuantizedKVCache(
group_size=layer_cache.group_size, bits=layer_cache.bits
)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
tc.offset = max(layer_cache.offset - trim_by, 0)
tc.group_size = layer_cache.group_size
tc.bits = layer_cache.bits
trimmed.append(tc)
elif (
hasattr(layer_cache, "offset")
@@ -406,6 +409,12 @@ class MemoryAwarePrefixCache:
This class is NOT thread-safe. Use external locking if needed.
"""

def _maybe_dequantize(self, cache: list[Any]) -> list[Any]:
"""Return dequantized cache if KV quantization is enabled, else pass through."""
if self._config.kv_quantize:
return _dequantize_cache(cache)
return cache

def __init__(
self,
model: Any,
@@ -479,12 +488,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
self._stats.hits += 1
self._stats.tokens_saved += len(tokens)
self._last_match_type = "exact"
cache_out = (
_dequantize_cache(entry.cache)
if self._config.kv_quantize
else entry.cache
)
return cache_out, []
return self._maybe_dequantize(entry.cache), []

# --- O(log N) prefix & supersequence match via sorted index ---
best_match: _CacheEntry | None = None
@@ -554,23 +558,13 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
self._stats.hits += 1
self._stats.tokens_saved += n_requested
self._last_match_type = "supersequence"
trimmed_cache = (
_dequantize_cache(trimmed_cache)
if self._config.kv_quantize
else trimmed_cache
)
return trimmed_cache, []
return self._maybe_dequantize(trimmed_cache), []
else:
self._entries.move_to_end(best_super.tokens)
self._stats.hits += 1
self._stats.tokens_saved += n_requested
self._last_match_type = "supersequence"
cache_out = (
_dequantize_cache(best_super.cache)
if self._config.kv_quantize
else best_super.cache
)
return cache_out, []
return self._maybe_dequantize(best_super.cache), []

# --- Prefix match ---
if best_match is not None:
@@ -579,12 +573,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
self._stats.tokens_saved += best_length
remaining = tokens[best_length:]
self._last_match_type = "prefix"
cache_out = (
_dequantize_cache(best_match.cache)
if self._config.kv_quantize
else best_match.cache
)
return cache_out, remaining
return self._maybe_dequantize(best_match.cache), remaining

# --- LCP (Longest Common Prefix) for divergent sequences ---
# This handles the agentic pattern: same system+context prefix
@@ -646,12 +635,7 @@ def fetch(self, tokens: list[int]) -> tuple[list[Any] | None, list[int]]:
f"trimmed={excess} remaining={len(remaining)}"
)
self._last_match_type = "lcp"
trimmed_cache = (
_dequantize_cache(trimmed_cache)
if self._config.kv_quantize
else trimmed_cache
)
return trimmed_cache, remaining
return self._maybe_dequantize(trimmed_cache), remaining

self._stats.misses += 1
self._last_match_type = "miss"
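After these edits the fetch() contract is unchanged: it returns a (cache, remaining_tokens) pair, and when kv_quantize is enabled the returned cache has already been dequantized by _maybe_dequantize. A minimal caller sketch (prefix_cache, model, and the prefill helper are placeholders for illustration, not APIs defined in this diff):

# Hypothetical caller; only fetch()'s (cache, remaining) contract comes from this change.
cache, remaining = prefix_cache.fetch(prompt_tokens)
if cache is None:
    # Miss: prefill the whole prompt from scratch.
    cache = prefill(model, prompt_tokens)
elif remaining:
    # Prefix/LCP hit: only the tokens past the shared prefix still need prefilling.
    cache = prefill(model, remaining, cache=cache)
else:
    # Exact or supersequence hit: the cache already covers the full prompt.
    pass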