From 61cf4e0fa579bebeee0bd51e721aa8a637412c30 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 10:57:48 -0800 Subject: [PATCH 1/7] =?UTF-8?q?fix:=205=20bugs=20from=20code=20review=20?= =?UTF-8?q?=E2=80=94=20init=20crash,=20JSON=20corruption,=20GC=20leak,=20c?= =?UTF-8?q?loud=20gaps?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. SimpleEngine._inject_shared_model: set missing MLXLanguageModel attributes (prefill_step_size, kv_bits, kv_group_size, _prompt_cache, _cached_token_ids, _cache_lock) that __new__ skips, preventing AttributeError on first generate 2. Non-streaming chat: guard extract_json_from_response with `if response_format` so plain text responses aren't corrupted by JSON extraction 3. stream_chat_completion: wrap generator body in try/finally so gc.enable() runs even on client disconnect, preventing permanent GC disable 4. Cloud streaming: wrap with _disconnect_guard like local streaming path 5. Cloud routing: forward response_format to cloud provider so structured output works consistently regardless of routing decision Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/engine/simple.py | 6 + vllm_mlx/server.py | 746 +++++++++++++++++++------------------- 2 files changed, 384 insertions(+), 368 deletions(-) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 79a3b2bc..02aa0f94 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -593,6 +593,12 @@ async def _inject_shared_model( self._model.trust_remote_code = self._trust_remote_code self._model.draft_model_name = self._draft_model_name self._model.num_draft_tokens = self._num_draft_tokens + self._model.prefill_step_size = 2048 + self._model.kv_bits = None + self._model.kv_group_size = 64 + self._model._prompt_cache = None + self._model._cached_token_ids = [] + self._model._cache_lock = False self._model.model = model self._model.tokenizer = tokenizer self._model.draft_model = None diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 460298b7..d09e56af 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1806,6 +1806,11 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re "max_tokens": chat_kwargs.get("max_tokens"), "top_p": chat_kwargs.get("top_p"), } + if request.response_format: + rf = request.response_format + cloud_kwargs["response_format"] = ( + rf.model_dump() if hasattr(rf, "model_dump") else rf + ) if request.tools: # Pass raw tool defs (OpenAI format), not template-converted cloud_kwargs["tools"] = [ @@ -1814,10 +1819,13 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ] if request.stream: return StreamingResponse( - _cloud_router.stream_completion( - cloud_messages, - model_name=_model_name or "cloud", - **cloud_kwargs, + _disconnect_guard( + _cloud_router.stream_completion( + cloud_messages, + model_name=_model_name or "cloud", + **cloud_kwargs, + ), + raw_request, ), media_type="text/event-stream", ) @@ -1967,9 +1975,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re final_content = None if cleaned_text: final_content = strip_thinking_tags(clean_output_text(cleaned_text)) - # If response looks like it ends with JSON, extract just the JSON part - # This handles Qwen3 reasoning mode: "Let me think... {json}" - final_content = extract_json_from_response(final_content) + # If JSON mode requested, extract JSON from reasoning text + # (e.g., Qwen3 reasoning mode: "Let me think... {json}") + if response_format: + final_content = extract_json_from_response(final_content) # Build logprobs for response if requested choice_logprobs = None @@ -2434,386 +2443,387 @@ async def stream_chat_completion( if _gc_control and gc_was_enabled: gc.disable() - response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" - start_time = time.perf_counter() - - # Check if we should include usage in the final chunk - include_usage = request.stream_options and request.stream_options.include_usage + try: + response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" + start_time = time.perf_counter() - # Logprobs configuration - want_logprobs = request.logprobs and request.top_logprobs - top_k_logprobs = request.top_logprobs or 0 + # Check if we should include usage in the final chunk + include_usage = request.stream_options and request.stream_options.include_usage - def _build_chunk_logprobs(output: GenerationOutput) -> ChoiceLogProbs | None: - """Build ChoiceLogProbs for a streaming chunk if logprobs requested.""" - if not want_logprobs or output.logprobs is None: - return None - token_id = output.tokens[-1] if output.tokens else 0 - token_lp = _extract_token_logprob( - output.logprobs, token_id, engine.tokenizer, top_k_logprobs - ) - return ChoiceLogProbs(content=[token_lp]) + # Logprobs configuration + want_logprobs = request.logprobs and request.top_logprobs + top_k_logprobs = request.top_logprobs or 0 - # First chunk with role - first_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(role="assistant"), + def _build_chunk_logprobs(output: GenerationOutput) -> ChoiceLogProbs | None: + """Build ChoiceLogProbs for a streaming chunk if logprobs requested.""" + if not want_logprobs or output.logprobs is None: + return None + token_id = output.tokens[-1] if output.tokens else 0 + token_lp = _extract_token_logprob( + output.logprobs, token_id, engine.tokenizer, top_k_logprobs ) - ], - ) - _first_sse = f"data: {first_chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-ROLE] {_first_sse.strip()[:200]}") - yield _first_sse + return ChoiceLogProbs(content=[token_lp]) - # Track if we need to add prefix for thinking models (when no reasoning parser) - # The template adds to the prompt, so the model output starts inside the think block - is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser - think_prefix_sent = False - - # Reset reasoning parser state for this stream - if _reasoning_parser: - _reasoning_parser.reset_state() - - # Track accumulated text for reasoning parser - accumulated_text = "" - - # Track token counts for usage reporting - prompt_tokens = 0 - completion_tokens = 0 - last_output = None - - # Tool call streaming state - global _tool_parser_instance - tool_parser = None - tool_accumulated_text = "" - tool_calls_detected = False - tool_markup_possible = False # Fast path: skip parsing until '<' seen - if _enable_auto_tool_choice and _tool_call_parser: - # Initialize parser if needed (same as _parse_tool_calls_with_parser) - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - logger.info(f"Initialized tool call parser: {_tool_call_parser}") - except Exception as e: - logger.warning(f"Failed to init tool parser for streaming: {e}") - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() - - # Fallback: auto-infer tool parser when tools requested + reasoning parser set - if tool_parser is None and request.tools and _reasoning_parser_name: - _PARSER_MAP = {"minimax": "minimax"} - inferred = _PARSER_MAP.get(_reasoning_parser_name) - if inferred: - try: - parser_cls = ToolParserManager.get_tool_parser(inferred) - tokenizer = getattr(_engine, "_tokenizer", None) - tool_parser = parser_cls(tokenizer) + # First chunk with role + first_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(role="assistant"), + ) + ], + ) + _first_sse = f"data: {first_chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-ROLE] {_first_sse.strip()[:200]}") + yield _first_sse + + # Track if we need to add prefix for thinking models (when no reasoning parser) + # The template adds to the prompt, so the model output starts inside the think block + is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser + think_prefix_sent = False + + # Reset reasoning parser state for this stream + if _reasoning_parser: + _reasoning_parser.reset_state() + + # Track accumulated text for reasoning parser + accumulated_text = "" + + # Track token counts for usage reporting + prompt_tokens = 0 + completion_tokens = 0 + last_output = None + + # Tool call streaming state + global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_calls_detected = False + tool_markup_possible = False # Fast path: skip parsing until '<' seen + if _enable_auto_tool_choice and _tool_call_parser: + # Initialize parser if needed (same as _parse_tool_calls_with_parser) + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + logger.info(f"Initialized tool call parser: {_tool_call_parser}") + except Exception as e: + logger.warning(f"Failed to init tool parser for streaming: {e}") + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance tool_parser.reset() - except Exception as e: - logger.debug(f"Auto-infer tool parser for streaming failed: {e}") - - # Stream content - async for output in engine.stream_chat(messages=messages, **kwargs): - delta_text = output.new_text - last_output = output - - # Track token counts from output (updated each chunk) - if hasattr(output, "prompt_tokens") and output.prompt_tokens: - prompt_tokens = output.prompt_tokens - if hasattr(output, "completion_tokens") and output.completion_tokens: - completion_tokens = output.completion_tokens - - # Use reasoning parser if enabled - if _reasoning_parser and delta_text: - previous_text = accumulated_text - accumulated_text += delta_text - delta_msg = _reasoning_parser.extract_reasoning_streaming( - previous_text, accumulated_text, delta_text - ) - - if delta_msg is None: - # Skip this chunk (e.g., token itself) - continue - content = delta_msg.content - reasoning = delta_msg.reasoning - - # Some models (e.g. MiniMax) wrap tool calls in - # blocks, so reasoning parser captures tool call XML as - # reasoning while content stays None. Redirect reasoning - # to the content stream so the tool parser can handle it. - # Check even when content is present (e.g. "\n" from - # boundary) to avoid XML leaking as reasoning. - if tool_parser and reasoning: - _check = tool_accumulated_text + reasoning - if ( - "" in _check - or "" in _check - or ' token itself) + continue + + content = delta_msg.content + reasoning = delta_msg.reasoning + + # Some models (e.g. MiniMax) wrap tool calls in + # blocks, so reasoning parser captures tool call XML as + # reasoning while content stays None. Redirect reasoning + # to the content stream so the tool parser can handle it. + # Check even when content is present (e.g. "\n" from + # boundary) to avoid XML leaking as reasoning. + if tool_parser and reasoning: + _check = tool_accumulated_text + reasoning + if ( + "" in _check + or "" in _check + or ' prefix on first content chunk for thinking models + if is_thinking_model and not think_prefix_sent and content: + content = "" + content + think_prefix_sent = True + + # Tool call streaming parsing + if tool_parser and delta_text: + # Fast path: skip full parsing until '<' is seen in the stream, + # which could start tool markup (e.g. ). This avoids + # per-token string scanning on the growing accumulated text. + if not tool_markup_possible and "<" not in delta_text: + tool_accumulated_text += delta_text + # No tool markup yet, fall through to normal chunk emission + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += delta_text + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, delta_text + ) - # Filter special tokens that may leak into streaming output - if content: - content = SPECIAL_TOKENS_PATTERN.sub("", content) - - # Add prefix on first content chunk for thinking models - if is_thinking_model and not think_prefix_sent and content: - content = "" + content - think_prefix_sent = True - - # Tool call streaming parsing - if tool_parser and delta_text: - # Fast path: skip full parsing until '<' is seen in the stream, - # which could start tool markup (e.g. ). This avoids - # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: - tool_accumulated_text += delta_text - # No tool markup yet, fall through to normal chunk emission - else: - if not tool_markup_possible: - tool_markup_possible = True - tool_previous = tool_accumulated_text - tool_accumulated_text += delta_text - tool_result = tool_parser.extract_tool_calls_streaming( - tool_previous, tool_accumulated_text, delta_text - ) + if tool_result is None: + # Inside tool markup - suppress output + continue + + if "tool_calls" in tool_result: + # Emit structured tool calls + tool_calls_detected = True + chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=tool_result["tool_calls"] + ), + finish_reason=( + "tool_calls" if output.finished else None + ), + ) + ], + usage=get_usage(output) if output.finished else None, + ) + _tc_sse = f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-TC] {_tc_sse.strip()[:300]}") + yield _tc_sse + continue + + # Normal content from tool parser + content = tool_result.get("content", "") + + # Skip empty-string content but preserve whitespace/newlines + # (newlines are significant for markdown formatting) + if content is not None and content == "": + content = None + + # Compute finish reason + finish_reason = ( + "tool_calls" + if (output.finished and tool_calls_detected) + else (output.finish_reason if output.finished else None) + ) - if tool_result is None: - # Inside tool markup - suppress output - continue - - if "tool_calls" in tool_result: - # Emit structured tool calls - tool_calls_detected = True - chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - tool_calls=tool_result["tool_calls"] - ), - finish_reason=( - "tool_calls" if output.finished else None - ), - ) - ], - usage=get_usage(output) if output.finished else None, + # Skip empty chunks (no content and no finish) + if not content and not finish_reason: + continue + + chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=content if content else None + ), + finish_reason=finish_reason, + logprobs=_build_chunk_logprobs(output), ) - _tc_sse = f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-TC] {_tc_sse.strip()[:300]}") - yield _tc_sse - continue - - # Normal content from tool parser - content = tool_result.get("content", "") - - # Skip empty-string content but preserve whitespace/newlines - # (newlines are significant for markdown formatting) - if content is not None and content == "": - content = None - - # Compute finish reason - finish_reason = ( - "tool_calls" - if (output.finished and tool_calls_detected) - else (output.finish_reason if output.finished else None) - ) + ], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + + # Finalize reasoning parser: emit correction if short no-tag output + # was misclassified as reasoning during streaming. + if _reasoning_parser and accumulated_text: + correction = _reasoning_parser.finalize_streaming(accumulated_text) + if correction and correction.content: + correction_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=correction.content, + ), + finish_reason=None, + ) + ], + usage=None, + ) + yield f"data: {correction_chunk.model_dump_json(exclude_none=True)}\n\n" + + # Fallback: if tool parser accumulated text but never emitted tool_calls + # (e.g., closing tag never arrived - incomplete tool call). + # Use parser-aware check so non-standard markers (MiniMax, Llama, etc.) + # are detected instead of only checking for "". + # Also check accumulated_text (full output including reasoning) as a + # safety net — tool XML may have leaked into reasoning stream. + _fallback_text = tool_accumulated_text or accumulated_text + if ( + tool_parser + and _fallback_text + and not tool_calls_detected + and tool_parser.has_pending_tool_call(_fallback_text) + ): + result = tool_parser.extract_tool_calls(_fallback_text) + if result.tools_called: + tool_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=[ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + ), + finish_reason="tool_calls", + ) + ], + ) + _fb_sse = f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-FALLBACK-TC] {_fb_sse.strip()[:300]}") + yield _fb_sse - # Skip empty chunks (no content and no finish) - if not content and not finish_reason: - continue + # Log throughput + elapsed = time.perf_counter() - start_time + tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + ) - chunk = ChatCompletionChunk( + # Send final chunk with usage if requested + if include_usage: + usage_chunk = ChatCompletionChunk( id=response_id, model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - content=content if content else None - ), - finish_reason=finish_reason, - logprobs=_build_chunk_logprobs(output), - ) - ], - usage=get_usage(output) if output.finished else None, - ) - yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - - # Finalize reasoning parser: emit correction if short no-tag output - # was misclassified as reasoning during streaming. - if _reasoning_parser and accumulated_text: - correction = _reasoning_parser.finalize_streaming(accumulated_text) - if correction and correction.content: - correction_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - content=correction.content, - ), - finish_reason=None, - ) - ], - usage=None, - ) - yield f"data: {correction_chunk.model_dump_json(exclude_none=True)}\n\n" - - # Fallback: if tool parser accumulated text but never emitted tool_calls - # (e.g., closing tag never arrived - incomplete tool call). - # Use parser-aware check so non-standard markers (MiniMax, Llama, etc.) - # are detected instead of only checking for "". - # Also check accumulated_text (full output including reasoning) as a - # safety net — tool XML may have leaked into reasoning stream. - _fallback_text = tool_accumulated_text or accumulated_text - if ( - tool_parser - and _fallback_text - and not tool_calls_detected - and tool_parser.has_pending_tool_call(_fallback_text) - ): - result = tool_parser.extract_tool_calls(_fallback_text) - if result.tools_called: - tool_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - tool_calls=[ - { - "index": i, - "id": tc["id"], - "type": "function", - "function": { - "name": tc["name"], - "arguments": tc["arguments"], - }, - } - for i, tc in enumerate(result.tool_calls) - ] - ), - finish_reason="tool_calls", - ) - ], + choices=[], # Empty choices for usage-only chunk + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), ) - _fb_sse = f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-FALLBACK-TC] {_fb_sse.strip()[:300]}") - yield _fb_sse - - # Log throughput - elapsed = time.perf_counter() - start_time - tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 - logger.info( - f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" - ) - - # Send final chunk with usage if requested - if include_usage: - usage_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[], # Empty choices for usage-only chunk - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" - yield "data: [DONE]\n\n" - - # Re-enable GC and collect after generation completes - if _gc_control and gc_was_enabled: - gc.enable() - gc.collect() + yield "data: [DONE]\n\n" + finally: + # Re-enable GC even if generator is abandoned (client disconnect) + if _gc_control and gc_was_enabled: + gc.enable() + gc.collect() # ============================================================================= From 968706021271685a1835e0c5bafe74047765e3d5 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:06:11 -0800 Subject: [PATCH 2/7] fix: cloud routing sends pre-mutation messages and forwards stop/tool_choice MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cloud routing was using locally-mutated messages (tool→user conversion, developer→system normalization, suffix injection) instead of original OpenAI-format messages. Also forward stop and tool_choice parameters. Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/server.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index d09e56af..7612a0f5 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -1678,6 +1678,19 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ) logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + # Save original messages (clean dicts) for cloud routing BEFORE + # local mutations (extract_multimodal_content, developer→system, suffix injection). + # Cloud APIs expect standard OpenAI-format messages. + if _cloud_router: + _cloud_original_messages = [ + msg.model_dump(exclude_none=True) + if hasattr(msg, "model_dump") + else {k: v for k, v in dict(msg).items() if v is not None} + for msg in request.messages + ] + else: + _cloud_original_messages = None + # For MLLM models, keep original messages with embedded images # (MLLM.chat() extracts images from message content internally) if engine.is_mllm: @@ -1799,13 +1812,17 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re f"> threshold {_cloud_router.threshold}, " f"routing to {_cloud_router.cloud_model}" ) - # Pass original dict-format messages to cloud (not template-applied prompt) - cloud_messages = messages + # Use pre-mutation messages for cloud (standard OpenAI format) + cloud_messages = _cloud_original_messages cloud_kwargs = { "temperature": chat_kwargs.get("temperature"), "max_tokens": chat_kwargs.get("max_tokens"), "top_p": chat_kwargs.get("top_p"), } + if request.stop: + cloud_kwargs["stop"] = request.stop + if request.tool_choice is not None: + cloud_kwargs["tool_choice"] = request.tool_choice if request.response_format: rf = request.response_format cloud_kwargs["response_format"] = ( From abaae424c8e38f937a93ef6351cf85a1de4cedac Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:17:26 -0800 Subject: [PATCH 3/7] fix: prefix cache pin stability and guided.py return types PrefixCacheManager.pin_prefix was silently undone by store_cache and _touch_lru re-adding entries to LRU. Added _pinned set to track pinned entries, ensuring they stay out of LRU. Pinned entries now count toward capacity to prevent unbounded cache growth. Fixed generate_json/generate_json_object return type from str to str|None to match actual behavior (returns None on failure). Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/api/guided.py | 8 +++--- vllm_mlx/prefix_cache.py | 53 +++++++++++++++++++++++++--------------- 2 files changed, 37 insertions(+), 24 deletions(-) diff --git a/vllm_mlx/api/guided.py b/vllm_mlx/api/guided.py index 454a7944..0c7a4da3 100644 --- a/vllm_mlx/api/guided.py +++ b/vllm_mlx/api/guided.py @@ -128,7 +128,7 @@ def generate_json( json_schema: dict[str, Any], max_tokens: int = 256, temperature: float = 0.7, - ) -> str: + ) -> str | None: """ Generate JSON output constrained to a schema. @@ -139,7 +139,7 @@ def generate_json( temperature: Sampling temperature Returns: - JSON string matching the schema + JSON string matching the schema, or None on failure """ # Convert schema to Pydantic model pydantic_model = json_schema_to_pydantic(json_schema) @@ -172,7 +172,7 @@ def generate_json_object( prompt: str, max_tokens: int = 256, temperature: float = 0.7, - ) -> str: + ) -> str | None: """ Generate any valid JSON object. @@ -182,7 +182,7 @@ def generate_json_object( temperature: Sampling temperature Returns: - JSON string + JSON string, or None on failure """ try: from outlines import generate diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index 68301985..b94ebf9d 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -110,6 +110,9 @@ def __init__(self, model: Any, max_entries: int = 100): # LRU tracking: (model_key, tuple(tokens)) ordered by access time self._lru: deque = deque() + # Pinned entries: keys excluded from LRU eviction + self._pinned: set = set() + # Statistics self.stats = PrefixCacheStats() @@ -243,20 +246,24 @@ def store_cache(self, tokens: List[int], prompt_cache: List[Any]) -> None: current = current[tok] # Store or update cache entry + key = (self.model_key, tokens_tuple) if "cache" in current: current["cache"].count += 1 - # Update LRU position - try: - self._lru.remove((self.model_key, tokens_tuple)) - except ValueError: - pass + # Update LRU position (skip if pinned) + if key not in self._pinned: + try: + self._lru.remove(key) + except ValueError: + pass else: current["cache"] = CacheEntry(prompt_cache, 1) - self._lru.append((self.model_key, tokens_tuple)) + # Only add to LRU if not pinned + if key not in self._pinned: + self._lru.append(key) - # Evict if over capacity - while len(self._lru) > self.max_size: + # Evict if over capacity (count pinned entries toward total) + while len(self._lru) + len(self._pinned) > self.max_size and len(self._lru) > 0: self._evict_lru() def _get_cache_entry(self, tokens: List[int]) -> Optional[CacheEntry]: @@ -275,6 +282,8 @@ def _get_cache_entry(self, tokens: List[int]) -> Optional[CacheEntry]: def _touch_lru(self, tokens_tuple: tuple) -> None: """Move entry to end of LRU queue (most recently used).""" key = (self.model_key, tokens_tuple) + if key in self._pinned: + return # Pinned entries stay out of LRU try: self._lru.remove(key) except ValueError: @@ -345,6 +354,7 @@ def clear(self) -> None: """Clear all cached entries.""" self._cache.clear() self._lru.clear() + self._pinned.clear() self.reset_stats() def pin_prefix(self, tokens: List[int]) -> bool: @@ -362,13 +372,18 @@ def pin_prefix(self, tokens: List[int]) -> bool: """ tokens_tuple = tuple(tokens) key = (self.model_key, tokens_tuple) + # Verify entry exists in trie + entry = self._get_cache_entry(tokens) + if entry is None: + logger.warning(f"Cannot pin prefix: not found in cache") + return False try: self._lru.remove(key) - logger.info(f"Pinned prefix ({len(tokens)} tokens) - removed from LRU") - return True except ValueError: - logger.warning(f"Cannot pin prefix: not found in cache") - return False + pass # May already be removed from LRU + self._pinned.add(key) + logger.info(f"Pinned prefix ({len(tokens)} tokens)") + return True def unpin_prefix(self, tokens: List[int]) -> bool: """ @@ -382,20 +397,18 @@ def unpin_prefix(self, tokens: List[int]) -> bool: """ tokens_tuple = tuple(tokens) key = (self.model_key, tokens_tuple) - # Check the entry exists in the trie - entry = self._get_cache_entry(tokens) - if entry is None: + if key not in self._pinned: return False + self._pinned.discard(key) # Re-add to LRU (at MRU end) if key not in self._lru: self._lru.append(key) - logger.info(f"Unpinned prefix ({len(tokens)} tokens) - added back to LRU") - return True - return False + logger.info(f"Unpinned prefix ({len(tokens)} tokens) - added back to LRU") + return True def __len__(self) -> int: - """Return number of cached entries.""" - return len(self._lru) + """Return number of cached entries (including pinned).""" + return len(self._lru) + len(self._pinned) # ============================================================================= From c23a8bc1483710438d4090822b80bf449f98bb2d Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:18:24 -0800 Subject: [PATCH 4/7] fix: rate limiter stale key cleanup and demote user content log to DEBUG RateLimiter._requests dict grew unbounded with unique client keys that stopped making requests. Added periodic purge of stale keys when dict exceeds 100 entries. Demoted user message preview logging from INFO to DEBUG to prevent PII/sensitive content from appearing in production logs. Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/server.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 7612a0f5..e54d6857 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -379,6 +379,15 @@ def is_allowed(self, client_id: str) -> tuple[bool, int]: window_start = current_time - self.window_size with self._lock: + # Periodically purge stale client keys (every ~100 requests) + if len(self._requests) > 100: + stale = [ + k for k, v in self._requests.items() + if not v or max(v) <= window_start + ] + for k in stale: + del self._requests[k] + # Clean old requests outside window self._requests[client_id] = [ t for t in self._requests[client_id] if t > window_start @@ -1676,7 +1685,7 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re f"total_chars={total_chars} tools={n_tools} " f"response_format={request.response_format}" ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + logger.debug(f"[REQUEST] last user message preview: {last_user_preview!r}") # Save original messages (clean dicts) for cloud routing BEFORE # local mutations (extract_multimodal_content, developer→system, suffix injection). @@ -2104,7 +2113,7 @@ async def create_anthropic_message( f"msgs={n_msgs} total_chars={total_chars} system_chars={sys_chars} " f"tools={n_tools}" ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + logger.debug(f"[REQUEST] last user message preview: {last_user_preview!r}") # Convert Anthropic request -> OpenAI request openai_request = anthropic_to_openai(anthropic_request) From 35b9f9edeb0cd4fe62895ad95f5964445a7135b6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:44:19 -0800 Subject: [PATCH 5/7] fix: cloud response_format passthrough, inject config values, pin cap 1. CloudRouter._build_call_kwargs now forwards response_format to litellm so structured output works on cloud-routed requests. 2. _inject_shared_model uses engine config (self._prefill_step_size, self._kv_bits, self._kv_group_size) instead of hardcoded defaults. 3. pin_prefix rejects when pinned count reaches max_size, preventing capacity from becoming unenforceable. Co-Authored-By: Claude Opus 4.6 --- vllm_mlx/cloud_router.py | 2 +- vllm_mlx/engine/simple.py | 6 +++--- vllm_mlx/prefix_cache.py | 7 +++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/vllm_mlx/cloud_router.py b/vllm_mlx/cloud_router.py index a980b46c..1603ee89 100644 --- a/vllm_mlx/cloud_router.py +++ b/vllm_mlx/cloud_router.py @@ -178,7 +178,7 @@ def _build_call_kwargs( for key in ( "temperature", "max_tokens", "top_p", "stop", "frequency_penalty", "presence_penalty", - "tools", "tool_choice", + "tools", "tool_choice", "response_format", ): if key in kwargs and kwargs[key] is not None: call_kwargs[key] = kwargs[key] diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 02aa0f94..2c45773b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -593,9 +593,9 @@ async def _inject_shared_model( self._model.trust_remote_code = self._trust_remote_code self._model.draft_model_name = self._draft_model_name self._model.num_draft_tokens = self._num_draft_tokens - self._model.prefill_step_size = 2048 - self._model.kv_bits = None - self._model.kv_group_size = 64 + self._model.prefill_step_size = self._prefill_step_size + self._model.kv_bits = self._kv_bits + self._model.kv_group_size = self._kv_group_size self._model._prompt_cache = None self._model._cached_token_ids = [] self._model._cache_lock = False diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index b94ebf9d..f0f4c873 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -377,6 +377,13 @@ def pin_prefix(self, tokens: List[int]) -> bool: if entry is None: logger.warning(f"Cannot pin prefix: not found in cache") return False + # Reject if pinning would make capacity unenforceable + if key not in self._pinned and len(self._pinned) >= self.max_size: + logger.warning( + f"Cannot pin prefix: pinned count ({len(self._pinned)}) " + f"already at capacity ({self.max_size})" + ) + return False try: self._lru.remove(key) except ValueError: From 27438cbc3288677ecf0ca9df892626aa232074a4 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:48:11 -0800 Subject: [PATCH 6/7] test: regression tests for cloud response_format, inject config, pin cap - test_passes_through_response_format: verifies response_format is forwarded through _build_call_kwargs (was silently dropped) - TestInjectSharedModelConfig: verifies _inject_shared_model propagates engine config (prefill_step_size, kv_bits, kv_group_size) instead of hardcoded defaults - TestPrefixCachePinning: verifies pin survives store/touch, capacity guard rejects at max_size, unpin restores evictability, clear resets Also adds docstring note to pin_prefix about capacity policy. Co-Authored-By: Claude Opus 4.6 --- tests/test_cloud_router.py | 31 ++++++++++++ tests/test_prefix_cache.py | 87 ++++++++++++++++++++++++++++++++ tests/test_simple_engine_unit.py | 42 +++++++++++++++ vllm_mlx/prefix_cache.py | 5 ++ 4 files changed, 165 insertions(+) diff --git a/tests/test_cloud_router.py b/tests/test_cloud_router.py index fdd48ddc..83d41dbb 100644 --- a/tests/test_cloud_router.py +++ b/tests/test_cloud_router.py @@ -170,6 +170,37 @@ def test_ignores_unsupported_kwargs(self): assert "unsupported_param" not in kwargs + def test_passes_through_response_format(self): + """response_format is forwarded to litellm (regression: was silently dropped).""" + from vllm_mlx.cloud_router import CloudRouter + + router = CloudRouter(cloud_model="test-model", threshold=1000) + messages = [{"role": "user", "content": "Hello"}] + rf = {"type": "json_schema", "json_schema": {"name": "out", "schema": {"type": "object"}}} + + kwargs = router._build_call_kwargs( + messages=messages, + stream=False, + response_format=rf, + ) + + assert kwargs["response_format"] == rf + + def test_response_format_none_omitted(self): + """response_format=None is not included in kwargs.""" + from vllm_mlx.cloud_router import CloudRouter + + router = CloudRouter(cloud_model="test-model", threshold=1000) + messages = [{"role": "user", "content": "Hello"}] + + kwargs = router._build_call_kwargs( + messages=messages, + stream=False, + response_format=None, + ) + + assert "response_format" not in kwargs + class TestCloudRouterLazyImport: """Tests for CloudRouter lazy litellm import.""" diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 2a2321c4..cd6bfb5c 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -256,6 +256,93 @@ def test_trie_structure(self, cache_manager): assert cache_124 == ["cache_124"] +class TestPrefixCachePinning: + """Tests for pin/unpin prefix functionality.""" + + @pytest.fixture + def mock_model(self): + return MagicMock() + + def test_pin_survives_store(self, mock_model): + """Pinned entry stays pinned after store_cache re-stores same key (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=3) + mgr.store_cache([1, 2], ["cache_a"]) + assert mgr.pin_prefix([1, 2]) is True + + # Re-store same tokens — pin must not be silently undone + mgr.store_cache([1, 2], ["cache_b"]) + + # Fill remaining capacity and overflow — pinned entry must survive + mgr.store_cache([3], ["cache_c"]) + mgr.store_cache([4], ["cache_d"]) + mgr.store_cache([5], ["cache_e"]) # would evict [1,2] if it were in LRU + + cache, _ = mgr.fetch_cache([1, 2]) + assert cache is not None, "Pinned entry was evicted after store_cache" + + def test_pin_survives_touch(self, mock_model): + """Pinned entry stays pinned after fetch_cache touches it (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=3) + mgr.store_cache([1, 2], ["cache_a"]) + mgr.pin_prefix([1, 2]) + + # Access pinned entry — must not re-add to LRU + mgr.fetch_cache([1, 2]) + + # Overflow — pinned entry must survive + mgr.store_cache([3], ["c"]) + mgr.store_cache([4], ["d"]) + mgr.store_cache([5], ["e"]) + mgr.store_cache([6], ["f"]) + + cache, _ = mgr.fetch_cache([1, 2]) + assert cache is not None, "Pinned entry was evicted after fetch_cache touch" + + def test_pin_capacity_guard(self, mock_model): + """pin_prefix rejects when pinned count reaches max_size (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=2) + mgr.store_cache([1], ["a"]) + mgr.store_cache([2], ["b"]) + assert mgr.pin_prefix([1]) is True + assert mgr.pin_prefix([2]) is True + + # Now at capacity — next pin must fail + mgr.store_cache([3], ["c"]) # won't fit in LRU, but trie entry exists + # Actually [3] can't be stored because LRU+pinned > max_size and LRU is empty + # So test with existing entries: + # Unpin one, store a new entry, try to pin 3 total + mgr.unpin_prefix([2]) + mgr.store_cache([3], ["c"]) + assert mgr.pin_prefix([3]) is True # now 2 pinned (max_size=2) + + mgr.store_cache([4], ["d"]) + assert mgr.pin_prefix([4]) is False, "Pin should fail when at capacity" + + def test_unpin_restores_lru(self, mock_model): + """Unpinned entry becomes evictable again.""" + mgr = PrefixCacheManager(mock_model, max_entries=2) + mgr.store_cache([1], ["a"]) + mgr.store_cache([2], ["b"]) + mgr.pin_prefix([1]) + + mgr.unpin_prefix([1]) + + # Now [1] is back in LRU and can be evicted + mgr.store_cache([3], ["c"]) + mgr.store_cache([4], ["d"]) + + cache, _ = mgr.fetch_cache([1]) + assert cache is None, "Unpinned entry should be evictable" + + def test_clear_resets_pinned(self, mock_model): + """clear() removes pinned entries too.""" + mgr = PrefixCacheManager(mock_model, max_entries=5) + mgr.store_cache([1], ["a"]) + mgr.pin_prefix([1]) + mgr.clear() + assert len(mgr) == 0 + + class TestSchedulerIntegration: """Test integration with scheduler.""" diff --git a/tests/test_simple_engine_unit.py b/tests/test_simple_engine_unit.py index 6cecd553..e6734d04 100644 --- a/tests/test_simple_engine_unit.py +++ b/tests/test_simple_engine_unit.py @@ -323,3 +323,45 @@ def test_set_false(self): engine.preserve_native_tool_format = True engine.preserve_native_tool_format = False assert engine.preserve_native_tool_format is False + + +# --------------------------------------------------------------------------- +# _inject_shared_model config propagation +# --------------------------------------------------------------------------- + + +class TestInjectSharedModelConfig: + """Tests for _inject_shared_model using engine config values.""" + + @pytest.mark.asyncio + async def test_inject_propagates_engine_config(self): + """Injected model uses engine's config, not hardcoded defaults (regression).""" + engine = SimpleEngine( + model_name="test", + prefill_step_size=4096, + kv_bits=4, + kv_group_size=128, + ) + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + await engine._inject_shared_model(mock_model, mock_tokenizer) + + assert engine._model.prefill_step_size == 4096 + assert engine._model.kv_bits == 4 + assert engine._model.kv_group_size == 128 + + @pytest.mark.asyncio + async def test_inject_default_config(self): + """Injected model uses default config when engine uses defaults.""" + engine = SimpleEngine(model_name="test") + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + await engine._inject_shared_model(mock_model, mock_tokenizer) + + assert engine._model.prefill_step_size == 2048 + assert engine._model.kv_bits is None + assert engine._model.kv_group_size == 64 diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index f0f4c873..c4c2d66b 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -364,6 +364,11 @@ def pin_prefix(self, tokens: List[int]) -> bool: For the trie-based cache, this removes the entry from the LRU queue so it is never evicted. The entry remains accessible for lookups. + Note: Pinned entries count toward max_size capacity. If the number of + pinned entries already equals max_size, this method returns False to + prevent capacity from becoming unenforceable. Unpin existing entries + first to make room. + Args: tokens: Token sequence of the prefix to pin From 8e0a5d834ac3117bf2e1f470c5e9a7e1b7766980 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 3 Mar 2026 11:51:26 -0800 Subject: [PATCH 7/7] test: rate limiter stale purge and JSON extraction guard coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_rate_limiter_stale_key_purge: verifies stale client keys are purged when dict exceeds 100 entries - TestExtractJsonFromResponse: documents why extract_json_from_response must be guarded by `if response_format` — it corrupts plain text that ends with balanced braces Co-Authored-By: Claude Opus 4.6 --- tests/test_server.py | 23 +++++++++++++++++++++++ tests/test_server_utils.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+) diff --git a/tests/test_server.py b/tests/test_server.py index 9fb86a3e..5630e76e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -676,6 +676,29 @@ def test_rate_limiter_window_cleanup(self): allowed, _ = limiter.is_allowed("test_client") assert allowed is True + def test_rate_limiter_stale_key_purge(self): + """Stale client keys are purged when dict exceeds 100 entries (regression).""" + from vllm_mlx.server import RateLimiter + import time + + limiter = RateLimiter(requests_per_minute=10, enabled=True) + + # Seed 101 unique clients with expired timestamps + old_time = time.time() - 120 # 2 minutes ago (outside window) + with limiter._lock: + for i in range(101): + limiter._requests[f"stale_client_{i}"] = [old_time] + + assert len(limiter._requests) == 101 + + # One more request triggers purge (len > 100) + limiter.is_allowed("new_client") + + # Stale keys should be purged + assert len(limiter._requests) < 101 + # new_client should still be present + assert "new_client" in limiter._requests + # ============================================================================= # Integration Tests (require running server) diff --git a/tests/test_server_utils.py b/tests/test_server_utils.py index a3338374..7d9c4de0 100644 --- a/tests/test_server_utils.py +++ b/tests/test_server_utils.py @@ -321,3 +321,40 @@ class ToolDef(BaseModel): self._make_tool_calls(), tools, ) + + +# --------------------------------------------------------------------------- +# extract_json_from_response guard tests +# --------------------------------------------------------------------------- + + +class TestExtractJsonFromResponse: + """Tests for extract_json_from_response showing why guard is needed.""" + + def test_extracts_json_from_reasoning_text(self): + """Correctly extracts JSON from reasoning prefix.""" + from vllm_mlx.api.utils import extract_json_from_response + + text = 'Let me think... {"result": 42}' + assert extract_json_from_response(text) == '{"result": 42}' + + def test_corrupts_plain_text_ending_with_json(self): + """Without guard, plain text ending with JSON-like braces gets corrupted. + + This documents why server.py wraps the call with `if response_format`. + """ + from vllm_mlx.api.utils import extract_json_from_response + + # Plain text that happens to end with balanced braces + plain = 'The config looks like {"debug": true}' + result = extract_json_from_response(plain) + # The function extracts '{"debug": true}' — losing the prefix + assert result == '{"debug": true}' + assert result != plain + + def test_returns_original_when_no_json(self): + """Returns original text when no JSON structure found.""" + from vllm_mlx.api.utils import extract_json_from_response + + text = "Hello, world!" + assert extract_json_from_response(text) == text