Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,177 @@ def test_rate_limiter_window_cleanup(self):
class TestStreamChatCompletion:
"""Tests for streaming chat completion behavior."""

@pytest.mark.anyio
async def test_stream_without_parser_flags_emits_structured_tool_calls(
self, monkeypatch
):
"""Streaming tools should still parse without explicit parser flags."""
from vllm_mlx.engine.base import GenerationOutput
from vllm_mlx.server import (
ChatCompletionRequest,
Message,
stream_chat_completion,
)
import vllm_mlx.server as server

class FakeEngine:
model_name = "fake-engine"

async def stream_chat(self, messages, **kwargs):
chunks = [
GenerationOutput(text="", new_text="<tool_call>", finished=False),
GenerationOutput(
text="",
new_text="<function=list_directory>",
finished=False,
),
GenerationOutput(
text="",
new_text="<parameter=path>/Users/testuser</parameter>",
finished=False,
),
GenerationOutput(
text="",
new_text="</function>",
finished=False,
),
GenerationOutput(
text="",
new_text="</tool_call>",
finished=True,
finish_reason="stop",
prompt_tokens=5,
completion_tokens=7,
),
]
for chunk in chunks:
yield chunk

monkeypatch.setattr(server, "_model_name", "served-model")
monkeypatch.setattr(server, "_reasoning_parser", None)
monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
monkeypatch.setattr(server, "_tool_call_parser", None)
monkeypatch.setattr(server, "_tool_parser_instance", None)

request = ChatCompletionRequest(
model="served-model",
messages=[Message(role="user", content="hi")],
tools=[
{
"type": "function",
"function": {
"name": "list_directory",
"description": "List files in a directory",
"parameters": {
"type": "object",
"properties": {"path": {"type": "string"}},
"required": ["path"],
},
},
}
],
stream=True,
)

chunks = [
chunk
async for chunk in stream_chat_completion(
FakeEngine(), request.messages, request
)
]

payloads = [
json.loads(chunk.removeprefix("data: ").strip())
for chunk in chunks
if chunk != "data: [DONE]\n\n"
]
tool_payloads = [
payload
for payload in payloads
if payload["choices"] and payload["choices"][0]["delta"].get("tool_calls")
]

assert len(tool_payloads) == 1
delta = tool_payloads[0]["choices"][0]["delta"]
assert delta["tool_calls"][0]["function"]["name"] == "list_directory"
assert delta["tool_calls"][0]["function"]["arguments"] == (
'{"path": "/Users/testuser"}'
)
assert delta["content"] is None
assert tool_payloads[0]["choices"][0]["finish_reason"] == "tool_calls"
assert tool_payloads[0]["usage"] == {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12,
}

@pytest.mark.anyio
async def test_stream_without_parser_flags_keeps_plain_text(self, monkeypatch):
"""Generic streaming fallback should not interfere with normal text."""
from vllm_mlx.engine.base import GenerationOutput
from vllm_mlx.server import (
ChatCompletionRequest,
Message,
stream_chat_completion,
)
import vllm_mlx.server as server

class FakeEngine:
model_name = "fake-engine"

async def stream_chat(self, messages, **kwargs):
chunks = [
GenerationOutput(text="", new_text="hello ", finished=False),
GenerationOutput(
text="",
new_text="world",
finished=True,
finish_reason="stop",
prompt_tokens=4,
completion_tokens=2,
),
]
for chunk in chunks:
yield chunk

monkeypatch.setattr(server, "_model_name", "served-model")
monkeypatch.setattr(server, "_reasoning_parser", None)
monkeypatch.setattr(server, "_enable_auto_tool_choice", False)
monkeypatch.setattr(server, "_tool_call_parser", None)
monkeypatch.setattr(server, "_tool_parser_instance", None)

request = ChatCompletionRequest(
model="served-model",
messages=[Message(role="user", content="hi")],
tools=[
{
"type": "function",
"function": {
"name": "list_directory",
"description": "List files in a directory",
"parameters": {"type": "object", "properties": {}},
},
}
],
stream=True,
)

chunks = [
chunk
async for chunk in stream_chat_completion(
FakeEngine(), request.messages, request
)
]
payloads = [
json.loads(chunk.removeprefix("data: ").strip())
for chunk in chunks
if chunk != "data: [DONE]\n\n"
]

assert payloads[1]["choices"][0]["delta"]["content"] == "hello "
assert payloads[2]["choices"][0]["delta"]["content"] == "world"
assert payloads[2]["choices"][0]["finish_reason"] == "stop"

