diff --git a/tests/test_cloud_router.py b/tests/test_cloud_router.py index fdd48ddc..83d41dbb 100644 --- a/tests/test_cloud_router.py +++ b/tests/test_cloud_router.py @@ -170,6 +170,37 @@ def test_ignores_unsupported_kwargs(self): assert "unsupported_param" not in kwargs + def test_passes_through_response_format(self): + """response_format is forwarded to litellm (regression: was silently dropped).""" + from vllm_mlx.cloud_router import CloudRouter + + router = CloudRouter(cloud_model="test-model", threshold=1000) + messages = [{"role": "user", "content": "Hello"}] + rf = {"type": "json_schema", "json_schema": {"name": "out", "schema": {"type": "object"}}} + + kwargs = router._build_call_kwargs( + messages=messages, + stream=False, + response_format=rf, + ) + + assert kwargs["response_format"] == rf + + def test_response_format_none_omitted(self): + """response_format=None is not included in kwargs.""" + from vllm_mlx.cloud_router import CloudRouter + + router = CloudRouter(cloud_model="test-model", threshold=1000) + messages = [{"role": "user", "content": "Hello"}] + + kwargs = router._build_call_kwargs( + messages=messages, + stream=False, + response_format=None, + ) + + assert "response_format" not in kwargs + class TestCloudRouterLazyImport: """Tests for CloudRouter lazy litellm import.""" diff --git a/tests/test_prefix_cache.py b/tests/test_prefix_cache.py index 2a2321c4..cd6bfb5c 100644 --- a/tests/test_prefix_cache.py +++ b/tests/test_prefix_cache.py @@ -256,6 +256,93 @@ def test_trie_structure(self, cache_manager): assert cache_124 == ["cache_124"] +class TestPrefixCachePinning: + """Tests for pin/unpin prefix functionality.""" + + @pytest.fixture + def mock_model(self): + return MagicMock() + + def test_pin_survives_store(self, mock_model): + """Pinned entry stays pinned after store_cache re-stores same key (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=3) + mgr.store_cache([1, 2], ["cache_a"]) + assert mgr.pin_prefix([1, 2]) is True + + # Re-store same tokens — pin must not be silently undone + mgr.store_cache([1, 2], ["cache_b"]) + + # Fill remaining capacity and overflow — pinned entry must survive + mgr.store_cache([3], ["cache_c"]) + mgr.store_cache([4], ["cache_d"]) + mgr.store_cache([5], ["cache_e"]) # would evict [1,2] if it were in LRU + + cache, _ = mgr.fetch_cache([1, 2]) + assert cache is not None, "Pinned entry was evicted after store_cache" + + def test_pin_survives_touch(self, mock_model): + """Pinned entry stays pinned after fetch_cache touches it (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=3) + mgr.store_cache([1, 2], ["cache_a"]) + mgr.pin_prefix([1, 2]) + + # Access pinned entry — must not re-add to LRU + mgr.fetch_cache([1, 2]) + + # Overflow — pinned entry must survive + mgr.store_cache([3], ["c"]) + mgr.store_cache([4], ["d"]) + mgr.store_cache([5], ["e"]) + mgr.store_cache([6], ["f"]) + + cache, _ = mgr.fetch_cache([1, 2]) + assert cache is not None, "Pinned entry was evicted after fetch_cache touch" + + def test_pin_capacity_guard(self, mock_model): + """pin_prefix rejects when pinned count reaches max_size (regression).""" + mgr = PrefixCacheManager(mock_model, max_entries=2) + mgr.store_cache([1], ["a"]) + mgr.store_cache([2], ["b"]) + assert mgr.pin_prefix([1]) is True + assert mgr.pin_prefix([2]) is True + + # Now at capacity — next pin must fail + mgr.store_cache([3], ["c"]) # won't fit in LRU, but trie entry exists + # Actually [3] can't be stored because LRU+pinned > max_size and LRU is empty + # So test with existing entries: + # Unpin one, store a new entry, try to pin 3 total + mgr.unpin_prefix([2]) + mgr.store_cache([3], ["c"]) + assert mgr.pin_prefix([3]) is True # now 2 pinned (max_size=2) + + mgr.store_cache([4], ["d"]) + assert mgr.pin_prefix([4]) is False, "Pin should fail when at capacity" + + def test_unpin_restores_lru(self, mock_model): + """Unpinned entry becomes evictable again.""" + mgr = PrefixCacheManager(mock_model, max_entries=2) + mgr.store_cache([1], ["a"]) + mgr.store_cache([2], ["b"]) + mgr.pin_prefix([1]) + + mgr.unpin_prefix([1]) + + # Now [1] is back in LRU and can be evicted + mgr.store_cache([3], ["c"]) + mgr.store_cache([4], ["d"]) + + cache, _ = mgr.fetch_cache([1]) + assert cache is None, "Unpinned entry should be evictable" + + def test_clear_resets_pinned(self, mock_model): + """clear() removes pinned entries too.""" + mgr = PrefixCacheManager(mock_model, max_entries=5) + mgr.store_cache([1], ["a"]) + mgr.pin_prefix([1]) + mgr.clear() + assert len(mgr) == 0 + + class TestSchedulerIntegration: """Test integration with scheduler.""" diff --git a/tests/test_server.py b/tests/test_server.py index 9fb86a3e..5630e76e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -676,6 +676,29 @@ def test_rate_limiter_window_cleanup(self): allowed, _ = limiter.is_allowed("test_client") assert allowed is True + def test_rate_limiter_stale_key_purge(self): + """Stale client keys are purged when dict exceeds 100 entries (regression).""" + from vllm_mlx.server import RateLimiter + import time + + limiter = RateLimiter(requests_per_minute=10, enabled=True) + + # Seed 101 unique clients with expired timestamps + old_time = time.time() - 120 # 2 minutes ago (outside window) + with limiter._lock: + for i in range(101): + limiter._requests[f"stale_client_{i}"] = [old_time] + + assert len(limiter._requests) == 101 + + # One more request triggers purge (len > 100) + limiter.is_allowed("new_client") + + # Stale keys should be purged + assert len(limiter._requests) < 101 + # new_client should still be present + assert "new_client" in limiter._requests + # ============================================================================= # Integration Tests (require running server) diff --git a/tests/test_server_utils.py b/tests/test_server_utils.py index a3338374..7d9c4de0 100644 --- a/tests/test_server_utils.py +++ b/tests/test_server_utils.py @@ -321,3 +321,40 @@ class ToolDef(BaseModel): self._make_tool_calls(), tools, ) + + +# --------------------------------------------------------------------------- +# extract_json_from_response guard tests +# --------------------------------------------------------------------------- + + +class TestExtractJsonFromResponse: + """Tests for extract_json_from_response showing why guard is needed.""" + + def test_extracts_json_from_reasoning_text(self): + """Correctly extracts JSON from reasoning prefix.""" + from vllm_mlx.api.utils import extract_json_from_response + + text = 'Let me think... {"result": 42}' + assert extract_json_from_response(text) == '{"result": 42}' + + def test_corrupts_plain_text_ending_with_json(self): + """Without guard, plain text ending with JSON-like braces gets corrupted. + + This documents why server.py wraps the call with `if response_format`. + """ + from vllm_mlx.api.utils import extract_json_from_response + + # Plain text that happens to end with balanced braces + plain = 'The config looks like {"debug": true}' + result = extract_json_from_response(plain) + # The function extracts '{"debug": true}' — losing the prefix + assert result == '{"debug": true}' + assert result != plain + + def test_returns_original_when_no_json(self): + """Returns original text when no JSON structure found.""" + from vllm_mlx.api.utils import extract_json_from_response + + text = "Hello, world!" + assert extract_json_from_response(text) == text diff --git a/tests/test_simple_engine_unit.py b/tests/test_simple_engine_unit.py index 6cecd553..e6734d04 100644 --- a/tests/test_simple_engine_unit.py +++ b/tests/test_simple_engine_unit.py @@ -323,3 +323,45 @@ def test_set_false(self): engine.preserve_native_tool_format = True engine.preserve_native_tool_format = False assert engine.preserve_native_tool_format is False + + +# --------------------------------------------------------------------------- +# _inject_shared_model config propagation +# --------------------------------------------------------------------------- + + +class TestInjectSharedModelConfig: + """Tests for _inject_shared_model using engine config values.""" + + @pytest.mark.asyncio + async def test_inject_propagates_engine_config(self): + """Injected model uses engine's config, not hardcoded defaults (regression).""" + engine = SimpleEngine( + model_name="test", + prefill_step_size=4096, + kv_bits=4, + kv_group_size=128, + ) + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + await engine._inject_shared_model(mock_model, mock_tokenizer) + + assert engine._model.prefill_step_size == 4096 + assert engine._model.kv_bits == 4 + assert engine._model.kv_group_size == 128 + + @pytest.mark.asyncio + async def test_inject_default_config(self): + """Injected model uses default config when engine uses defaults.""" + engine = SimpleEngine(model_name="test") + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + + await engine._inject_shared_model(mock_model, mock_tokenizer) + + assert engine._model.prefill_step_size == 2048 + assert engine._model.kv_bits is None + assert engine._model.kv_group_size == 64 diff --git a/vllm_mlx/api/guided.py b/vllm_mlx/api/guided.py index 454a7944..0c7a4da3 100644 --- a/vllm_mlx/api/guided.py +++ b/vllm_mlx/api/guided.py @@ -128,7 +128,7 @@ def generate_json( json_schema: dict[str, Any], max_tokens: int = 256, temperature: float = 0.7, - ) -> str: + ) -> str | None: """ Generate JSON output constrained to a schema. @@ -139,7 +139,7 @@ def generate_json( temperature: Sampling temperature Returns: - JSON string matching the schema + JSON string matching the schema, or None on failure """ # Convert schema to Pydantic model pydantic_model = json_schema_to_pydantic(json_schema) @@ -172,7 +172,7 @@ def generate_json_object( prompt: str, max_tokens: int = 256, temperature: float = 0.7, - ) -> str: + ) -> str | None: """ Generate any valid JSON object. @@ -182,7 +182,7 @@ def generate_json_object( temperature: Sampling temperature Returns: - JSON string + JSON string, or None on failure """ try: from outlines import generate diff --git a/vllm_mlx/cloud_router.py b/vllm_mlx/cloud_router.py index a980b46c..1603ee89 100644 --- a/vllm_mlx/cloud_router.py +++ b/vllm_mlx/cloud_router.py @@ -178,7 +178,7 @@ def _build_call_kwargs( for key in ( "temperature", "max_tokens", "top_p", "stop", "frequency_penalty", "presence_penalty", - "tools", "tool_choice", + "tools", "tool_choice", "response_format", ): if key in kwargs and kwargs[key] is not None: call_kwargs[key] = kwargs[key] diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index 79a3b2bc..2c45773b 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -593,6 +593,12 @@ async def _inject_shared_model( self._model.trust_remote_code = self._trust_remote_code self._model.draft_model_name = self._draft_model_name self._model.num_draft_tokens = self._num_draft_tokens + self._model.prefill_step_size = self._prefill_step_size + self._model.kv_bits = self._kv_bits + self._model.kv_group_size = self._kv_group_size + self._model._prompt_cache = None + self._model._cached_token_ids = [] + self._model._cache_lock = False self._model.model = model self._model.tokenizer = tokenizer self._model.draft_model = None diff --git a/vllm_mlx/prefix_cache.py b/vllm_mlx/prefix_cache.py index 68301985..c4c2d66b 100644 --- a/vllm_mlx/prefix_cache.py +++ b/vllm_mlx/prefix_cache.py @@ -110,6 +110,9 @@ def __init__(self, model: Any, max_entries: int = 100): # LRU tracking: (model_key, tuple(tokens)) ordered by access time self._lru: deque = deque() + # Pinned entries: keys excluded from LRU eviction + self._pinned: set = set() + # Statistics self.stats = PrefixCacheStats() @@ -243,20 +246,24 @@ def store_cache(self, tokens: List[int], prompt_cache: List[Any]) -> None: current = current[tok] # Store or update cache entry + key = (self.model_key, tokens_tuple) if "cache" in current: current["cache"].count += 1 - # Update LRU position - try: - self._lru.remove((self.model_key, tokens_tuple)) - except ValueError: - pass + # Update LRU position (skip if pinned) + if key not in self._pinned: + try: + self._lru.remove(key) + except ValueError: + pass else: current["cache"] = CacheEntry(prompt_cache, 1) - self._lru.append((self.model_key, tokens_tuple)) + # Only add to LRU if not pinned + if key not in self._pinned: + self._lru.append(key) - # Evict if over capacity - while len(self._lru) > self.max_size: + # Evict if over capacity (count pinned entries toward total) + while len(self._lru) + len(self._pinned) > self.max_size and len(self._lru) > 0: self._evict_lru() def _get_cache_entry(self, tokens: List[int]) -> Optional[CacheEntry]: @@ -275,6 +282,8 @@ def _get_cache_entry(self, tokens: List[int]) -> Optional[CacheEntry]: def _touch_lru(self, tokens_tuple: tuple) -> None: """Move entry to end of LRU queue (most recently used).""" key = (self.model_key, tokens_tuple) + if key in self._pinned: + return # Pinned entries stay out of LRU try: self._lru.remove(key) except ValueError: @@ -345,6 +354,7 @@ def clear(self) -> None: """Clear all cached entries.""" self._cache.clear() self._lru.clear() + self._pinned.clear() self.reset_stats() def pin_prefix(self, tokens: List[int]) -> bool: @@ -354,6 +364,11 @@ def pin_prefix(self, tokens: List[int]) -> bool: For the trie-based cache, this removes the entry from the LRU queue so it is never evicted. The entry remains accessible for lookups. + Note: Pinned entries count toward max_size capacity. If the number of + pinned entries already equals max_size, this method returns False to + prevent capacity from becoming unenforceable. Unpin existing entries + first to make room. + Args: tokens: Token sequence of the prefix to pin @@ -362,13 +377,25 @@ def pin_prefix(self, tokens: List[int]) -> bool: """ tokens_tuple = tuple(tokens) key = (self.model_key, tokens_tuple) + # Verify entry exists in trie + entry = self._get_cache_entry(tokens) + if entry is None: + logger.warning(f"Cannot pin prefix: not found in cache") + return False + # Reject if pinning would make capacity unenforceable + if key not in self._pinned and len(self._pinned) >= self.max_size: + logger.warning( + f"Cannot pin prefix: pinned count ({len(self._pinned)}) " + f"already at capacity ({self.max_size})" + ) + return False try: self._lru.remove(key) - logger.info(f"Pinned prefix ({len(tokens)} tokens) - removed from LRU") - return True except ValueError: - logger.warning(f"Cannot pin prefix: not found in cache") - return False + pass # May already be removed from LRU + self._pinned.add(key) + logger.info(f"Pinned prefix ({len(tokens)} tokens)") + return True def unpin_prefix(self, tokens: List[int]) -> bool: """ @@ -382,20 +409,18 @@ def unpin_prefix(self, tokens: List[int]) -> bool: """ tokens_tuple = tuple(tokens) key = (self.model_key, tokens_tuple) - # Check the entry exists in the trie - entry = self._get_cache_entry(tokens) - if entry is None: + if key not in self._pinned: return False + self._pinned.discard(key) # Re-add to LRU (at MRU end) if key not in self._lru: self._lru.append(key) - logger.info(f"Unpinned prefix ({len(tokens)} tokens) - added back to LRU") - return True - return False + logger.info(f"Unpinned prefix ({len(tokens)} tokens) - added back to LRU") + return True def __len__(self) -> int: - """Return number of cached entries.""" - return len(self._lru) + """Return number of cached entries (including pinned).""" + return len(self._lru) + len(self._pinned) # ============================================================================= diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 460298b7..e54d6857 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -379,6 +379,15 @@ def is_allowed(self, client_id: str) -> tuple[bool, int]: window_start = current_time - self.window_size with self._lock: + # Periodically purge stale client keys (every ~100 requests) + if len(self._requests) > 100: + stale = [ + k for k, v in self._requests.items() + if not v or max(v) <= window_start + ] + for k in stale: + del self._requests[k] + # Clean old requests outside window self._requests[client_id] = [ t for t in self._requests[client_id] if t > window_start @@ -1676,7 +1685,20 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re f"total_chars={total_chars} tools={n_tools} " f"response_format={request.response_format}" ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + logger.debug(f"[REQUEST] last user message preview: {last_user_preview!r}") + + # Save original messages (clean dicts) for cloud routing BEFORE + # local mutations (extract_multimodal_content, developer→system, suffix injection). + # Cloud APIs expect standard OpenAI-format messages. + if _cloud_router: + _cloud_original_messages = [ + msg.model_dump(exclude_none=True) + if hasattr(msg, "model_dump") + else {k: v for k, v in dict(msg).items() if v is not None} + for msg in request.messages + ] + else: + _cloud_original_messages = None # For MLLM models, keep original messages with embedded images # (MLLM.chat() extracts images from message content internally) @@ -1799,13 +1821,22 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re f"> threshold {_cloud_router.threshold}, " f"routing to {_cloud_router.cloud_model}" ) - # Pass original dict-format messages to cloud (not template-applied prompt) - cloud_messages = messages + # Use pre-mutation messages for cloud (standard OpenAI format) + cloud_messages = _cloud_original_messages cloud_kwargs = { "temperature": chat_kwargs.get("temperature"), "max_tokens": chat_kwargs.get("max_tokens"), "top_p": chat_kwargs.get("top_p"), } + if request.stop: + cloud_kwargs["stop"] = request.stop + if request.tool_choice is not None: + cloud_kwargs["tool_choice"] = request.tool_choice + if request.response_format: + rf = request.response_format + cloud_kwargs["response_format"] = ( + rf.model_dump() if hasattr(rf, "model_dump") else rf + ) if request.tools: # Pass raw tool defs (OpenAI format), not template-converted cloud_kwargs["tools"] = [ @@ -1814,10 +1845,13 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re ] if request.stream: return StreamingResponse( - _cloud_router.stream_completion( - cloud_messages, - model_name=_model_name or "cloud", - **cloud_kwargs, + _disconnect_guard( + _cloud_router.stream_completion( + cloud_messages, + model_name=_model_name or "cloud", + **cloud_kwargs, + ), + raw_request, ), media_type="text/event-stream", ) @@ -1967,9 +2001,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re final_content = None if cleaned_text: final_content = strip_thinking_tags(clean_output_text(cleaned_text)) - # If response looks like it ends with JSON, extract just the JSON part - # This handles Qwen3 reasoning mode: "Let me think... {json}" - final_content = extract_json_from_response(final_content) + # If JSON mode requested, extract JSON from reasoning text + # (e.g., Qwen3 reasoning mode: "Let me think... {json}") + if response_format: + final_content = extract_json_from_response(final_content) # Build logprobs for response if requested choice_logprobs = None @@ -2078,7 +2113,7 @@ async def create_anthropic_message( f"msgs={n_msgs} total_chars={total_chars} system_chars={sys_chars} " f"tools={n_tools}" ) - logger.info(f"[REQUEST] last user message preview: {last_user_preview!r}") + logger.debug(f"[REQUEST] last user message preview: {last_user_preview!r}") # Convert Anthropic request -> OpenAI request openai_request = anthropic_to_openai(anthropic_request) @@ -2434,386 +2469,387 @@ async def stream_chat_completion( if _gc_control and gc_was_enabled: gc.disable() - response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" - start_time = time.perf_counter() + try: + response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" + start_time = time.perf_counter() - # Check if we should include usage in the final chunk - include_usage = request.stream_options and request.stream_options.include_usage + # Check if we should include usage in the final chunk + include_usage = request.stream_options and request.stream_options.include_usage - # Logprobs configuration - want_logprobs = request.logprobs and request.top_logprobs - top_k_logprobs = request.top_logprobs or 0 - - def _build_chunk_logprobs(output: GenerationOutput) -> ChoiceLogProbs | None: - """Build ChoiceLogProbs for a streaming chunk if logprobs requested.""" - if not want_logprobs or output.logprobs is None: - return None - token_id = output.tokens[-1] if output.tokens else 0 - token_lp = _extract_token_logprob( - output.logprobs, token_id, engine.tokenizer, top_k_logprobs - ) - return ChoiceLogProbs(content=[token_lp]) + # Logprobs configuration + want_logprobs = request.logprobs and request.top_logprobs + top_k_logprobs = request.top_logprobs or 0 - # First chunk with role - first_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta(role="assistant"), + def _build_chunk_logprobs(output: GenerationOutput) -> ChoiceLogProbs | None: + """Build ChoiceLogProbs for a streaming chunk if logprobs requested.""" + if not want_logprobs or output.logprobs is None: + return None + token_id = output.tokens[-1] if output.tokens else 0 + token_lp = _extract_token_logprob( + output.logprobs, token_id, engine.tokenizer, top_k_logprobs ) - ], - ) - _first_sse = f"data: {first_chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-ROLE] {_first_sse.strip()[:200]}") - yield _first_sse + return ChoiceLogProbs(content=[token_lp]) - # Track if we need to add prefix for thinking models (when no reasoning parser) - # The template adds to the prompt, so the model output starts inside the think block - is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser - think_prefix_sent = False - - # Reset reasoning parser state for this stream - if _reasoning_parser: - _reasoning_parser.reset_state() - - # Track accumulated text for reasoning parser - accumulated_text = "" - - # Track token counts for usage reporting - prompt_tokens = 0 - completion_tokens = 0 - last_output = None - - # Tool call streaming state - global _tool_parser_instance - tool_parser = None - tool_accumulated_text = "" - tool_calls_detected = False - tool_markup_possible = False # Fast path: skip parsing until '<' seen - if _enable_auto_tool_choice and _tool_call_parser: - # Initialize parser if needed (same as _parse_tool_calls_with_parser) - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - logger.info(f"Initialized tool call parser: {_tool_call_parser}") - except Exception as e: - logger.warning(f"Failed to init tool parser for streaming: {e}") - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() - - # Fallback: auto-infer tool parser when tools requested + reasoning parser set - if tool_parser is None and request.tools and _reasoning_parser_name: - _PARSER_MAP = {"minimax": "minimax"} - inferred = _PARSER_MAP.get(_reasoning_parser_name) - if inferred: - try: - parser_cls = ToolParserManager.get_tool_parser(inferred) - tokenizer = getattr(_engine, "_tokenizer", None) - tool_parser = parser_cls(tokenizer) + # First chunk with role + first_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta(role="assistant"), + ) + ], + ) + _first_sse = f"data: {first_chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-ROLE] {_first_sse.strip()[:200]}") + yield _first_sse + + # Track if we need to add prefix for thinking models (when no reasoning parser) + # The template adds to the prompt, so the model output starts inside the think block + is_thinking_model = "nemotron" in request.model.lower() and not _reasoning_parser + think_prefix_sent = False + + # Reset reasoning parser state for this stream + if _reasoning_parser: + _reasoning_parser.reset_state() + + # Track accumulated text for reasoning parser + accumulated_text = "" + + # Track token counts for usage reporting + prompt_tokens = 0 + completion_tokens = 0 + last_output = None + + # Tool call streaming state + global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_calls_detected = False + tool_markup_possible = False # Fast path: skip parsing until '<' seen + if _enable_auto_tool_choice and _tool_call_parser: + # Initialize parser if needed (same as _parse_tool_calls_with_parser) + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + logger.info(f"Initialized tool call parser: {_tool_call_parser}") + except Exception as e: + logger.warning(f"Failed to init tool parser for streaming: {e}") + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance tool_parser.reset() - except Exception as e: - logger.debug(f"Auto-infer tool parser for streaming failed: {e}") - - # Stream content - async for output in engine.stream_chat(messages=messages, **kwargs): - delta_text = output.new_text - last_output = output - - # Track token counts from output (updated each chunk) - if hasattr(output, "prompt_tokens") and output.prompt_tokens: - prompt_tokens = output.prompt_tokens - if hasattr(output, "completion_tokens") and output.completion_tokens: - completion_tokens = output.completion_tokens - - # Use reasoning parser if enabled - if _reasoning_parser and delta_text: - previous_text = accumulated_text - accumulated_text += delta_text - delta_msg = _reasoning_parser.extract_reasoning_streaming( - previous_text, accumulated_text, delta_text - ) - if delta_msg is None: - # Skip this chunk (e.g., token itself) - continue - - content = delta_msg.content - reasoning = delta_msg.reasoning - - # Some models (e.g. MiniMax) wrap tool calls in - # blocks, so reasoning parser captures tool call XML as - # reasoning while content stays None. Redirect reasoning - # to the content stream so the tool parser can handle it. - # Check even when content is present (e.g. "\n" from - # boundary) to avoid XML leaking as reasoning. - if tool_parser and reasoning: - _check = tool_accumulated_text + reasoning - if ( - "" in _check - or "" in _check - or ' token itself) + continue + + content = delta_msg.content + reasoning = delta_msg.reasoning + + # Some models (e.g. MiniMax) wrap tool calls in + # blocks, so reasoning parser captures tool call XML as + # reasoning while content stays None. Redirect reasoning + # to the content stream so the tool parser can handle it. + # Check even when content is present (e.g. "\n" from + # boundary) to avoid XML leaking as reasoning. + if tool_parser and reasoning: + _check = tool_accumulated_text + reasoning + if ( + "" in _check + or "" in _check + or ' prefix on first content chunk for thinking models + if is_thinking_model and not think_prefix_sent and content: + content = "" + content + think_prefix_sent = True + + # Tool call streaming parsing + if tool_parser and delta_text: + # Fast path: skip full parsing until '<' is seen in the stream, + # which could start tool markup (e.g. ). This avoids + # per-token string scanning on the growing accumulated text. + if not tool_markup_possible and "<" not in delta_text: + tool_accumulated_text += delta_text + # No tool markup yet, fall through to normal chunk emission + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_previous = tool_accumulated_text + tool_accumulated_text += delta_text + tool_result = tool_parser.extract_tool_calls_streaming( + tool_previous, tool_accumulated_text, delta_text + ) - # Filter special tokens that may leak into streaming output - if content: - content = SPECIAL_TOKENS_PATTERN.sub("", content) - - # Add prefix on first content chunk for thinking models - if is_thinking_model and not think_prefix_sent and content: - content = "" + content - think_prefix_sent = True - - # Tool call streaming parsing - if tool_parser and delta_text: - # Fast path: skip full parsing until '<' is seen in the stream, - # which could start tool markup (e.g. ). This avoids - # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: - tool_accumulated_text += delta_text - # No tool markup yet, fall through to normal chunk emission - else: - if not tool_markup_possible: - tool_markup_possible = True - tool_previous = tool_accumulated_text - tool_accumulated_text += delta_text - tool_result = tool_parser.extract_tool_calls_streaming( - tool_previous, tool_accumulated_text, delta_text - ) + if tool_result is None: + # Inside tool markup - suppress output + continue + + if "tool_calls" in tool_result: + # Emit structured tool calls + tool_calls_detected = True + chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=tool_result["tool_calls"] + ), + finish_reason=( + "tool_calls" if output.finished else None + ), + ) + ], + usage=get_usage(output) if output.finished else None, + ) + _tc_sse = f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-TC] {_tc_sse.strip()[:300]}") + yield _tc_sse + continue + + # Normal content from tool parser + content = tool_result.get("content", "") + + # Skip empty-string content but preserve whitespace/newlines + # (newlines are significant for markdown formatting) + if content is not None and content == "": + content = None + + # Compute finish reason + finish_reason = ( + "tool_calls" + if (output.finished and tool_calls_detected) + else (output.finish_reason if output.finished else None) + ) - if tool_result is None: - # Inside tool markup - suppress output - continue - - if "tool_calls" in tool_result: - # Emit structured tool calls - tool_calls_detected = True - chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - tool_calls=tool_result["tool_calls"] - ), - finish_reason=( - "tool_calls" if output.finished else None - ), - ) - ], - usage=get_usage(output) if output.finished else None, + # Skip empty chunks (no content and no finish) + if not content and not finish_reason: + continue + + chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=content if content else None + ), + finish_reason=finish_reason, + logprobs=_build_chunk_logprobs(output), ) - _tc_sse = f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-TC] {_tc_sse.strip()[:300]}") - yield _tc_sse - continue - - # Normal content from tool parser - content = tool_result.get("content", "") - - # Skip empty-string content but preserve whitespace/newlines - # (newlines are significant for markdown formatting) - if content is not None and content == "": - content = None - - # Compute finish reason - finish_reason = ( - "tool_calls" - if (output.finished and tool_calls_detected) - else (output.finish_reason if output.finished else None) - ) + ], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" + + # Finalize reasoning parser: emit correction if short no-tag output + # was misclassified as reasoning during streaming. + if _reasoning_parser and accumulated_text: + correction = _reasoning_parser.finalize_streaming(accumulated_text) + if correction and correction.content: + correction_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=correction.content, + ), + finish_reason=None, + ) + ], + usage=None, + ) + yield f"data: {correction_chunk.model_dump_json(exclude_none=True)}\n\n" + + # Fallback: if tool parser accumulated text but never emitted tool_calls + # (e.g., closing tag never arrived - incomplete tool call). + # Use parser-aware check so non-standard markers (MiniMax, Llama, etc.) + # are detected instead of only checking for "". + # Also check accumulated_text (full output including reasoning) as a + # safety net — tool XML may have leaked into reasoning stream. + _fallback_text = tool_accumulated_text or accumulated_text + if ( + tool_parser + and _fallback_text + and not tool_calls_detected + and tool_parser.has_pending_tool_call(_fallback_text) + ): + result = tool_parser.extract_tool_calls(_fallback_text) + if result.tools_called: + tool_chunk = ChatCompletionChunk( + id=response_id, + model=request.model, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + tool_calls=[ + { + "index": i, + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"], + }, + } + for i, tc in enumerate(result.tool_calls) + ] + ), + finish_reason="tool_calls", + ) + ], + ) + _fb_sse = f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" + logger.info(f"[SSE-FALLBACK-TC] {_fb_sse.strip()[:300]}") + yield _fb_sse - # Skip empty chunks (no content and no finish) - if not content and not finish_reason: - continue + # Log throughput + elapsed = time.perf_counter() - start_time + tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 + logger.info( + f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" + ) - chunk = ChatCompletionChunk( + # Send final chunk with usage if requested + if include_usage: + usage_chunk = ChatCompletionChunk( id=response_id, model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - content=content if content else None - ), - finish_reason=finish_reason, - logprobs=_build_chunk_logprobs(output), - ) - ], - usage=get_usage(output) if output.finished else None, - ) - yield f"data: {chunk.model_dump_json(exclude_none=True)}\n\n" - - # Finalize reasoning parser: emit correction if short no-tag output - # was misclassified as reasoning during streaming. - if _reasoning_parser and accumulated_text: - correction = _reasoning_parser.finalize_streaming(accumulated_text) - if correction and correction.content: - correction_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - content=correction.content, - ), - finish_reason=None, - ) - ], - usage=None, - ) - yield f"data: {correction_chunk.model_dump_json(exclude_none=True)}\n\n" - - # Fallback: if tool parser accumulated text but never emitted tool_calls - # (e.g., closing tag never arrived - incomplete tool call). - # Use parser-aware check so non-standard markers (MiniMax, Llama, etc.) - # are detected instead of only checking for "". - # Also check accumulated_text (full output including reasoning) as a - # safety net — tool XML may have leaked into reasoning stream. - _fallback_text = tool_accumulated_text or accumulated_text - if ( - tool_parser - and _fallback_text - and not tool_calls_detected - and tool_parser.has_pending_tool_call(_fallback_text) - ): - result = tool_parser.extract_tool_calls(_fallback_text) - if result.tools_called: - tool_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - tool_calls=[ - { - "index": i, - "id": tc["id"], - "type": "function", - "function": { - "name": tc["name"], - "arguments": tc["arguments"], - }, - } - for i, tc in enumerate(result.tool_calls) - ] - ), - finish_reason="tool_calls", - ) - ], + choices=[], # Empty choices for usage-only chunk + usage=Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), ) - _fb_sse = f"data: {tool_chunk.model_dump_json(exclude_none=True)}\n\n" - logger.info(f"[SSE-FALLBACK-TC] {_fb_sse.strip()[:300]}") - yield _fb_sse + yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" - # Log throughput - elapsed = time.perf_counter() - start_time - tokens_per_sec = completion_tokens / elapsed if elapsed > 0 else 0 - logger.info( - f"Chat completion (stream): {completion_tokens} tokens in {elapsed:.2f}s ({tokens_per_sec:.1f} tok/s)" - ) - - # Send final chunk with usage if requested - if include_usage: - usage_chunk = ChatCompletionChunk( - id=response_id, - model=request.model, - choices=[], # Empty choices for usage-only chunk - usage=Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - yield f"data: {usage_chunk.model_dump_json(exclude_none=True)}\n\n" - - yield "data: [DONE]\n\n" - - # Re-enable GC and collect after generation completes - if _gc_control and gc_was_enabled: - gc.enable() - gc.collect() + yield "data: [DONE]\n\n" + finally: + # Re-enable GC even if generator is abandoned (client disconnect) + if _gc_control and gc_was_enabled: + gc.enable() + gc.collect() # =============================================================================