
Add KV cache quantization for prefix cache memory reduction#62

Merged
waybarrios merged 9 commits into main from feat/kv-cache-quantization on Feb 14, 2026
Conversation

@waybarrios
Owner

@waybarrios waybarrios commented Feb 11, 2026

This adds KV cache quantization support for the prefix cache, so stored cache entries use significantly less memory.

The new --kv-cache-quantization flag tells the system to compress KV cache entries using mlx-lm's QuantizedKVCache when storing them in the prefix cache. By default it uses 8-bit quantization, but you can also go down to 4-bit with --kv-cache-quantization-bits 4. In practice this gives around 3.5x memory savings with very little quality loss (mean absolute error around 0.005).

The approach is simple: quantize when storing, dequantize when fetching. Active inference stays on full precision since BatchKVCache doesn't have a quantized variant, so there's no impact on generation quality during a request.

What changed:

memory_cache.py - New _quantize_cache() and _dequantize_cache() helpers. Added kv_quantize, kv_bits, and kv_group_size to MemoryCacheConfig. Updated estimate_kv_cache_memory() and _trim_cache_offset() to handle quantized layers. All 5 fetch return points now dequantize when needed, and store() quantizes before saving.

scheduler.py - Added kv_cache_quantization fields to SchedulerConfig and wired them through to MemoryCacheConfig.

cli.py - Three new flags (--kv-cache-quantization, --kv-cache-quantization-bits, --kv-cache-quantization-group-size) for both serve and bench commands.

tests/test_kv_cache_quantization.py - 16 tests covering round-trip correctness, memory reduction, config propagation, mixed cache layers, and store/fetch integration.
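For intuition, here is a minimal pure-Python sketch of the group-wise affine quantization idea behind QuantizedKVCache. The real implementation uses mlx-lm's QuantizedKVCache and MLX's dequantize kernels; the group size, rounding, and storage details below are illustrative assumptions only:

```python
import random

def quantize_group(xs, bits=8):
    """Affine-quantize one group of floats to `bits`-bit integers.

    Illustrative only: shows the per-group scale/bias idea, not the
    actual MLX kernel.
    """
    levels = (1 << bits) - 1
    lo, hi = min(xs), max(xs)
    scale = (hi - lo) / levels or 1.0  # avoid div-by-zero for constant groups
    q = [round((x - lo) / scale) for x in xs]
    return q, scale, lo

def dequantize_group(q, scale, bias):
    """Reconstruct approximate floats from integers plus scale/bias."""
    return [v * scale + bias for v in q]

random.seed(0)
group = [random.gauss(0, 1) for _ in range(64)]  # one group of KV values
q, scale, bias = quantize_group(group, bits=8)
restored = dequantize_group(q, scale, bias)
err = sum(abs(a - b) for a, b in zip(group, restored)) / len(group)
print(f"mean abs error (8-bit): {err:.4f}")
```

The per-group scale/bias is why the error stays small at 8-bit: each group of 64 values gets its own dynamic range, so outliers in one group don't degrade the rest.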

Usage:

vllm-mlx serve model --continuous-batching --kv-cache-quantization
vllm-mlx serve model --continuous-batching --kv-cache-quantization --kv-cache-quantization-bits 4

Testing:

  • All 16 tests pass
  • Black and ruff clean
  • Still needs a manual test with a real model

Closes #60

Adds --kv-cache-quantization flag that uses mlx-lm's QuantizedKVCache
to compress stored prefix cache entries (8-bit or 4-bit), reducing
memory usage ~3.5x with minimal quality loss (~0.005 mean abs error).

Quantization happens on store, dequantization on fetch, so active
inference is unaffected. Includes CLI flags for serve and bench
commands, config wiring through SchedulerConfig, and 16 tests.

Closes #60
@waybarrios waybarrios force-pushed the feat/kv-cache-quantization branch from 6cf1a9c to 27e6884 on February 11, 2026 at 03:06
@waybarrios waybarrios self-assigned this Feb 11, 2026
New CLI command that compares FP16, 8-bit, and 4-bit KV cache
quantization using synthetic data, reporting memory usage, compression
ratio, quality metrics, and quantize/dequantize latency.

Usage: vllm-mlx bench-kv-cache [--layers 32] [--seq-len 512]
@waybarrios
Owner Author

waybarrios commented Feb 11, 2026

Ran the KV cache quantization benchmark on synthetic data (32 layers, 512 seq len, 32 heads, 128 head dim) to see how much memory we actually save and what the quality tradeoff looks like.

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 512.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 144.00 MB | 3.56x | 0.00456 | 0.02897 | 2.1 ms | 3.9 ms |
| 4-bit | 80.00 MB | 6.40x | 0.07711 | 0.50543 | 8.8 ms | 4.7 ms |

The quantize/dequantize overhead is pretty small, around 2-9ms per round trip depending on bit width. 8-bit is the sweet spot for most use cases since the quality loss is basically negligible. 4-bit is there if you really need to squeeze memory but the max error gets noticeable.

You can run this yourself with the new bench command:

vllm-mlx bench-kv-cache
vllm-mlx bench-kv-cache --layers 64 --seq-len 1024
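The reported ratios are consistent with simple byte accounting. One model that reproduces 3.56x and 6.40x exactly is a 4-byte-per-element baseline with a 4-byte scale and 4-byte bias per group of 64; this is inferred from the numbers, not confirmed by the PR, so treat the constants below as assumptions:

```python
def kv_bytes_per_element(bits, group_size=64, scale_bias_bytes=8):
    """Bytes per cached element after quantization, including the
    per-group scale/bias overhead (assumed 8 bytes per group)."""
    return bits / 8 + scale_bias_bytes / group_size

BASELINE_BYTES = 4.0  # per-element baseline inferred from the reported ratios
for bits in (8, 4):
    ratio = BASELINE_BYTES / kv_bytes_per_element(bits)
    print(f"{bits}-bit compression: {ratio:.2f}x")
# 3.56x and 6.40x, matching the table
```

This also explains why the compression ratio is independent of head_dim and seq_len in the later tables: the overhead is a fixed fraction per group, so only absolute sizes scale.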

@waybarrios
Owner Author

Same benchmark with larger head dimensions to see how it scales.

head_dim=512 (32 layers, 512 seq len, 32 heads):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 2048.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 576.00 MB | 3.56x | 0.00456 | 0.03141 | 5.1 ms | 14.7 ms |
| 4-bit | 320.00 MB | 6.40x | 0.07710 | 0.55279 | 4.0 ms | 3.2 ms |

head_dim=1024 (32 layers, 512 seq len, 32 heads):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 4096.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 1152.00 MB | 3.56x | 0.00456 | 0.03238 | 9.8 ms | 30.0 ms |
| 4-bit | 640.00 MB | 6.40x | 0.07711 | 0.49785 | 7.0 ms | 6.1 ms |

The compression ratio stays consistent at 3.56x for 8-bit and 6.4x for 4-bit regardless of head dimension. Mean error also stays stable. The main difference is absolute memory saved, which gets pretty significant at larger dimensions. Going from 4 GB down to 1.1 GB with 8-bit at head_dim=1024 is a big deal for running larger models on machines with limited RAM.

@waybarrios
Owner Author

Pushing it further with head_dim=2048 and head_dim=4096. These are the kind of dimensions you see in larger models.

head_dim=2048 (32 layers, 512 seq len, 32 heads):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 8192.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 2304.00 MB | 3.56x | 0.00456 | 0.03174 | 17.8 ms | 65.8 ms |
| 4-bit | 1280.00 MB | 6.40x | 0.07711 | 0.52067 | 13.5 ms | 12.2 ms |

head_dim=4096 (32 layers, 512 seq len, 32 heads):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 16384.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 4608.00 MB | 3.56x | 0.00456 | 0.03126 | 33.7 ms | 132.1 ms |
| 4-bit | 2560.00 MB | 6.40x | 0.07711 | 0.52438 | 26.2 ms | 24.5 ms |

At head_dim=4096 you go from 16 GB down to 4.6 GB with 8-bit, or 2.5 GB with 4-bit. That's the difference between fitting in memory or not on a 16 GB machine. The compression ratio and error stay the same across all dimensions, only the absolute savings and latency scale up.

To reproduce:

vllm-mlx bench-kv-cache --head-dim 2048
vllm-mlx bench-kv-cache --head-dim 4096

@waybarrios
Owner Author

Tested with long context windows, seq_len=32k and seq_len=64k. This simulates what happens when you cache long conversations or large documents.

seq_len=32768 (32 layers, 32 heads, head_dim=128):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 32768.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 9216.00 MB | 3.56x | 0.00456 | 0.03281 | 80.0 ms | 300.5 ms |
| 4-bit | 5120.00 MB | 6.40x | 0.07711 | 0.55431 | 109.8 ms | 49.8 ms |

seq_len=65536 (32 layers, 32 heads, head_dim=128):

| Mode | Memory | Compression | Mean Error | Max Error | Quantize | Dequantize |
|------|--------|-------------|------------|-----------|----------|------------|
| FP16 | 65536.00 MB | 1.00x | 0.000 | 0.000 | - | - |
| 8-bit | 18432.00 MB | 3.56x | 0.00456 | 0.03278 | 171.9 ms | 8429.8 ms |
| 4-bit | 10240.00 MB | 6.40x | 0.07711 | 0.56201 | 25531.3 ms | 4045.0 ms |

At 32k context the overhead is still reasonable, around 80-300ms. At 64k things get heavier, especially the 8-bit dequantize at 8.4 seconds and 4-bit quantize at 25.5 seconds. That said, the memory savings are massive: going from 64 GB down to 18 GB with 8-bit or 10 GB with 4-bit. For long context scenarios on machines with 32-64 GB of unified memory, this is what makes the difference between being able to cache those conversations or not.

To reproduce:

vllm-mlx bench-kv-cache --seq-len 32768
vllm-mlx bench-kv-cache --seq-len 65536

@waybarrios
Owner Author

waybarrios commented Feb 11, 2026

The quantize/dequantize latency at 64k is worth calling out. 8.4 seconds for dequantize and 25 seconds for quantize is way too slow if this is in the hot path. At 32k it's still fine (80-300ms), but 64k hits a wall.

A few things I want to try to bring that down:

  • Process the cache in chunks instead of all at once
  • Only dequantize the layers we actually need instead of the full stack
  • Use a hybrid strategy where short caches stay in FP16 and we only quantize the longer ones that are worth compressing

For now this works well up to 32k context, which covers most real use cases. The long context optimization is something I'll dig into separately.
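The chunked-processing idea from the first bullet could be sketched like this. Everything here is hypothetical, including the dequantize_layer callback; a real MLX version would evaluate each chunk (e.g. with mx.eval) before touching the next, so only one chunk of temporaries is live at a time:

```python
def dequantize_in_chunks(quantized_layers, dequantize_layer, chunk_size=4):
    """Dequantize a list of layers a few at a time.

    The sketch only shows the control flow: in a real implementation,
    forcing evaluation after each chunk is what bounds peak memory,
    instead of materializing all 32 dequantized layers at once.
    """
    out = []
    for i in range(0, len(quantized_layers), chunk_size):
        chunk = quantized_layers[i:i + chunk_size]
        out.extend(dequantize_layer(layer) for layer in chunk)
        # real version: mx.eval(out[-len(chunk):]) here
    return out

layers = list(range(32))  # stand-ins for 32 quantized layers
result = dequantize_in_chunks(layers, lambda layer: layer * 2, chunk_size=8)
print(len(result))  # 32
```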

@waybarrios
Owner Author

Did some profiling to understand the latency spike at 64k. Turns out it's memory pressure, not a code issue.

Here's what happens at seq_len=65536 (32 layers, 32 heads, head_dim=128):

| Step | Memory | Notes |
|------|--------|-------|
| Initial | 0.00 GB | |
| After creating FP16 cache | 68.72 GB | The original cache |
| After quantize | 78.38 GB | Original + quantized (68 + 9 GB) |
| After dequantize | 112.74 GB | Original + quantized + dequantized (68 + 9 + 34 GB) |

The machine has 128 GB of unified memory. At 112 GB (87% usage), macOS starts doing memory compaction and swap, which is what causes the jump from ~460ms at 49k to ~7200ms at 65k.

To confirm, I ran dequantize again after deleting the original FP16 cache. With only ~44 GB in memory instead of ~112 GB, it dropped to 3188ms, roughly half the time.

Scaling from 8k to 49k is perfectly linear:

| seq_len | FP16 size | Quantize (8-bit) | Dequantize (8-bit) |
|---------|-----------|------------------|--------------------|
| 8192 | 4.3 GB | 18.2 ms | 67.1 ms |
| 16384 | 8.6 GB | 33.4 ms | 137.5 ms |
| 32768 | 17.2 GB | 77.2 ms | 276.9 ms |
| 49152 | 25.8 GB | 118.0 ms | 462.9 ms |
| 65536 | 34.4 GB | 166.1 ms | 7215.1 ms |

The cliff only appears when total memory usage crosses ~80% of available RAM.

In production this won't be a problem because the real flow is: quantize on store (then the original is discarded) and dequantize on fetch (the quantized version stays in cache). You never have all three copies in memory at the same time like the benchmark does.

After storing the quantized cache in the prefix cache, the original
FP16 reference on the request is no longer needed. Setting it to None
allows the memory to be reclaimed sooner, preventing temporary memory
spikes when quantization is enabled on long sequences.
@waybarrios
Owner Author

Follow-up on the latency investigation. Simulated the actual production flow (quantize, release original FP16 cache, then dequantize on fetch) and the latency spike at 64k is gone.

Production flow (original freed before dequantize):

| seq_len | Quantize | Dequantize | Peak Memory |
|---------|----------|------------|-------------|
| 8k | 19 ms | 81 ms | 9.9 GB |
| 16k | 33 ms | 136 ms | 19.9 GB |
| 32k | 85 ms | 291 ms | 39.7 GB |
| 49k | 183 ms | 471 ms | 59.6 GB |
| 65k | 154 ms | 615 ms | 79.5 GB |

Compared to the benchmark (which kept all three copies in memory):

| seq_len | Dequantize (benchmark) | Dequantize (production) |
|---------|------------------------|-------------------------|
| 49k | 463 ms | 471 ms |
| 65k | 7215 ms | 615 ms |

Scaling is perfectly linear now across the full range. The 7 second spike was caused by the benchmark holding the original FP16 cache, the quantized cache, and the dequantized cache all in memory at once (112 GB on a 128 GB machine). In the real server flow, the original FP16 reference is released right after the quantized version is stored, so you never hit that memory wall.

Added a small fix in the scheduler to explicitly set _extracted_cache = None after storing the quantized version, so the FP16 data gets freed immediately.

@janhilgard
Collaborator

Local Testing Results: KV Cache Quantization

Tested on Apple Silicon (M4 Ultra, 247GB unified memory) with mlx-community/Qwen3-0.6B-4bit model.

Test Setup

Two server instances running identical workloads:

  • Port 1244: with KV cache quantization (8-bit, group_size=64)
  • Port 1245: without KV cache quantization (baseline)
vllm-mlx serve mlx-community/Qwen3-0.6B-4bit \
    --port 1244 --host 0.0.0.0 --continuous-batching \
    --max-num-seqs 4 --cache-memory-mb 2048 \
    --kv-cache-quantization --kv-cache-quantization-bits 8

Note: tested with a parallel implementation (#67, now closed as duplicate) using --kv-cache-bits 8. Same underlying mechanism (QuantizedKVCache + mx.dequantize()), results are directly applicable.

Functional Verification

| Check | Result |
|-------|--------|
| Server startup | ✅ No errors |
| Startup log: `KV cache quantization enabled: 8-bit` | ✅ Present |
| Health check `/health` | `{"status": "healthy"}` |
| Chat completion (non-streaming) | ✅ 298–330 tok/s |
| Prefix cache MISS (1st request) | `cache_fetch MISS`, `cache_store stored=True` |
| Prefix cache HIT (2nd request, same system prompt) | `cache_fetch HIT cached=17 remaining=27` |
| Response coherence | ✅ Normal outputs, no artifacts from quantization |

Memory Savings

Both servers received identical requests (same prompts, same max_tokens):

| Cache state | With 8-bit quant | Without quant | Savings |
|-------------|------------------|---------------|---------|
| After 1st request (253 tokens cached) | 19 MB | 35 MB | 46% |
| After 2nd request (447 tokens cached) | 33 MB | 62 MB | 47% |
Matches the theoretical ~50% savings (fp16 → uint8 + small overhead for scale/bias params).
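As a sanity check on that figure, here is the byte accounting under the assumption of an fp16 scale and fp16 bias (4 bytes total) per group of 64; the group size and overhead bytes are assumptions, not confirmed internals:

```python
def quantized_fraction(bits=8, fp16_bytes=2, group_size=64, overhead_bytes=4):
    """Fraction of the fp16 cache size remaining after quantization,
    assuming 4 bytes of scale/bias overhead per group of 64 elements."""
    return (bits / 8 + overhead_bytes / group_size) / fp16_bytes

frac = quantized_fraction()
print(f"expected size vs fp16: {frac:.1%}")  # 53.1%, i.e. ~47% savings
```

This lines up with the observed 19/35 ≈ 54% and 33/62 ≈ 53%.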

Raw Log Evidence

Quantized server:

KV cache quantization enabled: 8-bit, group_size=64 (~50% memory savings)
[cache_store] request=e452afeb tokens=253 stored=True cache_entries=2 cache_mem=19MB
[cache_fetch] request=940349d0 HIT cached=17 remaining=27
[cache_store] request=940349d0 tokens=194 stored=True cache_entries=4 cache_mem=33MB

Baseline server:

[cache_store] request=f9353909 tokens=253 stored=True cache_entries=2 cache_mem=35MB
[cache_fetch] request=a06ca6eb HIT cached=17 remaining=27
[cache_store] request=a06ca6eb tokens=194 stored=True cache_entries=4 cache_mem=62MB

Conclusion

KV cache quantization works correctly with ~47% real-world memory savings. No impact on response quality or throughput. Prefix cache hit/miss behavior unchanged. ✅

@waybarrios waybarrios added the enhancement New feature or request label Feb 11, 2026
@waybarrios
Owner Author

@janhilgard what do you think about the latency when we increase the token count and reduce the quantization bits? Any plan for how to fix it or handle it more robustly?

@janhilgard
Collaborator

Bug Report: Cache HIT path stores oversized KV arrays

Finding

When testing PR #62 locally against the same workload as above, I found that KV cache quantization only delivers memory savings on cache MISS, not on cache HIT.

Evidence

Four _quantize_cache calls were logged (2 per request: prompt-only + full):

| Store | Expected seq_len | Actual keys shape | mem_before | mem_after |
|-------|------------------|-------------------|------------|-----------|
| R1 prompt-only (53 tok) | 53 | (1, 8, 53, 128) | 5.80MB | 3.08MB |
| R1 full (253 tok) | 253 | (1, 8, 253, 128) | 27.67MB | 14.70MB |
| R2 prompt-only (44 tok) | 44 | (1, 8, **280**, 128) | 30.62MB | 16.27MB |
| R2 full (194 tok) | 194 | (1, 8, **430**, 128) | 47.03MB | 24.99MB |

  • R2's prompt-only: seq_len=280 instead of 44. That's 253 (R1 full) + 27 (remaining tokens)
  • R2's full: seq_len=430 instead of 194. That's 280 + 150 (output tokens)

Root cause

When there's a prefix cache HIT, the cached KV arrays are passed to BatchGenerator.insert() via _merge_caches. The BatchGenerator appends new tokens to the existing cache buffer rather than creating a fresh buffer. When batch.extract_cache(e) is called, it returns the full accumulated buffer including the old cached data.

This means:

  1. Quantization itself works correctly (0.53x ratio on all calls), but it quantizes oversized arrays
  2. Memory accounting is inflated: estimate_kv_cache_memory correctly measures the oversized quantized arrays
  3. Net effect: after 2 requests, cache_mem=62MB (same as the unquantized baseline) vs the expected ~33MB

Impact

The quantization provides no real memory savings for the common case of repeated requests with shared prefixes (cache HITs), which is the primary use case for prefix caching.

Suggested fix

Trim the extracted cache arrays to offset before quantizing in store(), or trim before calling _quantize_cache. Something like:

# Before quantizing, trim KV arrays to actual offset to avoid storing unused buffer
if self._config.kv_quantize:
    cache = _trim_to_offset(cache)  # new helper
    cache = _quantize_cache(cache, ...)
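A pure-Python sketch of what such a helper could do, using a stand-in layer object. In mlx-lm, KVCache layers hold pre-allocated buffers plus an offset; the attribute names below mirror the PR's description, but the flat-list representation and everything else are illustrative assumptions:

```python
class FakeLayer:
    """Stand-in for a KV cache layer: pre-allocated K/V buffers plus
    the number of positions actually used (offset)."""
    def __init__(self, keys, values, offset):
        self.keys, self.values, self.offset = keys, values, offset

def _trim_to_offset(layers):
    """Slice each layer's K/V buffers down to the used offset so unused
    pre-allocated positions are neither quantized nor stored."""
    return [
        FakeLayer(l.keys[:l.offset], l.values[:l.offset], l.offset)
        for l in layers
    ]

# the R2 bug: the buffer grew to 280 positions but only 44 are real
layer = FakeLayer(keys=list(range(280)), values=list(range(280)), offset=44)
trimmed = _trim_to_offset([layer])[0]
print(len(trimmed.keys))  # 44
```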

@janhilgard
Collaborator

Good question. Your profiling already nailed the key insight — the production flow scales linearly because we never hold all three copies in memory simultaneously. The 615ms dequantize at 65K is very reasonable.

Here's my take on the latency picture:

What's already fine:

  • 8-bit at typical context lengths (≤32K): ~80-300ms round-trip is negligible compared to generation time (a 32K-token generation at 60 tok/s takes ~9 minutes)
  • The quantize cost is fully off the critical path (happens during post-generation cleanup)
  • Dequantize only fires on cache HITs, and the time saved by skipping prefill far exceeds the dequantize overhead

Where it gets interesting (4-bit):

  • Your benchmarks show 4-bit dequantize is actually faster than 8-bit at large head_dims (24.5ms vs 132ms at head_dim=4096). That's a nice property — the more you compress, the less data to move through Metal
  • The real concern with 4-bit isn't latency but quality: max_error=0.52 means individual KV values can be off by half a unit. For long multi-turn conversations where errors compound, this could cause subtle degradation

Concrete improvements we could pursue (roughly priority-ordered):

  1. PR #69 (Fix _trim_cache_offset for QuantizedKVCache layers) already helps: the _trim_cache_offset fix prevents storing oversized buffers on cache HITs, so we quantize/dequantize only the actual token count instead of inflated arrays (280 tokens → 44 tokens in our test case)

  2. Threshold-based quantization — only quantize entries above N tokens (e.g., 512). Short entries save little memory but still pay the full quantize/dequantize overhead per layer. A simple if len(tokens) < min_quantize_tokens: skip in store() would handle this

  3. Lazy dequantization — instead of dequantizing all layers upfront in fetch(), return the quantized cache and dequantize layer-by-layer during the model forward pass. This would overlap dequantize compute with prefill and avoid a single blocking dequantize call. Biggest win for long contexts but requires deeper integration with BatchGenerator

  4. Adaptive bit selection — expose a --kv-cache-bits auto mode that uses 8-bit by default and falls back to 4-bit only when cache memory exceeds a threshold (e.g., 80% of limit). This gives quality where it matters and compression when you need it

For now I'd say the current 8-bit implementation with the PR #69 fix covers the main use case well. The threshold-based approach (#2) would be a clean next step — low complexity, avoids unnecessary overhead on short entries. The lazy dequant (#3) is the real long-term win but requires more work.

What do you think? Want me to prototype the threshold approach?
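For reference, the threshold check from item 2 could be as small as this guard at the top of store(). The name and defaults come from the discussion (256 or 512 were floated, and PR #73 later used 256); the surrounding signature is hypothetical:

```python
MIN_QUANTIZE_TOKENS = 512  # 256 or 512 were both suggested in review

def should_quantize(tokens, kv_quantize_enabled, min_tokens=MIN_QUANTIZE_TOKENS):
    """Quantize only entries long enough that the memory savings
    outweigh the fixed per-layer quantize/dequantize overhead."""
    return kv_quantize_enabled and len(tokens) >= min_tokens

print(should_quantize(list(range(50)), True))    # False: entry too short
print(should_quantize(list(range(2048)), True))  # True
```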

When KV cache quantization is enabled, prefix cache entries are stored
as QuantizedKVCache objects. The _trim_cache_offset function (used for
supersequence and LCP matches) was silently skipping these layers
because QuantizedKVCache.keys returns a tuple, failing the
`not isinstance(keys, (list, tuple))` guard.

This caused the offset to remain untrimmed, so dequantized caches
passed to BatchGenerator had their original (large) offset. The
BatchGenerator then concatenated new tokens to the full buffer instead
of the trimmed prefix, producing oversized KV arrays that negated
all memory savings from quantization.

Tested: after fix, cache_mem=33MB (correct) vs 62MB (broken, same as
unquantized baseline) for the same 2-request workload.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
@waybarrios
Owner Author

Nice breakdown on the latency side. Your priority ordering makes sense to me.

PR #69, which fixes the QuantizedKVCache handling in _trim_cache_offset, is already merged. The tuple guard was a subtle one, good find.

The threshold-based quantization is the one I'd want next. Quantizing a 50-token entry saves almost nothing but you still pay the full per-layer overhead. Something like a min_quantize_tokens check in store() with a reasonable default (maybe 256 or 512) would keep it simple. Go ahead and prototype that.

Lazy dequantization is appealing but yeah, threading that through BatchGenerator is a bigger change. Let's revisit once we have real numbers from the threshold approach.

Adaptive bit selection is cool but I think most people will be fine with 8-bit as the default. Anyone who needs 4-bit can just pass the flag. Lower priority for now.

On the oversized buffer thing you reported, that's a real issue but it lives in the cache extraction path, not in the quantization logic itself. The non-quantized path has the same problem, it just stores oversized FP16 arrays. A _trim_to_offset step before _quantize_cache in store() would fix it. Can you fold that into the threshold PR?

@janhilgard
Collaborator

Both items from @waybarrios's comment are now implemented in PR #73:

  1. Threshold-based quantization — new kv_min_quantize_tokens config (default 256). Sequences shorter than this skip quantization since the overhead exceeds memory savings. Configurable via --kv-cache-min-quantize-tokens.

  2. Trim oversized KV buffers — new _trim_to_offset() helper trims pre-allocated KV arrays to their actual used size (offset) before storage. Applied unconditionally — saves memory in both FP16 and quantized paths.

PR: #73

@waybarrios
Owner Author

@janhilgard any update so far on this?

@waybarrios
Owner Author

Did you update the branch? @janhilgard

@janhilgard
Collaborator

Hey! Here's the current status:

  1. PR #62 branch (Add KV cache quantization for prefix cache memory reduction): I haven't pushed any additional changes since PR #69 (Fix _trim_cache_offset for QuantizedKVCache layers) was merged and you synced with main on Feb 12. The branch looks clean and mergeable.

  2. PR #73 (min_quantize_tokens threshold + trim oversized KV buffers): this is the follow-up you requested, implementing:

    • min_quantize_tokens threshold (default 256) to skip quantization for short sequences
    • _trim_to_offset() to trim oversized KV buffers before storage

    Last updated Feb 13 with a hardening commit. This one is based on main, not on #62, so it should be merged after #62.

  3. Remaining work from our discussion:

    • Lazy dequantization (layer-by-layer during forward pass) — haven't started this yet, it's a bigger change
    • Adaptive bit selection — lower priority, punting for now

I think #62 is ready to merge as-is. Then #73 can go on top. Want me to rebase #73 onto feat/kv-cache-quantization so they can be reviewed together, or keep them separate?

@janhilgard
Collaborator

No changes from my side — the branch is as you left it after the merge with main on Feb 12. PR #69 fix is already in there. Ready to merge whenever you are.

* feat: add min_quantize_tokens threshold and trim oversized KV buffers

- Add _trim_to_offset() to trim pre-allocated KV arrays to their actual
  used size before storage, saving memory in both FP16 and quantized paths
- Add kv_min_quantize_tokens config (default 256) to skip quantization
  for short sequences where overhead exceeds savings
- Thread the new config through SchedulerConfig and CLI arguments

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: apply black formatting

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Harden _trim_to_offset and store() in memory cache

Move the duplicate entry check in store() before trim and quantize
so repeated tokens skip the expensive work entirely. Rewrite
_trim_to_offset to validate that offset is positive before slicing,
use KVCache() instead of __new__ to avoid skipping init, call
mx.eval on the sliced arrays so the original large buffer gets freed
and memory accounting stays accurate, and skip the function entirely
when no layer actually needs trimming.

Add validation for kv_min_quantize_tokens in MemoryCacheConfig so
negative values are rejected at init time. Document the field in the
class docstring and add Args and Returns to the _trim_to_offset
docstring.

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Wayner Barrios <waybarrios@gmail.com>
Owner Author

@waybarrios waybarrios left a comment


I ran the full test suite locally after the latest push (23 quantization tests + 42 memory cache mock tests + full suite): 902 passed, 1 pre-existing failure (unrelated mock issue on main).

Pushed a fix for the ModuleNotFoundError: No module named 'mlx' that was failing test-matrix (3.10/3.11/3.12). The root cause: _trim_to_offset() imported mlx.core unconditionally at the top of the function, but store() calls it on every entry — including mock objects in tests running on Linux CI without MLX.

Fix: Extracted _needs_kv_trim() — a duck-typed helper that checks if any layer has oversized KV arrays using getattr instead of isinstance(layer, KVCache). The MLX imports now only happen when trimming is actually needed. Also fixed a ruff N806 lint error in _trim_cache_offset.

Changes

  1. Core quantization: _quantize_cache() / _dequantize_cache() with a quantize-on-store, dequantize-on-fetch pattern
  2. _trim_cache_offset fix for QuantizedKVCache — explicit handling with offset/group_size/bits preserved
  3. _trim_to_offset — trims pre-allocated KV buffers to actual used size before storage
  4. kv_min_quantize_tokens threshold — skips quantization for short sequences (default: 256)
  5. bench-kv-cache CLI command — synthetic benchmarks for quantization
  6. Config validation: kv_min_quantize_tokens >= 0 enforced in __post_init__
  7. 23 tests covering round-trip, memory reduction, config, store/fetch integration, trim, and threshold

Improvements

  • The _dequantize_cache(x) if self._config.kv_quantize else x pattern appears 5x in fetch() — could be a _maybe_dequantize() method
  • Disk persistence (save_to_disk/load_from_disk): works with QuantizedKVCache (verified), but if kv_quantize=False at load time and entries were saved quantized, they'll be returned raw. Consider always calling _dequantize_cache on fetch regardless of config flag (it's a no-op on regular KVCache)
  • No test for supersequence/LCP match with quantized cache — the _trim_cache_offset fix addresses this path but no test exercises it directly

Overall the feature is solid and ready to merge once CI is green.
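The refactor suggested in the first improvement bullet could be a one-line method; this is a sketch, with the config attribute name taken from the PR text and everything else hypothetical:

```python
class MemoryCacheSketch:
    """Minimal sketch of the suggested _maybe_dequantize() refactor: the
    five fetch() return points would call this instead of repeating the
    inline conditional."""
    def __init__(self, kv_quantize, dequantize_fn):
        self.kv_quantize = kv_quantize
        self._dequantize_cache = dequantize_fn

    def _maybe_dequantize(self, cache):
        # no-op when quantization is disabled; per the review note, always
        # calling _dequantize_cache would also be safe on regular KVCache
        return self._dequantize_cache(cache) if self.kv_quantize else cache

mc = MemoryCacheSketch(kv_quantize=True, dequantize_fn=lambda c: f"deq({c})")
print(mc._maybe_dequantize("entry"))  # deq(entry)
```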

@waybarrios waybarrios force-pushed the feat/kv-cache-quantization branch from 73279ca to bb53581 on February 13, 2026 at 23:48
@waybarrios waybarrios merged commit 5f2fd32 into main Feb 14, 2026
7 checks passed
janhilgard added a commit to janhilgard/vllm-mlx that referenced this pull request Feb 14, 2026
Resolve merge conflicts from kv_min_quantize_tokens feature (PR waybarrios#62).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
sooth pushed a commit to sooth/vllm-mlx that referenced this pull request Feb 27, 2026
Merge 17 upstream commits including:
- KV cache quantization for prefix cache memory reduction (waybarrios#62)
- Streaming tool call parsing via ToolParser integration (waybarrios#46)
- MTP speculative decoding for Qwen3-Next (waybarrios#82)
- GPT-OSS reasoning parser and Harmony format parsers
- mlx-lm >= 0.30.5 requirement, transformers >= 5.0.0
- BatchMambaCache fix for mlx-lm >= 0.30.6 (waybarrios#89)
- MLLM continuous batching fixes (waybarrios#76)
- Force MLLM mode option (waybarrios#81)
- Various bug fixes

Conflict resolution:
- server.py: Replaced local tool_call_buffering with upstream's
  ToolParser-based streaming (more robust)
- cli.py: Deduplicated --mllm, --default-temperature, --default-top-p
  args (upstream already added them), kept local --embedding-model
- mamba_cache.py: Took upstream's conditional HAS_MAMBA_CACHE approach
- pyproject.toml: Took upstream's version and dependency changes

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@waybarrios waybarrios deleted the feat/kv-cache-quantization branch March 21, 2026 21:49
Successfully merging this pull request may close these issues.

Add KV cache quantization support (--kv-cache-quantization)
