Add KV cache quantization for prefix cache memory reduction (#62)

waybarrios merged 9 commits into main
Adds --kv-cache-quantization flag that uses mlx-lm's QuantizedKVCache to compress stored prefix cache entries (8-bit or 4-bit), reducing memory usage ~3.5x with minimal quality loss (~0.005 mean abs error). Quantization happens on store, dequantization on fetch, so active inference is unaffected. Includes CLI flags for serve and bench commands, config wiring through SchedulerConfig, and 16 tests. Closes #60
Force-pushed from 6cf1a9c to 27e6884
New CLI command that compares FP16, 8-bit, and 4-bit KV cache quantization using synthetic data, reporting memory usage, compression ratio, quality metrics, and quantize/dequantize latency. Usage: vllm-mlx bench-kv-cache [--layers 32] [--seq-len 512]
Ran the KV cache quantization benchmark on synthetic data (32 layers, 512 seq len, 32 heads, 128 head dim) to see how much memory we actually save and what the quality tradeoff looks like.
The quantize/dequantize overhead is pretty small, around 2-9ms per round trip depending on bit width. 8-bit is the sweet spot for most use cases since the quality loss is basically negligible. 4-bit is there if you really need to squeeze memory, but the max error gets noticeable. You can run this yourself with the new bench command: vllm-mlx bench-kv-cache
Same benchmark with larger head dimensions to see how it scales. head_dim=512 (32 layers, 512 seq len, 32 heads):
head_dim=1024 (32 layers, 512 seq len, 32 heads):
The compression ratio stays consistent at 3.56x for 8-bit and 6.4x for 4-bit regardless of head dimension. Mean error also stays stable. The main difference is absolute memory saved, which gets pretty significant at larger dimensions. Going from 4 GB down to 1.1 GB with 8-bit at head_dim=1024 is a big deal for running larger models on machines with limited RAM.
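Those constant ratios fall out of simple bit accounting, assuming the synthetic benchmark stores FP32 values and each quantization group of 32 elements carries an FP16 scale plus an FP16 bias (these layout details are my assumptions about the mlx-lm packing, not taken from the PR):

```python
def effective_bits(bits: int, group_size: int = 32) -> float:
    # Each group carries one FP16 scale + one FP16 bias (32 bits total)
    # on top of the packed `bits`-wide quantized elements.
    return bits + 32 / group_size

def compression_ratio(bits: int, baseline_bits: int = 32,
                      group_size: int = 32) -> float:
    return baseline_bits / effective_bits(bits, group_size)

print(round(compression_ratio(8), 2))   # 3.56 -> matches the 8-bit benchmark
print(round(compression_ratio(4), 2))   # 6.4  -> matches the 4-bit benchmark
# Against an FP16 baseline (the real serving path) 8-bit gives ~1.78x,
# i.e. roughly 44% savings, close to the ~47% measured in local testing.
print(round(compression_ratio(8, baseline_bits=16), 2))
```

Under these assumptions the ratio depends only on bits and group size, which is why it stays flat as head_dim grows.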
Pushing it further with head_dim=2048 and head_dim=4096. These are the kind of dimensions you see in larger models. head_dim=2048 (32 layers, 512 seq len, 32 heads):
head_dim=4096 (32 layers, 512 seq len, 32 heads):
At head_dim=4096 you go from 16 GB down to 4.6 GB with 8-bit, or 2.5 GB with 4-bit. That's the difference between fitting in memory or not on a 16 GB machine. The compression ratio and error stay the same across all dimensions, only the absolute savings and latency scale up. To reproduce: vllm-mlx bench-kv-cache --head-dim 2048
Tested with long context windows, seq_len=32k and seq_len=64k. This simulates what happens when you cache long conversations or large documents. seq_len=32768 (32 layers, 32 heads, head_dim=128):
seq_len=65536 (32 layers, 32 heads, head_dim=128):
At 32k context the overhead is still reasonable, around 80-300ms. At 64k things get heavier, especially the 8-bit dequantize at 8.4 seconds and 4-bit quantize at 25.5 seconds. That said, the memory savings are massive: going from 64 GB down to 18 GB with 8-bit or 10 GB with 4-bit. For long context scenarios on machines with 32-64 GB of unified memory, this is what makes the difference between being able to cache those conversations or not. To reproduce: vllm-mlx bench-kv-cache --seq-len 32768
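The headline numbers follow directly from the tensor shapes. A quick sanity check, assuming FP32 synthetic data (4 bytes per element) and one K plus one V tensor of shape [heads, seq_len, head_dim] per layer:

```python
def kv_cache_bytes(layers: int, seq_len: int, heads: int, head_dim: int,
                   bytes_per_elem: int = 4) -> int:
    # K and V are each [heads, seq_len, head_dim] per layer, hence the * 2.
    return layers * 2 * heads * seq_len * head_dim * bytes_per_elem

GIB = 1024 ** 3
fp32 = kv_cache_bytes(32, 65536, 32, 128)
print(fp32 // GIB)                # 64 GiB at seq_len=64k, matching the benchmark
print(round(fp32 / GIB / 3.56))   # ~18 GiB with 8-bit quantization
print(round(fp32 / GIB / 6.4))    # 10 GiB with 4-bit quantization
```

The same formula reproduces the 32k numbers by halving seq_len, since memory is linear in sequence length.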
The quantize/dequantize latency at 64k is worth calling out. 8.4 seconds for dequantize and 25 seconds for quantize is way too slow if this is in the hot path. At 32k it's still fine (80-300ms), but 64k hits a wall. A few things I want to try to bring that down:
For now this works well up to 32k context, which covers most real use cases. The long context optimization is something I'll dig into separately.
Did some profiling to understand the latency spike at 64k. Turns out it's memory pressure, not a code issue. Here's what happens at seq_len=65536 (32 layers, 32 heads, head_dim=128):
The machine has 128 GB of unified memory. At 112 GB (87% usage), macOS starts doing memory compaction and swap, which is what causes the jump from ~460ms at 49k to ~7200ms at 65k. To confirm, I ran dequantize again after deleting the original FP16 cache. With only ~44 GB in memory instead of ~112 GB, it dropped to 3188ms, roughly half the time. Scaling from 8k to 49k is perfectly linear:
The cliff only appears when total memory usage crosses ~80% of available RAM. In production this won't be a problem because the real flow is: quantize on store (then the original is discarded) and dequantize on fetch (the quantized version stays in cache). You never have all three copies in memory at the same time like the benchmark does.
After storing the quantized cache in the prefix cache, the original FP16 reference on the request is no longer needed. Setting it to None allows the memory to be reclaimed sooner, preventing temporary memory spikes when quantization is enabled on long sequences.
Follow-up on the latency investigation. Simulated the actual production flow (quantize, release original FP16 cache, then dequantize on fetch) and the latency spike at 64k is gone. Production flow (original freed before dequantize):
Compared to the benchmark (which kept all three copies in memory):
Scaling is perfectly linear now across the full range. The 7 second spike was caused by the benchmark holding the original FP16 cache, the quantized cache, and the dequantized cache all in memory at once (112 GB on a 128 GB machine). In the real server flow, the original FP16 reference is released right after the quantized version is stored, so you never hit that memory wall. Added a small fix in the scheduler to explicitly set _extracted_cache = None after storing the quantized version, so the FP16 data gets freed immediately.
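The release-on-store fix can be sketched like this (the quantize helper and Request class are simplified stand-ins, not the actual scheduler code; only the `_extracted_cache = None` step comes from the thread):

```python
def quantize_cache(cache):
    # Stand-in for the mlx-lm QuantizedKVCache conversion.
    return {"quantized": True, "payload": cache}

class Request:
    def __init__(self, key, extracted_cache):
        self.key = key
        self._extracted_cache = extracted_cache  # FP16 cache pulled off the batch

def store_prefix_entry(request, prefix_cache, kv_quantize=True):
    cache = request._extracted_cache
    if kv_quantize:
        cache = quantize_cache(cache)
    prefix_cache[request.key] = cache
    # Release the FP16 reference right after storing the quantized copy, so
    # at most two copies coexist, and only briefly. Keeping all copies alive
    # (as the benchmark did) is what pushed the machine into swap at 64k.
    request._extracted_cache = None
```

After `store_prefix_entry` returns, only the quantized entry remains reachable, which matches the linear scaling seen in the production-flow measurements.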
**Local Testing Results: KV Cache Quantization**

Tested on Apple Silicon (M4 Ultra, 247 GB unified memory) with

**Test Setup**

Two server instances running identical workloads:
vllm-mlx serve mlx-community/Qwen3-0.6B-4bit \
--port 1244 --host 0.0.0.0 --continuous-batching \
--max-num-seqs 4 --cache-memory-mb 2048 \
--kv-cache-quantization --kv-cache-quantization-bits 8
**Functional Verification**

**Memory Savings**

Both servers received identical requests (same prompts, same

Matches the theoretical ~50% savings (fp16 → uint8 + small overhead for scale/bias params).

**Raw Log Evidence**

Quantized server:

Baseline server:

**Conclusion**

KV cache quantization works correctly with ~47% real-world memory savings. No impact on response quality or throughput. Prefix cache hit/miss behavior unchanged. ✅
@janhilgard what do you think about the latency when we increase the tokens and reduce the bits in quantization? Any plan for how to fix or handle it properly?
**Bug Report: Cache HIT path stores oversized KV arrays**

**Finding**

When testing PR #62 locally against the same workload as above, I found that KV cache quantization only delivers memory savings on cache MISS, not on cache HIT.

**Evidence**

Four

**Root cause**

When there's a prefix cache HIT, the cached KV arrays are passed to

This means:
**Impact**

The quantization provides no real memory savings for the common case of repeated requests with shared prefixes (cache HITs), which is the primary use case for prefix caching.

**Suggested fix**

Trim the extracted cache arrays to

    # Before quantizing, trim KV arrays to actual offset to avoid storing unused buffer
    if self._config.kv_quantize:
        cache = _trim_to_offset(cache)  # new helper
        cache = _quantize_cache(cache, ...)
Good question. Your profiling already nailed the key insight — the production flow scales linearly because we never hold all three copies in memory simultaneously. The 615ms dequantize at 65K is very reasonable. Here's my take on the latency picture:

What's already fine:

Where it gets interesting (4-bit):

Concrete improvements we could pursue (roughly priority-ordered):

For now I'd say the current 8-bit implementation with the PR #69 fix covers the main use case well. The threshold-based approach (#2) would be a clean next step — low complexity, avoids unnecessary overhead on short entries. The lazy dequant (#3) is the real long-term win but requires more work. What do you think? Want me to prototype the threshold approach?
When KV cache quantization is enabled, prefix cache entries are stored as QuantizedKVCache objects. The _trim_cache_offset function (used for supersequence and LCP matches) was silently skipping these layers because QuantizedKVCache.keys returns a tuple, failing the `not isinstance(keys, (list, tuple))` guard. This caused the offset to remain untrimmed, so dequantized caches passed to BatchGenerator had their original (large) offset. The BatchGenerator then concatenated new tokens to the full buffer instead of the trimmed prefix, producing oversized KV arrays that negated all memory savings from quantization. Tested: after fix, cache_mem=33MB (correct) vs 62MB (broken, same as unquantized baseline) for the same 2-request workload. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
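A minimal illustration of the guard bug described above (toy classes standing in for the real layers, which hold MLX arrays):

```python
class DenseLayer:
    """Regular KVCache layer: keys is a single array-like object."""
    def __init__(self, offset):
        self.keys = object()   # stand-in for an mx.array
        self.offset = offset

class QuantizedLayer:
    """QuantizedKVCache layer: keys is a (packed, scales, biases) tuple."""
    def __init__(self, offset):
        self.keys = (object(), object(), object())
        self.offset = offset

def trim_cache_offset(layers, new_offset, fixed=True):
    for layer in layers:
        if not fixed and isinstance(layer.keys, (list, tuple)):
            # Old behavior: the `not isinstance(keys, (list, tuple))` guard
            # treated the quantized tuple as "not a real layer" and skipped
            # it, leaving layer.offset untrimmed.
            continue
        layer.offset = min(layer.offset, new_offset)
```

With `fixed=False` a quantized layer keeps its original (large) offset, which is exactly what made BatchGenerator concatenate onto the full buffer; with `fixed=True` both layer types get trimmed.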
Nice breakdown on the latency side. Your priority ordering makes sense to me.

PR #69, already merged, fixes the QuantizedKVCache handling in _trim_cache_offset. The tuple guard was a subtle one, good find.

The threshold-based quantization is the one I'd want next. Quantizing a 50-token entry saves almost nothing but you still pay the full per-layer overhead. Something like a min_quantize_tokens check in store() with a reasonable default (maybe 256 or 512) would keep it simple. Go ahead and prototype that.

Lazy dequantization is appealing but yeah, threading that through BatchGenerator is a bigger change. Let's revisit once we have real numbers from the threshold approach.

Adaptive bit selection is cool but I think most people will be fine with 8-bit as the default. Anyone who needs 4-bit can just pass the flag. Lower priority for now.

On the oversized buffer thing you reported, that's a real issue but it lives in the cache extraction path, not in the quantization logic itself. The non-quantized path has the same problem, it just stores oversized FP16 arrays. A _trim_to_offset step before _quantize_cache in store() would fix it. Can you fold that into the threshold PR?
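A sketch of what that threshold check in store() could look like. The field name kv_min_quantize_tokens and the trim-before-quantize ordering come from the discussion; the helper bodies here are illustrative stand-ins:

```python
def trim_to_offset(cache):
    # Stand-in: the real helper slices each layer's KV arrays down to the
    # actually used offset before storage.
    return cache

def quantize_cache(cache, bits=8):
    # Stand-in for the QuantizedKVCache conversion.
    return {"quantized_bits": bits, "payload": cache}

def store(prefix_cache, key, cache, num_tokens,
          kv_quantize=True, min_quantize_tokens=256):
    # Trim first: fixes the oversized-buffer issue for both the FP16 and
    # quantized storage paths.
    cache = trim_to_offset(cache)
    if kv_quantize and num_tokens >= min_quantize_tokens:
        # Below the threshold, per-layer scale/bias overhead and the
        # quantize/dequantize latency outweigh the few KB saved.
        cache = quantize_cache(cache)
    prefix_cache[key] = cache
```

With the suggested default of 256, a 50-token entry is stored as plain FP16 while long documents still get compressed.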
Both items from @waybarrios's comment are now implemented in PR #73:

PR: #73
@janhilgard any update so far on this matter?
Did you update the branch? @janhilgard
Hey! Here's the current status:
I think #62 is ready to merge as-is. Then #73 can go on top. Want me to rebase #73 onto
No changes from my side — the branch is as you left it after the merge with main on Feb 12. PR #69 fix is already in there. Ready to merge whenever you are.
* feat: add min_quantize_tokens threshold and trim oversized KV buffers

  - Add _trim_to_offset() to trim pre-allocated KV arrays to their actual used size before storage, saving memory in both FP16 and quantized paths
  - Add kv_min_quantize_tokens config (default 256) to skip quantization for short sequences where overhead exceeds savings
  - Thread the new config through SchedulerConfig and CLI arguments

* style: apply black formatting

* Harden _trim_to_offset and store() in memory cache

  Move the duplicate entry check in store() before trim and quantize so repeated tokens skip the expensive work entirely. Rewrite _trim_to_offset to validate that offset is positive before slicing, use KVCache() instead of __new__ to avoid skipping init, call mx.eval on the sliced arrays so the original large buffer gets freed and memory accounting stays accurate, and skip the function entirely when no layer actually needs trimming. Add validation for kv_min_quantize_tokens in MemoryCacheConfig so negative values are rejected at init time. Document the field in the class docstring and add Args and Returns to the _trim_to_offset docstring.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Wayner Barrios <waybarrios@gmail.com>
I ran the full test suite locally after the latest push (23 quantization tests + 42 memory cache mock tests + full suite): 902 passed, 1 pre-existing failure (unrelated mock issue on main).
Pushed a fix for the ModuleNotFoundError: No module named 'mlx' that was failing test-matrix (3.10/3.11/3.12). The root cause: _trim_to_offset() imported mlx.core unconditionally at the top of the function, but store() calls it on every entry — including mock objects in tests running on Linux CI without MLX.
Fix: Extracted _needs_kv_trim() — a duck-typed helper that checks if any layer has oversized KV arrays using getattr instead of isinstance(layer, KVCache). The MLX imports now only happen when trimming is actually needed. Also fixed a ruff N806 lint error in _trim_cache_offset.
**Changes**

- Core quantization — `_quantize_cache()`/`_dequantize_cache()` with quantize-on-store, dequantize-on-fetch pattern
- `_trim_cache_offset` fix for QuantizedKVCache — explicit handling with offset/group_size/bits preserved
- `_trim_to_offset` — trims pre-allocated KV buffers to actual used size before storage
- `kv_min_quantize_tokens` threshold — skips quantization for short sequences (default: 256)
- `bench-kv-cache` CLI command — synthetic benchmarks for quantization
- Config validation — `kv_min_quantize_tokens >= 0` enforced in `__post_init__`
- 23 tests covering round-trip, memory reduction, config, store/fetch integration, trim, and threshold

**Improvements**

- The `_dequantize_cache(x) if self._config.kv_quantize else x` pattern appears 5x in `fetch()` — could be a `_maybe_dequantize()` method
- Disk persistence (`save_to_disk`/`load_from_disk`): works with QuantizedKVCache (verified), but if `kv_quantize=False` at load time and entries were saved quantized, they'll be returned raw. Consider always calling `_dequantize_cache` on fetch regardless of config flag (it's a no-op on regular KVCache)
- No test for supersequence/LCP match with quantized cache — the `_trim_cache_offset` fix addresses this path but no test exercises it directly
Overall the feature is solid and ready to merge once CI is green.
Force-pushed from 73279ca to bb53581
Resolve merge conflicts from kv_min_quantize_tokens feature (PR waybarrios#62). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Merge 17 upstream commits including:

- KV cache quantization for prefix cache memory reduction (waybarrios#62)
- Streaming tool call parsing via ToolParser integration (waybarrios#46)
- MTP speculative decoding for Qwen3-Next (waybarrios#82)
- GPT-OSS reasoning parser and Harmony format parsers
- mlx-lm >= 0.30.5 requirement, transformers >= 5.0.0
- BatchMambaCache fix for mlx-lm >= 0.30.6 (waybarrios#89)
- MLLM continuous batching fixes (waybarrios#76)
- Force MLLM mode option (waybarrios#81)
- Various bug fixes

Conflict resolution:

- server.py: Replaced local tool_call_buffering with upstream's ToolParser-based streaming (more robust)
- cli.py: Deduplicated --mllm, --default-temperature, --default-top-p args (upstream already added them), kept local --embedding-model
- mamba_cache.py: Took upstream's conditional HAS_MAMBA_CACHE approach
- pyproject.toml: Took upstream's version and dependency changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This adds KV cache quantization support for the prefix cache, so stored cache entries take up much less memory.
The new --kv-cache-quantization flag tells the system to compress KV cache entries using mlx-lm's QuantizedKVCache when storing them in the prefix cache. By default it uses 8-bit quantization, but you can also go down to 4-bit with --kv-cache-quantization-bits 4. In practice this gives around 3.5x memory savings with very little quality loss (mean absolute error around 0.005).
The approach is simple: quantize when storing, dequantize when fetching. Active inference stays on full precision since BatchKVCache doesn't have a quantized variant, so there's no impact on generation quality during a request.
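The shape of that pattern, with a toy per-tensor affine quantizer standing in for mlx-lm's QuantizedKVCache (the real one quantizes per group and packs into integer words; this sketch just shows the store/fetch symmetry):

```python
def quantize_cache(values, bits=8):
    # Toy affine quantization: map [lo, hi] onto integers 0..2^bits - 1.
    lo, hi = min(values), max(values)
    scale = (hi - lo) / (2 ** bits - 1) or 1.0  # avoid div-by-zero when flat
    return {"q": [round((v - lo) / scale) for v in values],
            "scale": scale, "bias": lo}

def dequantize_cache(packed):
    return [x * packed["scale"] + packed["bias"] for x in packed["q"]]

def store(prefix_cache, key, cache, kv_quantize=True):
    # Quantize on store: the prefix cache only ever holds the compact form.
    prefix_cache[key] = quantize_cache(cache) if kv_quantize else cache

def fetch(prefix_cache, key, kv_quantize=True):
    cache = prefix_cache.get(key)
    # Dequantize on fetch: generation always sees full-precision values,
    # since BatchKVCache has no quantized variant.
    if cache is not None and kv_quantize:
        cache = dequantize_cache(cache)
    return cache
```

Round-tripping values through store/fetch reproduces them to within one quantization step, which is the source of the small mean absolute error reported above.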
What changed:
memory_cache.py - New _quantize_cache() and _dequantize_cache() helpers. Added kv_quantize, kv_bits, and kv_group_size to MemoryCacheConfig. Updated estimate_kv_cache_memory() and _trim_cache_offset() to handle quantized layers. All 5 fetch return points now dequantize when needed, and store() quantizes before saving.
scheduler.py - Added kv_cache_quantization fields to SchedulerConfig and wired them through to MemoryCacheConfig.
cli.py - Three new flags (--kv-cache-quantization, --kv-cache-quantization-bits, --kv-cache-quantization-group-size) for both serve and bench commands.
tests/test_kv_cache_quantization.py - 16 tests covering round-trip correctness, memory reduction, config propagation, mixed cache layers, and store/fetch integration.
Usage:
vllm-mlx serve model --continuous-batching --kv-cache-quantization
vllm-mlx serve model --continuous-batching --kv-cache-quantization --kv-cache-quantization-bits 4
Testing:
Closes #60