@pytest.mark.anyio
async def test_reasoning_stream_emits_structured_tool_calls(self, monkeypatch):
"""Tool markup after </think> should emit tool_calls chunks."""
Expand Down
146 changes: 101 additions & 45 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def _resolve_top_p(request_value: float | None) -> float:
# Safety net: the tool parser should consume these, but if it doesn't
# (e.g. malformed JSON, stray closing tags), strip them before emitting.
_TOOL_MARKUP_PATTERN = re.compile(r"</?tool_call>|</?tool_call_reasoning>")
# Substrings that may begin tool-call markup in a streamed delta. Used by
# _streaming_tool_markup_possible() as a cheap pre-filter so the tool parser
# only runs on text that plausibly contains tool calls; covers several
# model-specific markup dialects (XML tags, bracketed markers, invoke tags).
_STREAMING_TOOL_MARKERS = (
    "<tool_call>",
    "<|tool_call>",
    "<function=",
    "[Calling tool:",
    "[TOOL_CALLS]",
    "<minimax:tool_call>",
    '<invoke name="',
)


def _load_prefix_cache_from_disk() -> None:
Expand Down Expand Up @@ -580,6 +589,66 @@ def _detect_native_tool_support() -> bool:
return False


def _tool_choice_disabled(request: ChatCompletionRequest | None) -> bool:
    """Return True when tool_choice explicitly disables tool calling."""
    if request is None:
        return False

    choice = getattr(request, "tool_choice", None)
    if choice is None:
        # Attribute missing or unset — fall back to the serialized request
        # payload before deciding.
        choice = request.model_dump().get("tool_choice")
    return choice == "none"


def _get_streaming_tool_parser(request: ChatCompletionRequest | None):
    """Get a streaming-capable tool parser for this request.

    Uses the configured parser when auto tool choice is enabled, otherwise falls
    back to the generic auto parser so streaming still matches the generic
    non-streaming tool parsing behavior.

    Returns:
        A reset tool parser instance, or ``None`` when no request is given,
        tool calling is explicitly disabled, the fallback path has no tools,
        or parser initialization fails.
    """
    # The configured parser is cached at module level and reused across calls.
    global _tool_parser_instance

    if request is None:
        return None
    if _tool_choice_disabled(request):
        return None

    # Parsers may want the engine's tokenizer; tolerate engines without one.
    tokenizer = None
    if _engine is not None and hasattr(_engine, "_tokenizer"):
        tokenizer = _engine._tokenizer

    if _enable_auto_tool_choice and _tool_call_parser:
        # Explicit parser flags: lazily initialize and reuse a single shared
        # instance, resetting its streaming state for each request.
        # NOTE(review): the shared instance is reset per call — presumably
        # concurrent streams do not share it; confirm single-stream usage.
        if _tool_parser_instance is None:
            try:
                parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
                _tool_parser_instance = parser_cls(tokenizer)
                logger.info(f"Initialized tool call parser: {_tool_call_parser}")
            except Exception as e:
                # Initialization failure degrades to plain-text streaming.
                logger.warning(f"Failed to init tool parser for streaming: {e}")
                return None
        _tool_parser_instance.reset()
        return _tool_parser_instance

    # Generic fallback only applies when the request actually declares tools.
    if not getattr(request, "tools", None):
        return None

    # No parser flags configured: build a fresh generic "auto" parser per
    # request so streaming matches the generic non-streaming behavior.
    try:
        parser_cls = ToolParserManager.get_tool_parser("auto")
        parser = parser_cls(tokenizer)
        parser.reset()
        return parser
    except Exception as e:
        logger.warning(f"Failed to init generic streaming tool parser: {e}")
        return None


def _streaming_tool_markup_possible(text: str) -> bool:
    """Heuristic marker check to avoid parser work on ordinary text chunks."""
    for marker in _STREAMING_TOOL_MARKERS:
        if marker in text:
            return True
    return False


