4 changes: 3 additions & 1 deletion vllm_mlx/api/anthropic_models.py
@@ -84,8 +84,10 @@ class AnthropicUsage(BaseModel):
 class AnthropicResponseContentBlock(BaseModel):
     """A content block in the Anthropic response."""

-    type: str  # "text" or "tool_use"
+    type: str  # "text", "thinking", or "tool_use"
     text: str | None = None
+    # thinking block
+    thinking: str | None = None
     # tool_use fields
     id: str | None = None
     name: str | None = None
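A quick illustration of the new field (not part of the diff itself; it assumes the remaining, unshown fields of the model are optional, as the shown ones are):

    blocks = [
        AnthropicResponseContentBlock(type="thinking", thinking="Check the cache invariants first..."),
        AnthropicResponseContentBlock(type="text", text="The trimmed cache keeps the oldest positions."),
    ]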
176 changes: 156 additions & 20 deletions vllm_mlx/memory_cache.py
@@ -255,44 +255,121 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:


 def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
-    """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced.
+    """Create copies of cache layers with the last ``trim_by`` positions removed.

     This is used when returning a cached KV state to the scheduler so that
     the last N positions are "freed" and the model will recompute them on the
     next forward pass (preventing duplicate KV entries).

-    Supports both KVCache (keys/values are arrays) and QuantizedKVCache
-    (keys/values are 3-tuples of arrays).
-    """
-    from mlx_lm.models.cache import KVCache
-
-    try:
-        from mlx_lm.models.cache import QuantizedKVCache
-    except ImportError:
-        QuantizedKVCache = None  # noqa: N806
+    For plain KVCache: reduces offset (surplus data beyond offset is harmless
+    since merge slices to ``keys[:, :, :offset, :]``).
+
+    For RotatingKVCache: actually trims the circular buffer — reducing offset
+    alone breaks ``size()`` / ``_temporal_order`` invariants.
+
+    Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper.
+    """
+    import mlx.core as mx
+    from mlx_lm.models.cache import RotatingKVCache

     trimmed: list[Any] = []
+    eval_targets: list[Any] = []
     for layer_cache in cache:
-        if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
-            tc = QuantizedKVCache.__new__(QuantizedKVCache)
+        if isinstance(layer_cache, _QuantizedCacheWrapper):
+            # Shallow copy with reduced offset
+            tc = _QuantizedCacheWrapper.__new__(_QuantizedCacheWrapper)
             tc.keys = layer_cache.keys
             tc.values = layer_cache.values
             tc.offset = max(layer_cache.offset - trim_by, 0)
-            tc.group_size = layer_cache.group_size
             tc.bits = layer_cache.bits
+            tc.group_size = layer_cache.group_size
+            tc.orig_type = layer_cache.orig_type
+            tc.orig_attrs = layer_cache.orig_attrs
             trimmed.append(tc)
+        elif isinstance(layer_cache, RotatingKVCache):
+            if layer_cache.keys is None or trim_by <= 0:
+                trimmed.append(layer_cache)
+                continue
+            # RotatingKVCache: must trim buffer, not just offset.
+            # The buffer stores the last min(offset, max_size) tokens in a
+            # circular arrangement. Trimming excess positions from the END
+            # means removing the newest entries (chronologically last).
+            old_offset = layer_cache.offset
+            new_offset = max(old_offset - trim_by, 0)
+            old_size = min(old_offset, layer_cache.max_size)
+            entries_to_keep = max(0, old_size - trim_by)
+
+            orig_cls = type(layer_cache)
+            tc = orig_cls.__new__(orig_cls)
+            tc.offset = new_offset
+            tc.max_size = layer_cache.max_size
+            tc.keep = getattr(layer_cache, "keep", 0)
+            tc.step = getattr(layer_cache, "step", layer_cache.max_size)
+
+            if entries_to_keep <= 0:
+                # All buffer content is beyond the trim point — clear
+                tc.keys = None
+                tc.values = None
+                tc._idx = 0
+            elif entries_to_keep < old_size:
+                # Reorder to temporal order, keep the oldest entries
+                ordered_k = layer_cache._temporal_order(layer_cache.keys)
+                ordered_v = layer_cache._temporal_order(layer_cache.values)
+                kept_k = ordered_k[:, :, :entries_to_keep, :]
+                kept_v = ordered_v[:, :, :entries_to_keep, :]
+
+                if new_offset >= tc.max_size:
+                    # Invariant: when offset >= max_size, buffer must be
+                    # full (keys.shape[2] == max_size). Left-pad with
+                    # zeros to restore the full buffer. Zeros represent
+                    # positions evicted long ago; _idx = max_size so
+                    # _temporal_order returns as-is and _update_in_place
+                    # rotates to overwrite zeros first.
+                    pad_n = tc.max_size - entries_to_keep
+                    pad_k = mx.zeros(
+                        (kept_k.shape[0], kept_k.shape[1], pad_n, kept_k.shape[3]),
+                        dtype=kept_k.dtype,
+                    )
+                    pad_v = mx.zeros(
+                        (kept_v.shape[0], kept_v.shape[1], pad_n, kept_v.shape[3]),
+                        dtype=kept_v.dtype,
+                    )
+                    tc.keys = mx.concatenate([pad_k, kept_k], axis=2)
+                    tc.values = mx.concatenate([pad_v, kept_v], axis=2)
+                    tc._idx = tc.max_size
+                else:
+                    tc.keys = kept_k
+                    tc.values = kept_v
+                    tc._idx = entries_to_keep
+                eval_targets.extend([tc.keys, tc.values])
+            else:
+                # No entries removed (trim_by == 0 already handled above,
+                # this covers entries_to_keep == old_size edge case)
+                tc.keys = layer_cache.keys
+                tc.values = layer_cache.values
+                tc._idx = layer_cache._idx
+            trimmed.append(tc)
         elif (
             hasattr(layer_cache, "offset")
            and hasattr(layer_cache, "keys")
            and not isinstance(layer_cache.keys, (list, tuple))
         ):
-            tc = KVCache.__new__(KVCache)
+            orig_cls = type(layer_cache)
+            tc = orig_cls.__new__(orig_cls)
             tc.keys = layer_cache.keys
             tc.values = layer_cache.values
             tc.offset = max(layer_cache.offset - trim_by, 0)
+            # Preserve type-specific attrs (max_size, keep, step, _idx)
+            for attr in ("max_size", "keep", "step", "_idx"):
+                if hasattr(layer_cache, attr):
+                    setattr(tc, attr, getattr(layer_cache, attr))
             trimmed.append(tc)
         else:
             trimmed.append(layer_cache)

+    if eval_targets:
+        mx.eval(*eval_targets)
+
     return trimmed
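A rough usage sketch of the trim (not part of the change; it assumes mlx and mlx_lm are installed, and the shapes and values are illustrative). The first part shows the plain-KVCache case, where only the offset moves back; the second mimics the circular-buffer idea in the RotatingKVCache branch with a plain mx.roll:

    import mlx.core as mx
    from mlx_lm.models.cache import KVCache

    # Plain KVCache: only the offset shrinks; the buffer itself is left alone.
    layer = KVCache()
    k = mx.random.normal((1, 8, 16, 64))   # (batch, heads, positions, head_dim)
    v = mx.random.normal((1, 8, 16, 64))
    layer.update_and_fetch(k, v)            # offset -> 16
    trimmed = _trim_cache_offset([layer], trim_by=4)
    assert trimmed[0].offset == 12          # the last 4 positions get recomputed
    assert trimmed[0].keys.shape[2] >= 12   # surplus data past offset is ignored

    # Circular-buffer idea behind the RotatingKVCache branch, in miniature:
    # roll back to oldest-first order, then drop the newest entries.
    ring = mx.array([[[[4], [5], [0], [1], [2], [3]]]])  # holds tokens 0..5, write index 2
    ordered = mx.roll(ring, -2, axis=2)                   # oldest-first: 0, 1, 2, 3, 4, 5
    kept = ordered[:, :, :4, :]                           # trim the 2 newest positions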


@@ -353,35 +430,94 @@ def _trim_to_offset(cache: list[Any]) -> list[Any]:
     return trimmed


+class _QuantizedCacheWrapper:
+    """Lightweight wrapper storing quantized KV arrays + original cache metadata.
+
+    Unlike ``QuantizedKVCache``, this preserves enough info to reconstruct
+    the *original* cache type (KVCache, RotatingKVCache, etc.) on dequantize.
+    """
+
+    __slots__ = (
+        "keys",
+        "values",
+        "offset",
+        "bits",
+        "group_size",
+        "orig_type",
+        "orig_attrs",
+    )
+
+    def __init__(self, layer: Any, bits: int, group_size: int):
+        import mlx.core as mx
+
+        self.keys = mx.quantize(layer.keys, group_size=group_size, bits=bits)
+        self.values = mx.quantize(layer.values, group_size=group_size, bits=bits)
+        self.offset = layer.offset
+        self.bits = bits
+        self.group_size = group_size
+        self.orig_type = type(layer)
+        # Preserve RotatingKVCache-specific attrs
+        self.orig_attrs = {}
+        for attr in ("max_size", "keep", "step", "_idx"):
+            if hasattr(layer, attr):
+                self.orig_attrs[attr] = getattr(layer, attr)
+
+
 def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]:
-    """Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is."""
+    """Quantize KV cache layers to reduce memory.
+
+    Only plain KVCache layers are quantized. RotatingKVCache (sliding window)
+    is left as-is because its internal _idx/rotation state is tightly coupled
+    with update_and_fetch logic and cannot survive quantize/dequantize roundtrip.
+    RotatingKVCache is typically small (max_size=1024) so skipping it is fine.
+    """
     from mlx_lm.models.cache import KVCache

     quantized = []
     for layer in cache:
-        if isinstance(layer, KVCache) and layer.keys is not None:
-            quantized.append(layer.to_quantized(group_size=group_size, bits=bits))
+        if type(layer) is KVCache and getattr(layer, "keys", None) is not None:
+            quantized.append(_QuantizedCacheWrapper(layer, bits, group_size))
         else:
             quantized.append(layer)
     return quantized


 def _dequantize_cache(cache: list[Any]) -> list[Any]:
-    """Dequantize QuantizedKVCache layers back to regular KVCache."""
+    """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers.
+
+    All layers are copied (never returned by reference) so that the model's
+    ``update_and_fetch`` mutations don't corrupt the stored cache entry.
+    """
     import mlx.core as mx
-    from mlx_lm.models.cache import KVCache, QuantizedKVCache

     result = []
     for layer in cache:
-        if isinstance(layer, QuantizedKVCache) and layer.keys is not None:
-            kv = KVCache()
+        if isinstance(layer, _QuantizedCacheWrapper):
+            # Reconstruct original cache type from quantized data
+            orig_cls = layer.orig_type
+            kv = orig_cls.__new__(orig_cls)
             kv.keys = mx.dequantize(
                 *layer.keys, group_size=layer.group_size, bits=layer.bits
             )
             kv.values = mx.dequantize(
                 *layer.values, group_size=layer.group_size, bits=layer.bits
             )
             kv.offset = layer.offset
+            # Restore type-specific attrs (max_size, keep, step, _idx)
+            for attr, val in layer.orig_attrs.items():
+                setattr(kv, attr, val)
             result.append(kv)
+        elif hasattr(layer, "keys") and hasattr(layer, "offset"):
+            # Deep-copy non-quantized cache layers (e.g. RotatingKVCache)
+            # so model's in-place mutations don't corrupt stored entries
+            orig_cls = type(layer)
+            kv = orig_cls.__new__(orig_cls)
+            kv.keys = mx.array(layer.keys) if layer.keys is not None else None
+            kv.values = mx.array(layer.values) if layer.values is not None else None
+            kv.offset = layer.offset
+            for attr in ("max_size", "keep", "step", "_idx"):
+                if hasattr(layer, attr):
+                    setattr(kv, attr, getattr(layer, attr))
+            result.append(kv)
         else:
             result.append(layer)
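A rough round-trip sketch of the pair above (not part of the change; it assumes mlx and mlx_lm are installed, that the truncated tail of _dequantize_cache returns the ``result`` list, and that the shapes are illustrative). The point is that _quantize_cache wraps plain KVCache layers while _dequantize_cache rebuilds the original cache type as a fresh copy, so later in-place updates by the model cannot corrupt the stored entry:

    import mlx.core as mx
    from mlx_lm.models.cache import KVCache

    layer = KVCache()
    layer.update_and_fetch(
        mx.random.normal((1, 8, 64, 64)), mx.random.normal((1, 8, 64, 64))
    )

    stored = _quantize_cache([layer], bits=8, group_size=64)   # [_QuantizedCacheWrapper]
    restored = _dequantize_cache(stored)                       # [KVCache], dequantized copy
    assert type(restored[0]) is KVCache
    assert restored[0].offset == layer.offset
    assert restored[0].keys is not layer.keys                  # a copy; values are approximate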
77 changes: 40 additions & 37 deletions vllm_mlx/mllm_batch_generator.py
@@ -156,15 +156,44 @@ def extend(self, other: "MLLMBatch") -> None:

     def extract_cache(self, idx: int) -> List[Any]:
         """
-        Extract cache for a single request (for caching).
+        Extract cache for a single request (for prefix caching).

-        Args:
-            idx: Index of request in batch
-
-        Returns:
-            Cache state for that request
+        Handles BatchRotatingKVCache negative left_padding bug:
+        during generation with rotation, left_padding becomes negative,
+        causing extract() to use Python negative indexing and truncate
+        the buffer to only generation tokens instead of the full window.
         """
-        return [c.extract(idx) if hasattr(c, "extract") else None for c in self.cache]
+        from mlx_lm.models.cache import (
+            BatchRotatingKVCache,
+            RotatingKVCache,
+        )
+
+        result = []
+        for c in self.cache:
+            if not hasattr(c, "extract"):
+                result.append(None)
+            elif isinstance(c, BatchRotatingKVCache):
+                # Custom extraction: clamp left_padding to >= 0
+                cache = RotatingKVCache(c.max_size)
+                padding = max(0, c.left_padding[idx].item())
+                offset = c.offset[idx].item()
+                cache.keys = c.keys[idx : idx + 1]
+                cache.values = c.values[idx : idx + 1]
+                cache._idx = c._idx
+                if c.rotated:
+                    cache.keys = mx.roll(cache.keys, -c._idx, axis=2)
+                    cache.values = mx.roll(cache.values, -c._idx, axis=2)
+                    cache._idx = c.max_size
+                cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx])
+                cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx])
+                cache.offset = offset
+                cache._idx = cache.keys.shape[2]
+                cache.step = getattr(c, "step", c.max_size)
+                cache.keep = getattr(c, "keep", 0)
+                result.append(cache)
+            else:
+                result.append(c.extract(idx))
+        return result
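A standalone sketch of the pitfall the clamp above avoids (not part of the change; illustrative values, not real cache state). Once left_padding goes negative, slicing with it switches to Python negative indexing and silently keeps only the tail of the window, while clamping to zero keeps the full buffer:

    import mlx.core as mx

    buf = mx.arange(10).reshape(1, 1, 10, 1)   # stand-in for a 10-position KV buffer
    left_padding = -3                          # goes negative once rotation starts

    bad = buf[:, :, left_padding:10]           # negative start -> only the last 3 positions
    good = buf[:, :, max(0, left_padding):10]  # the clamp used above -> all 10 positions
    assert bad.shape[2] == 3 and good.shape[2] == 10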


class MLLMBatchStats:
@@ -205,32 +234,6 @@ def to_dict(self) -> Dict[str, Any]:
         }


-def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]:
-    """
-    Create batch-aware KV cache for the language model.
-
-    Args:
-        model: The language model (model.language_model from VLM)
-        left_padding: Padding amounts for left-padded prompts
-
-    Returns:
-        List of BatchKVCache objects for each layer
-    """
-    from mlx_lm.models.cache import BatchKVCache, KVCache
-
-    def to_batch_cache(c):
-        if isinstance(c, KVCache):
-            return BatchKVCache(left_padding)
-        else:
-            raise ValueError(f"{type(c)} does not yet support batching")
-
-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-        return [to_batch_cache(c) for c in cache]
-    else:
-        return [BatchKVCache(left_padding) for _ in model.layers]
-
-
 def _left_pad_prompts(
     prompts: List[List[int]], max_length: Optional[int] = None
 ) -> mx.array:
@@ -679,10 +682,10 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
         sample_cache = per_request_caches[0][0]
         if not isinstance(sample_cache, (KVCache, RotatingKVCache)):
             raise ValueError(
-                f"MLLM continuous batching requires KVCache or RotatingKVCache "
-                f"but got {type(sample_cache).__name__}. Disable "
-                f"--kv-cache-quantization when using multimodal models with "
-                f"--continuous-batching."
+                f"MLLM continuous batching requires KVCache or "
+                f"RotatingKVCache but got {type(sample_cache).__name__}. "
+                f"Disable --kv-cache-quantization when using multimodal "
+                f"models with --continuous-batching."
             )

         # Fix: RotatingKVCache._update_concat does NOT trim on first call —