diff --git a/tests/test_tool_call_promotion.py b/tests/test_tool_call_promotion.py new file mode 100644 index 00000000..187b0375 --- /dev/null +++ b/tests/test_tool_call_promotion.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for tool call promotion from reasoning to content.""" + +import logging + +import pytest + +from vllm_mlx.reasoning import get_parser + + +@pytest.fixture +def parser(): + cls = get_parser("qwen3") + return cls() + + +class TestNonStreamingPromotion: + """Non-streaming extract_reasoning() promotes from reasoning.""" + + def test_closed_tool_call_inside_think_appended(self, parser): + """Test case 1: closed block appended to content.""" + output = ( + "I should check the weather.\n" + "\n" + "\n" + "Tokyo\n" + "\n" + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning is not None + assert "I should check the weather" in reasoning + assert "" not in reasoning + assert content is not None + assert "" in content + assert "get_weather" in content + + def test_tool_call_after_think_unchanged(self, parser): + """Test case 2: tool call in content stays in content.""" + output = ( + "Let me think about this.\n" + "\n" + "\n" + "Tokyo\n" + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert reasoning == "Let me think about this." 
+ assert content is not None + assert "" in content + + def test_multiple_tool_calls_inside_think(self, parser): + """Test case 3: multiple closed blocks all appended.""" + output = ( + "I need two lookups.\n" + "\n" + "Tokyo\n" + "\n" + "Now the second one.\n" + "\n" + "JST\n" + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert "" not in reasoning + assert "I need two lookups" in reasoning + assert "Now the second one" in reasoning + assert content is not None + assert content.count("") == 2 + + def test_truncated_unclosed_tool_call_prepended(self, parser): + """Test case 4: unclosed tool call in reasoning prepended to content.""" + output = ( + "Let me call the API.\n" + "\n" + "\n" + "Tokyo\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert content is not None + assert "" in content + assert "Let me call the API" in (reasoning or "") + + def test_hermes_json_tool_call(self, parser): + """Test case 5: Hermes JSON format promoted.""" + output = ( + "I will check.\n" + "\n" + '{"name": "get_weather", "arguments": {"city": "Tokyo"}}\n' + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert "" not in reasoning + assert content is not None + assert "get_weather" in content + + def test_prose_mention_not_promoted(self, parser): + """Test case 6: prose mentioning without structure stays.""" + output = ( + "The model should use to invoke functions. " + "Then it should verify.\n" + "The answer is 42." + ) + reasoning, content = parser.extract_reasoning(output) + assert "should use" in reasoning + assert content == "The answer is 42." 
+ + def test_content_none_handled(self, parser): + """Test case 7: content=None when reasoning is entire output.""" + output = ( + "Checking.\n" + "\n" + "test\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert content is not None + assert "" in content + + def test_tool_call_spanning_think_boundary(self, parser): + """Test case 8: unclosed portion prepended, reassembles with content.""" + output = ( + "R\n" + "\n" + "" + "1\n" + "\n" + "\n" + "C" + ) + reasoning, content = parser.extract_reasoning(output) + assert content is not None + assert "" in content + assert "" in content + + def test_closed_appended_preserves_existing_content(self, parser): + """Closed block appended after existing post-think content.""" + output = ( + "Let me check.\n" + "\n" + "test\n" + "\n" + "\n" + "Here is my answer." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Here is my answer." in content + assert "" in content + assert content.index("Here is my answer") < content.index("") + + def test_promotion_logs_warning(self, parser, caplog): + """Promotion should log a warning for operators.""" + with caplog.at_level(logging.WARNING): + output = ( + "\n" + "\n" + "1\n" + "\n" + "\n" + ) + parser.extract_reasoning(output) + assert any("tool_call" in r.message.lower() for r in caplog.records) + + def test_no_tool_calls_no_warning(self, parser, caplog): + """No promotion -> no warning logged.""" + with caplog.at_level(logging.WARNING): + output = "Just reasoning.\nContent." + parser.extract_reasoning(output) + assert not any("tool_call" in r.message.lower() for r in caplog.records) + + +class TestStreamingPromotion: + """Streaming extract_reasoning_streaming() promotes tool calls. + + Uses full-text or large chunk sizes. The upstream streaming parser + uses current_text/previous_text tag detection which has imprecision + when tags span chunk boundaries at very small chunk sizes. 
+ """ + + def _stream(self, parser, text, chunk_size=None): + """Feed text through streaming parser in chunks.""" + if chunk_size is None: + chunk_size = len(text) + parser.reset_state() + reasoning_parts = [] + content_parts = [] + accumulated = "" + for i in range(0, len(text), chunk_size): + chunk = text[i : i + chunk_size] + previous = accumulated + accumulated += chunk + delta = parser.extract_reasoning_streaming(previous, accumulated, chunk) + if delta: + if delta.reasoning: + reasoning_parts.append(delta.reasoning) + if delta.content: + content_parts.append(delta.content) + final = parser.finalize_stream() + if final: + if final.reasoning: + reasoning_parts.append(final.reasoning) + if final.content: + content_parts.append(final.content) + return "".join(reasoning_parts) or None, "".join(content_parts) or None + + def test_stream_tool_call_inside_think(self, parser): + """Tool call promoted as content during streaming.""" + text = ( + "I should check.\n" + "\n" + "\n" + "Tokyo\n" + "\n" + "\n" + "\n" + "Final answer." + ) + reasoning, content = self._stream(parser, text) + assert reasoning is not None + assert "I should check" in reasoning + assert content is not None + assert "" in content + assert "get_weather" in content + assert "Final answer." in content + + def test_stream_think_ends_while_buffering(self, parser): + """ before flushes as content.""" + text = ( + "Check.\n" + "\n" + "test\n" + "\n" + "Done." 
+ ) + reasoning, content = self._stream(parser, text) + assert content is not None + assert "" in content + + def test_stream_finalize_with_buffered_tool_call(self, parser): + """Stream ends mid-tool-call, flushed as content.""" + text = ( + "\n" + "\n" + "1\n" + ) + reasoning, content = self._stream(parser, text) + assert content is not None + assert "" in content + + def test_stream_multiple_tool_calls(self, parser): + """Each tool call promoted independently.""" + text = ( + "Two calls.\n" + "\n" + "1\n" + "\n" + "Middle reasoning.\n" + "\n" + "2\n" + "\n" + "\n" + ) + reasoning, content = self._stream(parser, text) + assert "Two calls" in (reasoning or "") + assert "Middle reasoning" in (reasoning or "") + assert content is not None + assert content.count("") == 2 + + def test_stream_large_chunks(self, parser): + """Promotion correct at chunk sizes that don't split tags.""" + text = ( + "Check.\n" + "\n" + "1\n" + "\n" + "\nDone." + ) + for cs in [20, 50, len(text)]: + reasoning, content = self._stream(parser, text, chunk_size=cs) + assert content is not None, f"chunk_size={cs}" + assert "" in content, f"chunk_size={cs}" + + def test_stream_tool_call_closed_immediately_before_think_end(self, parser): + """ with no trailing content.""" + text = ( + "R\n" + "\n" + "1\n" + "" + ) + reasoning, content = self._stream(parser, text) + assert content is not None + assert "" in content + + def test_stream_single_delta_promotion_via_transition(self, parser): + """Single large delta hits _transition_to_content catch-all.""" + text = ( + "R\n" + "\n" + "1\n" + "\n" + "More reasoning.\n" + "\n" + "Content." + ) + reasoning, content = self._stream(parser, text) + assert content is not None + assert "" in content + assert "Content." in content + + def test_stream_no_tool_calls_regression(self, parser): + """Normal reasoning unchanged when streamed as full text.""" + text = "Just thinking here.\nThe answer is 42." 
+ reasoning, content = self._stream(parser, text) + assert "Just thinking here." in (reasoning or "") + assert content is not None + assert "The answer is 42." in content + + +class TestComposition: + """End-to-end: reasoning parser + tool parser compose correctly.""" + + def test_promoted_parsed_by_tool_parser(self, parser): + """Tool parser finds promoted calls.""" + pytest.importorskip("transformers") + from vllm_mlx.tool_parsers import ToolParserManager + + output = ( + "Let me look this up.\n" + "\n" + "\n" + "Tokyo\n" + "\n" + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert content is not None + + for parser_name in ["qwen3_xml", "qwen3.5", "qwen", "qwen3"]: + try: + tool_cls = ToolParserManager.get_tool_parser(parser_name) + break + except KeyError: + continue + else: + pytest.skip("No Qwen tool parser registered") + + tool_parser = tool_cls(None) + result = tool_parser.extract_tool_calls(content) + assert result.tools_called + assert len(result.tool_calls) >= 1 + assert result.tool_calls[0]["name"] == "get_weather" + + def test_promoted_preserves_trailing_content(self, parser): + """Closed appended after existing content.""" + output = ( + "Checking.\n" + "\n" + "test\n" + "\n" + "\n" + "Based on results, here is my answer." + ) + reasoning, content = parser.extract_reasoning(output) + assert "Based on results" in content + assert "" in content + + def test_tool_choice_required_with_promotion(self, parser): + """Required tool choice + promoted content.""" + output = ( + "User requires a tool call.\n" + "\n" + "1\n" + "\n" + "\n" + ) + reasoning, content = parser.extract_reasoning(output) + assert content is not None + assert "mandatory_fn" in content diff --git a/vllm_mlx/reasoning/think_parser.py b/vllm_mlx/reasoning/think_parser.py index 4c7f9719..dfdec903 100644 --- a/vllm_mlx/reasoning/think_parser.py +++ b/vllm_mlx/reasoning/think_parser.py @@ -17,10 +17,14 @@ whole-output rescanning behavior. 
""" +import logging +import re from abc import abstractmethod from .base import DeltaMessage, ReasoningParser +logger = logging.getLogger(__name__) + class BaseThinkingReasoningParser(ReasoningParser): """ @@ -51,18 +55,28 @@ def start_token(self) -> str: def end_token(self) -> str: """The token/tag that ends reasoning content (e.g., '').""" + _TOOL_CALL_START = "" + _TOOL_CALL_END = "" + _TOOL_CALL_CLOSED_RE = re.compile(r"(.*?)", re.DOTALL) + _TOOL_CALL_UNCLOSED_RE = re.compile(r"\s*[\{<].*$", re.DOTALL) + def __init__(self, tokenizer=None): super().__init__(tokenizer) # Streaming state — reset per request via reset_state() self._phase: str = "pre_think" # "pre_think" | "thinking" | "content" self._content_started = False self._content_buffer = "" + # Tool call promotion state. + self._in_tool_call = False + self._tool_call_buffer = "" def reset_state(self): """Reset state machine for a new streaming request.""" self._phase = "pre_think" self._content_started = False self._content_buffer = "" + self._in_tool_call = False + self._tool_call_buffer = "" def extract_reasoning( self, @@ -84,19 +98,15 @@ def extract_reasoning( """ text = model_output - # Cases 1 and 2: consume one or more leading reasoning spans. Some - # thinking models emit an extra empty ```` block after - # the forced transition; that block still belongs to reasoning, not - # final content. 
if self.end_token in text: - return self._extract_complete_reasoning(text) + reasoning, content = self._extract_complete_reasoning(text) + return self._promote_tool_calls(reasoning, content) - # Case 3: Only start tag (incomplete reasoning, no end yet) if self.start_token in text: _, _, reasoning = text.partition(self.start_token) - return reasoning.strip() or None, None + reasoning = reasoning.strip() or None + return self._promote_tool_calls(reasoning, None) - # Case 4: No tags at all — pure content return None, model_output def extract_reasoning_streaming( @@ -149,6 +159,15 @@ def extract_reasoning_streaming( reasoning = after[:eidx] content = after[eidx + len(end_tok) :] return self._transition_to_content(reasoning, content) + + tc_start = self._TOOL_CALL_START + if tc_start in after: + tc_idx = after.find(tc_start) + self._in_tool_call = True + self._tool_call_buffer = after[tc_idx:] + before = after[:tc_idx] + return DeltaMessage(reasoning=before) if before else None + return DeltaMessage(reasoning=after) if after else None # Implicit mode: completed without an explicit . @@ -171,7 +190,23 @@ def extract_reasoning_streaming( # ── Phase: thinking ─────────────────────────────────────── # Inside a reasoning block, waiting for end tag. + # Also detects blocks and promotes them to content. 
if self._phase == "thinking": + if self._in_tool_call: + return self._thinking_tool_call(previous_text, current_text, delta_text) + + tc_start = self._TOOL_CALL_START + if tc_start in current_text and tc_start not in previous_text: + self._in_tool_call = True + idx = delta_text.find(tc_start) + if idx >= 0: + reasoning = delta_text[:idx] + self._tool_call_buffer = delta_text[idx:] + else: + self._tool_call_buffer = tc_start + reasoning = delta_text + return DeltaMessage(reasoning=reasoning) if reasoning else None + if end_tok in current_text and end_tok not in previous_text: self._phase = "content" idx = delta_text.find(end_tok) @@ -228,6 +263,7 @@ def _transition_to_content( self, reasoning: str | None, content: str | None ) -> DeltaMessage | None: """Return a delta while suppressing leading post-transition think blocks.""" + reasoning, content = self._promote_tool_calls(reasoning, content) content_msg = self._content_delta(content or "") extra_reasoning = content_msg.reasoning if content_msg else None final_content = content_msg.content if content_msg else None @@ -287,3 +323,140 @@ def _content_delta(self, delta_text: str) -> DeltaMessage | None: if reasoning_parts: return DeltaMessage(reasoning="".join(reasoning_parts)) return None + + def _thinking_tool_call( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """Handle streaming while inside a during thinking phase.""" + tc_end = self._TOOL_CALL_END + end_tok = self.end_token + + if tc_end in current_text and tc_end not in previous_text: + self._tool_call_buffer += delta_text + idx = self._tool_call_buffer.find(tc_end) + promoted = self._tool_call_buffer[: idx + len(tc_end)] + remainder = self._tool_call_buffer[idx + len(tc_end) :] + self._tool_call_buffer = "" + self._in_tool_call = False + logger.warning("Promoted streaming tool_call block from reasoning") + + if end_tok in remainder: + self._phase = "content" + eidx = remainder.find(end_tok) + reasoning = 
remainder[:eidx].strip() or None + after_think = remainder[eidx + len(end_tok) :] + content_msg = self._content_delta(after_think) if after_think else None + final_content = promoted + ( + (content_msg.content or "") if content_msg else "" + ) + extra_r = content_msg.reasoning if content_msg else None + r_text = (reasoning or "") + (extra_r or "") + return DeltaMessage( + content=final_content or None, + reasoning=r_text or None, + ) + + tc_start = self._TOOL_CALL_START + if tc_start in remainder: + tc_idx = remainder.find(tc_start) + self._in_tool_call = True + self._tool_call_buffer = remainder[tc_idx:] + reasoning = remainder[:tc_idx].strip() or None + return DeltaMessage(content=promoted, reasoning=reasoning) + + reasoning = remainder.strip() or None + return DeltaMessage(content=promoted, reasoning=reasoning) + + if end_tok in current_text and end_tok not in previous_text: + self._tool_call_buffer += delta_text + self._in_tool_call = False + self._phase = "content" + logger.warning( + "Promoted unclosed streaming tool_call " + "(think ended before tool_call closed)" + ) + idx = self._tool_call_buffer.find(end_tok) + if idx >= 0: + promoted = self._tool_call_buffer[:idx] + after = self._tool_call_buffer[idx + len(end_tok) :] + else: + promoted = self._tool_call_buffer + after = "" + self._tool_call_buffer = "" + content_msg = self._content_delta(after) if after else None + final_content = ( + promoted + (content_msg.content or "") if content_msg else promoted + ) + return DeltaMessage(content=final_content or None) + + self._tool_call_buffer += delta_text + return None + + def finalize_stream(self) -> DeltaMessage | None: + """Flush any buffered tool call text at end of stream.""" + if self._in_tool_call and self._tool_call_buffer: + promoted = self._tool_call_buffer + self._tool_call_buffer = "" + self._in_tool_call = False + logger.warning("Promoted unclosed streaming tool_call at stream end") + return DeltaMessage(content=promoted) + return None + + 
@classmethod
def _promote_tool_calls(
    cls, reasoning: str | None, content: str | None
) -> tuple[str | None, str | None]:
    """Move tool-call blocks found inside reasoning text into content.

    Every span matched by ``_TOOL_CALL_CLOSED_RE`` is removed from the
    reasoning and appended (newline-separated) after any existing content.
    If ``_TOOL_CALL_UNCLOSED_RE`` then matches what is left, that trailing
    span is cut off the reasoning and placed in front of the content, so
    a call that was truncated or split across the end of the reasoning
    block sits directly before the text that follows it.

    Args:
        reasoning: Extracted reasoning text, or ``None`` if there is none.
        content: Extracted final content, or ``None`` if there is none.

    Returns:
        ``(reasoning, content)`` with promoted blocks moved over. Either
        element becomes ``None`` when it is empty after stripping. The
        inputs are returned untouched when nothing is promoted.
    """
    # Use the class-level marker rather than a second copy of the literal
    # so this guard cannot drift from the patterns and streaming code.
    if not reasoning or cls._TOOL_CALL_START not in reasoning:
        return reasoning, content

    # Closed blocks first; the unclosed pattern then only sees what is
    # left, so it cannot swallow a block that was properly terminated.
    closed_blocks = [
        match.group(0) for match in cls._TOOL_CALL_CLOSED_RE.finditer(reasoning)
    ]
    remaining = cls._TOOL_CALL_CLOSED_RE.sub("", reasoning)

    unclosed_block = None
    unclosed_match = cls._TOOL_CALL_UNCLOSED_RE.search(remaining)
    if unclosed_match:
        unclosed_block = unclosed_match.group(0)
        remaining = remaining[: unclosed_match.start()]

    unclosed_count = 1 if unclosed_block else 0
    promoted_count = len(closed_blocks) + unclosed_count
    if promoted_count == 0:
        # Marker was present but neither pattern matched: leave as-is.
        return reasoning, content

    # Order: unclosed block, existing content, then the closed blocks.
    pieces: list[str] = []
    if unclosed_block:
        pieces.append(unclosed_block)
    if content:
        pieces.append(content)
    pieces.extend(closed_blocks)
    promoted_content = "\n".join(pieces).strip() or None

    logger.warning(
        "Promoted %d tool_call block(s) from reasoning to content "
        "(%d closed, %d unclosed)",
        promoted_count,
        len(closed_blocks),
        unclosed_count,
    )

    return remaining.strip() or None, promoted_content