diff --git a/tests/test_engine_tool_output_preservation.py b/tests/test_engine_tool_output_preservation.py
new file mode 100644
index 00000000..6e8c9dce
--- /dev/null
+++ b/tests/test_engine_tool_output_preservation.py
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests that tool-enabled chat preserves raw parser-visible output."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+
+@pytest.fixture
+def anyio_backend():
+ return "asyncio"
+
+
+class TestSimpleEngineToolOutputPreservation:
+ @pytest.mark.anyio
+ async def test_chat_with_tools_preserves_raw_harmony_output(self):
+ from vllm_mlx.engine.simple import SimpleEngine
+
+ async def fake_stream_chat(*args, **kwargs):
+ yield MagicMock(
+ text=(
+ "<|channel|>commentary to=functions.get_weather"
+ '<|message|>{"city":"Paris"}<|call|>'
+ ),
+ tokens=[],
+ prompt_tokens=11,
+ completion_tokens=4,
+ finish_reason="stop",
+ finished=True,
+ )
+
+ with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+ engine = SimpleEngine("test-model")
+ engine._model = MagicMock()
+ engine._loaded = True
+ engine.stream_chat = fake_stream_chat # type: ignore[method-assign]
+
+ output = await engine.chat(
+ messages=[{"role": "user", "content": "Weather in Paris?"}],
+ tools=[
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {}},
+ },
+ }
+ ],
+ )
+
+ assert "<|channel|>commentary" in output.text
+ assert "<|call|>" in output.text
+
+
+class TestBatchedEngineToolOutputPreservation:
+ @pytest.mark.anyio
+ async def test_chat_with_tools_preserves_raw_output(self):
+ from vllm_mlx.engine.batched import BatchedEngine
+
+ raw_output = (
+ "<|channel|>commentary to=functions.get_weather"
+ '<|message|>{"city":"Paris"}<|call|>'
+ )
+
+ with patch("vllm_mlx.engine.batched.is_mllm_model", return_value=False):
+ engine = BatchedEngine("test-model")
+ engine._loaded = True
+ engine._tokenizer = MagicMock()
+ engine._apply_chat_template = MagicMock(return_value="prompt")
+ engine._engine = MagicMock()
+ engine._engine.generate = AsyncMock(
+ return_value=MagicMock(
+ output_text=raw_output,
+ prompt_tokens=9,
+ completion_tokens=3,
+ finish_reason="stop",
+ )
+ )
+
+ output = await engine.chat(
+ messages=[{"role": "user", "content": "Weather in Paris?"}],
+ tools=[
+ {
+ "type": "function",
+ "function": {
+ "name": "get_weather",
+ "parameters": {"type": "object", "properties": {}},
+ },
+ }
+ ],
+ )
+
+ assert output.text == raw_output
+ engine._engine.generate.assert_called_once()
diff --git a/tests/test_harmony_parsers.py b/tests/test_harmony_parsers.py
index 9ca509f6..77d3db59 100644
--- a/tests/test_harmony_parsers.py
+++ b/tests/test_harmony_parsers.py
@@ -198,6 +198,23 @@ def test_nested_json_arguments(self, parser):
parsed_args = json.loads(result.tool_calls[0]["arguments"])
assert parsed_args["filter"]["type"] == "range"
+ def test_consecutive_duplicate_tool_calls_are_deduped(self, parser):
+ """Repeated identical commentary blocks should collapse to one call."""
+ text = (
+ "<|channel|>commentary to=functions.get_weather\n"
+ "<|constrain|>json\n"
+ '<|message|>{"city":"Paris"}\n'
+ "<|call|>\n"
+ "<|channel|>commentary to=functions.get_weather\n"
+ "<|constrain|>json\n"
+ '<|message|>{"city":"Paris"}\n'
+ "<|call|>"
+ )
+ result = parser.extract_tool_calls(text)
+
+ assert result.tools_called
+ assert len(result.tool_calls) == 1
+
def test_streaming_no_tool_markers(self, parser):
"""Streaming: plain text passes through as content."""
result = parser.extract_tool_calls_streaming("", "Hello", "Hello")
@@ -227,6 +244,25 @@ def test_streaming_building_tool_call(self, parser):
result = parser.extract_tool_calls_streaming("", current, '{"a":')
assert result is None
+ def test_streaming_duplicate_tool_call_is_not_reemitted(self, parser):
+ """Streaming should only emit newly completed Harmony tool calls."""
+ previous = (
+ "<|channel|>commentary to=functions.func\n"
+ "<|constrain|>json\n"
+ '<|message|>{"a": 1}\n'
+ "<|call|>"
+ )
+ current = (
+ previous
+ + "\n<|channel|>commentary to=functions.func\n"
+ + "<|constrain|>json\n"
+ + '<|message|>{"a": 1}\n'
+ + "<|call|>"
+ )
+
+ result = parser.extract_tool_calls_streaming(previous, current, "<|call|>")
+ assert result is None
+
# ============================================================================
# Reasoning Parser Tests
@@ -388,26 +424,60 @@ def test_streaming_reset(self, parser):
assert parser._current_channel is None
assert parser._in_message is False
- def test_streaming_commentary_suppressed(self, parser):
- """Streaming: commentary channel output is suppressed."""
+ def test_streaming_commentary_routed_as_content(self, parser):
+ """Streaming: commentary channel is forwarded for downstream tool parsing."""
parser.reset_state()
- parser.extract_reasoning_streaming(
+ r1 = parser.extract_reasoning_streaming(
"",
"<|channel|>commentary to=functions.f\n",
"<|channel|>commentary to=functions.f\n",
)
- parser.extract_reasoning_streaming(
+ assert r1 is not None
+ assert r1.content == "<|channel|>commentary to=functions.f\n"
+
+ r2 = parser.extract_reasoning_streaming(
"<|channel|>commentary to=functions.f\n",
"<|channel|>commentary to=functions.f\n<|message|>",
"<|message|>",
)
+ assert r2 is not None
+ assert r2.content == "<|message|>"
+
r = parser.extract_reasoning_streaming(
"<|channel|>commentary to=functions.f\n<|message|>",
'<|channel|>commentary to=functions.f\n<|message|>{"a":1}',
'{"a":1}',
)
- assert r is None
+ assert r is not None
+ assert r.content == '{"a":1}'
+
+ def test_streaming_split_commentary_header(self, parser):
+ """Split commentary headers should still be routed to downstream tool parsing."""
+ parser.reset_state()
+
+ accumulated = ""
+ content_parts = []
+ for token in [
+ "<|channel|>",
+ "comment",
+ "ary to",
+ "=functions.get_weather",
+ " <|constrain|>",
+ "json",
+ "<|message|>",
+ '{"city":"Paris"}',
+ "<|call|>",
+ ]:
+ prev = accumulated
+ accumulated += token
+ result = parser.extract_reasoning_streaming(prev, accumulated, token)
+ if result and result.content:
+ content_parts.append(result.content)
+
+ combined = "".join(content_parts)
+ assert "<|channel|>commentary to=functions.get_weather" in combined
+ assert '<|message|>{"city":"Paris"}' in combined
# ============================================================================
diff --git a/tests/test_simple_engine.py b/tests/test_simple_engine.py
index cce42bfc..7590ffb5 100644
--- a/tests/test_simple_engine.py
+++ b/tests/test_simple_engine.py
@@ -10,6 +10,10 @@
class TestSimpleEngineConcurrency:
"""Test SimpleEngine lock behavior with concurrent requests."""
+ @pytest.fixture
+ def anyio_backend(self):
+ return "asyncio"
+
@pytest.fixture
def mock_model(self):
"""Create a mock model that tracks concurrent calls."""
@@ -65,7 +69,7 @@ def chat_side_effect(**kwargs):
model.chat = MagicMock(side_effect=chat_side_effect)
return model
- @pytest.mark.asyncio
+ @pytest.mark.anyio
async def test_lock_prevents_concurrent_generate(self, mock_model):
"""Test that the lock prevents concurrent generate calls."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -89,7 +93,7 @@ async def test_lock_prevents_concurrent_generate(self, mock_model):
"The lock is not working correctly."
)
- @pytest.mark.asyncio
+ @pytest.mark.anyio
async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"""Test that the lock prevents concurrent chat calls."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -115,7 +119,57 @@ async def test_lock_prevents_concurrent_chat(self, mock_llm_model):
"The lock is not working correctly."
)
- @pytest.mark.asyncio
+ @pytest.mark.anyio
+ async def test_chat_with_tools_aggregates_streaming_path(self, mock_llm_model):
+ """Tool-enabled non-stream chat should use the streaming path."""
+ from vllm_mlx.engine.simple import SimpleEngine
+
+ async def fake_stream_chat(*args, **kwargs):
+ yield MagicMock(
+ text="partial",
+ tokens=[],
+ prompt_tokens=11,
+ completion_tokens=1,
+ finish_reason=None,
+ finished=False,
+ )
+ yield MagicMock(
+ text="<|im_end|>{\"name\":\"bash\",\"arguments\":{\"command\":\"pwd\"}}",
+ tokens=[],
+ prompt_tokens=11,
+ completion_tokens=4,
+ finish_reason="stop",
+ finished=True,
+ )
+
+ with patch("vllm_mlx.engine.simple.is_mllm_model", return_value=False):
+ engine = SimpleEngine("test-model")
+ engine._model = mock_llm_model
+ engine._loaded = True
+ engine.stream_chat = fake_stream_chat # type: ignore[method-assign]
+
+ output = await engine.chat(
+ messages=[{"role": "user", "content": "run pwd"}],
+ max_tokens=16,
+ tools=[
+ {
+ "type": "function",
+ "function": {
+ "name": "bash",
+ "parameters": {"type": "object", "properties": {}},
+ },
+ }
+ ],
+ )
+
+ assert output.text.startswith("<|im_end|>")
+ assert output.tokens == []
+ assert output.prompt_tokens == 11
+ assert output.completion_tokens == 4
+ assert output.finish_reason == "stop"
+ mock_llm_model.chat.assert_not_called()
+
+ @pytest.mark.anyio
async def test_lock_serializes_stream_generate(self, mock_model):
"""Test that stream_generate uses the same lock as other methods."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -178,7 +232,7 @@ async def try_stream():
result = await stream_task
assert len(result) == 3, f"Expected 3 chunks, got {len(result)}"
- @pytest.mark.asyncio
+ @pytest.mark.anyio
async def test_engine_initialization_creates_lock(self):
"""Test that SimpleEngine creates a lock on initialization."""
from vllm_mlx.engine.simple import SimpleEngine
@@ -189,7 +243,7 @@ async def test_engine_initialization_creates_lock(self):
assert hasattr(engine, "_generation_lock")
assert isinstance(engine._generation_lock, asyncio.Lock)
- @pytest.mark.asyncio
+ @pytest.mark.anyio
async def test_requests_complete_in_order(self, mock_model):
"""Test that concurrent requests complete (may be in any order due to lock)."""
from vllm_mlx.engine.simple import SimpleEngine
diff --git a/tests/test_streaming_reasoning_tools.py b/tests/test_streaming_reasoning_tools.py
new file mode 100644
index 00000000..c6726bf1
--- /dev/null
+++ b/tests/test_streaming_reasoning_tools.py
@@ -0,0 +1,396 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Streaming chat-completion tests for reasoning + tool parser coexistence."""
+
+import json
+from collections.abc import AsyncIterator
+from typing import Any
+
+import pytest
+
+import vllm_mlx.server as server
+from vllm_mlx.api.models import ChatCompletionRequest, Message, ToolDefinition
+from vllm_mlx.engine.base import BaseEngine, GenerationOutput
+from vllm_mlx.reasoning import get_parser
+from vllm_mlx.tool_parsers import ToolParserManager
+
+
+TEST_TOOL = ToolDefinition(
+ function={
+ "name": "get_weather",
+ "parameters": {
+ "type": "object",
+ "properties": {"city": {"type": "string"}},
+ "required": ["city"],
+ },
+ }
+)
+
+
+class FakeStreamEngine(BaseEngine):
+ """Minimal engine for deterministic stream_chat tests."""
+
+ def __init__(self, deltas: list[str], model_name: str = "test-model"):
+ self._deltas = deltas
+ self._model_name = model_name
+ self._tokenizer = None
+
+ @property
+ def model_name(self) -> str:
+ return self._model_name
+
+ @property
+ def is_mllm(self) -> bool:
+ return False
+
+ @property
+ def tokenizer(self) -> Any:
+ return None
+
+ async def start(self) -> None:
+ return None
+
+ async def stop(self) -> None:
+ return None
+
+ async def generate(self, *args, **kwargs) -> GenerationOutput:
+ raise NotImplementedError
+
+ async def stream_generate(self, *args, **kwargs) -> AsyncIterator[GenerationOutput]:
+ raise NotImplementedError
+
+ async def chat(self, *args, **kwargs) -> GenerationOutput:
+ raise NotImplementedError
+
+ async def stream_chat(self, *args, **kwargs) -> AsyncIterator[GenerationOutput]:
+ text = ""
+ for i, delta in enumerate(self._deltas):
+ text += delta
+ yield GenerationOutput(
+ text=text,
+ new_text=delta,
+ finished=i == len(self._deltas) - 1,
+ completion_tokens=i + 1,
+ finish_reason="stop",
+ )
+
+
+def _collect_payloads(stream_output: list[str]) -> list[dict[str, Any]]:
+ payloads: list[dict[str, Any]] = []
+ for chunk in stream_output:
+ assert chunk.startswith("data: ")
+ payload = chunk[6:].strip()
+ if payload == "[DONE]":
+ continue
+ payloads.append(json.loads(payload))
+ return payloads
+
+
+def _flatten_deltas(payloads: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ deltas: list[dict[str, Any]] = []
+ for payload in payloads:
+ for choice in payload["choices"]:
+ deltas.append(
+ {
+ "delta": choice["delta"],
+ "finish_reason": choice["finish_reason"],
+ }
+ )
+ return deltas
+
+
+def _reasoning_text(deltas: list[dict[str, Any]]) -> str:
+ return "".join(delta["delta"].get("reasoning") or "" for delta in deltas)
+
+
+def _content_text(deltas: list[dict[str, Any]]) -> str:
+ return "".join(delta["delta"].get("content") or "" for delta in deltas)
+
+
+def _tool_calls(deltas: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ calls: list[dict[str, Any]] = []
+ for delta in deltas:
+ calls.extend(delta["delta"].get("tool_calls") or [])
+ return calls
+
+
+async def _run_stream(
+ monkeypatch: pytest.MonkeyPatch,
+ *,
+ deltas: list[str],
+ reasoning_parser: str | None,
+ tool_parser: str | None,
+ model_name: str = "test-model",
+) -> list[dict[str, Any]]:
+ engine = FakeStreamEngine(deltas, model_name=model_name)
+ parser_instance = (
+ ToolParserManager.get_tool_parser(tool_parser)() if tool_parser else None
+ )
+
+ monkeypatch.setattr(server, "_engine", engine)
+ monkeypatch.setattr(server, "_model_name", model_name)
+ monkeypatch.setattr(
+ server,
+ "_reasoning_parser",
+ get_parser(reasoning_parser)() if reasoning_parser else None,
+ )
+ monkeypatch.setattr(server, "_enable_auto_tool_choice", tool_parser is not None)
+ monkeypatch.setattr(server, "_tool_call_parser", tool_parser)
+ monkeypatch.setattr(server, "_tool_parser_instance", parser_instance)
+
+ request = ChatCompletionRequest(
+ model=model_name,
+ messages=[Message(role="user", content="Weather in Paris?")],
+ stream=True,
+ tools=[TEST_TOOL],
+ )
+
+ raw_chunks: list[str] = []
+ async for chunk in server.stream_chat_completion(
+ engine,
+ [{"role": "user", "content": "Weather in Paris?"}],
+ request,
+ ):
+ raw_chunks.append(chunk)
+
+ return _flatten_deltas(_collect_payloads(raw_chunks))
+
+
+@pytest.mark.anyio
+async def test_streaming_qwen_reasoning_then_tool_call(monkeypatch):
+ deltas = [
+ "",
+ "Need tool",
+ "",
+ "\n",
+ '{"name":"get_weather","arguments":{"city":"Paris"}}',
+ "\n",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="qwen3",
+ tool_parser="qwen",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == "Need tool"
+ assert _content_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"}
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_qwen_direct_tool_call_with_reasoning_parser(monkeypatch):
+ deltas = [
+ "\n",
+ '{"name":"get_weather","arguments":{"city":"Paris"}}',
+ "\n",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="qwen3",
+ tool_parser="qwen",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == ""
+ assert _content_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_qwen_reasoning_then_plain_text(monkeypatch):
+ deltas = ["", "R", "", "Hello", " world"]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="qwen3",
+ tool_parser="qwen",
+ )
+
+ assert _reasoning_text(deltas_out) == "R"
+ assert _content_text(deltas_out) == "Hello world"
+ assert _tool_calls(deltas_out) == []
+ assert deltas_out[-1]["finish_reason"] == "stop"
+
+
+@pytest.mark.anyio
+async def test_streaming_qwen_tool_call_inside_think_is_not_emitted(monkeypatch):
+ deltas = [
+ "",
+ 'R {"name":"get_weather","arguments":{"city":"Paris"}}',
+ "",
+ "Text",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="qwen3",
+ tool_parser="qwen",
+ )
+
+ assert "" in _reasoning_text(deltas_out)
+ assert _content_text(deltas_out) == "Text"
+ assert _tool_calls(deltas_out) == []
+ assert deltas_out[-1]["finish_reason"] == "stop"
+
+
+@pytest.mark.anyio
+async def test_streaming_harmony_reasoning_then_tool_call(monkeypatch):
+ deltas = [
+ "<|channel|>analysis",
+ "<|message|>",
+ "Need weather",
+ "<|end|>",
+ "<|channel|>commentary to=functions.get_weather",
+ "<|constrain|>json",
+ "<|message|>",
+ '{"city":"Paris"}',
+ "<|call|>",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="harmony",
+ tool_parser="harmony",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == "Need weather"
+ assert _content_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"}
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_gpt_oss_split_commentary_routes_to_harmony_tool_parser(
+ monkeypatch,
+):
+ deltas = [
+ "<|channel|>",
+ "analysis",
+ "<|message|>",
+ "Need weather",
+ "<|end|>",
+ "<|channel|>",
+ "comment",
+ "ary to",
+ "=functions.get_weather",
+ " <|constrain|>",
+ "json",
+ "<|message|>",
+ '{"city":"Paris"}',
+ "<|call|>",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="gpt_oss",
+ tool_parser="harmony",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == "Need weather"
+ assert _content_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert json.loads(calls[0]["function"]["arguments"]) == {"city": "Paris"}
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_qwen_function_parameter_tool_call(monkeypatch):
+ deltas = [
+ "",
+ "Need tool",
+ "",
+ "\n",
+ "\n",
+ "\n",
+ '"Paris"\n',
+ "\n",
+ "\n",
+ "3\n",
+ "\n",
+ "\n",
+ "",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser="qwen3",
+ tool_parser="qwen",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == "Need tool"
+ assert _content_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert json.loads(calls[0]["function"]["arguments"]) == {
+ "city": "Paris",
+ "days": 3,
+ }
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_without_reasoning_parser_keeps_qwen_tools_working(monkeypatch):
+ deltas = [
+ "\n",
+ '{"name":"get_weather","arguments":{"city":"Paris"}}',
+ "\n",
+ ]
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser=None,
+ tool_parser="qwen",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert _reasoning_text(deltas_out) == ""
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
+
+
+@pytest.mark.anyio
+async def test_streaming_without_reasoning_parser_keeps_mistral_tools_working(
+ monkeypatch,
+):
+ deltas = ["[TOOL_CALLS]", 'get_weather{"city":"Paris"}']
+
+ deltas_out = await _run_stream(
+ monkeypatch,
+ deltas=deltas,
+ reasoning_parser=None,
+ tool_parser="mistral",
+ )
+
+ calls = _tool_calls(deltas_out)
+
+ assert len(calls) == 1
+ assert calls[0]["function"]["name"] == "get_weather"
+ assert deltas_out[-1]["finish_reason"] == "tool_calls"
diff --git a/tests/test_tool_parsers.py b/tests/test_tool_parsers.py
index dfe2bb6a..6606f0c2 100644
--- a/tests/test_tool_parsers.py
+++ b/tests/test_tool_parsers.py
@@ -185,6 +185,24 @@ def test_multiple_xml_calls(self, parser):
assert result.tools_called
assert len(result.tool_calls) == 2
+ def test_function_parameter_format(self, parser):
+ """Test parsing Qwen's function/parameter XML format."""
+ text = (
+ "\n"
+ "\n"
+ "Paris\n"
+ "3\n"
+ "\n"
+ ""
+ )
+ result = parser.extract_tool_calls(text)
+
+ assert result.tools_called
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0]["name"] == "get_weather"
+ args = json.loads(result.tool_calls[0]["arguments"])
+ assert args == {"city": "Paris", "days": 3}
+
def test_no_tool_call(self, parser):
"""Test text without tool calls."""
text = "I can help you with that question."
@@ -681,6 +699,20 @@ def test_auto_streaming(self):
)
assert result == {"content": "Hello world"}
+ def test_qwen_streaming_split_closing_tag(self):
+ """Qwen streaming should finish on accumulated text, not lucky final deltas."""
+ parser = QwenToolParser()
+
+ r = parser.extract_tool_calls_streaming(
+ previous_text="ParisParis",
+ delta_text=">",
+ )
+
+ assert r is not None
+ assert "tool_calls" in r
+ assert r["tool_calls"][0]["function"]["name"] == "get_weather"
+
class TestThinkTagStripping:
"""Test tag stripping in tool parsers (Issue #26)."""
diff --git a/vllm_mlx/engine/batched.py b/vllm_mlx/engine/batched.py
index ce33e628..b2c92983 100644
--- a/vllm_mlx/engine/batched.py
+++ b/vllm_mlx/engine/batched.py
@@ -453,6 +453,8 @@ async def generate(
if not self._loaded:
await self.start()
+ preserve_raw_output = bool(kwargs.pop("preserve_raw_output", False))
+
if self._is_mllm and self._mllm_scheduler:
# Use MLLM scheduler for all requests when model is multimodal.
# MLLM models only initialise the _mllm_scheduler (not _engine),
@@ -467,7 +469,11 @@ async def generate(
)
return GenerationOutput(
- text=clean_output_text(output.output_text),
+ text=(
+ output.output_text
+ if preserve_raw_output
+ else clean_output_text(output.output_text)
+ ),
prompt_tokens=output.prompt_tokens,
completion_tokens=output.completion_tokens,
finish_reason=output.finish_reason,
@@ -488,7 +494,11 @@ async def generate(
sampling_params=sampling_params,
)
- text = clean_output_text(output.output_text)
+ text = (
+ output.output_text
+ if preserve_raw_output
+ else clean_output_text(output.output_text)
+ )
return GenerationOutput(
text=text,
@@ -635,6 +645,7 @@ async def chat(
top_p=top_p,
images=all_images if all_images else None,
videos=all_videos if all_videos else None,
+ preserve_raw_output=bool(tools),
**kwargs,
)
diff --git a/vllm_mlx/engine/simple.py b/vllm_mlx/engine/simple.py
index e96317ef..33d45305 100644
--- a/vllm_mlx/engine/simple.py
+++ b/vllm_mlx/engine/simple.py
@@ -437,6 +437,36 @@ async def chat(
if not self._loaded:
await self.start()
+ chat_template_kwargs = dict(kwargs.pop("chat_template_kwargs", {}) or {})
+
+ # mlx-lm non-streaming chat with tools can stall indefinitely on some
+ # local models, while the streaming path completes normally. Reuse the
+ # streaming implementation and aggregate its final state so both chat
+ # APIs share the same tool-capable execution path.
+ if tools and not self._is_mllm:
+ stream_kwargs = dict(kwargs)
+ if chat_template_kwargs:
+ stream_kwargs["chat_template_kwargs"] = chat_template_kwargs
+ final_output = GenerationOutput(text="")
+ async for output in self.stream_chat(
+ messages=messages,
+ max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ tools=tools,
+ images=images,
+ videos=videos,
+ **stream_kwargs,
+ ):
+ final_output = output
+ text = final_output.text
+ return GenerationOutput(
+ text=text,
+ tokens=list(final_output.tokens),
+ prompt_tokens=final_output.prompt_tokens,
+ completion_tokens=final_output.completion_tokens,
+ finish_reason=final_output.finish_reason,
+ )
# Convert tools for template if provided
template_tools = convert_tools_for_template(tools) if tools else None
@@ -450,6 +480,7 @@ async def chat(
max_tokens=max_tokens,
temperature=temperature,
tools=template_tools,
+ chat_template_kwargs=chat_template_kwargs,
**kwargs,
)
text = clean_output_text(output.text)
@@ -469,6 +500,7 @@ async def chat(
temperature=temperature,
top_p=top_p,
tools=template_tools,
+ chat_template_kwargs=chat_template_kwargs,
**kwargs,
)
text = clean_output_text(output.text)
diff --git a/vllm_mlx/models/llm.py b/vllm_mlx/models/llm.py
index 72182037..9669d0c0 100644
--- a/vllm_mlx/models/llm.py
+++ b/vllm_mlx/models/llm.py
@@ -231,6 +231,17 @@ def stream_generate(
should_stop = True
break
+ eos_token_ids = getattr(self.tokenizer, "eos_token_ids", None)
+ if eos_token_ids is None:
+ eos_token_id = getattr(self.tokenizer, "eos_token_id", None)
+ eos_token_ids = [] if eos_token_id is None else [eos_token_id]
+ elif isinstance(eos_token_ids, int):
+ eos_token_ids = [eos_token_ids]
+
+ response_token = getattr(response, "token", None)
+ if response_token is not None and response_token in set(eos_token_ids):
+ should_stop = True
+
finished = should_stop or token_count >= max_tokens
finish_reason = None
if finished:
diff --git a/vllm_mlx/reasoning/gpt_oss_parser.py b/vllm_mlx/reasoning/gpt_oss_parser.py
index 8541faf2..fc2f9882 100644
--- a/vllm_mlx/reasoning/gpt_oss_parser.py
+++ b/vllm_mlx/reasoning/gpt_oss_parser.py
@@ -15,6 +15,7 @@
import re
from .base import DeltaMessage, ReasoningParser
+from .harmony_parser import HarmonyReasoningParser
# Structural tokens that should be stripped from output
_STRUCTURAL_TOKENS = re.compile(
@@ -69,6 +70,12 @@ class GptOssReasoningParser(ReasoningParser):
<|channel|>final <|constrain|>JSON<|message|>[content]<|return|>
"""
+ def __init__(self, tokenizer=None):
+ super().__init__(tokenizer)
+ # GPT-OSS streams the same Harmony channel tokens as the dedicated
+ # Harmony parser, including split channel markers across deltas.
+ self._stream_parser = HarmonyReasoningParser(tokenizer)
+
def extract_reasoning(
self,
model_output: str,
@@ -124,41 +131,13 @@ def extract_reasoning_streaming(
Returns:
DeltaMessage with reasoning and/or content, or None to skip.
"""
- prev_phase = self._detect_phase(previous_text)
- curr_phase = self._detect_phase(current_text)
-
- # Phase changed — extract content after the new marker
- if curr_phase != prev_phase and curr_phase in ("analysis", "final"):
- after_marker = self._extract_content_after_marker_in_delta(
- current_text, curr_phase
- )
- if after_marker:
- after_marker = self._strip_return(after_marker)
- if curr_phase == "analysis":
- return DeltaMessage(reasoning=after_marker)
- else:
- return DeltaMessage(content=after_marker)
- return None
-
- # In a steady phase — emit delta directly
- if curr_phase == "analysis":
- cleaned = self._strip_return(delta_text)
- # Skip structural tokens in the delta
- if _STRUCTURAL_TOKENS.search(cleaned):
- cleaned = _STRUCTURAL_TOKENS.sub("", cleaned)
- if cleaned:
- return DeltaMessage(reasoning=cleaned)
- return None
- elif curr_phase == "final":
- cleaned = self._strip_return(delta_text)
- if _STRUCTURAL_TOKENS.search(cleaned):
- cleaned = _STRUCTURAL_TOKENS.sub("", cleaned)
- if cleaned:
- return DeltaMessage(content=cleaned)
- return None
-
- # init or transition phase — skip structural tokens
- return None
+ return self._stream_parser.extract_reasoning_streaming(
+ previous_text, current_text, delta_text
+ )
+
+ def reset_state(self):
+ """Reset streaming state for a new request."""
+ self._stream_parser.reset_state()
@staticmethod
def _detect_phase(text: str) -> str:
diff --git a/vllm_mlx/reasoning/harmony_parser.py b/vllm_mlx/reasoning/harmony_parser.py
index 73b94f0e..63fb4aee 100644
--- a/vllm_mlx/reasoning/harmony_parser.py
+++ b/vllm_mlx/reasoning/harmony_parser.py
@@ -50,6 +50,8 @@ def __init__(self, tokenizer=None):
super().__init__(tokenizer)
self._current_channel: str | None = None
self._in_message: bool = False
+ self._pending_channel_text = ""
+ self._collecting_channel = False
def extract_reasoning(
self,
@@ -97,31 +99,34 @@ def extract_reasoning_streaming(
Returns:
DeltaMessage with reasoning and/or content, or None.
"""
+ if self._collecting_channel:
+ self._pending_channel_text += delta_text
+ resolved = self._try_resolve_pending_channel()
+ if resolved is not None:
+ return resolved
+
# Detect channel switches in the delta
if "<|channel|>" in delta_text:
- if "analysis" in delta_text:
- self._current_channel = "analysis"
- self._in_message = False
- return None
- elif "final" in delta_text:
- self._current_channel = "final"
- self._in_message = False
- return None
- elif "commentary" in delta_text:
- self._current_channel = "commentary"
- self._in_message = False
- return None
+ self._collecting_channel = True
+ self._pending_channel_text = delta_text
+ resolved = self._try_resolve_pending_channel()
+ if resolved is not None:
+ return resolved
+ return None
# Detect channel from full context if not yet determined
if self._current_channel is None and "<|channel|>" in current_text:
last_channel = current_text.rfind("<|channel|>")
after = current_text[last_channel + len("<|channel|>") :]
- if after.startswith("analysis"):
- self._current_channel = "analysis"
- elif after.startswith("final"):
- self._current_channel = "final"
- elif after.startswith("commentary"):
- self._current_channel = "commentary"
+ self._current_channel = self._resolve_channel_name(after)
+
+ # Commentary is routed through the tool parser via DeltaMessage.content.
+ if self._current_channel == "commentary":
+ if "<|message|>" in delta_text:
+ self._in_message = True
+ if any(token in delta_text for token in ("<|call|>", "<|end|>", "<|return|>")):
+ self._in_message = False
+ return DeltaMessage(content=delta_text)
# Handle message start
if "<|message|>" in delta_text:
@@ -155,3 +160,40 @@ def reset_state(self):
"""Reset streaming state for a new request."""
self._current_channel = None
self._in_message = False
+ self._pending_channel_text = ""
+ self._collecting_channel = False
+
+ @staticmethod
+ def _resolve_channel_name(text: str) -> str | None:
+ """Extract the active Harmony channel name from a delta or suffix."""
+ if "analysis" in text:
+ return "analysis"
+ if "final" in text:
+ return "final"
+ if "commentary" in text:
+ return "commentary"
+ return None
+
+ def _try_resolve_pending_channel(self) -> DeltaMessage | None:
+ """Resolve a partially streamed channel header once enough text arrives."""
+ suffix = self._pending_channel_text.split("<|channel|>", 1)[-1]
+ channel_name = self._resolve_channel_name(suffix)
+ if channel_name is None:
+ return None
+
+ buffered = self._pending_channel_text
+ self._pending_channel_text = ""
+ self._collecting_channel = False
+ self._current_channel = channel_name
+ self._in_message = "<|message|>" in buffered
+
+ if channel_name == "commentary":
+ return DeltaMessage(content=buffered)
+
+ if "<|message|>" in buffered:
+ content = buffered.split("<|message|>", 1)[1]
+ if content:
+ if channel_name == "analysis":
+ return DeltaMessage(reasoning=content)
+ return DeltaMessage(content=content)
+ return None
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 4cb15f02..ba2076a9 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -162,6 +162,15 @@ def _resolve_top_p(request_value: float | None) -> float:
_tool_parser_instance = None # Instantiated parser
+def _looks_like_streaming_tool_markup(delta_text: str) -> bool:
+ """Cheap trigger check before invoking streaming tool parsers."""
+ return (
+ "<" in delta_text
+ or "[TOOL_CALLS]" in delta_text
+ or "[Calling tool:" in delta_text
+ )
+
+
def _load_prefix_cache_from_disk() -> None:
"""Load prefix cache from disk during startup."""
try:
@@ -418,7 +427,1042 @@ def _parse_tool_calls_with_parser(
logger.warning(f"Tool parser error: {e}")
return parse_tool_calls(output_text, request_dict)
+def _parse_and_validate_tools(
+ output_text: str,
+ request: ChatCompletionRequest,
+ should_parse: bool,
+) -> tuple[str, list | None]:
+ """Parse tool calls from model output and validate against declared tools.
+
+ When *should_parse* is False (tool_choice="none"), returns the raw text
+ with no tool calls.
+ """
+ if not should_parse:
+ return output_text, None
+ cleaned, tool_calls = _parse_tool_calls_with_parser(output_text, request)
+ if tool_calls:
+ tool_calls = _validate_tool_calls(tool_calls, request.tools)
+ return cleaned, tool_calls
+
+
+def _set_tool_grammar_processor(chat_kwargs: dict, tools: list) -> None:
+ """Attach an Outlines grammar-constrained logits processor for tool calls.
+
+ Modifies *chat_kwargs* in place. Does nothing when Outlines is unavailable.
+ """
+ from .guided_decoding import build_tool_call_processor
+
+ processor = build_tool_call_processor(tools)
+ if processor:
+ chat_kwargs["logits_processors"] = [processor]
+
+
+def _apply_tool_choice(
+ tool_choice: str | dict | None,
+ chat_kwargs: dict,
+ messages: list[dict],
+) -> bool:
+ """Apply tool_choice policy to chat kwargs and messages.
+
+ Modifies *chat_kwargs* and *messages* in place so that the chat template
+ and downstream parsing honour the caller's tool_choice setting.
+
+ Returns ``True`` when the model output should be parsed for tool calls,
+ ``False`` when tool-call parsing must be skipped (``tool_choice="none"``).
+ """
+ if tool_choice == "none":
+ chat_kwargs.pop("tools", None)
+ return False
+
+ if tool_choice == "required":
+ messages.insert(
+ 0,
+ {
+ "role": "system",
+ "content": (
+ "You MUST call one of the provided tools. "
+ "Do not respond with plain text."
+ ),
+ },
+ )
+ _set_tool_grammar_processor(chat_kwargs, chat_kwargs.get("tools", []))
+ return True
+
+ if isinstance(tool_choice, dict):
+ func_info = tool_choice.get("function", {})
+ fname = func_info.get("name", "") if isinstance(func_info, dict) else ""
+ if fname:
+ template_tools = chat_kwargs.get("tools")
+ if template_tools:
+ filtered = [
+ t
+ for t in template_tools
+ if t.get("function", {}).get("name") == fname
+ ]
+ if filtered:
+ chat_kwargs["tools"] = filtered
+ messages.insert(
+ 0,
+ {
+ "role": "system",
+ "content": f"You MUST call the function: {fname}",
+ },
+ )
+ _set_tool_grammar_processor(chat_kwargs, filtered)
+ return True
+ # Named function not found in tools — fall back to auto
+ logger.warning(
+ "tool_choice function %r not found in tools, falling back to auto",
+ fname,
+ )
+ return True
+
+ # "auto" or None — apply lazy grammar trigger when the parser defines one.
+ # The processor stays inactive until the model emits a tool call trigger
+ # token, then constrains the JSON body to match the tool schemas.
+ if chat_kwargs.get("tools") and _tool_call_parser and _tool_parser_instance:
+ parser_cls = type(_tool_parser_instance)
+ if parser_cls.TRIGGER_TOKEN_IDS:
+ from .guided_decoding import build_lazy_tool_call_processor
+
+ processor = build_lazy_tool_call_processor(
+ tools=chat_kwargs.get("tools", []),
+ trigger_tokens=parser_cls.TRIGGER_TOKEN_IDS,
+ end_tokens=parser_cls.END_TOKEN_IDS,
+ prefix_skip=parser_cls.PREFIX_SKIP_TOKENS,
+ )
+ if processor:
+ chat_kwargs["logits_processors"] = [processor]
+ return True
+
+
+def _validate_tool_calls(
+ tool_calls: list | None,
+ tools: list | None,
+) -> list | None:
+ """Remove tool calls with unknown names or invalid arguments.
+
+ Validates each parsed tool call against the declared tools:
+ 1. Function name must exist in the tools list.
+ 2. Arguments must be valid JSON.
+ 3. Arguments must conform to the tool's parameters schema (if declared).
+
+ Invalid tool calls are dropped with a warning log.
+ Returns None if all tool calls are invalid.
+ """
+ if not tool_calls:
+ return None
+ if tools is None:
+ return tool_calls
+
+ # Build lookup: function name -> parameters schema
+ tools_by_name: dict[str, dict | None] = {}
+ for t in tools:
+ func = (
+ t.function
+ if isinstance(t, ToolDefinition)
+ else (t.get("function") if isinstance(t, dict) else None)
+ )
+ if func and isinstance(func, dict):
+ fname = func.get("name")
+ if fname:
+ tools_by_name[fname] = func.get("parameters")
+
+ valid = []
+ for tc in tool_calls:
+ name = tc.function.name
+ if name not in tools_by_name:
+ logger.warning("Dropping tool call with unknown function: %s", name)
+ continue
+ schema = tools_by_name[name]
+ if schema:
+ try:
+ args = json.loads(tc.function.arguments)
+ jsonschema.validate(args, schema)
+ except (json.JSONDecodeError, jsonschema.ValidationError, jsonschema.SchemaError) as exc:
+ logger.warning(
+ "Dropping tool call %s: invalid arguments: %s", name, exc
+ )
+ continue
+ valid.append(tc)
+
+ return valid or None
+
+
+def _new_response_item_id(prefix: str) -> str:
+ """Generate stable OpenAI-style item ids."""
+ return f"{prefix}_{uuid.uuid4().hex}"
+
+
+def _response_content_to_text(content) -> str:
+ """Normalize Responses API content items into plain text."""
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+
+ text_parts = []
+ for part in content:
+ if isinstance(part, dict):
+ part_type = part.get("type")
+ text = part.get("text", "")
+ else:
+ part_type = getattr(part, "type", None)
+ text = getattr(part, "text", "")
+ if part_type in {"text", "input_text", "output_text"}:
+ text_parts.append(text)
+ return "\n".join(part for part in text_parts if part)
+
+
+def _responses_tools_to_chat_tools(
+ tools: list[ResponseFunctionTool | dict],
+) -> tuple[list[dict] | None, list[str]]:
+ """Convert supported Responses tools and report unsupported tool types."""
+ if not tools:
+ return None, []
+
+ supported: list[dict] = []
+ unsupported: list[str] = []
+
+ for tool in tools:
+ if isinstance(tool, ResponseFunctionTool):
+ tool_type = tool.type
+ tool_name = tool.name
+ tool_description = tool.description or ""
+ tool_parameters = tool.parameters
+ elif isinstance(tool, dict):
+ tool_type = tool.get("type", "unknown")
+ tool_name = tool.get("name", "")
+ tool_description = tool.get("description", "")
+ tool_parameters = tool.get("parameters", {})
+ else:
+ unsupported.append(type(tool).__name__)
+ continue
+
+ if tool_type == "function":
+ supported.append(
+ {
+ "type": "function",
+ "function": {
+ "name": tool_name,
+ "description": tool_description,
+ "parameters": tool_parameters
+ or {"type": "object", "properties": {}},
+ },
+ }
+ )
+ else:
+ unsupported.append(tool_type)
+
+ return supported or None, unsupported
+
+
+def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
+ """Convert Responses API input items into chat-completions-style messages."""
+ messages: list[dict] = []
+
+ if request.previous_response_id:
+ previous = _responses_store.get(request.previous_response_id)
+ if previous is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Previous response `{request.previous_response_id}` not found",
+ )
+ messages.extend(copy.deepcopy(previous["messages"]))
+
+ if request.instructions:
+ messages.append({"role": "system", "content": request.instructions})
+
+ if isinstance(request.input, str):
+ messages.append({"role": "user", "content": request.input})
+ return messages
+
+ for item in request.input:
+ if isinstance(item, dict):
+ item_type = item.get("type", "")
+ if item_type == "message":
+ role = item.get("role", "user")
+ if role == "developer":
+ role = "system"
+ messages.append(
+ {
+ "role": role,
+ "content": _response_content_to_text(item.get("content")),
+ }
+ )
+ elif item_type == "function_call":
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "id": item.get("call_id", _new_response_item_id("call")),
+ "type": "function",
+ "function": {
+ "name": item.get("name", ""),
+ "arguments": item.get("arguments", ""),
+ },
+ }
+ ],
+ }
+ )
+ elif item_type == "function_call_output":
+ messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": item.get("call_id", ""),
+ "content": item.get("output", ""),
+ }
+ )
+ elif item_type == "reasoning":
+ parts = item.get("content", [])
+ reasoning_text = "\n".join(
+ p.get("text", "") for p in parts if isinstance(p, dict)
+ )
+ if reasoning_text:
+ messages.append({"role": "assistant", "content": reasoning_text})
+ else:
+ logger.info(
+ "Skipping unsupported Responses input item type %r", item_type
+ )
+ continue
+
+ if isinstance(item, ResponseMessageItem):
+ role = item.role
+ if role == "developer":
+ role = "system"
+ messages.append(
+ {
+ "role": role,
+ "content": _response_content_to_text(item.content),
+ }
+ )
+ elif isinstance(item, ResponseFunctionCallItem):
+ messages.append(
+ {
+ "role": "assistant",
+ "content": "",
+ "tool_calls": [
+ {
+ "id": item.call_id,
+ "type": "function",
+ "function": {
+ "name": item.name,
+ "arguments": item.arguments,
+ },
+ }
+ ],
+ }
+ )
+ elif isinstance(item, ResponseFunctionCallOutputItem):
+ messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": item.call_id,
+ "content": item.output,
+ }
+ )
+ elif isinstance(item, ResponseReasoningItem):
+ reasoning_text = "\n".join(part.text for part in (item.content or []))
+ if reasoning_text:
+ messages.append({"role": "assistant", "content": reasoning_text})
+ else:
+ logger.info(
+ "Skipping unsupported Responses input item type %r",
+ getattr(item, "type", type(item).__name__),
+ )
+
+ return messages
+
+
+def _responses_request_to_new_persisted_messages(request: ResponsesRequest) -> list[dict]:
+ """Persist only the current request's replayable input items."""
+ request_without_history = request.model_copy(
+ update={"previous_response_id": None, "instructions": None},
+ deep=True,
+ )
+ return _responses_input_to_chat_messages(request_without_history)
+
+
+def _responses_request_to_persisted_messages(request: ResponsesRequest) -> list[dict]:
+ """Persist replayable history for chained previous_response_id requests.
+
+ Responses `instructions` are intentionally not replayed across
+ `previous_response_id`, but replayable message items are.
+ """
+ messages: list[dict] = []
+ if request.previous_response_id:
+ previous = _responses_store.get(request.previous_response_id)
+ if previous is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Previous response `{request.previous_response_id}` not found",
+ )
+ messages.extend(copy.deepcopy(previous["messages"]))
+ messages.extend(_responses_request_to_new_persisted_messages(request))
+ return messages
+
+
+def _responses_request_to_chat_request(request: ResponsesRequest) -> ChatCompletionRequest:
+ """Build a ChatCompletionRequest from a ResponsesRequest."""
+ if request.text.format.type == "json_object":
+ raise HTTPException(
+ status_code=400,
+ detail="Responses text.format.type='json_object' is not supported on this backend",
+ )
+ if request.reasoning is not None:
+ logger.debug("Ignoring reasoning configuration (not supported on this backend)")
+
+ tools, unsupported_tools = _responses_tools_to_chat_tools(request.tools)
+ messages = _responses_input_to_chat_messages(request)
+ if unsupported_tools:
+ tool_list = ", ".join(sorted(set(unsupported_tools)))
+ messages.insert(
+ 0,
+ {
+ "role": "system",
+ "content": (
+ "The following requested tool types are not available on this "
+ f"backend: {tool_list}. Do not call them."
+ ),
+ },
+ )
+
+ system_messages = [msg for msg in messages if msg.get("role") == "system"]
+ non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
+ merged_system_content = "\n\n".join(
+ str(msg.get("content", "")).strip()
+ for msg in system_messages
+ if str(msg.get("content", "")).strip()
+ )
+ messages = (
+ [{"role": "system", "content": merged_system_content}]
+ if merged_system_content
+ else []
+ ) + non_system_messages
+
+ return ChatCompletionRequest(
+ model=request.model,
+ messages=[Message(**msg) for msg in messages],
+ temperature=request.temperature,
+ top_p=request.top_p,
+ max_tokens=request.max_output_tokens,
+ stream=False,
+ tools=tools,
+ tool_choice=request.tool_choice,
+ )
+
+def _build_responses_output_items(
+ text: str | None,
+ reasoning: str | None,
+ tool_calls: list[ToolCall] | None,
+) -> list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem]:
+ """Convert parsed assistant output into Responses API output items."""
+ output_items: list[
+ ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+ ] = []
+
+ if reasoning:
+ output_items.append(
+ ResponseReasoningItem(
+ id=_new_response_item_id("rs"),
+ content=[ResponseReasoningTextPart(text=reasoning)],
+ )
+ )
+
+ if text:
+ output_items.append(
+ ResponseMessageItem(
+ id=_new_response_item_id("msg"),
+ role="assistant",
+ content=[ResponseTextContentPart(type="output_text", text=text)],
+ )
+ )
+
+ for tool_call in tool_calls or []:
+ output_items.append(
+ ResponseFunctionCallItem(
+ id=_new_response_item_id("fc"),
+ call_id=tool_call.id,
+ name=tool_call.function.name,
+ arguments=tool_call.function.arguments,
+ )
+ )
+
+ return output_items
+
+
+def _response_output_items_to_chat_messages(output_items: list) -> list[dict]:
+ """Persist assistant output in chat-completions form for previous_response_id."""
+ assistant_text_parts: list[str] = []
+ assistant_tool_calls: list[dict] = []
+
+ for item in output_items:
+ if isinstance(item, ResponseMessageItem):
+ assistant_text_parts.append(_response_content_to_text(item.content))
+ elif isinstance(item, ResponseFunctionCallItem):
+ assistant_tool_calls.append(
+ {
+ "id": item.call_id,
+ "type": "function",
+ "function": {
+ "name": item.name,
+ "arguments": item.arguments,
+ },
+ }
+ )
+
+ if not assistant_text_parts and not assistant_tool_calls:
+ return []
+
+ return [
+ {
+ "role": "assistant",
+ "content": "".join(assistant_text_parts),
+ "tool_calls": assistant_tool_calls or None,
+ }
+ ]
+
+
+def _build_response_object(
+ request: ResponsesRequest,
+ output_items: list[ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem],
+ prompt_tokens: int,
+ completion_tokens: int,
+ finish_reason: str | None,
+ response_id: str | None = None,
+) -> ResponseObject:
+ """Build a full Responses API object."""
+ response = ResponseObject(
+ id=response_id or _new_response_item_id("resp"),
+ model=_model_name or request.model,
+ instructions=request.instructions,
+ max_output_tokens=request.max_output_tokens,
+ metadata=request.metadata,
+ output=output_items,
+ parallel_tool_calls=request.parallel_tool_calls,
+ previous_response_id=request.previous_response_id,
+ text=request.text,
+ tool_choice=request.tool_choice,
+ tools=request.tools,
+ top_p=_resolve_top_p(request.top_p),
+ temperature=_resolve_temperature(request.temperature),
+ truncation=request.truncation,
+ user=request.user,
+ store=request.store,
+ usage=ResponsesUsage(
+ input_tokens=prompt_tokens,
+ output_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ ),
+ )
+ if finish_reason == "length":
+ response.status = "incomplete"
+ response.incomplete_details = ResponseIncompleteDetails(
+ reason="max_output_tokens"
+ )
+ return response
+
+
+def _prepare_responses_request(
+ request: ResponsesRequest,
+) -> tuple[BaseEngine, ChatCompletionRequest, list[dict], dict, bool]:
+ """Prepare a Responses request for execution on the chat engine."""
+ _validate_model_name(request.model)
+ engine = get_engine()
+ chat_request = _responses_request_to_chat_request(request)
+
+ if chat_request.messages:
+ logger.info(
+ f"[REQUEST] POST /v1/responses stream={request.stream} "
+ f"model={request.model!r} items="
+ f"{len(request.input) if isinstance(request.input, list) else 1} "
+ f"tools={len(request.tools)}"
+ )
+
+ messages, images, videos = extract_multimodal_content(
+ chat_request.messages,
+ preserve_native_format=engine.preserve_native_tool_format,
+ )
+
+ chat_kwargs = {
+ "max_tokens": chat_request.max_tokens or _default_max_tokens,
+ "temperature": _resolve_temperature(chat_request.temperature),
+ "top_p": _resolve_top_p(chat_request.top_p),
+ }
+ if request.tools:
+ chat_kwargs["tools"] = convert_tools_for_template(chat_request.tools)
+ should_parse_tools = _apply_tool_choice(
+ chat_request.tool_choice, chat_kwargs, messages
+ )
+ if images:
+ chat_kwargs["images"] = images
+ if videos:
+ chat_kwargs["videos"] = videos
+
+ return engine, chat_request, messages, chat_kwargs, should_parse_tools
+
+
+async def _run_responses_request(
+ request: ResponsesRequest,
+ raw_request: Request,
+) -> tuple[ResponseObject | None, list[dict]]:
+ """Execute a Responses API request against the backend chat engine."""
+ engine, chat_request, messages, chat_kwargs, should_parse_tools = (
+ _prepare_responses_request(request)
+ )
+
+ timeout = _default_timeout
+ output = await _wait_with_disconnect(
+ engine.chat(messages=messages, **chat_kwargs),
+ raw_request,
+ timeout=timeout,
+ )
+ if output is None:
+ return None, []
+
+ cleaned_text, tool_calls = _parse_and_validate_tools(
+ output.text, chat_request, should_parse_tools
+ )
+ reasoning_text = None
+ if _reasoning_parser and not tool_calls:
+ reasoning_text, cleaned_text = _reasoning_parser.extract_reasoning(
+ cleaned_text or output.text
+ )
+
+ output_items = _build_responses_output_items(
+ clean_output_text(cleaned_text) if cleaned_text else None,
+ reasoning_text,
+ tool_calls,
+ )
+ response_object = _build_response_object(
+ request=request,
+ output_items=output_items,
+ prompt_tokens=output.prompt_tokens,
+ completion_tokens=output.completion_tokens,
+ finish_reason=output.finish_reason,
+ )
+
+ persisted_messages = _responses_request_to_persisted_messages(request)
+ persisted_messages.extend(_response_output_items_to_chat_messages(output_items))
+ if request.store:
+ _responses_store[response_object.id] = {
+ "messages": copy.deepcopy(persisted_messages),
+ "response": response_object.model_copy(deep=True),
+ }
+ while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE:
+ _responses_store.popitem(last=False)
+
+ return response_object, persisted_messages
+
+
+async def _stream_responses_request(request: ResponsesRequest) -> AsyncIterator[str]:
+ """Execute a Responses API request and stream SSE events incrementally."""
+ engine, chat_request, messages, chat_kwargs, should_parse_tools = (
+ _prepare_responses_request(request)
+ )
+
+ response_id = _new_response_item_id("resp")
+ sequence = 1
+ base_response = _build_response_object(
+ request=request,
+ output_items=[],
+ prompt_tokens=0,
+ completion_tokens=0,
+ finish_reason=None,
+ response_id=response_id,
+ )
+ base_response.status = "in_progress"
+ base_response.usage = None
+
+ yield _responses_sse_event(
+ "response.created",
+ ResponseCreatedEvent(sequence_number=sequence, response=base_response),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.in_progress",
+ ResponseInProgressEvent(sequence_number=sequence, response=base_response),
+ )
+ sequence += 1
+
+ prompt_tokens = 0
+ completion_tokens = 0
+ finish_reason = None
+ last_output = None
+ raw_accumulated_text = ""
+ accumulated_text = ""
+ accumulated_reasoning = ""
+
+ text_item_id: str | None = None
+ text_output_index: int | None = None
+ reasoning_item_id: str | None = None
+ reasoning_output_index: int | None = None
+ next_output_index = 0
+
+ def _start_text_item() -> list[str]:
+ nonlocal text_item_id, text_output_index, next_output_index, sequence
+ events: list[str] = []
+ if text_item_id is None:
+ text_item_id = _new_response_item_id("msg")
+ text_output_index = next_output_index
+ next_output_index += 1
+ events.append(
+ _responses_sse_event(
+ "response.output_item.added",
+ ResponseOutputItemAddedEvent(
+ sequence_number=sequence,
+ output_index=text_output_index,
+ item=ResponseMessageItem(
+ id=text_item_id,
+ role="assistant",
+ status="in_progress",
+ content=[],
+ ),
+ ),
+ )
+ )
+ sequence += 1
+ events.append(
+ _responses_sse_event(
+ "response.content_part.added",
+ ResponseContentPartAddedEvent(
+ sequence_number=sequence,
+ item_id=text_item_id,
+ output_index=text_output_index,
+ content_index=0,
+ part=ResponseTextContentPart(type="output_text", text=""),
+ ),
+ )
+ )
+ sequence += 1
+ return events
+
+ def _start_reasoning_item() -> list[str]:
+ nonlocal reasoning_item_id, reasoning_output_index, next_output_index, sequence
+ events: list[str] = []
+ if reasoning_item_id is None:
+ reasoning_item_id = _new_response_item_id("rs")
+ reasoning_output_index = next_output_index
+ next_output_index += 1
+ events.append(
+ _responses_sse_event(
+ "response.output_item.added",
+ ResponseOutputItemAddedEvent(
+ sequence_number=sequence,
+ output_index=reasoning_output_index,
+ item=ResponseReasoningItem(
+ id=reasoning_item_id,
+ status="in_progress",
+ content=[],
+ ),
+ ),
+ )
+ )
+ sequence += 1
+ events.append(
+ _responses_sse_event(
+ "response.content_part.added",
+ ResponseContentPartAddedEvent(
+ sequence_number=sequence,
+ item_id=reasoning_item_id,
+ output_index=reasoning_output_index,
+ content_index=0,
+ part=ResponseReasoningTextPart(text=""),
+ ),
+ )
+ )
+ sequence += 1
+ return events
+
+ if _reasoning_parser:
+ _reasoning_parser.reset_state()
+
+ global _tool_parser_instance
+ tool_parser = None
+ tool_accumulated_text = ""
+ tool_markup_possible = False
+ if should_parse_tools and _enable_auto_tool_choice and _tool_call_parser:
+ if _tool_parser_instance is None:
+ try:
+ parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
+ tokenizer = None
+ if _engine is not None and hasattr(_engine, "_tokenizer"):
+ tokenizer = _engine._tokenizer
+ _tool_parser_instance = parser_cls(tokenizer)
+ logger.info(
+ "Initialized tool call parser for responses streaming: %s",
+ _tool_call_parser,
+ )
+ except Exception as e:
+ logger.warning(
+ "Failed to init tool parser for responses streaming: %s", e
+ )
+ if _tool_parser_instance is not None:
+ tool_parser = _tool_parser_instance
+ tool_parser.reset()
+
+ async for output in engine.stream_chat(messages=messages, **chat_kwargs):
+ last_output = output
+ finish_reason = output.finish_reason
+ if hasattr(output, "prompt_tokens") and output.prompt_tokens:
+ prompt_tokens = output.prompt_tokens
+ if hasattr(output, "completion_tokens") and output.completion_tokens:
+ completion_tokens = output.completion_tokens
+
+ delta_text = output.new_text or ""
+ if not delta_text:
+ continue
+
+ previous_text = raw_accumulated_text
+ raw_accumulated_text += delta_text
+
+ if _reasoning_parser:
+ delta_msg = _reasoning_parser.extract_reasoning_streaming(
+ previous_text, raw_accumulated_text, delta_text
+ )
+ if delta_msg is None:
+ continue
+
+ if delta_msg.reasoning:
+ for event in _start_reasoning_item():
+ yield event
+ accumulated_reasoning += delta_msg.reasoning
+ yield _responses_sse_event(
+ "response.reasoning_text.delta",
+ ResponseReasoningTextDeltaEvent(
+ sequence_number=sequence,
+ item_id=reasoning_item_id,
+ output_index=reasoning_output_index,
+ content_index=0,
+ delta=delta_msg.reasoning,
+ ),
+ )
+ sequence += 1
+
+ if delta_msg.content:
+ for event in _start_text_item():
+ yield event
+ accumulated_text += delta_msg.content
+ yield _responses_sse_event(
+ "response.output_text.delta",
+ ResponseOutputTextDeltaEvent(
+ sequence_number=sequence,
+ item_id=text_item_id,
+ output_index=text_output_index,
+ content_index=0,
+ delta=delta_msg.content,
+ ),
+ )
+ sequence += 1
+ continue
+
+ content = SPECIAL_TOKENS_PATTERN.sub("", delta_text)
+ if tool_parser and delta_text:
+ if not tool_markup_possible and not _looks_like_streaming_tool_markup(
+ delta_text
+ ):
+ tool_accumulated_text += delta_text
+ else:
+ if not tool_markup_possible:
+ tool_markup_possible = True
+ tool_result = tool_parser.extract_tool_calls_streaming(
+ tool_accumulated_text, tool_accumulated_text + delta_text, delta_text
+ )
+ tool_accumulated_text += delta_text
+ if tool_result is None:
+ continue
+ if "tool_calls" in tool_result:
+ continue
+ content = tool_result.get("content", "")
+
+ if not content:
+ continue
+
+ for event in _start_text_item():
+ yield event
+ accumulated_text += content
+ yield _responses_sse_event(
+ "response.output_text.delta",
+ ResponseOutputTextDeltaEvent(
+ sequence_number=sequence,
+ item_id=text_item_id,
+ output_index=text_output_index,
+ content_index=0,
+ delta=content,
+ ),
+ )
+ sequence += 1
+
+ cleaned_text, tool_calls = _parse_and_validate_tools(
+ raw_accumulated_text, chat_request, should_parse_tools
+ )
+ final_text = accumulated_text
+ if cleaned_text is not None and not final_text and not tool_calls:
+ final_text = clean_output_text(cleaned_text)
+
+ reasoning_item = None
+ if reasoning_item_id is not None:
+ reasoning_item = ResponseReasoningItem(
+ id=reasoning_item_id,
+ status="completed",
+ content=[ResponseReasoningTextPart(text=accumulated_reasoning)],
+ )
+ yield _responses_sse_event(
+ "response.reasoning_text.done",
+ ResponseReasoningTextDoneEvent(
+ sequence_number=sequence,
+ item_id=reasoning_item_id,
+ output_index=reasoning_output_index,
+ content_index=0,
+ text=accumulated_reasoning,
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.content_part.done",
+ ResponseContentPartDoneEvent(
+ sequence_number=sequence,
+ item_id=reasoning_item_id,
+ output_index=reasoning_output_index,
+ content_index=0,
+ part=reasoning_item.content[0],
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.output_item.done",
+ ResponseOutputItemDoneEvent(
+ sequence_number=sequence,
+ output_index=reasoning_output_index,
+ item=reasoning_item,
+ ),
+ )
+ sequence += 1
+
+ text_item = None
+ if text_item_id is not None or final_text:
+ if text_item_id is None:
+ for event in _start_text_item():
+ yield event
+ text_item = ResponseMessageItem(
+ id=text_item_id,
+ role="assistant",
+ status="completed",
+ content=[ResponseTextContentPart(type="output_text", text=final_text)],
+ )
+ yield _responses_sse_event(
+ "response.output_text.done",
+ ResponseOutputTextDoneEvent(
+ sequence_number=sequence,
+ item_id=text_item_id,
+ output_index=text_output_index,
+ content_index=0,
+ text=final_text,
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.content_part.done",
+ ResponseContentPartDoneEvent(
+ sequence_number=sequence,
+ item_id=text_item_id,
+ output_index=text_output_index,
+ content_index=0,
+ part=text_item.content[0],
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.output_item.done",
+ ResponseOutputItemDoneEvent(
+ sequence_number=sequence,
+ output_index=text_output_index,
+ item=text_item,
+ ),
+ )
+ sequence += 1
+
+ function_call_items: list[ResponseFunctionCallItem] = []
+ for tool_call in tool_calls or []:
+ output_index = next_output_index
+ next_output_index += 1
+ item = ResponseFunctionCallItem(
+ id=_new_response_item_id("fc"),
+ call_id=tool_call.id,
+ name=tool_call.function.name,
+ arguments=tool_call.function.arguments,
+ )
+ function_call_items.append(item)
+ yield _responses_sse_event(
+ "response.output_item.added",
+ ResponseOutputItemAddedEvent(
+ sequence_number=sequence,
+ output_index=output_index,
+ item=item.model_copy(update={"status": "in_progress"}),
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.function_call_arguments.delta",
+ ResponseFunctionCallArgumentsDeltaEvent(
+ sequence_number=sequence,
+ item_id=item.id,
+ output_index=output_index,
+ delta=item.arguments,
+ ),
+ )
+ sequence += 1
+ yield _responses_sse_event(
+ "response.output_item.done",
+ ResponseOutputItemDoneEvent(
+ sequence_number=sequence,
+ output_index=output_index,
+ item=item,
+ ),
+ )
+ sequence += 1
+
+ output_items: list[
+ ResponseMessageItem | ResponseReasoningItem | ResponseFunctionCallItem
+ ] = []
+ if reasoning_item is not None:
+ output_items.append(reasoning_item)
+ if text_item is not None:
+ output_items.append(text_item)
+ output_items.extend(function_call_items)
+
+ response_object = _build_response_object(
+ request=request,
+ output_items=output_items,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ finish_reason=finish_reason,
+ response_id=response_id,
+ )
+
+ if request.store and last_output is not None:
+ persisted_messages = _responses_request_to_persisted_messages(request)
+ persisted_messages.extend(_response_output_items_to_chat_messages(output_items))
+ _responses_store[response_object.id] = {
+ "messages": copy.deepcopy(persisted_messages),
+ "response": response_object.model_copy(deep=True),
+ }
+ while len(_responses_store) > _RESPONSES_STORE_MAX_SIZE:
+ _responses_store.popitem(last=False)
+
+ yield _responses_sse_event(
+ "response.completed",
+ ResponseCompletedEvent(sequence_number=sequence, response=response_object),
+ )
+
+
+def _responses_sse_event(event_type: str, payload: BaseModel | dict) -> str:
+ """Encode a Responses API SSE event."""
+ data = payload.model_dump_json() if isinstance(payload, BaseModel) else json.dumps(payload)
+ return f"event: {event_type}\ndata: {data}\n\n"
def _detect_native_tool_support() -> bool:
"""
Detect if the active tool parser supports native tool format.
@@ -2000,73 +3044,69 @@ async def stream_chat_completion(
# Skip this chunk (e.g., token itself)
continue
- content = delta_msg.content
- reasoning = delta_msg.reasoning
-
- # Tool call parsing on the content portion (fix: was
- # previously unreachable because tool parser was in the
- # else branch of this if/else)
- if tool_parser and content:
- if not tool_markup_possible and "<" not in content:
- tool_accumulated_text += content
- else:
- if not tool_markup_possible:
- tool_markup_possible = True
- tool_previous = tool_accumulated_text
- tool_accumulated_text += content
+ # Route content through tool parser when both are active
+ emit_content = delta_msg.content
+ if tool_parser and emit_content:
+ tool_accumulated_text += emit_content
+ if tool_markup_possible or _looks_like_streaming_tool_markup(
+ emit_content
+ ):
+ tool_markup_possible = True
tool_result = tool_parser.extract_tool_calls_streaming(
- tool_previous, tool_accumulated_text, content
+ tool_accumulated_text[: -len(emit_content)],
+ tool_accumulated_text,
+ emit_content,
)
if tool_result is None:
- # Inside tool markup — suppress content, keep reasoning
- content = None
+ emit_content = None
elif "tool_calls" in tool_result:
+ tool_call_complete = bool(tool_result.get("complete"))
tool_calls_detected = True
- chunk = ChatCompletionChunk(
+ tc_chunk = ChatCompletionChunk(
id=response_id,
- model=request.model,
+ model=_model_name,
choices=[
ChatCompletionChunkChoice(
delta=ChatCompletionChunkDelta(
- reasoning=reasoning,
tool_calls=tool_result["tool_calls"],
),
finish_reason=(
- "tool_calls" if output.finished else None
+ "tool_calls"
+ if (output.finished or tool_call_complete)
+ else None
),
)
],
- usage=get_usage(output) if output.finished else None,
+ usage=(
+ get_usage(output)
+ if (output.finished or tool_call_complete)
+ else None
+ ),
)
- yield f"data: {chunk.model_dump_json()}\n\n"
+ yield f"data: {tc_chunk.model_dump_json()}\n\n"
+ if tool_call_complete:
+ break
continue
else:
- content = tool_result.get("content", "")
-
- # Skip if both content and reasoning are empty
- if not content and not reasoning:
- continue
-
- chunk = ChatCompletionChunk(
- id=response_id,
- model=_model_name,
- choices=[
- ChatCompletionChunkChoice(
- delta=ChatCompletionChunkDelta(
- content=content if content else None,
- reasoning=reasoning,
- ),
- finish_reason=(
- "tool_calls"
- if (output.finished and tool_calls_detected)
- else (output.finish_reason if output.finished else None)
- ),
- )
- ],
- usage=get_usage(output) if output.finished else None,
- )
- yield f"data: {chunk.model_dump_json()}\n\n"
+ emit_content = tool_result.get("content", emit_content)
+
+ if emit_content or delta_msg.reasoning:
+ chunk = ChatCompletionChunk(
+ id=response_id,
+ model=_model_name,
+ choices=[
+ ChatCompletionChunkChoice(
+ delta=ChatCompletionChunkDelta(
+ content=emit_content,
+ reasoning=delta_msg.reasoning,
+ ),
+ finish_reason=output.finish_reason if output.finished else None,
+ )
+ ],
+ usage=get_usage(output) if output.finished else None,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
else:
# Standard path without reasoning parsing
content = delta_text
@@ -2085,7 +3125,9 @@ async def stream_chat_completion(
# Fast path: skip full parsing until '<' is seen in the stream,
# which could start tool markup (e.g. <tool_call>). This avoids
# per-token string scanning on the growing accumulated text.
- if not tool_markup_possible and "<" not in delta_text:
+ if not tool_markup_possible and not _looks_like_streaming_tool_markup(
+ delta_text
+ ):
tool_accumulated_text += delta_text
# No tool markup yet, fall through to normal chunk emission
else:
@@ -2102,6 +3144,7 @@ async def stream_chat_completion(
continue
if "tool_calls" in tool_result:
+ tool_call_complete = bool(tool_result.get("complete"))
# Emit structured tool calls
tool_calls_detected = True
chunk = ChatCompletionChunk(
@@ -2113,13 +3156,21 @@ async def stream_chat_completion(
tool_calls=tool_result["tool_calls"]
),
finish_reason=(
- "tool_calls" if output.finished else None
+ "tool_calls"
+ if (output.finished or tool_call_complete)
+ else None
),
)
],
- usage=get_usage(output) if output.finished else None,
+ usage=(
+ get_usage(output)
+ if (output.finished or tool_call_complete)
+ else None
+ ),
)
yield f"data: {chunk.model_dump_json()}\n\n"
+ if tool_call_complete:
+ break
continue
# Normal content from tool parser
@@ -2144,16 +3195,14 @@ async def stream_chat_completion(
)
yield f"data: {chunk.model_dump_json()}\n\n"
- # Fallback: if tool parser accumulated text but never emitted tool_calls
- # (e.g., </tool_call> never arrived - incomplete tool call)
- if (
- tool_parser
- and tool_accumulated_text
- and not tool_calls_detected
- and "" in tool_accumulated_text
- ):
- result = tool_parser.extract_tool_calls(tool_accumulated_text)
- if result.tools_called:
+ # Completion guard: if a stream ended with fully accumulated tool markup but
+ # the incremental parser never emitted, run one final validated parse over
+ # the complete tool text before returning.
+ if tool_parser and tool_accumulated_text and not tool_calls_detected:
+ _, fallback_tool_calls = _parse_and_validate_tools(
+ tool_accumulated_text, request, True
+ )
+ if fallback_tool_calls:
tool_chunk = ChatCompletionChunk(
id=response_id,
model=_model_name,
@@ -2163,14 +3212,14 @@ async def stream_chat_completion(
tool_calls=[
{
"index": i,
- "id": tc["id"],
+ "id": tc.id,
"type": "function",
"function": {
- "name": tc["name"],
- "arguments": tc["arguments"],
+ "name": tc.function.name,
+ "arguments": tc.function.arguments,
},
}
- for i, tc in enumerate(result.tool_calls)
+ for i, tc in enumerate(fallback_tool_calls)
]
),
finish_reason="tool_calls",
diff --git a/vllm_mlx/tool_parsers/harmony_tool_parser.py b/vllm_mlx/tool_parsers/harmony_tool_parser.py
index 34f8555d..c4d08952 100644
--- a/vllm_mlx/tool_parsers/harmony_tool_parser.py
+++ b/vllm_mlx/tool_parsers/harmony_tool_parser.py
@@ -34,6 +34,11 @@ def _generate_tool_id() -> str:
return f"call_{uuid.uuid4().hex[:8]}"
+def _same_tool_call(left: dict[str, str], right: dict[str, str]) -> bool:
+ """Return True when two parsed tool calls are semantically identical."""
+ return left["name"] == right["name"] and left["arguments"] == right["arguments"]
+
+
# Pattern: <|channel|>commentary to=functions.tool_name ... <|call|>
_COMMENTARY_BLOCK_PATTERN = re.compile(
r"<\|channel\|>commentary\s+to=functions\.(\w+)"
@@ -82,26 +87,26 @@ def extract_tool_calls(
try:
arguments = json.loads(args_str)
- tool_calls.append(
- {
- "id": _generate_tool_id(),
- "name": tool_name,
- "arguments": (
- json.dumps(arguments, ensure_ascii=False)
- if isinstance(arguments, dict)
- else str(arguments)
- ),
- }
- )
+ parsed_call = {
+ "id": _generate_tool_id(),
+ "name": tool_name,
+ "arguments": (
+ json.dumps(arguments, ensure_ascii=False)
+ if isinstance(arguments, dict)
+ else str(arguments)
+ ),
+ }
except json.JSONDecodeError:
# Keep the raw arguments string
- tool_calls.append(
- {
- "id": _generate_tool_id(),
- "name": tool_name,
- "arguments": args_str,
- }
- )
+ parsed_call = {
+ "id": _generate_tool_id(),
+ "name": tool_name,
+ "arguments": args_str,
+ }
+
+ if tool_calls and _same_tool_call(parsed_call, tool_calls[-1]):
+ continue
+ tool_calls.append(parsed_call)
# Extract final channel content
final_match = _FINAL_BLOCK_PATTERN.search(model_output)
@@ -143,13 +148,16 @@ def extract_tool_calls_streaming(
channel content as regular content deltas.
"""
# If we see a tool call completion marker in the delta
- if "<|call|>" in delta_text:
- result = self.extract_tool_calls(current_text)
- if result.tools_called:
+ if "<|call|>" in current_text:
+ previous_result = self.extract_tool_calls(previous_text)
+ current_result = self.extract_tool_calls(current_text)
+ new_tool_calls = current_result.tool_calls[len(previous_result.tool_calls) :]
+ if new_tool_calls:
return {
+ "complete": True,
"tool_calls": [
{
- "index": i,
+ "index": len(previous_result.tool_calls) + i,
"id": tc["id"],
"type": "function",
"function": {
@@ -157,7 +165,7 @@ def extract_tool_calls_streaming(
"arguments": tc["arguments"],
},
}
- for i, tc in enumerate(result.tool_calls)
+ for i, tc in enumerate(new_tool_calls)
]
}
diff --git a/vllm_mlx/tool_parsers/qwen_tool_parser.py b/vllm_mlx/tool_parsers/qwen_tool_parser.py
index fd69b96c..5348d53c 100644
--- a/vllm_mlx/tool_parsers/qwen_tool_parser.py
+++ b/vllm_mlx/tool_parsers/qwen_tool_parser.py
@@ -40,6 +40,17 @@ class QwenToolParser(ToolParser):
# Pattern for XML-style: <tool_call>{"json"}</tool_call>
XML_PATTERN = re.compile(r"\s*(\{.*?\})\s*", re.DOTALL)
+ # Pattern for XML-style function blocks:
+ # <function=name><parameter=key>value</parameter></function>
+ FUNCTION_PATTERN = re.compile(
+ r"\s*]+)>(.*?)\s*",
+ re.DOTALL,
+ )
+ PARAM_PATTERN = re.compile(
+ r"]+)>\s*(.*?)\s*",
+ re.DOTALL,
+ )
+
# Pattern for bracket-style: [Calling tool: func_name({...})]
BRACKET_PATTERN = re.compile(r"\[Calling tool:\s*(\w+)\((\{.*?\})\)\]", re.DOTALL)
@@ -101,6 +112,29 @@ def extract_tool_calls(
if xml_matches:
cleaned_text = self.XML_PATTERN.sub("", cleaned_text).strip()
+ # Try function/parameter pattern used by Qwen tool templates
+ function_matches = self.FUNCTION_PATTERN.findall(cleaned_text)
+ for name, params_block in function_matches:
+ arguments = {}
+ for param_name, param_value in self.PARAM_PATTERN.findall(params_block):
+ raw_value = param_value.strip()
+ try:
+ arguments[param_name.strip()] = json.loads(raw_value)
+ except json.JSONDecodeError:
+ arguments[param_name.strip()] = raw_value
+
+ if name.strip():
+ tool_calls.append(
+ {
+ "id": generate_tool_id(),
+ "name": name.strip(),
+ "arguments": json.dumps(arguments, ensure_ascii=False),
+ }
+ )
+
+ if function_matches:
+ cleaned_text = self.FUNCTION_PATTERN.sub("", cleaned_text).strip()
+
if tool_calls:
return ExtractedToolCallInformation(
tools_called=True,
@@ -135,11 +169,12 @@ def extract_tool_calls_streaming(
# If we're in a tool call, accumulate and parse at the end
# For simplicity, return None during accumulation
- if "" in delta_text or ")]" in delta_text:
+ if "" in current_text or ")]" in current_text:
# Tool call complete, parse the whole thing
result = self.extract_tool_calls(current_text)
if result.tools_called:
return {
+ "complete": True,
"tool_calls": [
{
"index": i,