Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,18 @@ def serve_command(args):
max_cache_blocks=args.max_cache_blocks,
# Chunked prefill
chunked_prefill_tokens=args.chunked_prefill_tokens,
# KV cache quantization
kv_bits=args.kv_cache_bits,
kv_group_size=args.kv_cache_group_size,
)

print("Mode: Continuous batching (for multiple concurrent users)")
if args.kv_cache_bits:
savings = "~75%" if args.kv_cache_bits == 4 else "~50%"
print(
f"KV cache quantization: {args.kv_cache_bits}-bit "
f"(group_size={args.kv_cache_group_size}, {savings} memory savings)"
)
if args.chunked_prefill_tokens > 0:
print(f"Chunked prefill: {args.chunked_prefill_tokens} tokens per step")
print(f"Stream interval: {args.stream_interval} tokens")
Expand Down Expand Up @@ -506,6 +515,21 @@ def main():
default=1000,
help="Maximum number of cache blocks (default: 1000)",
)
# KV cache quantization
serve_parser.add_argument(
"--kv-cache-bits",
type=int,
default=None,
choices=[4, 8],
help="Quantize KV cache in prefix cache to reduce memory (4=~75%% savings, 8=~50%% savings). "
"Default: disabled (full precision).",
)
serve_parser.add_argument(
"--kv-cache-group-size",
type=int,
default=64,
help="Group size for KV cache quantization (default: 64)",
)
# Chunked prefill
serve_parser.add_argument(
"--chunked-prefill-tokens",
Expand Down
118 changes: 107 additions & 11 deletions vllm_mlx/memory_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def _array_memory(arr) -> int:
return 0


def _tuple_memory(t) -> int:
    """Estimate memory of a value that may be a list/tuple of arrays (quantized) or a single array."""
    if not isinstance(t, (tuple, list)):
        # Plain array (standard KVCache layout).
        return _array_memory(t)
    # Quantized layout: sum over the component arrays (e.g. data/scales/biases).
    return sum(map(_array_memory, t))


def estimate_kv_cache_memory(cache: list[Any]) -> int:
"""
Estimate memory usage of a KV cache in bytes.
Expand All @@ -92,6 +99,9 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
total memory footprint using shape+dtype metadata to avoid triggering
lazy evaluation (which would cause a VRAM spike).

Supports both standard KVCache (keys/values are arrays) and
QuantizedKVCache (keys/values are 3-tuples of arrays).

Args:
cache: List of layer cache objects, each containing keys/values tensors.

Expand All @@ -107,27 +117,28 @@ def estimate_kv_cache_memory(cache: list[Any]) -> int:
# Handle different cache object types
# Check dict first since dicts have .keys() method that would match below
if isinstance(layer_cache, dict) and "state" in layer_cache:
# Extracted state dict
# Extracted state dict — state is (keys, values) where each may be
# a tuple of arrays (QuantizedKVCache) or a single array (KVCache)
keys, values = layer_cache["state"]
total_bytes += _array_memory(keys)
total_bytes += _array_memory(values)
total_bytes += _tuple_memory(keys)
total_bytes += _tuple_memory(values)
elif hasattr(layer_cache, "state") and not isinstance(layer_cache, dict):
# Cache with state property returning (keys, values)
try:
keys, values = layer_cache.state
total_bytes += _array_memory(keys)
total_bytes += _array_memory(values)
total_bytes += _tuple_memory(keys)
total_bytes += _tuple_memory(values)
except (TypeError, ValueError):
pass
elif hasattr(layer_cache, "keys") and hasattr(layer_cache, "values"):
# Standard KVCache with keys/values attributes (not dict)
# Standard KVCache or QuantizedKVCache with keys/values attributes
keys_attr = layer_cache.keys
values_attr = layer_cache.values
# Ensure these are arrays, not methods
# Ensure these are arrays/tuples, not methods
if not callable(keys_attr):
total_bytes += _array_memory(keys_attr)
total_bytes += _tuple_memory(keys_attr)
if not callable(values_attr):
total_bytes += _array_memory(values_attr)
total_bytes += _tuple_memory(values_attr)

return total_bytes

Expand Down Expand Up @@ -234,17 +245,35 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:


def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
"""Create shallow copies of KVCache layers with offset reduced by *trim_by*.
"""Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced.

This is used when returning a cached KV state to the scheduler so that
the last N positions are "freed" and the model will recompute them on the
next forward pass (preventing duplicate KV entries).

Supports both KVCache (keys/values are arrays) and QuantizedKVCache
(keys/values are 3-tuples of arrays).
"""
from mlx_lm.models.cache import KVCache

try:
from mlx_lm.models.cache import (
QuantizedKVCache as _QKVCache, # noqa: N812
)
except ImportError:
_QKVCache = None # noqa: N806

trimmed: list[Any] = []
for layer_cache in cache:
if hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"):
if _QKVCache is not None and isinstance(layer_cache, _QKVCache):
tc = _QKVCache.__new__(_QKVCache)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
tc.offset = max(layer_cache.offset - trim_by, 0)
tc.group_size = layer_cache.group_size
tc.bits = layer_cache.bits
trimmed.append(tc)
elif hasattr(layer_cache, "offset") and hasattr(layer_cache, "keys"):
tc = KVCache.__new__(KVCache)
tc.keys = layer_cache.keys
tc.values = layer_cache.values
Expand All @@ -255,6 +284,73 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
return trimmed


def quantize_kv_cache(
    cache: list[Any], group_size: int = 64, bits: int = 8
) -> list[Any]:
    """Quantize KVCache layers to QuantizedKVCache for memory-efficient storage.

    Converts each KVCache layer to QuantizedKVCache using mlx-lm's
    ``to_quantized()`` method. Non-KVCache layers (e.g. MambaCache) are
    left unchanged.

    Args:
        cache: List of cache layer objects (KVCache, MambaCache, etc.)
        group_size: Quantization group size (default 64)
        bits: Bits per element — 4 or 8 (default 8)

    Returns:
        New list with KVCache layers replaced by QuantizedKVCache.
    """
    # Duck-typed dispatch: any layer exposing to_quantized() gets converted,
    # everything else passes through untouched.
    return [
        layer.to_quantized(group_size=group_size, bits=bits)
        if hasattr(layer, "to_quantized")
        else layer
        for layer in cache
    ]


def dequantize_kv_cache(cache: list[Any]) -> list[Any]:
    """Dequantize QuantizedKVCache layers back to KVCache.

    Converts each QuantizedKVCache layer back to a standard KVCache so it
    can be used with BatchGenerator (which does not support quantized caches).
    Non-quantized layers are left unchanged.

    Args:
        cache: List of cache layer objects (QuantizedKVCache, KVCache, etc.)

    Returns:
        New list with QuantizedKVCache layers replaced by KVCache.
    """
    import mlx.core as mx
    from mlx_lm.models.cache import KVCache, QuantizedKVCache

    restored: list[Any] = []
    for layer in cache:
        if not isinstance(layer, QuantizedKVCache):
            restored.append(layer)
            continue
        kv = KVCache()
        keys_q = layer.keys
        values_q = layer.values
        # keys/values are stored as 3-tuples (data, scales, biases);
        # skip reconstruction when either side is absent.
        if keys_q is not None and values_q is not None:
            gs = layer.group_size
            b = layer.bits
            kv.keys = mx.dequantize(*keys_q, group_size=gs, bits=b)
            kv.values = mx.dequantize(*values_q, group_size=gs, bits=b)
        kv.offset = layer.offset
        restored.append(kv)
    return restored


class MemoryAwarePrefixCache:
"""
Prefix cache with memory-based eviction.
Expand Down
76 changes: 72 additions & 4 deletions vllm_mlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class SchedulerConfig:
# 0 = disabled. Only effective when chunked_prefill_tokens > 0.
mid_prefill_save_interval: int = 8192

# KV cache quantization: quantize cached KV states to reduce memory usage.
# None = disabled (default), 4 = 4-bit (~75% savings), 8 = 8-bit (~50% savings).
# Only affects prefix cache storage — BatchGenerator uses full-precision KV cache.
kv_bits: Optional[int] = None
kv_group_size: int = 64


@dataclass
class SchedulerOutput:
Expand Down Expand Up @@ -573,6 +579,14 @@ def __init__(
f"Prefix cache enabled with max_entries={self.config.prefix_cache_size}"
)

# KV cache quantization
if self.config.kv_bits is not None:
savings = "~75%" if self.config.kv_bits == 4 else "~50%"
logger.info(
f"KV cache quantization enabled: {self.config.kv_bits}-bit, "
f"group_size={self.config.kv_group_size} ({savings} memory savings)"
)

# Thread-safe set for deferred aborts (main thread → executor thread)
# CPython GIL guarantees set.add() and `x in set` are atomic.
self._pending_abort_ids: Set[str] = set()
Expand Down Expand Up @@ -718,12 +732,24 @@ def _prompt_cache_save(uid, extracted_cache):
return

prompt_tokens = list(request.prompt_token_ids)

# Quantize cache before storing if KV quantization enabled
cache_to_store = extracted_cache
if self.config.kv_bits is not None:
from .memory_cache import quantize_kv_cache

cache_to_store = quantize_kv_cache(
cache_to_store,
group_size=self.config.kv_group_size,
bits=self.config.kv_bits,
)

_t0 = _time.monotonic()
# evict_prefixes=False: keep mid-prefill boundary entries so
# that future requests with the same prefix but different
# suffix get a prefix cache hit (critical for agentic multi-turn).
stored = self.memory_aware_cache.store(
prompt_tokens, extracted_cache, evict_prefixes=False
prompt_tokens, cache_to_store, evict_prefixes=False
)
_dt = _time.monotonic() - _t0
if stored:
Expand Down Expand Up @@ -777,6 +803,17 @@ def _mid_prefill_save(uid, processed_tokens, prompt_cache):
if not reconstructed:
return

# Quantize cache before storing if KV quantization enabled
cache_to_store = reconstructed
if self.config.kv_bits is not None:
from .memory_cache import quantize_kv_cache

cache_to_store = quantize_kv_cache(
cache_to_store,
group_size=self.config.kv_group_size,
bits=self.config.kv_bits,
)

prefix_tokens = list(request.prompt_token_ids[:total_cached])

# Remove previous intermediate entry to avoid memory waste
Expand All @@ -785,7 +822,7 @@ def _mid_prefill_save(uid, processed_tokens, prompt_cache):
self.memory_aware_cache.remove(list(old_key))

_t0 = _time.monotonic()
stored = self.memory_aware_cache.store(prefix_tokens, reconstructed)
stored = self.memory_aware_cache.store(prefix_tokens, cache_to_store)
_dt = _time.monotonic() - _t0

if stored:
Expand Down Expand Up @@ -1095,6 +1132,11 @@ def add_request(self, request: Request) -> None:
_fetch_dt = _time.monotonic() - _fetch_t0
request.cache_hit_type = self.memory_aware_cache._last_match_type
if cache:
# Dequantize if KV cache quantization is active
if self.config.kv_bits is not None:
from .memory_cache import dequantize_kv_cache

cache = dequantize_kv_cache(cache)
request.prompt_cache = cache
request.cached_tokens = len(request.prompt_token_ids) - len(remaining)
request.remaining_tokens = remaining
Expand All @@ -1115,6 +1157,11 @@ def add_request(self, request: Request) -> None:
# Use legacy prefix cache
cache, remaining = self.prefix_cache.fetch_cache(request.prompt_token_ids)
if cache:
# Dequantize if KV cache quantization is active
if self.config.kv_bits is not None:
from .memory_cache import dequantize_kv_cache

cache = dequantize_kv_cache(cache)
request.cache_hit_type = "hit"
request.prompt_cache = cache
request.cached_tokens = len(request.prompt_token_ids) - len(remaining)
Expand Down Expand Up @@ -1487,12 +1534,23 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
full_token_sequence = list(request.prompt_token_ids) + list(
request.output_token_ids
)
# Quantize cache before storing if KV quantization enabled
cache_to_store = request._extracted_cache
if self.config.kv_bits is not None:
from .memory_cache import quantize_kv_cache

cache_to_store = quantize_kv_cache(
cache_to_store,
group_size=self.config.kv_group_size,
bits=self.config.kv_bits,
)

import time as _time

_store_t0 = _time.monotonic()
stored = self.memory_aware_cache.store(
full_token_sequence,
request._extracted_cache,
cache_to_store,
evict_prefixes=False,
)
_store_dt = _time.monotonic() - _store_t0
Expand Down Expand Up @@ -1532,9 +1590,19 @@ def _cleanup_finished(self, finished_ids: Set[str]) -> None:
full_token_sequence = list(request.prompt_token_ids) + list(
request.output_token_ids
)
# Quantize cache before storing if KV quantization enabled
cache_to_store = request._extracted_cache
if self.config.kv_bits is not None:
from .memory_cache import quantize_kv_cache

cache_to_store = quantize_kv_cache(
cache_to_store,
group_size=self.config.kv_group_size,
bits=self.config.kv_bits,
)
self.prefix_cache.store_cache(
full_token_sequence,
request._extracted_cache,
cache_to_store,
)
logger.debug(
f"Stored cache for request {request_id} "
Expand Down
Loading