diff --git a/vllm_mlx/api/anthropic_models.py b/vllm_mlx/api/anthropic_models.py
index a5bc6f776..e8854a5e6 100644
--- a/vllm_mlx/api/anthropic_models.py
+++ b/vllm_mlx/api/anthropic_models.py
@@ -84,8 +84,10 @@ class AnthropicUsage(BaseModel):
 class AnthropicResponseContentBlock(BaseModel):
     """A content block in the Anthropic response."""
 
-    type: str  # "text" or "tool_use"
+    type: str  # "text", "thinking", or "tool_use"
     text: str | None = None
+    # thinking block
+    thinking: str | None = None
     # tool_use fields
     id: str | None = None
     name: str | None = None
diff --git a/vllm_mlx/memory_cache.py b/vllm_mlx/memory_cache.py
index f43763541..111439330 100644
--- a/vllm_mlx/memory_cache.py
+++ b/vllm_mlx/memory_cache.py
@@ -255,44 +255,121 @@ def create(cls, tokens: list[int], cache: list[Any]) -> _CacheEntry:
 
 
 def _trim_cache_offset(cache: list[Any], trim_by: int) -> list[Any]:
-    """Create shallow copies of KVCache/QuantizedKVCache layers with offset reduced.
+    """Create copies of cache layers with the last ``trim_by`` positions removed.
 
     This is used when returning a cached KV state to the scheduler so that the
     last N positions are "freed" and the model will recompute them on the next
     forward pass (preventing duplicate KV entries).
 
-    Supports both KVCache (keys/values are arrays) and QuantizedKVCache
-    (keys/values are 3-tuples of arrays).
-    """
-    from mlx_lm.models.cache import KVCache
+    For plain KVCache: reduces offset (surplus data beyond offset is harmless
+    since merge slices to ``keys[:, :, :offset, :]``).
 
-    try:
-        from mlx_lm.models.cache import QuantizedKVCache
-    except ImportError:
-        QuantizedKVCache = None  # noqa: N806
+    For RotatingKVCache: actually trims the circular buffer — reducing offset
+    alone breaks ``size()`` / ``_temporal_order`` invariants.
+
+    Supports KVCache, RotatingKVCache, and _QuantizedCacheWrapper.
+    """
+    import mlx.core as mx
+    from mlx_lm.models.cache import RotatingKVCache
 
     trimmed: list[Any] = []
+    eval_targets: list[Any] = []
     for layer_cache in cache:
-        if QuantizedKVCache is not None and isinstance(layer_cache, QuantizedKVCache):
-            tc = QuantizedKVCache.__new__(QuantizedKVCache)
+        if isinstance(layer_cache, _QuantizedCacheWrapper):
+            # Shallow copy with reduced offset
+            tc = _QuantizedCacheWrapper.__new__(_QuantizedCacheWrapper)
             tc.keys = layer_cache.keys
             tc.values = layer_cache.values
             tc.offset = max(layer_cache.offset - trim_by, 0)
-            tc.group_size = layer_cache.group_size
             tc.bits = layer_cache.bits
