Fix _trim_cache_offset for QuantizedKVCache layers#69

Merged
waybarrios merged 1 commit into waybarrios:feat/kv-cache-quantization from janhilgard:fix/kv-cache-trim-quantized
Feb 11, 2026
Conversation

@janhilgard
Collaborator

Summary

Fixes a bug in PR #62 where KV cache quantization provides no memory savings on prefix cache HITs.

Root cause

`_trim_cache_offset()` silently skips `QuantizedKVCache` layers because their `.keys` attribute returns a tuple, failing the `not isinstance(keys, (list, tuple))` guard. The layer therefore passes through with its original (untrimmed) offset.

After dequantization, the cache is passed to `BatchGenerator` with the full original offset (e.g., 253 instead of 17). The `BatchGenerator` then concatenates new tokens to the entire buffer, producing oversized KV arrays (280 tokens instead of 44).
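The failing guard can be reproduced in isolation. A minimal sketch, using simplified stand-ins for mlx-lm's cache types (these classes only mimic the shape of the real `KVCache`/`QuantizedKVCache`):

```python
# Simplified repro of the buggy guard (FakeArray stands in for mx.array;
# these classes are illustrative, not mlx-lm's actual implementation).

class FakeArray:
    pass

class KVCache:
    def __init__(self):
        self.keys = FakeArray()          # a single array
        self.offset = 253

class QuantizedKVCache:
    def __init__(self):
        # quantized keys are a (packed, scales, biases) tuple
        self.keys = (FakeArray(), FakeArray(), FakeArray())
        self.offset = 253

def trim_cache_offset(layers, new_offset):
    for layer in layers:
        keys = layer.keys
        # Buggy guard: tuple-valued keys fail it, so the quantized
        # layer keeps its original (untrimmed) offset.
        if keys is not None and not isinstance(keys, (list, tuple)):
            layer.offset = new_offset
    return layers

plain, quant = trim_cache_offset([KVCache(), QuantizedKVCache()], 17)
print(plain.offset, quant.offset)  # 17 253
```

The plain layer is trimmed to 17, while the quantized layer silently keeps its original offset of 253.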

Fix

Add explicit handling for `QuantizedKVCache` in `_trim_cache_offset()`: create a shallow copy with a reduced offset, preserving the `group_size` and `bits` attributes.
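A minimal sketch of that handling, assuming the attribute names from the description above (the real mlx-lm `QuantizedKVCache` differs in detail):

```python
import copy

class QuantizedKVCache:            # illustrative stand-in, not mlx-lm's class
    def __init__(self):
        self.keys = ((), (), ())   # (packed, scales, biases)
        self.offset = 253
        self.group_size = 64
        self.bits = 8

def trim_quantized(layer, new_offset):
    # Shallow copy shares the quantized buffers; only the offset shrinks.
    trimmed = copy.copy(layer)
    trimmed.offset = new_offset
    return trimmed                 # group_size and bits carry over unchanged

t = trim_quantized(QuantizedKVCache(), 17)
print(t.offset, t.group_size, t.bits)  # 17 64 8
```

The shallow copy avoids re-quantizing or duplicating the packed buffers; only the bookkeeping offset changes.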

Test results

Same workload (2 requests with shared system prompt prefix):

| State | Before fix | After fix | Baseline (no quant) |
|---|---|---|---|
| After 1st request | 19 MB | 19 MB | 35 MB |
| After 2nd request | 62 MB (= baseline!) | 33 MB | 62 MB |

Test plan

  • Server starts without errors with `--kv-cache-quantization`
  • First request (cache MISS): quantization works correctly (19 MB vs 35 MB baseline)
  • Second request (cache HIT): quantization now works correctly (33 MB vs 62 MB before fix)
  • Responses are coherent
  • `black` clean

🤖 Generated with Claude Code

When KV cache quantization is enabled, prefix cache entries are stored
as QuantizedKVCache objects. The _trim_cache_offset function (used for
supersequence and LCP matches) was silently skipping these layers
because QuantizedKVCache.keys returns a tuple, failing the
`not isinstance(keys, (list, tuple))` guard.

This caused the offset to remain untrimmed, so dequantized caches
passed to BatchGenerator had their original (large) offset. The
BatchGenerator then concatenated new tokens to the full buffer instead
of the trimmed prefix, producing oversized KV arrays that negated
all memory savings from quantization.

Tested: after fix, cache_mem=33MB (correct) vs 62MB (broken, same as
unquantized baseline) for the same 2-request workload.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@waybarrios waybarrios merged commit d812eb7 into waybarrios:feat/kv-cache-quantization Feb 11, 2026
waybarrios added a commit that referenced this pull request Feb 14, 2026
* Add KV cache quantization for prefix cache memory reduction

Adds --kv-cache-quantization flag that uses mlx-lm's QuantizedKVCache
to compress stored prefix cache entries (8-bit or 4-bit), reducing
memory usage ~3.5x with minimal quality loss (~0.005 mean abs error).

Quantization happens on store, dequantization on fetch, so active
inference is unaffected. Includes CLI flags for serve and bench
commands, config wiring through SchedulerConfig, and 16 tests.
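The store/fetch round trip can be illustrated with a toy 8-bit scheme (numpy stands in for MLX here; the real path uses mlx-lm's QuantizedKVCache, which packs groups with per-group scales rather than one global scale):

```python
import numpy as np

def quantize8(x: np.ndarray):
    # One global scale for the toy example (mlx-lm uses per-group scales).
    scale = float(np.abs(x).max()) / 127.0
    return (x / scale).round().astype(np.int8), scale

def dequantize8(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

kv = np.random.randn(8, 64).astype(np.float32)   # fake KV slab
stored = quantize8(kv)                           # int8 + scale on store
restored = dequantize8(*stored)                  # FP32 again on fetch
print(restored.dtype, float(np.abs(kv - restored).mean()) < 0.02)
```

Active inference only ever sees the dequantized FP32/FP16 arrays, which is why generation quality is nearly unaffected.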

Closes #60

* Add bench-kv-cache command to benchmark quantization savings

New CLI command that compares FP16, 8-bit, and 4-bit KV cache
quantization using synthetic data, reporting memory usage, compression
ratio, quality metrics, and quantize/dequantize latency.

Usage: vllm-mlx bench-kv-cache [--layers 32] [--seq-len 512]

* Release FP16 cache reference after quantized store

After storing the quantized cache in the prefix cache, the original
FP16 reference on the request is no longer needed. Setting it to None
allows the memory to be reclaimed sooner, preventing temporary memory
spikes when quantization is enabled on long sequences.

* Fix _trim_cache_offset to handle QuantizedKVCache layers (#69)

When KV cache quantization is enabled, prefix cache entries are stored
as QuantizedKVCache objects. The _trim_cache_offset function (used for
supersequence and LCP matches) was silently skipping these layers
because QuantizedKVCache.keys returns a tuple, failing the
`not isinstance(keys, (list, tuple))` guard.

This caused the offset to remain untrimmed, so dequantized caches
passed to BatchGenerator had their original (large) offset. The
BatchGenerator then concatenated new tokens to the full buffer instead
of the trimmed prefix, producing oversized KV arrays that negated
all memory savings from quantization.

Tested: after fix, cache_mem=33MB (correct) vs 62MB (broken, same as
unquantized baseline) for the same 2-request workload.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

* feat: min_quantize_tokens threshold + trim oversized KV buffers (#73)

* feat: add min_quantize_tokens threshold and trim oversized KV buffers

- Add _trim_to_offset() to trim pre-allocated KV arrays to their actual
  used size before storage, saving memory in both FP16 and quantized paths
- Add kv_min_quantize_tokens config (default 256) to skip quantization
  for short sequences where overhead exceeds savings
- Thread the new config through SchedulerConfig and CLI arguments
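A hedged sketch of the trimming step described above (numpy stands in for mx.array, and the (batch, heads, seq, head_dim) axis layout is an assumption):

```python
import numpy as np

def trim_to_offset(keys: np.ndarray, offset: int) -> np.ndarray:
    # Only `offset` tokens of the pre-allocated seq axis are in use;
    # slice and copy so the oversized buffer can be freed.
    if offset <= 0 or offset >= keys.shape[2]:
        return keys                          # nothing to trim
    return keys[:, :, :offset, :].copy()

buf = np.zeros((1, 8, 512, 64), dtype=np.float32)  # pre-allocated buffer
print(trim_to_offset(buf, 44).shape)  # (1, 8, 44, 64)
```

Trimming before storage means both the FP16 and quantized paths store only the tokens actually produced, not the pre-allocated capacity.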

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: apply black formatting

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Harden _trim_to_offset and store() in memory cache

Move the duplicate entry check in store() before trim and quantize
so repeated tokens skip the expensive work entirely. Rewrite
_trim_to_offset to validate that offset is positive before slicing,
use KVCache() instead of __new__ to avoid skipping init, call
mx.eval on the sliced arrays so the original large buffer gets freed
and memory accounting stays accurate, and skip the function entirely
when no layer actually needs trimming.

Add validation for kv_min_quantize_tokens in MemoryCacheConfig so
negative values are rejected at init time. Document the field in the
class docstring and add Args and Returns to the _trim_to_offset
docstring.
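The store() reordering can be sketched as follows (the cache dict and `process` helper are hypothetical placeholders for the real trim/quantize pipeline):

```python
def store(cache: dict, token_key: tuple, layers, process=lambda x: x):
    # Duplicate check comes first, so repeated tokens skip the
    # expensive trim/quantize work entirely.
    if token_key in cache:
        return False
    cache[token_key] = process(layers)   # trim + quantize happens here
    return True

c = {}
print(store(c, (1, 2, 3), ["layer0"]))  # True: stored
print(store(c, (1, 2, 3), ["layer0"]))  # False: duplicate, no work done
```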

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Wayner Barrios <waybarrios@gmail.com>

* Defer MLX import in _trim_to_offset to fix non-Apple CI

* Fix mock path in TestMemoryStats to match mx.get_active_memory API

---------

Co-authored-by: Jan Hilgard <89418784+janhilgard@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>