diff --git a/tests/test_server.py b/tests/test_server.py index 08b169bd..a55ae5fb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -712,6 +712,177 @@ def test_rate_limiter_window_cleanup(self): class TestStreamChatCompletion: """Tests for streaming chat completion behavior.""" + @pytest.mark.anyio + async def test_stream_without_parser_flags_emits_structured_tool_calls( + self, monkeypatch + ): + """Streaming tools should still parse without explicit parser flags.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="", finished=False), + GenerationOutput( + text="", + new_text="", + finished=False, + ), + GenerationOutput( + text="", + new_text="/Users/testuser", + finished=False, + ), + GenerationOutput( + text="", + new_text="", + finished=False, + ), + GenerationOutput( + text="", + new_text="", + finished=True, + finish_reason="stop", + prompt_tokens=5, + completion_tokens=7, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", False) + monkeypatch.setattr(server, "_tool_call_parser", None) + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files in a directory", + "parameters": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"], + }, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + tool_payloads = [ + payload + for payload in payloads + if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls") + ] + + assert len(tool_payloads) == 1 + delta = tool_payloads[0]["choices"][0]["delta"] + assert delta["tool_calls"][0]["function"]["name"] == "list_directory" + assert delta["tool_calls"][0]["function"]["arguments"] == ( + '{"path": "/Users/testuser"}' + ) + assert delta["content"] is None + assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls" + assert tool_payloads[0]["usage"] == { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12, + } + + @pytest.mark.anyio + async def test_stream_without_parser_flags_keeps_plain_text(self, monkeypatch): + """Generic streaming fallback should not interfere with normal text.""" + from vllm_mlx.engine.base import GenerationOutput + from vllm_mlx.server import ( + ChatCompletionRequest, + Message, + stream_chat_completion, + ) + import vllm_mlx.server as server + + class FakeEngine: + model_name = "fake-engine" + + async def stream_chat(self, messages, **kwargs): + chunks = [ + GenerationOutput(text="", new_text="hello ", finished=False), + GenerationOutput( + text="", + new_text="world", + finished=True, + finish_reason="stop", + prompt_tokens=4, + completion_tokens=2, + ), + ] + for chunk in chunks: + yield chunk + + monkeypatch.setattr(server, "_model_name", "served-model") + monkeypatch.setattr(server, "_reasoning_parser", None) + monkeypatch.setattr(server, "_enable_auto_tool_choice", False) + monkeypatch.setattr(server, "_tool_call_parser", None) + monkeypatch.setattr(server, "_tool_parser_instance", None) + + request = ChatCompletionRequest( + model="served-model", + messages=[Message(role="user", content="hi")], + tools=[ + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List files in a directory", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + stream=True, + ) + + chunks = [ + chunk + async for chunk in stream_chat_completion( + FakeEngine(), request.messages, request + ) + ] + payloads = [ + json.loads(chunk.removeprefix("data: ").strip()) + for chunk in chunks + if chunk != "data: [DONE]\n\n" + ] + + assert payloads[1]["choices"][0]["delta"]["content"] == "hello " + assert payloads[2]["choices"][0]["delta"]["content"] == "world" + assert payloads[2]["choices"][0]["finish_reason"] == "stop" + @pytest.mark.anyio async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch): """Tool markup after should emit tool_calls chunks.""" diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 6cd9581b..8960f163 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -174,6 +174,15 @@ def _resolve_top_p(request_value: float | None) -> float: # Safety net: the tool parser should consume these, but if it doesn't # (e.g. malformed JSON, stray closing tags), strip them before emitting. _TOOL_MARKUP_PATTERN = re.compile(r"|") +_STREAMING_TOOL_MARKERS = ( + "", + "<|tool_call>", + "", + ' bool: + """Heuristic marker check to avoid parser work on ordinary text chunks.""" + return any(marker in text for marker in _STREAMING_TOOL_MARKERS) + + def load_embedding_model( model_name: str | None, *, @@ -2294,24 +2363,10 @@ async def _stream_anthropic_messages( # Tool call streaming suppression — prevents raw tool markup from leaking # as text_delta events. Mirrors the OpenAI streaming path logic. - global _tool_parser_instance tool_parser = None tool_accumulated_text = "" tool_markup_possible = False - tool_choice = getattr(openai_request, "tool_choice", None) - if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - except Exception: - pass - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() + tool_parser = _get_streaming_tool_parser(openai_request) try: async for output in engine.stream_chat(messages=messages, **chat_kwargs): @@ -2341,7 +2396,12 @@ async def _stream_anthropic_messages( # Filter tool call markup during streaming if tool_parser and content_to_emit: - if not tool_markup_possible and "<" not in content_to_emit: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content_to_emit + ) + ): tool_accumulated_text += content_to_emit else: if not tool_markup_possible: @@ -2386,7 +2446,12 @@ async def _stream_anthropic_messages( # Filter tool call markup during streaming if tool_parser and content_to_emit: - if not tool_markup_possible and "<" not in content_to_emit: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content_to_emit + ) + ): tool_accumulated_text += content_to_emit else: if not tool_markup_possible: @@ -2618,27 +2683,11 @@ async def stream_chat_completion( last_output = None # Tool call streaming state - global _tool_parser_instance tool_parser = None tool_accumulated_text = "" tool_calls_detected = False - tool_markup_possible = False # Fast path: skip parsing until '<' seen - tool_choice = getattr(request, "tool_choice", None) - if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none": - # Initialize parser if needed (same as _parse_tool_calls_with_parser) - if _tool_parser_instance is None: - try: - parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) - tokenizer = None - if _engine is not None and hasattr(_engine, "_tokenizer"): - tokenizer = _engine._tokenizer - _tool_parser_instance = parser_cls(tokenizer) - logger.info(f"Initialized tool call parser: {_tool_call_parser}") - except Exception as e: - logger.warning(f"Failed to init tool parser for streaming: {e}") - if _tool_parser_instance is not None: - tool_parser = _tool_parser_instance - tool_parser.reset() + tool_markup_possible = False # Fast path: skip parsing until markers appear + tool_parser = _get_streaming_tool_parser(request) try: # Stream content @@ -2689,7 +2738,12 @@ async def stream_chat_completion( # Tool call parsing on content portion if tool_parser and content: - if not tool_markup_possible and "<" not in content: + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + content + ) + ): tool_accumulated_text += content # Suppress whitespace-only content when tools are active; # avoids emitting stray newlines before tool call XML. @@ -2799,10 +2853,16 @@ async def stream_chat_completion( # Tool call streaming parsing if tool_parser and delta_text: - # Fast path: skip full parsing until '<' is seen in the stream, - # which could start tool markup (e.g. ). This avoids - # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: + # Fast path: skip full parsing until likely tool markup appears. + # This preserves the cheap path for ordinary text while still + # allowing generic streaming tool parsing when no explicit + # parser flags are configured. + if ( + not tool_markup_possible + and not _streaming_tool_markup_possible( + tool_accumulated_text + delta_text + ) + ): tool_accumulated_text += delta_text # No tool markup yet, fall through to normal chunk emission else: @@ -2883,11 +2943,7 @@ async def stream_chat_completion( tool_parser and tool_accumulated_text and not tool_calls_detected - and ( - "" in tool_accumulated_text - or "<|tool_call>" in tool_accumulated_text - or "