+            tc.group_size = layer_cache.group_size
+            tc.orig_type = layer_cache.orig_type
+            tc.orig_attrs = layer_cache.orig_attrs
+            trimmed.append(tc)
+        elif isinstance(layer_cache, RotatingKVCache):
+            if layer_cache.keys is None or trim_by <= 0:
+                trimmed.append(layer_cache)
+                continue
+            # RotatingKVCache: must trim buffer, not just offset.
+            # The buffer stores the last min(offset, max_size) tokens in a
+            # circular arrangement. Trimming excess positions from the END
+            # means removing the newest entries (chronologically last).
+            old_offset = layer_cache.offset
+            new_offset = max(old_offset - trim_by, 0)
+            old_size = min(old_offset, layer_cache.max_size)
+            entries_to_keep = max(0, old_size - trim_by)
+
+            orig_cls = type(layer_cache)
+            tc = orig_cls.__new__(orig_cls)
+            tc.offset = new_offset
+            tc.max_size = layer_cache.max_size
+            tc.keep = getattr(layer_cache, "keep", 0)
+            tc.step = getattr(layer_cache, "step", layer_cache.max_size)
+
+            if entries_to_keep <= 0:
+                # All buffer content is beyond the trim point — clear
+                tc.keys = None
+                tc.values = None
+                tc._idx = 0
+            elif entries_to_keep < old_size:
+                # Reorder to temporal order, keep the oldest entries
+                ordered_k = layer_cache._temporal_order(layer_cache.keys)
+                ordered_v = layer_cache._temporal_order(layer_cache.values)
+                kept_k = ordered_k[:, :, :entries_to_keep, :]
+                kept_v = ordered_v[:, :, :entries_to_keep, :]
+
+                if new_offset >= tc.max_size:
+                    # Invariant: when offset >= max_size, buffer must be
+                    # full (keys.shape[2] == max_size). Left-pad with
+                    # zeros to restore the full buffer. Zeros represent
+                    # positions evicted long ago; _idx = max_size so
+                    # _temporal_order returns as-is and _update_in_place
+                    # rotates to overwrite zeros first.
+                    pad_n = tc.max_size - entries_to_keep
+                    pad_k = mx.zeros(
+                        (kept_k.shape[0], kept_k.shape[1], pad_n, kept_k.shape[3]),
+                        dtype=kept_k.dtype,
+                    )
+                    pad_v = mx.zeros(
+                        (kept_v.shape[0], kept_v.shape[1], pad_n, kept_v.shape[3]),
+                        dtype=kept_v.dtype,
+                    )
+                    tc.keys = mx.concatenate([pad_k, kept_k], axis=2)
+                    tc.values = mx.concatenate([pad_v, kept_v], axis=2)
+                    tc._idx = tc.max_size
+                else:
+                    tc.keys = kept_k
+                    tc.values = kept_v
+                    tc._idx = entries_to_keep
+                eval_targets.extend([tc.keys, tc.values])
+            else:
+                # No entries removed (trim_by == 0 already handled above,
+                # this covers entries_to_keep == old_size edge case)
+                tc.keys = layer_cache.keys
+                tc.values = layer_cache.values
+                tc._idx = layer_cache._idx
             trimmed.append(tc)
         elif (
             hasattr(layer_cache, "offset")
             and hasattr(layer_cache, "keys")
             and not isinstance(layer_cache.keys, (list, tuple))
         ):
-            tc = KVCache.__new__(KVCache)
+            orig_cls = type(layer_cache)
+            tc = orig_cls.__new__(orig_cls)
             tc.keys = layer_cache.keys
             tc.values = layer_cache.values
             tc.offset = max(layer_cache.offset - trim_by, 0)
+            # Preserve type-specific attrs (max_size, keep, step, _idx)
+            for attr in ("max_size", "keep", "step", "_idx"):
+                if hasattr(layer_cache, attr):
+                    setattr(tc, attr, getattr(layer_cache, attr))
             trimmed.append(tc)
         else:
            trimmed.append(layer_cache)
+
+    if eval_targets:
+        mx.eval(*eval_targets)
+
     return trimmed
 
 
@@ -353,28 +430,72 @@ def _trim_to_offset(cache: list[Any]) -> list[Any]:
     return trimmed
 
 
