Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion tests/test_kv_cache_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
MemoryCacheConfig,
_dequantize_cache,
_quantize_cache,
_trim_to_offset,
estimate_kv_cache_memory,
)

Expand Down Expand Up @@ -180,7 +181,12 @@ def test_store_fetch_without_quantization(self):

def test_store_fetch_with_quantization(self):
model = self._make_cache_and_model()
config = MemoryCacheConfig(kv_quantize=True, kv_bits=8, max_memory_mb=500)
config = MemoryCacheConfig(
kv_quantize=True,
kv_bits=8,
kv_min_quantize_tokens=0,
max_memory_mb=500,
)
pc = MemoryAwarePrefixCache(model, config)

cache = _make_kv_cache(n_layers=2, seq_len=50)
Expand Down Expand Up @@ -215,3 +221,122 @@ def test_quantized_store_reduces_tracked_memory(self):
pc_q8.store(tokens, cache)

assert pc_q8._current_memory < pc_fp16._current_memory


class TestTrimToOffset:
    """Test the _trim_to_offset helper function."""

    @staticmethod
    def _oversized_cache(n_layers, capacity, used):
        # Build KVCache layers whose arrays hold `capacity` sequence slots
        # while only `used` of them are occupied (offset < array size).
        layers = []
        for _ in range(n_layers):
            layer = KVCache()
            layer.keys = mx.random.normal((1, 8, capacity, 64))
            layer.values = mx.random.normal((1, 8, capacity, 64))
            layer.offset = used
            layers.append(layer)
        mx.eval(*(layer.keys for layer in layers), *(layer.values for layer in layers))
        return layers

    def test_trim_oversized_arrays(self):
        """KV arrays larger than offset should be trimmed."""
        cache = self._oversized_cache(n_layers=2, capacity=4096, used=512)
        for layer in _trim_to_offset(cache):
            assert layer.keys.shape[2] == 512
            assert layer.values.shape[2] == 512
            assert layer.offset == 512

    def test_no_trim_when_exact(self):
        """No trimming needed when arrays match offset."""
        original = _make_kv_cache(n_layers=2, seq_len=100)
        for before, after in zip(original, _trim_to_offset(original)):
            assert after.keys.shape == before.keys.shape
            assert after.values.shape == before.values.shape

    def test_non_kvcache_layers_preserved(self):
        """Non-KVCache layers pass through unchanged."""
        mamba_like = {"state": mx.zeros((1, 16, 64)), "type": "mamba"}
        assert _trim_to_offset([mamba_like])[0] is mamba_like

    def test_none_keys_passthrough(self):
        """KVCache with None keys should pass through."""
        empty = KVCache()
        assert _trim_to_offset([empty])[0] is empty


class TestMinQuantizeTokensThreshold:
    """Test that short sequences skip quantization."""

    @staticmethod
    def _make_model():
        # A bare object suffices: store() never inspects model internals here.
        class FakeModel:
            pass

        return FakeModel()

    def _quantizing_cache(self):
        # Prefix cache configured to quantize only sequences >= 256 tokens.
        config = MemoryCacheConfig(
            kv_quantize=True,
            kv_bits=8,
            kv_min_quantize_tokens=256,
            max_memory_mb=500,
        )
        return MemoryAwarePrefixCache(self._make_model(), config)

    def test_store_skips_quantization_below_threshold(self):
        """Sequences shorter than min_quantize_tokens should not be quantized."""
        pc = self._quantizing_cache()

        pc.store(list(range(50)), _make_kv_cache(n_layers=2, seq_len=50))

        entry = next(iter(pc._entries.values()))
        for layer in entry.cache:
            assert isinstance(
                layer, KVCache
            ), "Short sequences should remain as KVCache (not quantized)"

    def test_store_quantizes_above_threshold(self):
        """Sequences >= min_quantize_tokens should be quantized."""
        pc = self._quantizing_cache()

        pc.store(list(range(300)), _make_kv_cache(n_layers=2, seq_len=300))

        entry = next(iter(pc._entries.values()))
        for layer in entry.cache:
            assert isinstance(
                layer, QuantizedKVCache
            ), "Long sequences should be quantized"

    def test_trim_applied_without_quantization(self):
        """Oversized arrays should be trimmed even without quantization."""
        pc = MemoryAwarePrefixCache(
            self._make_model(),
            MemoryCacheConfig(kv_quantize=False, max_memory_mb=500),
        )

        # Arrays pre-allocated to 4096 slots while only 100 are in use.
        layers = []
        for _ in range(2):
            layer = KVCache()
            layer.keys = mx.random.normal((1, 8, 4096, 64))
            layer.values = mx.random.normal((1, 8, 4096, 64))
            layer.offset = 100
            layers.append(layer)
        mx.eval(*(layer.keys for layer in layers), *(layer.values for layer in layers))

        pc.store(list(range(100)), layers)

        entry = next(iter(pc._entries.values()))
        for stored in entry.cache:
            assert (
                stored.keys.shape[2] == 100
            ), f"Expected trimmed to 100, got {stored.keys.shape[2]}"
            assert stored.values.shape[2] == 100
14 changes: 14 additions & 0 deletions vllm_mlx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def serve_command(args):
kv_cache_quantization=args.kv_cache_quantization,
kv_cache_quantization_bits=args.kv_cache_quantization_bits,
kv_cache_quantization_group_size=args.kv_cache_quantization_group_size,
kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens,
)