def load_embedding_model(
model_name: str | None,
*,
Expand Down Expand Up @@ -2294,24 +2363,10 @@ async def _stream_anthropic_messages(

# Tool call streaming suppression — prevents raw tool markup from leaking
# as text_delta events. Mirrors the OpenAI streaming path logic.
global _tool_parser_instance
tool_parser = None
tool_accumulated_text = ""
tool_markup_possible = False
tool_choice = getattr(openai_request, "tool_choice", None)
if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none":
if _tool_parser_instance is None:
try:
parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
tokenizer = None
if _engine is not None and hasattr(_engine, "_tokenizer"):
tokenizer = _engine._tokenizer
_tool_parser_instance = parser_cls(tokenizer)
except Exception:
pass
if _tool_parser_instance is not None:
tool_parser = _tool_parser_instance
tool_parser.reset()
tool_parser = _get_streaming_tool_parser(openai_request)

try:
async for output in engine.stream_chat(messages=messages, **chat_kwargs):
Expand Down Expand Up @@ -2341,7 +2396,12 @@ async def _stream_anthropic_messages(

# Filter tool call markup during streaming
if tool_parser and content_to_emit:
if not tool_markup_possible and "<" not in content_to_emit:
if (
not tool_markup_possible
and not _streaming_tool_markup_possible(
tool_accumulated_text + content_to_emit
)
):
tool_accumulated_text += content_to_emit
else:
if not tool_markup_possible:
Expand Down Expand Up @@ -2386,7 +2446,12 @@ async def _stream_anthropic_messages(

# Filter tool call markup during streaming
if tool_parser and content_to_emit:
if not tool_markup_possible and "<" not in content_to_emit:
if (
not tool_markup_possible
and not _streaming_tool_markup_possible(
tool_accumulated_text + content_to_emit
)
):
tool_accumulated_text += content_to_emit
else:
if not tool_markup_possible:
Expand Down Expand Up @@ -2618,27 +2683,11 @@ async def stream_chat_completion(
last_output = None

# Tool call streaming state
global _tool_parser_instance
tool_parser = None
tool_accumulated_text = ""
tool_calls_detected = False
tool_markup_possible = False # Fast path: skip parsing until '<' seen
tool_choice = getattr(request, "tool_choice", None)
if _enable_auto_tool_choice and _tool_call_parser and tool_choice != "none":
# Initialize parser if needed (same as _parse_tool_calls_with_parser)
if _tool_parser_instance is None:
try:
parser_cls = ToolParserManager.get_tool_parser(_tool_call_parser)
tokenizer = None
if _engine is not None and hasattr(_engine, "_tokenizer"):
tokenizer = _engine._tokenizer
_tool_parser_instance = parser_cls(tokenizer)
logger.info(f"Initialized tool call parser: {_tool_call_parser}")
except Exception as e:
logger.warning(f"Failed to init tool parser for streaming: {e}")
if _tool_parser_instance is not None:
tool_parser = _tool_parser_instance
tool_parser.reset()
tool_markup_possible = False # Fast path: skip parsing until markers appear
tool_parser = _get_streaming_tool_parser(request)

try:
# Stream content
Expand Down Expand Up @@ -2689,7 +2738,12 @@ async def stream_chat_completion(

# Tool call parsing on content portion
if tool_parser and content:
if not tool_markup_possible and "<" not in content:
if (
not tool_markup_possible
and not _streaming_tool_markup_possible(
tool_accumulated_text + content
)
):
tool_accumulated_text += content
# Suppress whitespace-only content when tools are active;
# avoids emitting stray newlines before tool call XML.
Expand Down Expand Up @@ -2799,10 +2853,16 @@ async def stream_chat_completion(

# Tool call streaming parsing
if tool_parser and delta_text:
# Fast path: skip full parsing until '<' is seen in the stream,
# which could start tool markup (e.g. <tool_call>). This avoids
# per-token string scanning on the growing accumulated text.
if not tool_markup_possible and "<" not in delta_text:
# Fast path: skip full parsing until likely tool markup appears.
# This preserves the cheap path for ordinary text while still
# allowing generic streaming tool parsing when no explicit
# parser flags are configured.
if (
not tool_markup_possible
and not _streaming_tool_markup_possible(
tool_accumulated_text + delta_text
)
):
tool_accumulated_text += delta_text
# No tool markup yet, fall through to normal chunk emission
else:
Expand Down Expand Up @@ -2883,11 +2943,7 @@ async def stream_chat_completion(
tool_parser
and tool_accumulated_text
and not tool_calls_detected
and (
"<tool_call>" in tool_accumulated_text
or "<|tool_call>" in tool_accumulated_text
or "<function" in tool_accumulated_text
)
and _streaming_tool_markup_possible(tool_accumulated_text)
):
final_parse_result = tool_parser.extract_tool_calls(tool_accumulated_text)
if final_parse_result.tools_called:
Expand Down
Loading