+class _QuantizedCacheWrapper:
+    """Lightweight wrapper storing quantized KV arrays + original cache metadata.
+
+    Unlike ``QuantizedKVCache``, this preserves enough info to reconstruct
+    the *original* cache type (KVCache, RotatingKVCache, etc.) on dequantize.
+ """ + + __slots__ = ( + "keys", + "values", + "offset", + "bits", + "group_size", + "orig_type", + "orig_attrs", + ) + + def __init__(self, layer: Any, bits: int, group_size: int): + import mlx.core as mx + + self.keys = mx.quantize(layer.keys, group_size=group_size, bits=bits) + self.values = mx.quantize(layer.values, group_size=group_size, bits=bits) + self.offset = layer.offset + self.bits = bits + self.group_size = group_size + self.orig_type = type(layer) + # Preserve RotatingKVCache-specific attrs + self.orig_attrs = {} + for attr in ("max_size", "keep", "step", "_idx"): + if hasattr(layer, attr): + self.orig_attrs[attr] = getattr(layer, attr) + + def _quantize_cache(cache: list[Any], bits: int = 8, group_size: int = 64) -> list[Any]: - """Quantize KVCache layers to reduce memory. Non-KVCache layers are kept as-is.""" + """Quantize KV cache layers to reduce memory. + + Only plain KVCache layers are quantized. RotatingKVCache (sliding window) + is left as-is because its internal _idx/rotation state is tightly coupled + with update_and_fetch logic and cannot survive quantize/dequantize roundtrip. + RotatingKVCache is typically small (max_size=1024) so skipping it is fine. + """ from mlx_lm.models.cache import KVCache quantized = [] for layer in cache: - if isinstance(layer, KVCache) and layer.keys is not None: - quantized.append(layer.to_quantized(group_size=group_size, bits=bits)) + if type(layer) is KVCache and getattr(layer, "keys", None) is not None: + quantized.append(_QuantizedCacheWrapper(layer, bits, group_size)) else: quantized.append(layer) return quantized def _dequantize_cache(cache: list[Any]) -> list[Any]: - """Dequantize QuantizedKVCache layers back to regular KVCache.""" + """Dequantize _QuantizedCacheWrapper layers and copy non-quantized layers. + + All layers are copied (never returned by reference) so that the model's + ``update_and_fetch`` mutations don't corrupt the stored cache entry. + """ import mlx.core as mx - from mlx_lm.models.cache import KVCache, QuantizedKVCache result = [] for layer in cache: - if isinstance(layer, QuantizedKVCache) and layer.keys is not None: - kv = KVCache() + if isinstance(layer, _QuantizedCacheWrapper): + # Reconstruct original cache type from quantized data + orig_cls = layer.orig_type + kv = orig_cls.__new__(orig_cls) kv.keys = mx.dequantize( *layer.keys, group_size=layer.group_size, bits=layer.bits ) @@ -382,6 +503,21 @@ def _dequantize_cache(cache: list[Any]) -> list[Any]: *layer.values, group_size=layer.group_size, bits=layer.bits ) kv.offset = layer.offset + # Restore type-specific attrs (max_size, keep, step, _idx) + for attr, val in layer.orig_attrs.items(): + setattr(kv, attr, val) + result.append(kv) + elif hasattr(layer, "keys") and hasattr(layer, "offset"): + # Deep-copy non-quantized cache layers (e.g. 
+            # so model's in-place mutations don't corrupt stored entries
+            orig_cls = type(layer)
+            kv = orig_cls.__new__(orig_cls)
+            kv.keys = mx.array(layer.keys) if layer.keys is not None else None
+            kv.values = mx.array(layer.values) if layer.values is not None else None
+            kv.offset = layer.offset
+            for attr in ("max_size", "keep", "step", "_idx"):
+                if hasattr(layer, attr):
+                    setattr(kv, attr, getattr(layer, attr))
             result.append(kv)
         else:
             result.append(layer)
diff --git a/vllm_mlx/mllm_batch_generator.py b/vllm_mlx/mllm_batch_generator.py
index a8845c5e8..1de137587 100644
--- a/vllm_mlx/mllm_batch_generator.py
+++ b/vllm_mlx/mllm_batch_generator.py
@@ -156,15 +156,44 @@ def extend(self, other: "MLLMBatch") -> None:
 
     def extract_cache(self, idx: int) -> List[Any]:
         """
-        Extract cache for a single request (for caching).
+        Extract cache for a single request (for prefix caching).
 
-        Args:
-            idx: Index of request in batch
-
-        Returns:
-            Cache state for that request
+        Handles BatchRotatingKVCache negative left_padding bug:
+        during generation with rotation, left_padding becomes negative,
+        causing extract() to use Python negative indexing and truncate
+        the buffer to only generation tokens instead of the full window.
         """