print("Mode: Continuous batching (for multiple concurrent users)")
Expand Down Expand Up @@ -222,6 +223,7 @@ async def run_benchmark():
kv_cache_quantization=args.kv_cache_quantization,
kv_cache_quantization_bits=args.kv_cache_quantization_bits,
kv_cache_quantization_group_size=args.kv_cache_quantization_group_size,
kv_cache_min_quantize_tokens=args.kv_cache_min_quantize_tokens,
)
engine_config = EngineConfig(
model_name=args.model,
Expand Down Expand Up @@ -649,6 +651,12 @@ def main():
default=64,
help="Group size for KV cache quantization (default: 64)",
)
serve_parser.add_argument(
"--kv-cache-min-quantize-tokens",
type=int,
default=256,
help="Minimum tokens for quantization to apply (default: 256)",
)
serve_parser.add_argument(
"--stream-interval",
type=int,
Expand Down Expand Up @@ -857,6 +865,12 @@ def main():
default=64,
help="Group size for KV cache quantization (default: 64)",
)
bench_parser.add_argument(
"--kv-cache-min-quantize-tokens",
type=int,
default=256,
help="Minimum tokens for quantization to apply (default: 256)",
)
# Paged cache options (experimental)
bench_parser.add_argument(
"--use-paged-cache",
Expand Down
75 changes: 68 additions & 7 deletions vllm_mlx/memory_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class MemoryCacheConfig:
kv_quantize: Whether to quantize KV cache layers for reduced memory.
kv_bits: Number of bits for KV cache quantization.
kv_group_size: Group size for KV cache quantization.
kv_min_quantize_tokens: Minimum sequence length for quantization to apply.
"""

max_memory_mb: int | None = None
Expand All @@ -163,6 +164,7 @@ class MemoryCacheConfig:
kv_quantize: bool = False
kv_bits: int = 8
kv_group_size: int = 64
kv_min_quantize_tokens: int = 256

def __post_init__(self) -> None:
if not 0.0 < self.max_memory_percent <= 1.0:
Expand All @@ -171,6 +173,10 @@ def __post_init__(self) -> None:
)
if self.max_entries < 1:
raise ValueError(f"max_entries must be >= 1, got {self.max_entries}")
if self.kv_min_quantize_tokens < 0:
raise ValueError(
f"kv_min_quantize_tokens must be >= 0, got {self.kv_min_quantize_tokens}"
)

def compute_memory_limit(self) -> int:
"""
Expand Down Expand Up @@ -290,6 +296,55 @@ def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
return trimmed


def _trim_to_offset(cache: list[Any]) -> list[Any]:
    """Trim KV arrays to their actual used size (offset) before storage.

    KV arrays are often pre-allocated larger than needed (e.g. 4096 slots
    when only 100 are used). This slices them down to ``offset`` and
    evaluates the result so the original large buffer can be freed.

    Args:
        cache: List of cache layer objects (KVCache or other types).

    Returns:
        New list with KVCache layers trimmed to their offset.
        Non-KVCache layers are passed through unchanged.
    """
    import mlx.core as mx
    from mlx_lm.models.cache import KVCache

    def _oversized(layer: Any) -> bool:
        # True only for KVCache layers whose backing arrays hold more
        # sequence slots than the offset actually uses.
        return (
            isinstance(layer, KVCache)
            and layer.keys is not None
            and 0 < layer.offset < layer.keys.shape[2]
        )

    # Fast path: nothing to trim, return the input list untouched.
    if not any(_oversized(layer) for layer in cache):
        return cache

    result: list[Any] = []
    pending: list[Any] = []
    for layer in cache:
        if not _oversized(layer):
            # Non-KVCache, empty, or exactly-sized layers pass through as-is.
            result.append(layer)
            continue
        used = layer.offset
        compact = KVCache()
        compact.keys = layer.keys[:, :, :used, :]
        compact.values = layer.values[:, :, :used, :]
        compact.offset = used
        pending.append(compact.keys)
        pending.append(compact.values)
        result.append(compact)

    # Force evaluation so the trimmed arrays materialize and the original
    # oversized buffers become eligible for release.
    if pending:
        mx.eval(*pending)

    return result


def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]:
"""Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is."""
from mlx_lm.models.cache import KVCache
Expand Down Expand Up @@ -620,19 +675,25 @@ def store(
if not tokens or not cache:
return False

# Quantize cache layers if configured
if self._config.kv_quantize:
cache = _quantize_cache(
cache, self._config.kv_bits, self._config.kv_group_size
)

tokens_key = tuple(tokens)

# If already cached, just update LRU order
# If already cached, just update LRU order (skip expensive trim/quantize)
if tokens_key in self._entries:
self._entries.move_to_end(tokens_key)
return True

# Trim oversized KV arrays to actual used size
cache = _trim_to_offset(cache)

# Quantize if enabled and sequence is long enough
if (
self._config.kv_quantize
and len(tokens) >= self._config.kv_min_quantize_tokens
):
cache = _quantize_cache(
cache, self._config.kv_bits, self._config.kv_group_size
)

# Create entry and estimate memory
entry = _CacheEntry.create(tokens, cache)

Expand Down
2 changes: 2 additions & 0 deletions vllm_mlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class SchedulerConfig:
kv_cache_quantization: bool = False
kv_cache_quantization_bits: int = 8
kv_cache_quantization_group_size: int = 64
kv_cache_min_quantize_tokens: int = 256

# Paged cache settings (experimental - for memory efficiency)
use_paged_cache: bool = (
Expand Down Expand Up @@ -562,6 +563,7 @@ def __init__(
kv_quantize=self.config.kv_cache_quantization,
kv_bits=self.config.kv_cache_quantization_bits,
kv_group_size=self.config.kv_cache_quantization_group_size,
kv_min_quantize_tokens=self.config.kv_cache_min_quantize_tokens,
)
self.memory_aware_cache = MemoryAwarePrefixCache(
model=model,
Expand Down