diff --git a/tests/test_engine_tool_output_preservation.py b/tests/test_engine_tool_output_preservation.py new file mode 100644 index 00000000..6e8c9dce --- /dev/null +++ b/tests/test_engine_tool_output_preservation.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests that tool-enabled chat preserves raw parser-visible output.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +class TestSimpleEngineToolOutputPreservation: + @pytest.mark.anyio + async def test_chat_with_tools_preserves_raw_harmony_output(self): + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_stream_chat(*args, **kwargs): + yield MagicMock( + text=( + "<|channel|>commentary to=functions.get_weather" + '<|message|>{"city":"Paris"}<|call|>' + ), + tokens=[], + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + finished=True, + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = MagicMock() + engine._loaded = True + engine.stream_chat = fake_stream_chat # type: ignore[method-assign] + + output = await engine.chat( + messages=[{"role": "user", "content": "Weather in Paris?"}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + + assert "<|channel|>commentary" in output.text + assert "<|call|>" in output.text + + +class TestBatchedEngineToolOutputPreservation: + @pytest.mark.anyio + async def test_chat_with_tools_preserves_raw_output(self): + from vllm_mlx.engine.batched import BatchedEngine + + raw_output = ( + "<|channel|>commentary to=functions.get_weather" + '<|message|>{"city":"Paris"}<|call|>' + ) + + with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False): + engine = BatchedEngine("test-model") + engine._loaded = True + engine._tokenizer = MagicMock() + engine._apply_chat_template = MagicMock(return_value="prompt") + engine._engine = MagicMock() + engine._engine.generate = AsyncMock( + return_value=MagicMock( + output_text=raw_output, + prompt_tokens=9, + completion_tokens=3, + finish_reason="stop", + ) + ) + + output = await engine.chat( + messages=[{"role": "user", "content": "Weather in Paris?"}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + + assert output.text == raw_output + engine._engine.generate.assert_called_once() diff --git a/tests/test_harmony_parsers.py b/tests/test_harmony_parsers.py index 9ca509f6..77d3db59 100644 --- a/tests/test_harmony_parsers.py +++ b/tests/test_harmony_parsers.py @@ -198,6 +198,23 @@ def test_nested_json_arguments(self, parser): parsed_args = json.loads(result.tool_calls[0]["arguments"]) assert parsed_args["filter"]["type"] == "range" + def test_consecutive_duplicate_tool_calls_are_deduped(self, parser): + """Repeated identical commentary blocks should collapse to one call.""" + text = ( + "<|channel|>commentary to=functions.get_weather\n" + "<|constrain|>json\n" + '<|message|>{"city":"Paris"}\n' + "<|call|>\n" + "<|channel|>commentary to=functions.get_weather\n" + "<|constrain|>json\n" + '<|message|>{"city":"Paris"}\n' + "<|call|>" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + def test_streaming_no_tool_markers(self, parser): """Streaming: plain text passes through as content.""" result = parser.extract_tool_calls_streaming("", "Hello", "Hello") @@ -227,6 +244,25 @@ def test_streaming_building_tool_call(self, parser): result = parser.extract_tool_calls_streaming("", current, '{"a":') assert result is None + def test_streaming_duplicate_tool_call_is_not_reemitted(self, parser): + """Streaming should only emit newly completed Harmony tool calls.""" + previous = ( + "<|channel|>commentary to=functions.func\n" + "<|constrain|>json\n" + '<|message|>{"a": 1}\n' + "<|call|>" + ) + current = ( + previous + + "\n<|channel|>commentary to=functions.func\n" + + "<|constrain|>json\n" + + '<|message|>{"a": 1}\n' + + "<|call|>" + ) + + result = parser.extract_tool_calls_streaming(previous, current, "<|call|>") + assert result is None + # ============================================================================ # Reasoning Parser Tests @@ -388,26 +424,60 @@ def test_streaming_reset(self, parser): assert parser._current_channel is None assert parser._in_message is False - def test_streaming_commentary_suppressed(self, parser): - """Streaming: commentary channel output is suppressed.""" + def test_streaming_commentary_routed_as_content(self, parser): + """Streaming: commentary channel is forwarded for downstream tool parsing.""" parser.reset_state() - parser.extract_reasoning_streaming( + r1 = parser.extract_reasoning_streaming( "", "<|channel|>commentary to=functions.f\n", "<|channel|>commentary to=functions.f\n", ) - parser.extract_reasoning_streaming( + assert r1 is not None + assert r1.content == "<|channel|>commentary to=functions.f\n" + + r2 = parser.extract_reasoning_streaming( "<|channel|>commentary to=functions.f\n", "<|channel|>commentary to=functions.f\n<|message|>", "<|message|>", ) + assert r2 is not None + assert r2.content == "<|message|>" + r = parser.extract_reasoning_streaming( "<|channel|>commentary to=functions.f\n<|message|>", '<|channel|>commentary to=functions.f\n<|message|>{"a":1}', '{"a":1}', ) - assert r is None + assert r is not None + assert r.content == '{"a":1}' + + def test_streaming_split_commentary_header(self, parser): + """Split commentary headers should still be routed to downstream tool parsing.""" + parser.reset_state() + + accumulated = "" + content_parts = [] + for token in [ + "<|channel|>", + "comment", + "ary to", + "=functions.get_weather", + " <|constrain|>", + "json", + "<|message|>", + '{"city":"Paris"}', + "<|call|>", + ]: + prev = accumulated + accumulated += token + result = parser.extract_reasoning_streaming(prev, accumulated, token) + if result and result.content: + content_parts.append(result.content) + + combined = "".join(content_parts) + assert "<|channel|>commentary to=functions.get_weather" in combined + assert '<|message|>{"city":"Paris"}' in combined # ============================================================================ diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py index cce42bfc..7590ffb5 100644 --- a/tests/test_simple_engine.py +++ b/tests/test_simple_engine.py @@ -10,6 +10,10 @@ class TestSimpleEngineConcurrency: """Test SimpleEngine lock behavior with concurrent requests.""" + @pytest.fixture + def anyio_backend(self): + return "asyncio" + @pytest.fixture def mock_model(self): """Create a mock model that tracks concurrent calls.""" @@ -65,7 +69,7 @@ def chat_side_effect(**kwargs): model.chat = MagicMock(side_effect=chat_side_effect) return model - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_generate(self, mock_model): """Test that the lock prevents concurrent generate calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -89,7 +93,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_lock_prevents_concurrent_chat(self, mock_llm_model): """Test that the lock prevents concurrent chat calls.""" from vllm_mlx.engine.simple import SimpleEngine @@ -115,7 +119,57 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model): "The lock is not working correctly." ) - @pytest.mark.asyncio + @pytest.mark.anyio + async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model): + """Tool-enabled non-stream chat should use the streaming path.""" + from vllm_mlx.engine.simple import SimpleEngine + + async def fake_stream_chat(*args, **kwargs): + yield MagicMock( + text="partial", + tokens=[], + prompt_tokens=11, + completion_tokens=1, + finish_reason=None, + finished=False, + ) + yield MagicMock( + text="<|im_end|>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}", + tokens=[], + prompt_tokens=11, + completion_tokens=4, + finish_reason="stop", + finished=True, + ) + + with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False): + engine = SimpleEngine("test-model") + engine._model = mock_llm_model + engine._loaded = True + engine.stream_chat = fake_stream_chat # type: ignore[method-assign] + + output = await engine.chat( + messages=[{"role": "user", "content": "run pwd"}], + max_tokens=16, + tools=[ + { + "type": "function", + "function": { + "name": "bash", + "parameters": {"type": "object", "properties": {}}, + }, + } + ], + ) + + assert output.text.startswith("<|im_end|>") + assert output.tokens == [] + assert output.prompt_tokens == 11 + assert output.completion_tokens == 4 + assert output.finish_reason == "stop" + mock_llm_model.chat.assert_not_called() + + @pytest.mark.anyio async def test_lock_serializes_stream_generate(self, mock_model): """Test that stream_generate uses the same lock as other methods.""" from vllm_mlx.engine.simple import SimpleEngine @@ -178,7 +232,7 @@ async def try_stream(): result = await stream_task assert len(result) == 3, f"Expected 3 chunks, got {len(result)}" - @pytest.mark.asyncio + @pytest.mark.anyio async def test_engine_initialization_creates_lock(self): """Test that SimpleEngine creates a lock on initialization.""" from vllm_mlx.engine.simple import SimpleEngine @@ -189,7 +243,7 @@ async def test_engine_initialization_creates_lock(self): assert hasattr(engine, "_generation_lock") assert isinstance(engine._generation_lock, asyncio.Lock) - @pytest.mark.asyncio + @pytest.mark.anyio async def test_requests_complete_in_order(self, mock_model): """Test that concurrent requests complete (may be in any order due to lock).""" from vllm_mlx.engine.simple import SimpleEngine diff --git a/tests/test_streaming_reasoning_tools.py b/tests/test_streaming_reasoning_tools.py new file mode 100644 index 00000000..c6726bf1 --- /dev/null +++ b/tests/test_streaming_reasoning_tools.py @@ -0,0 +1,396 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Streaming chat-completion tests for reasoning + tool parser coexistence.""" + +import json +from collections.abc import AsyncIterator +from typing import Any + +import pytest + +import vllm_mlx.server as server +from vllm_mlx.api.models import ChatCompletionRequest, Message, ToolDefinition +from vllm_mlx.engine.base import BaseEngine, GenerationOutput +from vllm_mlx.reasoning import get_parser +from vllm_mlx.tool_parsers import ToolParserManager + + +TEST_TOOL = ToolDefinition( + function={ + "name": "get_weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + } +) + + +class FakeStreamEngine(BaseEngine): + """Minimal engine for deterministic stream_chat tests.""" + + def __init__(self, deltas: list[str], model_name: str = "test-model"): + self._deltas = deltas + self._model_name = model_name + self._tokenizer = None + + @property + def model_name(self) -> str: + return self._model_name + + @property + def is_mllm(self) -> bool: + return False + + @property + def tokenizer(self) -> Any: + return None + + async def start(self) -> None: + return None + + async def stop(self) -> None: + return None + + async def generate(self, *args, **kwargs) -> GenerationOutput: + raise NotImplementedError + + async def stream_generate(self, *args, **kwargs) -> AsyncIterator[GenerationOutput]: + raise NotImplementedError + + async def chat(self, *args, **kwargs) -> GenerationOutput: + raise NotImplementedError + + async def stream_chat(self, *args, **kwargs) -> AsyncIterator[GenerationOutput]: + text = "" + for i, delta in enumerate(self._deltas): + text += delta + yield GenerationOutput( + text=text, + new_text=delta, + finished=i == len(self._deltas) - 1, + completion_tokens=i + 1, + finish_reason="stop", + ) + + +def _collect_payloads(stream_output: list[str]) -> list[dict[str, Any]]: + payloads: list[dict[str, Any]] = [] + for chunk in stream_output: + assert chunk.startswith("data: ") + payload = chunk[6:].strip() + if payload == "[DONE]": + continue + payloads.append(json.loads(payload)) + return payloads + + +def _flatten_deltas(payloads: list[dict[str, Any]]) -> list[dict[str, Any]]: + deltas: list[dict[str, Any]] = [] + for payload in payloads: + for choice in payload["choices"]: + deltas.append( + { + "delta": choice["delta"], + "finish_reason": choice["finish_reason"], + } + ) + return deltas + + +def _reasoning_text(deltas: list[dict[str, Any]]) -> str: + return "".join(delta["delta"].get("reasoning") or "" for delta in deltas) + + +def _content_text(deltas: list[dict[str, Any]]) -> str: + return "".join(delta["delta"].get("content") or "" for delta in deltas) + + +def _tool_calls(deltas: list[dict[str, Any]]) -> list[dict[str, Any]]: + calls: list[dict[str, Any]] = [] + for delta in deltas: + calls.extend(delta["delta"].get("tool_calls") or []) + return calls + + +async def _run_stream( + monkeypatch: pytest.MonkeyPatch, + *, + deltas: list[str], + reasoning_parser: str | None, + tool_parser: str | None, + model_name: str = "test-model", +) -> list[dict[str, Any]]: + engine = FakeStreamEngine(deltas, model_name=model_name) + parser_instance = ( + ToolParserManager.get_tool_parser(tool_parser)() if tool_parser else None + ) + + monkeypatch.setattr(server, "_engine", engine) + monkeypatch.setattr(server, "_model_name", model_name) + monkeypatch.setattr( + server, + "_reasoning_parser", + get_parser(reasoning_parser)() if reasoning_parser else None, + ) + monkeypatch.setattr(server, "_enable_auto_tool_choice", tool_parser is not None) + monkeypatch.setattr(server, "_tool_call_parser", tool_parser) + monkeypatch.setattr(server, "_tool_parser_instance", parser_instance) + + request = ChatCompletionRequest( + model=model_name, + messages=[Message(role="user", content="Weather in Paris?")], + stream=True, + tools=[TEST_TOOL], + ) + + raw_chunks: list[str] = [] + async for chunk in server.stream_chat_completion( + engine, + [{"role": "user", "content": "Weather in Paris?"}], + request, + ): + raw_chunks.append(chunk) + + return _flatten_deltas(_collect_payloads(raw_chunks)) + + +@pytest.mark.anyio +async def test_streaming_qwen_reasoning_then_tool_call(monkeypatch): + deltas = [ + "", + "Need tool", + "", + "\n", + '{"name":"get_weather","arguments":{"city":"Paris"}}', + "\n", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="qwen3", + tool_parser="qwen", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "Need tool" + assert _content_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"} + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_qwen_direct_tool_call_with_reasoning_parser(monkeypatch): + deltas = [ + "\n", + '{"name":"get_weather","arguments":{"city":"Paris"}}', + "\n", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="qwen3", + tool_parser="qwen", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "" + assert _content_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_qwen_reasoning_then_plain_text(monkeypatch): + deltas = ["", "R", "", "Hello", " world"] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="qwen3", + tool_parser="qwen", + ) + + assert _reasoning_text(deltas_out) == "R" + assert _content_text(deltas_out) == "Hello world" + assert _tool_calls(deltas_out) == [] + assert deltas_out[-1]["finish_reason"] == "stop" + + +@pytest.mark.anyio +async def test_streaming_qwen_tool_call_inside_think_is_not_emitted(monkeypatch): + deltas = [ + "", + 'R {"name":"get_weather","arguments":{"city":"Paris"}}', + "", + "Text", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="qwen3", + tool_parser="qwen", + ) + + assert "" in _reasoning_text(deltas_out) + assert _content_text(deltas_out) == "Text" + assert _tool_calls(deltas_out) == [] + assert deltas_out[-1]["finish_reason"] == "stop" + + +@pytest.mark.anyio +async def test_streaming_harmony_reasoning_then_tool_call(monkeypatch): + deltas = [ + "<|channel|>analysis", + "<|message|>", + "Need weather", + "<|end|>", + "<|channel|>commentary to=functions.get_weather", + "<|constrain|>json", + "<|message|>", + '{"city":"Paris"}', + "<|call|>", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="harmony", + tool_parser="harmony", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "Need weather" + assert _content_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"} + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_gpt_oss_split_commentary_routes_to_harmony_tool_parser( + monkeypatch, +): + deltas = [ + "<|channel|>", + "analysis", + "<|message|>", + "Need weather", + "<|end|>", + "<|channel|>", + "comment", + "ary to", + "=functions.get_weather", + " <|constrain|>", + "json", + "<|message|>", + '{"city":"Paris"}', + "<|call|>", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="gpt_oss", + tool_parser="harmony", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "Need weather" + assert _content_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"} + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_qwen_function_parameter_tool_call(monkeypatch): + deltas = [ + "", + "Need tool", + "", + "\n", + "\n", + "\n", + '"Paris"\n', + "\n", + "\n", + "3\n", + "\n", + "\n", + "", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser="qwen3", + tool_parser="qwen", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "Need tool" + assert _content_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert json.loads(calls[0]["function"]["arguments"]) == { + "city": "Paris", + "days": 3, + } + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_without_reasoning_parser_keeps_qwen_tools_working(monkeypatch): + deltas = [ + "\n", + '{"name":"get_weather","arguments":{"city":"Paris"}}', + "\n", + ] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser=None, + tool_parser="qwen", + ) + + calls = _tool_calls(deltas_out) + + assert _reasoning_text(deltas_out) == "" + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert deltas_out[-1]["finish_reason"] == "tool_calls" + + +@pytest.mark.anyio +async def test_streaming_without_reasoning_parser_keeps_mistral_tools_working( + monkeypatch, +): + deltas = ["[TOOL_CALLS]", 'get_weather{"city":"Paris"}'] + + deltas_out = await _run_stream( + monkeypatch, + deltas=deltas, + reasoning_parser=None, + tool_parser="mistral", + ) + + calls = _tool_calls(deltas_out) + + assert len(calls) == 1 + assert calls[0]["function"]["name"] == "get_weather" + assert deltas_out[-1]["finish_reason"] == "tool_calls" diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py index dfe2bb6a..6606f0c2 100644 --- a/tests/test_tool_parsers.py +++ b/tests/test_tool_parsers.py @@ -185,6 +185,24 @@ def test_multiple_xml_calls(self, parser): assert result.tools_called assert len(result.tool_calls) == 2 + def test_function_parameter_format(self, parser): + """Test parsing Qwen's function/parameter XML format.""" + text = ( + "\n" + "\n" + "Paris\n" + "3\n" + "\n" + "" + ) + result = parser.extract_tool_calls(text) + + assert result.tools_called + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "get_weather" + args = json.loads(result.tool_calls[0]["arguments"]) + assert args == {"city": "Paris", "days": 3} + def test_no_tool_call(self, parser): """Test text without tool calls.""" text = "I can help you with that question." @@ -681,6 +699,20 @@ def test_auto_streaming(self): ) assert result == {"content": "Hello world"} + def test_qwen_streaming_split_closing_tag(self): + """Qwen streaming should finish on accumulated text, not lucky final deltas.""" + parser = QwenToolParser() + + r = parser.extract_tool_calls_streaming( + previous_text="Paris tag stripping in tool parsers (Issue #26).""" diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py index ce33e628..b2c92983 100644 --- a/vllm_mlx/engine/batched.py +++ b/vllm_mlx/engine/batched.py @@ -453,6 +453,8 @@ async def generate( if not self._loaded: await self.start() + preserve_raw_output = bool(kwargs.pop("preserve_raw_output", False)) + if self._is_mllm and self._mllm_scheduler: # Use MLLM scheduler for all requests when model is multimodal. # MLLM models only initialise the _mllm_scheduler (not _engine), @@ -467,7 +469,11 @@ async def generate( ) return GenerationOutput( - text=clean_output_text(output.output_text), + text=( + output.output_text + if preserve_raw_output + else clean_output_text(output.output_text) + ), prompt_tokens=output.prompt_tokens, completion_tokens=output.completion_tokens, finish_reason=output.finish_reason, @@ -488,7 +494,11 @@ async def generate( sampling_params=sampling_params, ) - text = clean_output_text(output.output_text) + text = ( + output.output_text + if preserve_raw_output + else clean_output_text(output.output_text) + ) return GenerationOutput( text=text, @@ -635,6 +645,7 @@ async def chat( top_p=top_p, images=all_images if all_images else None, videos=all_videos if all_videos else None, + preserve_raw_output=bool(tools), **kwargs, ) diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py index e96317ef..33d45305 100644 --- a/vllm_mlx/engine/simple.py +++ b/vllm_mlx/engine/simple.py @@ -437,6 +437,36 @@ async def chat( if not self._loaded: await self.start() + chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {}) + + # mlx-lm non-streaming chat with tools can stall indefinitely on some + # local models, while the streaming path completes normally. Reuse the + # streaming implementation and aggregate its final state so both chat + # APIs share the same tool-capable execution path. + if tools and not self._is_mllm: + stream_kwargs = dict(kwargs) + if chat_template_kwargs: + stream_kwargs["chat_template_kwargs"] = chat_template_kwargs + final_output = GenerationOutput(text="") + async for output in self.stream_chat( + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + tools=tools, + images=images, + videos=videos, + **stream_kwargs, + ): + final_output = output + text = final_output.text + return GenerationOutput( + text=text, + tokens=list(final_output.tokens), + prompt_tokens=final_output.prompt_tokens, + completion_tokens=final_output.completion_tokens, + finish_reason=final_output.finish_reason, + ) # Convert tools for template if provided template_tools = convert_tools_for_template(tools) if tools else None @@ -450,6 +480,7 @@ async def chat( max_tokens=max_tokens, temperature=temperature, tools=template_tools, + chat_template_kwargs=chat_template_kwargs, **kwargs, ) text = clean_output_text(output.text) @@ -469,6 +500,7 @@ async def chat( temperature=temperature, top_p=top_p, tools=template_tools, + chat_template_kwargs=chat_template_kwargs, **kwargs, ) text = clean_output_text(output.text) diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py index 72182037..9669d0c0 100644 --- a/vllm_mlx/models/llm.py +++ b/vllm_mlx/models/llm.py @@ -231,6 +231,17 @@ def stream_generate( should_stop = True break + eos_token_ids = getattr(self.tokenizer, "eos_token_ids", None) + if eos_token_ids is None: + eos_token_id = getattr(self.tokenizer, "eos_token_id", None) + eos_token_ids = [] if eos_token_id is None else [eos_token_id] + elif isinstance(eos_token_ids, int): + eos_token_ids = [eos_token_ids] + + response_token = getattr(response, "token", None) + if response_token is not None and response_token in set(eos_token_ids): + should_stop = True + finished = should_stop or token_count >= max_tokens finish_reason = None if finished: diff --git a/vllm_mlx/reasoning/gpt_oss_parser.py b/vllm_mlx/reasoning/gpt_oss_parser.py index 8541faf2..fc2f9882 100644 --- a/vllm_mlx/reasoning/gpt_oss_parser.py +++ b/vllm_mlx/reasoning/gpt_oss_parser.py @@ -15,6 +15,7 @@ import re from .base import DeltaMessage, ReasoningParser +from .harmony_parser import HarmonyReasoningParser # Structural tokens that should be stripped from output _STRUCTURAL_TOKENS = re.compile( @@ -69,6 +70,12 @@ class GptOssReasoningParser(ReasoningParser): <|channel|>final <|constrain|>JSON<|message|>[content]<|return|> """ + def __init__(self, tokenizer=None): + super().__init__(tokenizer) + # GPT-OSS streams the same Harmony channel tokens as the dedicated + # Harmony parser, including split channel markers across deltas. + self._stream_parser = HarmonyReasoningParser(tokenizer) + def extract_reasoning( self, model_output: str, @@ -124,41 +131,13 @@ def extract_reasoning_streaming( Returns: DeltaMessage with reasoning and/or content, or None to skip. """ - prev_phase = self._detect_phase(previous_text) - curr_phase = self._detect_phase(current_text) - - # Phase changed — extract content after the new marker - if curr_phase != prev_phase and curr_phase in ("analysis", "final"): - after_marker = self._extract_content_after_marker_in_delta( - current_text, curr_phase - ) - if after_marker: - after_marker = self._strip_return(after_marker) - if curr_phase == "analysis": - return DeltaMessage(reasoning=after_marker) - else: - return DeltaMessage(content=after_marker) - return None - - # In a steady phase — emit delta directly - if curr_phase == "analysis": - cleaned = self._strip_return(delta_text) - # Skip structural tokens in the delta - if _STRUCTURAL_TOKENS.search(cleaned): - cleaned = _STRUCTURAL_TOKENS.sub("", cleaned) - if cleaned: - return DeltaMessage(reasoning=cleaned) - return None - elif curr_phase == "final": - cleaned = self._strip_return(delta_text) - if _STRUCTURAL_TOKENS.search(cleaned): - cleaned = _STRUCTURAL_TOKENS.sub("", cleaned) - if cleaned: - return DeltaMessage(content=cleaned) - return None - - # init or transition phase — skip structural tokens - return None + return self._stream_parser.extract_reasoning_streaming( + previous_text, current_text, delta_text + ) + + def reset_state(self): + """Reset streaming state for a new request.""" + self._stream_parser.reset_state() @staticmethod def _detect_phase(text: str) -> str: diff --git a/vllm_mlx/reasoning/harmony_parser.py b/vllm_mlx/reasoning/harmony_parser.py index 73b94f0e..63fb4aee 100644 --- a/vllm_mlx/reasoning/harmony_parser.py +++ b/vllm_mlx/reasoning/harmony_parser.py @@ -50,6 +50,8 @@ def __init__(self, tokenizer=None): super().__init__(tokenizer) self._current_channel: str | None = None self._in_message: bool = False + self._pending_channel_text = "" + self._collecting_channel = False def extract_reasoning( self, @@ -97,31 +99,34 @@ def extract_reasoning_streaming( Returns: DeltaMessage with reasoning and/or content, or None. """ + if self._collecting_channel: + self._pending_channel_text += delta_text + resolved = self._try_resolve_pending_channel() + if resolved is not None: + return resolved + # Detect channel switches in the delta if "<|channel|>" in delta_text: - if "analysis" in delta_text: - self._current_channel = "analysis" - self._in_message = False - return None - elif "final" in delta_text: - self._current_channel = "final" - self._in_message = False - return None - elif "commentary" in delta_text: - self._current_channel = "commentary" - self._in_message = False - return None + self._collecting_channel = True + self._pending_channel_text = delta_text + resolved = self._try_resolve_pending_channel() + if resolved is not None: + return resolved + return None # Detect channel from full context if not yet determined if self._current_channel is None and "<|channel|>" in current_text: last_channel = current_text.rfind("<|channel|>") after = current_text[last_channel + len("<|channel|>") :] - if after.startswith("analysis"): - self._current_channel = "analysis" - elif after.startswith("final"): - self._current_channel = "final" - elif after.startswith("commentary"): - self._current_channel = "commentary" + self._current_channel = self._resolve_channel_name(after) + + # Commentary is routed through the tool parser via DeltaMessage.content. + if self._current_channel == "commentary": + if "<|message|>" in delta_text: + self._in_message = True + if any(token in delta_text for token in ("<|call|>", "<|end|>", "<|return|>")): + self._in_message = False + return DeltaMessage(content=delta_text) # Handle message start if "<|message|>" in delta_text: @@ -155,3 +160,40 @@ def reset_state(self): """Reset streaming state for a new request.""" self._current_channel = None self._in_message = False + self._pending_channel_text = "" + self._collecting_channel = False + + @staticmethod + def _resolve_channel_name(text: str) -> str | None: + """Extract the active Harmony channel name from a delta or suffix.""" + if "analysis" in text: + return "analysis" + if "final" in text: + return "final" + if "commentary" in text: + return "commentary" + return None + + def _try_resolve_pending_channel(self) -> DeltaMessage | None: + """Resolve a partially streamed channel header once enough text arrives.""" + suffix = self._pending_channel_text.split("<|channel|>", 1)[-1] + channel_name = self._resolve_channel_name(suffix) + if channel_name is None: + return None + + buffered = self._pending_channel_text + self._pending_channel_text = "" + self._collecting_channel = False + self._current_channel = channel_name + self._in_message = "<|message|>" in buffered + + if channel_name == "commentary": + return DeltaMessage(content=buffered) + + if "<|message|>" in buffered: + content = buffered.split("<|message|>", 1)[1] + if content: + if channel_name == "analysis": + return DeltaMessage(reasoning=content) + return DeltaMessage(content=content) + return None diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py index 4cb15f02..ba2076a9 100644 --- a/vllm_mlx/server.py +++ b/vllm_mlx/server.py @@ -162,6 +162,15 @@ def _resolve_top_p(request_value: float | None) -> float: _tool_parser_instance = None # Instantiated parser +def _looks_like_streaming_tool_markup(delta_text: str) -> bool: + """Cheap trigger check before invoking streaming tool parsers.""" + return ( + "<" in delta_text + or "[TOOL_CALLS]" in delta_text + or "[Calling tool:" in delta_text + ) + + def _load_prefix_cache_from_disk() -> None: """Load prefix cache from disk during startup.""" try: @@ -418,7 +427,1042 @@ def _parse_tool_calls_with_parser( logger.warning(f"Tool parser error: {e}") return parse_tool_calls(output_text, request_dict) +def _parse_and_validate_tools( + output_text: str, + request: ChatCompletionRequest, + should_parse: bool, +) -> tuple[str, list | None]: + """Parse tool calls from model output and validate against declared tools. + + When *should_parse* is False (tool_choice="none"), returns the raw text + with no tool calls. + """ + if not should_parse: + return output_text, None + cleaned, tool_calls = _parse_tool_calls_with_parser(output_text, request) + if tool_calls: + tool_calls = _validate_tool_calls(tool_calls, request.tools) + return cleaned, tool_calls + + +def _set_tool_grammar_processor(chat_kwargs: dict, tools: list) -> None: + """Attach an Outlines grammar-constrained logits processor for tool calls. + + Modifies *chat_kwargs* in place. Does nothing when Outlines is unavailable. + """ + from .guided_decoding import build_tool_call_processor + + processor = build_tool_call_processor(tools) + if processor: + chat_kwargs["logits_processors"] = [processor] + + +def _apply_tool_choice( + tool_choice: str | dict | None, + chat_kwargs: dict, + messages: list[dict], +) -> bool: + """Apply tool_choice policy to chat kwargs and messages. + + Modifies *chat_kwargs* and *messages* in place so that the chat template + and downstream parsing honour the caller's tool_choice setting. + + Returns ``True`` when the model output should be parsed for tool calls, + ``False`` when tool-call parsing must be skipped (``tool_choice="none"``). + """ + if tool_choice == "none": + chat_kwargs.pop("tools", None) + return False + + if tool_choice == "required": + messages.insert( + 0, + { + "role": "system", + "content": ( + "You MUST call one of the provided tools. " + "Do not respond with plain text." + ), + }, + ) + _set_tool_grammar_processor(chat_kwargs, chat_kwargs.get("tools", [])) + return True + + if isinstance(tool_choice, dict): + func_info = tool_choice.get("function", {}) + fname = func_info.get("name", "") if isinstance(func_info, dict) else "" + if fname: + template_tools = chat_kwargs.get("tools") + if template_tools: + filtered = [ + t + for t in template_tools + if t.get("function", {}).get("name") == fname + ] + if filtered: + chat_kwargs["tools"] = filtered + messages.insert( + 0, + { + "role": "system", + "content": f"You MUST call the function: {fname}", + }, + ) + _set_tool_grammar_processor(chat_kwargs, filtered) + return True + # Named function not found in tools — fall back to auto + logger.warning( + "tool_choice function %r not found in tools, falling back to auto", + fname, + ) + return True + + # "auto" or None — apply lazy grammar trigger when the parser defines one. + # The processor stays inactive until the model emits a tool call trigger + # token, then constrains the JSON body to match the tool schemas. + if chat_kwargs.get("tools") and _tool_call_parser and _tool_parser_instance: + parser_cls = type(_tool_parser_instance) + if parser_cls.TRIGGER_TOKEN_IDS: + from .guided_decoding import build_lazy_tool_call_processor + + processor = build_lazy_tool_call_processor( + tools=chat_kwargs.get("tools", []), + trigger_tokens=parser_cls.TRIGGER_TOKEN_IDS, + end_tokens=parser_cls.END_TOKEN_IDS, + prefix_skip=parser_cls.PREFIX_SKIP_TOKENS, + ) + if processor: + chat_kwargs["logits_processors"] = [processor] + return True + + +def _validate_tool_calls( + tool_calls: list | None, + tools: list | None, +) -> list | None: + """Remove tool calls with unknown names or invalid arguments. + + Validates each parsed tool call against the declared tools: + 1. Function name must exist in the tools list. + 2. Arguments must be valid JSON. + 3. Arguments must conform to the tool's parameters schema (if declared). + + Invalid tool calls are dropped with a warning log. + Returns None if all tool calls are invalid. + """ + if not tool_calls: + return None + if tools is None: + return tool_calls + + # Build lookup: function name -> parameters schema + tools_by_name: dict[str, dict | None] = {} + for t in tools: + func = ( + t.function + if isinstance(t, ToolDefinition) + else (t.get("function") if isinstance(t, dict) else None) + ) + if func and isinstance(func, dict): + fname = func.get("name") + if fname: + tools_by_name[fname] = func.get("parameters") + + valid = [] + for tc in tool_calls: + name = tc.function.name + if name not in tools_by_name: + logger.warning("Dropping tool call with unknown function: %s", name) + continue + schema = tools_by_name[name] + if schema: + try: + args = json.loads(tc.function.arguments) + jsonschema.validate(args, schema) + except (json.JSONDecodeError, jsonschema.ValidationError, jsonschema.SchemaError) as exc: + logger.warning( + "Dropping tool call %s: invalid arguments: %s", name, exc + ) + continue + valid.append(tc) + + return valid or None + + +def _new_response_item_id(prefix: str) -> str: + """Generate stable OpenAI-style item ids.""" + return f"{prefix}_{uuid.uuid4().hex}" + + +def _response_content_to_text(content) -> str: + """Normalize Responses API content items into plain text.""" + if content is None: + return "" + if isinstance(content, str): + return content + + text_parts = [] + for part in content: + if isinstance(part, dict): + part_type = part.get("type") + text = part.get("text", "") + else: + part_type = getattr(part, "type", None) + text = getattr(part, "text", "") + if part_type in {"text", "input_text", "output_text"}: + text_parts.append(text) + return "\n".join(part for part in text_parts if part) + + +def _responses_tools_to_chat_tools( + tools: list[ResponseFunctionTool | dict], +) -> tuple[list[dict] | None, list[str]]: + """Convert supported Responses tools and report unsupported tool types.""" + if not tools: + return None, [] + + supported: list[dict] = [] + unsupported: list[str] = [] + + for tool in tools: + if isinstance(tool, ResponseFunctionTool): + tool_type = tool.type + tool_name = tool.name + tool_description = tool.description or "" + tool_parameters = tool.parameters + elif isinstance(tool, dict): + tool_type = tool.get("type", "unknown") + tool_name = tool.get("name", "") + tool_description = tool.get("description", "") + tool_parameters = tool.get("parameters", {}) + else: + unsupported.append(type(tool).__name__) + continue + + if tool_type == "function": + supported.append( + { + "type": "function", + "function": { + "name": tool_name, + "description": tool_description, + "parameters": tool_parameters + or {"type": "object", "properties": {}}, + }, + } + ) + else: + unsupported.append(tool_type) + + return supported or None, unsupported + + +def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]: + """Convert Responses API input items into chat-completions-style messages.""" + messages: list[dict] = [] + + if request.previous_response_id: + previous = _responses_store.get(request.previous_response_id) + if previous is None: + raise HTTPException( + status_code=404, + detail=f"Previous response `{request.previous_response_id}` not found", + ) + messages.extend(copy.deepcopy(previous["messages"])) + + if request.instructions: + messages.append({"role": "system", "content": request.instructions}) + + if isinstance(request.input, str): + messages.append({"role": "user", "content": request.input}) + return messages + + for item in request.input: + if isinstance(item, dict): + item_type = item.get("type", "") + if item_type == "message": + role = item.get("role", "user") + if role == "developer": + role = "system" + messages.append( + { + "role": role, + "content": _response_content_to_text(item.get("content")), + } + ) + elif item_type == "function_call": + messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": item.get("call_id", _new_response_item_id("call")), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", ""), + }, + } + ], + } + ) + elif item_type == "function_call_output": + messages.append( + { + "role": "tool", + "tool_call_id": item.get("call_id", ""), + "content": item.get("output", ""), + } + ) + elif item_type == "reasoning": + parts = item.get("content", []) + reasoning_text = "\n".join( + p.get("text", "") for p in parts if isinstance(p, dict) + ) + if reasoning_text: + messages.append({"role": "assistant", "content": reasoning_text}) + else: + logger.info( + "Skipping unsupported Responses input item type %r", item_type + ) + continue + + if isinstance(item, ResponseMessageItem): + role = item.role + if role == "developer": + role = "system" + messages.append( + { + "role": role, + "content": _response_content_to_text(item.content), + } + ) + elif isinstance(item, ResponseFunctionCallItem): + messages.append( + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": item.call_id, + "type": "function", + "function": { + "name": item.name, + "arguments": item.arguments, + }, + } + ], + } + ) + elif isinstance(item, ResponseFunctionCallOutputItem): + messages.append( + { + "role": "tool", + "tool_call_id": item.call_id, + "content": item.output, + } + ) + elif isinstance(item, ResponseReasoningItem): + reasoning_text = "\n".join(part.text for part in (item.content or [])) + if reasoning_text: + messages.append({"role": "assistant", "content": reasoning_text}) + else: + logger.info( + "Skipping unsupported Responses input item type %r", + getattr(item, "type", type(item).__name__), + ) + + return messages + + +def _responses_request_to_new_persisted_messages(request: ResponsesRequest) -> list[dict]: + """Persist only the current request's replayable input items.""" + request_without_history = request.model_copy( + update={"previous_response_id": None, "instructions": None}, + deep=True, + ) + return _responses_input_to_chat_messages(request_without_history) + + +def _responses_request_to_persisted_messages(request: ResponsesRequest) -> list[dict]: + """Persist replayable history for chained previous_response_id requests. + + Responses `instructions` are intentionally not replayed across + `previous_response_id`, but replayable message items are. + """ + messages: list[dict] = [] + if request.previous_response_id: + previous = _responses_store.get(request.previous_response_id) + if previous is None: + raise HTTPException( + status_code=404, + detail=f"Previous response `{request.previous_response_id}` not found", + ) + messages.extend(copy.deepcopy(previous["messages"])) + messages.extend(_responses_request_to_new_persisted_messages(request)) + return messages + + +def _responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest: + """Build a ChatCompletionRequest from a ResponsesRequest.""" + if request.text.format.type == "json_object": + raise HTTPException( + status_code=400, + detail="Responses text.format.type='json_object' is not supported on this backend", + ) + if request.reasoning is not None: + logger.debug("Ignoring reasoning configuration (not supported on this backend)") + + tools, unsupported_tools = _responses_tools_to_chat_tools(request.tools) + messages = _responses_input_to_chat_messages(request) + if unsupported_tools: + tool_list = ", ".join(sorted(set(unsupported_tools))) + messages.insert( + 0, + { + "role": "system", + "content": ( + "The following requested tool types are not available on this " + f"backend: {tool_list}. Do not call them." + ), + }, + ) + + system_messages = [msg for msg in messages if msg.get("role") == "system"] + non_system_messages = [msg for msg in messages if msg.get("role") != "system"] + merged_system_content = "\n\n".join( + str(msg.get("content", "")).strip() + for msg in system_messages + if str(msg.get("content", "")).strip() + ) + messages = ( + [{"role": "system", "content": merged_system_content}] + if merged_system_content + else [] + ) + non_system_messages + + return ChatCompletionRequest( + model=request.model, + messages=[Message(**msg) for msg in messages], + temperature=request.temperature, + top_p=request.top_p, + max_tokens=request.max_output_tokens, + stream=False, + tools=tools, + tool_choice=request.tool_choice, + ) + +def _build_responses_output_items( + text: str | None, + reasoning: str | None, + tool_calls: list[ToolCall] | None, +) -> list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem]: + """Convert parsed assistant output into Responses API output items.""" + output_items: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ] = [] + + if reasoning: + output_items.append( + ResponseReasoningItem( + id=_new_response_item_id("rs"), + content=[ResponseReasoningTextPart(text=reasoning)], + ) + ) + + if text: + output_items.append( + ResponseMessageItem( + id=_new_response_item_id("msg"), + role="assistant", + content=[ResponseTextContentPart(type="output_text", text=text)], + ) + ) + + for tool_call in tool_calls or []: + output_items.append( + ResponseFunctionCallItem( + id=_new_response_item_id("fc"), + call_id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + ) + + return output_items + + +def _response_output_items_to_chat_messages(output_items: list) -> list[dict]: + """Persist assistant output in chat-completions form for previous_response_id.""" + assistant_text_parts: list[str] = [] + assistant_tool_calls: list[dict] = [] + + for item in output_items: + if isinstance(item, ResponseMessageItem): + assistant_text_parts.append(_response_content_to_text(item.content)) + elif isinstance(item, ResponseFunctionCallItem): + assistant_tool_calls.append( + { + "id": item.call_id, + "type": "function", + "function": { + "name": item.name, + "arguments": item.arguments, + }, + } + ) + + if not assistant_text_parts and not assistant_tool_calls: + return [] + + return [ + { + "role": "assistant", + "content": "".join(assistant_text_parts), + "tool_calls": assistant_tool_calls or None, + } + ] + + +def _build_response_object( + request: ResponsesRequest, + output_items: list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem], + prompt_tokens: int, + completion_tokens: int, + finish_reason: str | None, + response_id: str | None = None, +) -> ResponseObject: + """Build a full Responses API object.""" + response = ResponseObject( + id=response_id or _new_response_item_id("resp"), + model=_model_name or request.model, + instructions=request.instructions, + max_output_tokens=request.max_output_tokens, + metadata=request.metadata, + output=output_items, + parallel_tool_calls=request.parallel_tool_calls, + previous_response_id=request.previous_response_id, + text=request.text, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=_resolve_top_p(request.top_p), + temperature=_resolve_temperature(request.temperature), + truncation=request.truncation, + user=request.user, + store=request.store, + usage=ResponsesUsage( + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + if finish_reason == "length": + response.status = "incomplete" + response.incomplete_details = ResponseIncompleteDetails( + reason="max_output_tokens" + ) + return response + + +def _prepare_responses_request( + request: ResponsesRequest, +) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict, bool]: + """Prepare a Responses request for execution on the chat engine.""" + _validate_model_name(request.model) + engine = get_engine() + chat_request = _responses_request_to_chat_request(request) + + if chat_request.messages: + logger.info( + f"[REQUEST] POST /v1/responses stream={request.stream} " + f"model={request.model!r} items=" + f"{len(request.input) if isinstance(request.input, list) else 1} " + f"tools={len(request.tools)}" + ) + + messages, images, videos = extract_multimodal_content( + chat_request.messages, + preserve_native_format=engine.preserve_native_tool_format, + ) + + chat_kwargs = { + "max_tokens": chat_request.max_tokens or _default_max_tokens, + "temperature": _resolve_temperature(chat_request.temperature), + "top_p": _resolve_top_p(chat_request.top_p), + } + if request.tools: + chat_kwargs["tools"] = convert_tools_for_template(chat_request.tools) + should_parse_tools = _apply_tool_choice( + chat_request.tool_choice, chat_kwargs, messages + ) + if images: + chat_kwargs["images"] = images + if videos: + chat_kwargs["videos"] = videos + + return engine, chat_request, messages, chat_kwargs, should_parse_tools + + +async def _run_responses_request( + request: ResponsesRequest, + raw_request: Request, +) -> tuple[ResponseObject | None, list[dict]]: + """Execute a Responses API request against the backend chat engine.""" + engine, chat_request, messages, chat_kwargs, should_parse_tools = ( + _prepare_responses_request(request) + ) + + timeout = _default_timeout + output = await _wait_with_disconnect( + engine.chat(messages=messages, **chat_kwargs), + raw_request, + timeout=timeout, + ) + if output is None: + return None, [] + + cleaned_text, tool_calls = _parse_and_validate_tools( + output.text, chat_request, should_parse_tools + ) + reasoning_text = None + if _reasoning_parser and not tool_calls: + reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning( + cleaned_text or output.text + ) + + output_items = _build_responses_output_items( + clean_output_text(cleaned_text) if cleaned_text else None, + reasoning_text, + tool_calls, + ) + response_object = _build_response_object( + request=request, + output_items=output_items, + prompt_tokens=output.prompt_tokens, + completion_tokens=output.completion_tokens, + finish_reason=output.finish_reason, + ) + + persisted_messages = _responses_request_to_persisted_messages(request) + persisted_messages.extend(_response_output_items_to_chat_messages(output_items)) + if request.store: + _responses_store[response_object.id] = { + "messages": copy.deepcopy(persisted_messages), + "response": response_object.model_copy(deep=True), + } + while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE: + _responses_store.popitem(last=False) + + return response_object, persisted_messages + + +async def _stream_responses_request(request: ResponsesRequest) -> AsyncIterator[str]: + """Execute a Responses API request and stream SSE events incrementally.""" + engine, chat_request, messages, chat_kwargs, should_parse_tools = ( + _prepare_responses_request(request) + ) + + response_id = _new_response_item_id("resp") + sequence = 1 + base_response = _build_response_object( + request=request, + output_items=[], + prompt_tokens=0, + completion_tokens=0, + finish_reason=None, + response_id=response_id, + ) + base_response.status = "in_progress" + base_response.usage = None + + yield _responses_sse_event( + "response.created", + ResponseCreatedEvent(sequence_number=sequence, response=base_response), + ) + sequence += 1 + yield _responses_sse_event( + "response.in_progress", + ResponseInProgressEvent(sequence_number=sequence, response=base_response), + ) + sequence += 1 + + prompt_tokens = 0 + completion_tokens = 0 + finish_reason = None + last_output = None + raw_accumulated_text = "" + accumulated_text = "" + accumulated_reasoning = "" + + text_item_id: str | None = None + text_output_index: int | None = None + reasoning_item_id: str | None = None + reasoning_output_index: int | None = None + next_output_index = 0 + + def _start_text_item() -> list[str]: + nonlocal text_item_id, text_output_index, next_output_index, sequence + events: list[str] = [] + if text_item_id is None: + text_item_id = _new_response_item_id("msg") + text_output_index = next_output_index + next_output_index += 1 + events.append( + _responses_sse_event( + "response.output_item.added", + ResponseOutputItemAddedEvent( + sequence_number=sequence, + output_index=text_output_index, + item=ResponseMessageItem( + id=text_item_id, + role="assistant", + status="in_progress", + content=[], + ), + ), + ) + ) + sequence += 1 + events.append( + _responses_sse_event( + "response.content_part.added", + ResponseContentPartAddedEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + part=ResponseTextContentPart(type="output_text", text=""), + ), + ) + ) + sequence += 1 + return events + + def _start_reasoning_item() -> list[str]: + nonlocal reasoning_item_id, reasoning_output_index, next_output_index, sequence + events: list[str] = [] + if reasoning_item_id is None: + reasoning_item_id = _new_response_item_id("rs") + reasoning_output_index = next_output_index + next_output_index += 1 + events.append( + _responses_sse_event( + "response.output_item.added", + ResponseOutputItemAddedEvent( + sequence_number=sequence, + output_index=reasoning_output_index, + item=ResponseReasoningItem( + id=reasoning_item_id, + status="in_progress", + content=[], + ), + ), + ) + ) + sequence += 1 + events.append( + _responses_sse_event( + "response.content_part.added", + ResponseContentPartAddedEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + part=ResponseReasoningTextPart(text=""), + ), + ) + ) + sequence += 1 + return events + + if _reasoning_parser: + _reasoning_parser.reset_state() + + global _tool_parser_instance + tool_parser = None + tool_accumulated_text = "" + tool_markup_possible = False + if should_parse_tools and _enable_auto_tool_choice and _tool_call_parser: + if _tool_parser_instance is None: + try: + parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser) + tokenizer = None + if _engine is not None and hasattr(_engine, "_tokenizer"): + tokenizer = _engine._tokenizer + _tool_parser_instance = parser_cls(tokenizer) + logger.info( + "Initialized tool call parser for responses streaming: %s", + _tool_call_parser, + ) + except Exception as e: + logger.warning( + "Failed to init tool parser for responses streaming: %s", e + ) + if _tool_parser_instance is not None: + tool_parser = _tool_parser_instance + tool_parser.reset() + + async for output in engine.stream_chat(messages=messages, **chat_kwargs): + last_output = output + finish_reason = output.finish_reason + if hasattr(output, "prompt_tokens") and output.prompt_tokens: + prompt_tokens = output.prompt_tokens + if hasattr(output, "completion_tokens") and output.completion_tokens: + completion_tokens = output.completion_tokens + + delta_text = output.new_text or "" + if not delta_text: + continue + + previous_text = raw_accumulated_text + raw_accumulated_text += delta_text + + if _reasoning_parser: + delta_msg = _reasoning_parser.extract_reasoning_streaming( + previous_text, raw_accumulated_text, delta_text + ) + if delta_msg is None: + continue + + if delta_msg.reasoning: + for event in _start_reasoning_item(): + yield event + accumulated_reasoning += delta_msg.reasoning + yield _responses_sse_event( + "response.reasoning_text.delta", + ResponseReasoningTextDeltaEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + delta=delta_msg.reasoning, + ), + ) + sequence += 1 + + if delta_msg.content: + for event in _start_text_item(): + yield event + accumulated_text += delta_msg.content + yield _responses_sse_event( + "response.output_text.delta", + ResponseOutputTextDeltaEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + delta=delta_msg.content, + ), + ) + sequence += 1 + continue + + content = SPECIAL_TOKENS_PATTERN.sub("", delta_text) + if tool_parser and delta_text: + if not tool_markup_possible and not _looks_like_streaming_tool_markup( + delta_text + ): + tool_accumulated_text += delta_text + else: + if not tool_markup_possible: + tool_markup_possible = True + tool_result = tool_parser.extract_tool_calls_streaming( + tool_accumulated_text, tool_accumulated_text + delta_text, delta_text + ) + tool_accumulated_text += delta_text + if tool_result is None: + continue + if "tool_calls" in tool_result: + continue + content = tool_result.get("content", "") + + if not content: + continue + + for event in _start_text_item(): + yield event + accumulated_text += content + yield _responses_sse_event( + "response.output_text.delta", + ResponseOutputTextDeltaEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + delta=content, + ), + ) + sequence += 1 + + cleaned_text, tool_calls = _parse_and_validate_tools( + raw_accumulated_text, chat_request, should_parse_tools + ) + final_text = accumulated_text + if cleaned_text is not None and not final_text and not tool_calls: + final_text = clean_output_text(cleaned_text) + + reasoning_item = None + if reasoning_item_id is not None: + reasoning_item = ResponseReasoningItem( + id=reasoning_item_id, + status="completed", + content=[ResponseReasoningTextPart(text=accumulated_reasoning)], + ) + yield _responses_sse_event( + "response.reasoning_text.done", + ResponseReasoningTextDoneEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + text=accumulated_reasoning, + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.content_part.done", + ResponseContentPartDoneEvent( + sequence_number=sequence, + item_id=reasoning_item_id, + output_index=reasoning_output_index, + content_index=0, + part=reasoning_item.content[0], + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.output_item.done", + ResponseOutputItemDoneEvent( + sequence_number=sequence, + output_index=reasoning_output_index, + item=reasoning_item, + ), + ) + sequence += 1 + + text_item = None + if text_item_id is not None or final_text: + if text_item_id is None: + for event in _start_text_item(): + yield event + text_item = ResponseMessageItem( + id=text_item_id, + role="assistant", + status="completed", + content=[ResponseTextContentPart(type="output_text", text=final_text)], + ) + yield _responses_sse_event( + "response.output_text.done", + ResponseOutputTextDoneEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + text=final_text, + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.content_part.done", + ResponseContentPartDoneEvent( + sequence_number=sequence, + item_id=text_item_id, + output_index=text_output_index, + content_index=0, + part=text_item.content[0], + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.output_item.done", + ResponseOutputItemDoneEvent( + sequence_number=sequence, + output_index=text_output_index, + item=text_item, + ), + ) + sequence += 1 + + function_call_items: list[ResponseFunctionCallItem] = [] + for tool_call in tool_calls or []: + output_index = next_output_index + next_output_index += 1 + item = ResponseFunctionCallItem( + id=_new_response_item_id("fc"), + call_id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + function_call_items.append(item) + yield _responses_sse_event( + "response.output_item.added", + ResponseOutputItemAddedEvent( + sequence_number=sequence, + output_index=output_index, + item=item.model_copy(update={"status": "in_progress"}), + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.function_call_arguments.delta", + ResponseFunctionCallArgumentsDeltaEvent( + sequence_number=sequence, + item_id=item.id, + output_index=output_index, + delta=item.arguments, + ), + ) + sequence += 1 + yield _responses_sse_event( + "response.output_item.done", + ResponseOutputItemDoneEvent( + sequence_number=sequence, + output_index=output_index, + item=item, + ), + ) + sequence += 1 + + output_items: list[ + ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem + ] = [] + if reasoning_item is not None: + output_items.append(reasoning_item) + if text_item is not None: + output_items.append(text_item) + output_items.extend(function_call_items) + + response_object = _build_response_object( + request=request, + output_items=output_items, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + finish_reason=finish_reason, + response_id=response_id, + ) + + if request.store and last_output is not None: + persisted_messages = _responses_request_to_persisted_messages(request) + persisted_messages.extend(_response_output_items_to_chat_messages(output_items)) + _responses_store[response_object.id] = { + "messages": copy.deepcopy(persisted_messages), + "response": response_object.model_copy(deep=True), + } + while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE: + _responses_store.popitem(last=False) + + yield _responses_sse_event( + "response.completed", + ResponseCompletedEvent(sequence_number=sequence, response=response_object), + ) + + +def _responses_sse_event(event_type: str, payload: BaseModel | dict) -> str: + """Encode a Responses API SSE event.""" + data = payload.model_dump_json() if isinstance(payload, BaseModel) else json.dumps(payload) + return f"event: {event_type}\ndata: {data}\n\n" def _detect_native_tool_support() -> bool: """ Detect if the active tool parser supports native tool format. @@ -2000,73 +3044,69 @@ async def stream_chat_completion( # Skip this chunk (e.g., token itself) continue - content = delta_msg.content - reasoning = delta_msg.reasoning - - # Tool call parsing on the content portion (fix: was - # previously unreachable because tool parser was in the - # else branch of this if/else) - if tool_parser and content: - if not tool_markup_possible and "<" not in content: - tool_accumulated_text += content - else: - if not tool_markup_possible: - tool_markup_possible = True - tool_previous = tool_accumulated_text - tool_accumulated_text += content + # Route content through tool parser when both are active + emit_content = delta_msg.content + if tool_parser and emit_content: + tool_accumulated_text += emit_content + if tool_markup_possible or _looks_like_streaming_tool_markup( + emit_content + ): + tool_markup_possible = True tool_result = tool_parser.extract_tool_calls_streaming( - tool_previous, tool_accumulated_text, content + tool_accumulated_text[: -len(emit_content)], + tool_accumulated_text, + emit_content, ) if tool_result is None: - # Inside tool markup — suppress content, keep reasoning - content = None + emit_content = None elif "tool_calls" in tool_result: + tool_call_complete = bool(tool_result.get("complete")) tool_calls_detected = True - chunk = ChatCompletionChunk( + tc_chunk = ChatCompletionChunk( id=response_id, - model=request.model, + model=_model_name, choices=[ ChatCompletionChunkChoice( delta=ChatCompletionChunkDelta( - reasoning=reasoning, tool_calls=tool_result["tool_calls"], ), finish_reason=( - "tool_calls" if output.finished else None + "tool_calls" + if (output.finished or tool_call_complete) + else None ), ) ], - usage=get_usage(output) if output.finished else None, + usage=( + get_usage(output) + if (output.finished or tool_call_complete) + else None + ), ) - yield f"data: {chunk.model_dump_json()}\n\n" + yield f"data: {tc_chunk.model_dump_json()}\n\n" + if tool_call_complete: + break continue else: - content = tool_result.get("content", "") - - # Skip if both content and reasoning are empty - if not content and not reasoning: - continue - - chunk = ChatCompletionChunk( - id=response_id, - model=_model_name, - choices=[ - ChatCompletionChunkChoice( - delta=ChatCompletionChunkDelta( - content=content if content else None, - reasoning=reasoning, - ), - finish_reason=( - "tool_calls" - if (output.finished and tool_calls_detected) - else (output.finish_reason if output.finished else None) - ), - ) - ], - usage=get_usage(output) if output.finished else None, - ) - yield f"data: {chunk.model_dump_json()}\n\n" + emit_content = tool_result.get("content", emit_content) + + if emit_content or delta_msg.reasoning: + chunk = ChatCompletionChunk( + id=response_id, + model=_model_name, + choices=[ + ChatCompletionChunkChoice( + delta=ChatCompletionChunkDelta( + content=emit_content, + reasoning=delta_msg.reasoning, + ), + finish_reason=output.finish_reason if output.finished else None, + ) + ], + usage=get_usage(output) if output.finished else None, + ) + yield f"data: {chunk.model_dump_json()}\n\n" else: # Standard path without reasoning parsing content = delta_text @@ -2085,7 +3125,9 @@ async def stream_chat_completion( # Fast path: skip full parsing until '<' is seen in the stream, # which could start tool markup (e.g. ). This avoids # per-token string scanning on the growing accumulated text. - if not tool_markup_possible and "<" not in delta_text: + if not tool_markup_possible and not _looks_like_streaming_tool_markup( + delta_text + ): tool_accumulated_text += delta_text # No tool markup yet, fall through to normal chunk emission else: @@ -2102,6 +3144,7 @@ async def stream_chat_completion( continue if "tool_calls" in tool_result: + tool_call_complete = bool(tool_result.get("complete")) # Emit structured tool calls tool_calls_detected = True chunk = ChatCompletionChunk( @@ -2113,13 +3156,21 @@ async def stream_chat_completion( tool_calls=tool_result["tool_calls"] ), finish_reason=( - "tool_calls" if output.finished else None + "tool_calls" + if (output.finished or tool_call_complete) + else None ), ) ], - usage=get_usage(output) if output.finished else None, + usage=( + get_usage(output) + if (output.finished or tool_call_complete) + else None + ), ) yield f"data: {chunk.model_dump_json()}\n\n" + if tool_call_complete: + break continue # Normal content from tool parser @@ -2144,16 +3195,14 @@ async def stream_chat_completion( ) yield f"data: {chunk.model_dump_json()}\n\n" - # Fallback: if tool parser accumulated text but never emitted tool_calls - # (e.g., never arrived - incomplete tool call) - if ( - tool_parser - and tool_accumulated_text - and not tool_calls_detected - and "" in tool_accumulated_text - ): - result = tool_parser.extract_tool_calls(tool_accumulated_text) - if result.tools_called: + # Completion guard: if a stream ended with fully accumulated tool markup but + # the incremental parser never emitted, run one final validated parse over + # the complete tool text before returning. + if tool_parser and tool_accumulated_text and not tool_calls_detected: + _, fallback_tool_calls = _parse_and_validate_tools( + tool_accumulated_text, request, True + ) + if fallback_tool_calls: tool_chunk = ChatCompletionChunk( id=response_id, model=_model_name, @@ -2163,14 +3212,14 @@ async def stream_chat_completion( tool_calls=[ { "index": i, - "id": tc["id"], + "id": tc.id, "type": "function", "function": { - "name": tc["name"], - "arguments": tc["arguments"], + "name": tc.function.name, + "arguments": tc.function.arguments, }, } - for i, tc in enumerate(result.tool_calls) + for i, tc in enumerate(fallback_tool_calls) ] ), finish_reason="tool_calls", diff --git a/vllm_mlx/tool_parsers/harmony_tool_parser.py b/vllm_mlx/tool_parsers/harmony_tool_parser.py index 34f8555d..c4d08952 100644 --- a/vllm_mlx/tool_parsers/harmony_tool_parser.py +++ b/vllm_mlx/tool_parsers/harmony_tool_parser.py @@ -34,6 +34,11 @@ def _generate_tool_id() -> str: return f"call_{uuid.uuid4().hex[:8]}" +def _same_tool_call(left: dict[str, str], right: dict[str, str]) -> bool: + """Return True when two parsed tool calls are semantically identical.""" + return left["name"] == right["name"] and left["arguments"] == right["arguments"] + + # Pattern: <|channel|>commentary to=functions.tool_name ... <|call|> _COMMENTARY_BLOCK_PATTERN = re.compile( r"<\|channel\|>commentary\s+to=functions\.(\w+)" @@ -82,26 +87,26 @@ def extract_tool_calls( try: arguments = json.loads(args_str) - tool_calls.append( - { - "id": _generate_tool_id(), - "name": tool_name, - "arguments": ( - json.dumps(arguments, ensure_ascii=False) - if isinstance(arguments, dict) - else str(arguments) - ), - } - ) + parsed_call = { + "id": _generate_tool_id(), + "name": tool_name, + "arguments": ( + json.dumps(arguments, ensure_ascii=False) + if isinstance(arguments, dict) + else str(arguments) + ), + } except json.JSONDecodeError: # Keep the raw arguments string - tool_calls.append( - { - "id": _generate_tool_id(), - "name": tool_name, - "arguments": args_str, - } - ) + parsed_call = { + "id": _generate_tool_id(), + "name": tool_name, + "arguments": args_str, + } + + if tool_calls and _same_tool_call(parsed_call, tool_calls[-1]): + continue + tool_calls.append(parsed_call) # Extract final channel content final_match = _FINAL_BLOCK_PATTERN.search(model_output) @@ -143,13 +148,16 @@ def extract_tool_calls_streaming( channel content as regular content deltas. """ # If we see a tool call completion marker in the delta - if "<|call|>" in delta_text: - result = self.extract_tool_calls(current_text) - if result.tools_called: + if "<|call|>" in current_text: + previous_result = self.extract_tool_calls(previous_text) + current_result = self.extract_tool_calls(current_text) + new_tool_calls = current_result.tool_calls[len(previous_result.tool_calls) :] + if new_tool_calls: return { + "complete": True, "tool_calls": [ { - "index": i, + "index": len(previous_result.tool_calls) + i, "id": tc["id"], "type": "function", "function": { @@ -157,7 +165,7 @@ def extract_tool_calls_streaming( "arguments": tc["arguments"], }, } - for i, tc in enumerate(result.tool_calls) + for i, tc in enumerate(new_tool_calls) ] } diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py index fd69b96c..5348d53c 100644 --- a/vllm_mlx/tool_parsers/qwen_tool_parser.py +++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py @@ -40,6 +40,17 @@ class QwenToolParser(ToolParser): # Pattern for XML-style: {"json"} XML_PATTERN = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL) + # Pattern for XML-style function blocks: + # value + FUNCTION_PATTERN = re.compile( + r"\s*]+)>(.*?)\s*", + re.DOTALL, + ) + PARAM_PATTERN = re.compile( + r"]+)>\s*(.*?)\s*", + re.DOTALL, + ) + # Pattern for bracket-style: [Calling tool: func_name({...})] BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL) @@ -101,6 +112,29 @@ def extract_tool_calls( if xml_matches: cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip() + # Try function/parameter pattern used by Qwen tool templates + function_matches = self.FUNCTION_PATTERN.findall(cleaned_text) + for name, params_block in function_matches: + arguments = {} + for param_name, param_value in self.PARAM_PATTERN.findall(params_block): + raw_value = param_value.strip() + try: + arguments[param_name.strip()] = json.loads(raw_value) + except json.JSONDecodeError: + arguments[param_name.strip()] = raw_value + + if name.strip(): + tool_calls.append( + { + "id": generate_tool_id(), + "name": name.strip(), + "arguments": json.dumps(arguments, ensure_ascii=False), + } + ) + + if function_matches: + cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip() + if tool_calls: return ExtractedToolCallInformation( tools_called=True, @@ -135,11 +169,12 @@ def extract_tool_calls_streaming( # If we're in a tool call, accumulate and parse at the end # For simplicity, return None during accumulation - if "" in delta_text or ")]" in delta_text: + if "" in current_text or ")]" in current_text: # Tool call complete, parse the whole thing result = self.extract_tool_calls(current_text) if result.tools_called: return { + "complete": True, "tool_calls": [ { "index": i,