-        return [c.extract(idx) if hasattr(c, "extract") else None for c in self.cache]
+        from mlx_lm.models.cache import (
+            BatchRotatingKVCache,
+            RotatingKVCache,
+        )
+
+        result = []
+        for c in self.cache:
+            if not hasattr(c, "extract"):
+                result.append(None)
+            elif isinstance(c, BatchRotatingKVCache):
+                # Custom extraction: clamp left_padding to >= 0
+                cache = RotatingKVCache(c.max_size)
+                padding = max(0, c.left_padding[idx].item())
+                offset = c.offset[idx].item()
+                cache.keys = c.keys[idx : idx + 1]
+                cache.values = c.values[idx : idx + 1]
+                cache._idx = c._idx
+                if c.rotated:
+                    cache.keys = mx.roll(cache.keys, -c._idx, axis=2)
+                    cache.values = mx.roll(cache.values, -c._idx, axis=2)
+                    cache._idx = c.max_size
+                cache.keys = mx.contiguous(cache.keys[:, :, padding : cache._idx])
+                cache.values = mx.contiguous(cache.values[:, :, padding : cache._idx])
+                cache.offset = offset
+                cache._idx = cache.keys.shape[2]
+                cache.step = getattr(c, "step", c.max_size)
+                cache.keep = getattr(c, "keep", 0)
+                result.append(cache)
+            else:
+                result.append(c.extract(idx))
+        return result
 
 
 class MLLMBatchStats:
@@ -205,32 +234,6 @@ def to_dict(self) -> Dict[str, Any]:
         }
 
 
-def _make_batch_cache(model: nn.Module, left_padding: List[int]) -> List[Any]:
-    """
-    Create batch-aware KV cache for the language model.
-
-    Args:
-        model: The language model (model.language_model from VLM)
-        left_padding: Padding amounts for left-padded prompts
-
-    Returns:
-        List of BatchKVCache objects for each layer
-    """
-    from mlx_lm.models.cache import BatchKVCache, KVCache
-
-    def to_batch_cache(c):
-        if isinstance(c, KVCache):
-            return BatchKVCache(left_padding)
-        else:
-            raise ValueError(f"{type(c)} does not yet support batching")
-
-    if hasattr(model, "make_cache"):
-        cache = model.make_cache()
-        return [to_batch_cache(c) for c in cache]
-    else:
-        return [BatchKVCache(left_padding) for _ in model.layers]
-
-
 def _left_pad_prompts(
     prompts: List[List[int]], max_length: Optional[int] = None
 ) -> mx.array:
@@ -679,10 +682,10 @@ def _process_prompts(self, requests: List[MLLMBatchRequest]) -> MLLMBatch:
         sample_cache = per_request_caches[0][0]
         if not isinstance(sample_cache, (KVCache, RotatingKVCache)):
             raise ValueError(
-                f"MLLM continuous batching requires KVCache or RotatingKVCache "
-                f"but got {type(sample_cache).__name__}. Disable "
-                f"--kv-cache-quantization when using multimodal models with "
-                f"--continuous-batching."
+                f"MLLM continuous batching requires KVCache or "
+                f"RotatingKVCache but got {type(sample_cache).__name__}. "
+                f"Disable --kv-cache-quantization when using multimodal "
+                f"models with --continuous-batching."
             )
 
         # Fix: RotatingKVCache._update_concat does NOT trim on first call —
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index af10e7341..9b66bbe2f 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -42,6 +42,7 @@
 import json
 import logging
 import os
+import re
 import secrets
 import tempfile
 import threading
@@ -57,8 +58,13 @@
 
 # Import from new modular API
 # Re-export for backwards compatibility with tests
-from .api.anthropic_adapter import anthropic_to_openai, openai_to_anthropic
-from .api.anthropic_models import AnthropicRequest
+from .api.anthropic_adapter import anthropic_to_openai
+from .api.anthropic_models import (
+    AnthropicRequest,
+    AnthropicResponse,
+    AnthropicResponseContentBlock,
+    AnthropicUsage,
+)
 from .api.models import (
     AssistantMessage,  # noqa: F401
     ChatCompletionChoice,  # noqa: F401
@@ -1535,6 +1541,17 @@ def _inject_json_instruction(messages: list, instruction: str) -> list:
 
 # =============================================================================
 
+def _convert_anthropic_stop_reason(openai_reason: str | None) -> str:
+    """Convert OpenAI finish_reason to Anthropic stop_reason."""
+    mapping = {
+        "stop": "end_turn",
+        "tool_calls": "tool_use",
+        "length": "max_tokens",
+        "content_filter": "end_turn",
+    }
+    return mapping.get(openai_reason or "", "end_turn")
+
+
 @app.post("/v1/messages")
 async def create_anthropic_message(
     request: Request,
@@ -1549,8 +1566,19 @@ async def create_anthropic_message(
     """
     engine = get_engine()
 
-    # Parse the raw body to handle Anthropic request format
-    body = await request.json()
+    # Parse the raw body to handle Anthropic request format.
+    # Some clients (e.g. Claude Code) may send JSON with invalid escape
+    # sequences like \s, \d in regex patterns within tool definitions.
+    # Python's json.loads is strict per RFC 8259 and rejects these.
+    try:
+        body = await request.json()
+    except json.JSONDecodeError as e:
+        if "Invalid \\escape" in str(e):
+            raw = await request.body()
+            # Replace lone backslashes (not valid JSON escapes) with \\
+            body = json.loads(re.sub(rb'\\(?!["\\/bfnrtu])', rb"\\\\", raw))
+        else:
+            raise
 
     anthropic_request = AnthropicRequest(**body)
     _validate_model_name(anthropic_request.model)
@@ -1627,35 +1655,63 @@ async def create_anthropic_message(
         output.text, openai_request
     )
 
+    # Extract reasoning if parser is configured
+    reasoning_text = None
+    if _reasoning_parser and not tool_calls:
+        text_to_parse = cleaned_text or output.text
+        reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
+            text_to_parse
+        )
+
     # Clean output text
     final_content = None
     if cleaned_text:
         final_content = clean_output_text(cleaned_text)
 
-    # Determine finish reason
-    finish_reason = "tool_calls" if tool_calls else output.finish_reason
+    # Build Anthropic content blocks directly (with thinking support)
+    content_blocks = []
 
-    # Build OpenAI response to convert
-    openai_response = ChatCompletionResponse(
-        model=_model_name,
-        choices=[
-            ChatCompletionChoice(
-                message=AssistantMessage(
-                    content=final_content,
-                    tool_calls=tool_calls,
-                ),
-                finish_reason=finish_reason,
+    if reasoning_text:
+        content_blocks.append(
+            AnthropicResponseContentBlock(type="thinking", thinking=reasoning_text)
+        )
+
+    if final_content:
+        content_blocks.append(
+            AnthropicResponseContentBlock(type="text", text=final_content)
+        )
+
+    if tool_calls:
+        for tc in tool_calls:
+            try:
+                tool_input = json.loads(tc.function.arguments)
+            except (json.JSONDecodeError, AttributeError):
+                tool_input = {}
+            content_blocks.append(
+                AnthropicResponseContentBlock(
+                    type="tool_use",
+                    id=tc.id,
+                    name=tc.function.name,
+                    input=tool_input,
+                )
             )
-        ],
-        usage=Usage(
-            prompt_tokens=output.prompt_tokens,
-            completion_tokens=output.completion_tokens,
-            total_tokens=output.prompt_tokens + output.completion_tokens,
-        ),
+
+    if not content_blocks:
+        content_blocks.append(AnthropicResponseContentBlock(type="text", text=""))
+
+    stop_reason = _convert_anthropic_stop_reason(
+        "tool_calls" if tool_calls else output.finish_reason
     )
-    # Convert to Anthropic response
-    anthropic_response = openai_to_anthropic(openai_response, _model_name)
+    anthropic_response = AnthropicResponse(
+        model=_model_name,
+        content=content_blocks,
+        stop_reason=stop_reason,
+        usage=AnthropicUsage(
+            input_tokens=output.prompt_tokens,
+            output_tokens=output.completion_tokens,
+        ),
+    )
 
     return Response(
         content=anthropic_response.model_dump_json(exclude_none=True),
         media_type="application/json",
@@ -1836,26 +1892,39 @@ async def _stream_anthropic_messages(
 
     # Stream pipeline: raw text → tool call filter → think router → emit
     # - Tool call filter strips tool call markup (emitted as structured blocks later)
-    # - Think router separates content into Anthropic thinking blocks
+    # - Think router separates reasoning from content into Anthropic blocks
+    #
+    # When a reasoning parser is configured (e.g. --reasoning-parser gemma4),
+    # it replaces the generic StreamingThinkRouter to handle model-specific
+    # reasoning formats (e.g. Gemma 4 <|channel>thought...).
     accumulated_text = ""
+    use_reasoning_parser = _reasoning_parser is not None
     tool_filter = StreamingToolCallFilter()
-    # Detect if the model's chat template injects <think> into the
-    # generation prompt. If so, the model starts in thinking mode and
-    # the opening <think> tag never appears in the output stream.
-    _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
-    _chat_template = ""
-    if _tokenizer and hasattr(_tokenizer, "chat_template"):
-        _chat_template = _tokenizer.chat_template or ""
-    _starts_thinking = (
-        "<think>" in _chat_template and "add_generation_prompt" in _chat_template
-    )
-    think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
+
+    if use_reasoning_parser:
+        _reasoning_parser.reset_state()
+        think_router = None
+    else:
+        # Detect if the model's chat template injects <think> into the
+        # generation prompt. If so, the model starts in thinking mode and
+        # the opening <think> tag never appears in the output stream.
+        _tokenizer = engine.tokenizer if hasattr(engine, "tokenizer") else None
+        _chat_template = ""
+        if _tokenizer and hasattr(_tokenizer, "chat_template"):
+            _chat_template = _tokenizer.chat_template or ""
+        _starts_thinking = (
+            "<think>" in _chat_template and "add_generation_prompt" in _chat_template
+        )
+        think_router = StreamingThinkRouter(start_in_thinking=_starts_thinking)
+
     prompt_tokens = 0
     completion_tokens = 0
 
     # Track which content blocks we've started
     current_block_type = None  # "thinking" or "text"
     block_index = 0
+    # For reasoning parser: track accumulated text for parser context
+    reasoning_accumulated = ""
+
     async for output in engine.stream_chat(messages=messages, **chat_kwargs):
         delta_text = output.new_text
@@ -1878,30 +1947,62 @@ async def _stream_anthropic_messages(
             filtered = tool_filter.process(content)
             if not filtered:
                 continue
-            # Stage 2: route thinking vs text
-            pieces = think_router.process(filtered)
+
+            if use_reasoning_parser:
+                # Stage 2a: use reasoning parser for model-specific formats
+                prev = reasoning_accumulated
+                reasoning_accumulated += filtered
+                delta_msg = _reasoning_parser.extract_reasoning_streaming(
+                    prev, reasoning_accumulated, filtered
+                )
+                if delta_msg is None:
+                    continue
+                pieces = []
+                if delta_msg.reasoning:
+                    pieces.append(("thinking", delta_msg.reasoning))
+                if delta_msg.content:
+                    pieces.append(("text", delta_msg.content))
+            else:
+                # Stage 2b: generic tag router
+                pieces = think_router.process(filtered)
+
             events, current_block_type, block_index = _emit_content_pieces(
                 pieces, current_block_type, block_index
             )
             for event in events:
                 yield event
 
-    # Flush remaining from both filters
+    # Flush remaining from tool filter
    remaining = tool_filter.flush()
     if remaining:
+        if use_reasoning_parser:
+            prev = reasoning_accumulated
+            reasoning_accumulated += remaining
+            delta_msg = _reasoning_parser.extract_reasoning_streaming(
+                prev, reasoning_accumulated, remaining
+            )
+            pieces = []
+            if delta_msg:
+                if delta_msg.reasoning:
+                    pieces.append(("thinking", delta_msg.reasoning))
+                if delta_msg.content:
+                    pieces.append(("text", delta_msg.content))
+        else:
+            pieces = think_router.process(remaining)
         events, current_block_type, block_index = _emit_content_pieces(
-            think_router.process(remaining), current_block_type, block_index
+            pieces, current_block_type, block_index
         )
         for event in events:
             yield event
 
-    flush_pieces = think_router.flush()
-    if flush_pieces:
-        events, current_block_type, block_index = _emit_content_pieces(
-            flush_pieces, current_block_type, block_index
-        )
-        for event in events:
-            yield event
+    if not use_reasoning_parser:
+        flush_pieces = think_router.flush()
+        if flush_pieces:
+            events, current_block_type, block_index = _emit_content_pieces(
+                flush_pieces, current_block_type, block_index
+            )
+            for event in events:
+                yield event
 
     # Close final content block
     if current_block_type is not None: