diff --git a/tests/reasoning/test_qwen3_tool_call_recovery.py b/tests/reasoning/test_qwen3_tool_call_recovery.py new file mode 100644 index 000000000000..29e54eb92907 --- /dev/null +++ b/tests/reasoning/test_qwen3_tool_call_recovery.py @@ -0,0 +1,207 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Tests for Qwen3 reasoning parser tool-call recovery (Issue #39056). + +These tests verify that XML tool-call blocks emitted inside are +correctly promoted into content so the downstream Qwen3CoderToolParser +can parse them — and that any pre-existing response text is preserved. +""" + +import json + +import pytest + +from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest +from vllm.reasoning import ReasoningParser, ReasoningParserManager +from vllm.tool_parsers.qwen3coder_tool_parser import Qwen3CoderToolParser + +parser_name = "qwen3" + +_TOOL_CALL_BLOCK = ( + "\n" + "\n" + "\n" + "204\n" + "\n" + "\n" + "" +) + + +class _FakeQwen3ToolTokenizer: + """Minimal tokenizer stub sufficient for parser construction.""" + + def get_vocab(self) -> dict[str, int]: + return { + "": 1, + "": 2, + "": 3, + "": 4, + } + + +def _make_parser() -> ReasoningParser: + return ReasoningParserManager.get_reasoning_parser(parser_name)( + _FakeQwen3ToolTokenizer() + ) + + +def _make_request() -> ChatCompletionRequest: + return ChatCompletionRequest(messages=[], model="test-model") + + +# --------------------------------------------------------------------------- +# Basic promotion: tool call extracted from reasoning, placed into content +# --------------------------------------------------------------------------- + +def test_embedded_tool_call_is_promoted_from_reasoning_into_content(): + """Tool-call block inside must move to content, not stay in reasoning.""" + parser = _make_parser() + request = _make_request() + + reasoning, content = parser.extract_reasoning( + "The verification confirms my solution:\n" + "- s = 2.5 km/h\n" + "- t = 24 minutes\n" + "- Total time at speed 3 km/h = 204 minutes\n\n" + + _TOOL_CALL_BLOCK + + "\n", + request=request, + ) + + assert reasoning is not None + assert "" not in reasoning, "tool call must not remain in reasoning" + assert content is not None + assert "" in content, "tool call must be present in content" + assert "" in content + + +# --------------------------------------------------------------------------- +# Ordering fix: existing content text must be preserved after promotion +# --------------------------------------------------------------------------- + +def test_existing_content_text_is_preserved_after_tool_call_promotion(): + """ + Pre-existing response text must appear BEFORE the promoted tool-call block + in content. The Qwen3CoderToolParser reads content up to the first tool + marker as the human-readable reply; if the tool block were prepended that + text would be silently discarded. + """ + parser = _make_parser() + request = _make_request() + + # Model emits tool call inside , then text after + _, content = parser.extract_reasoning( + "verify result\n" + + _TOOL_CALL_BLOCK + + "Here is the answer.", + request=request, + ) + + assert content is not None + assert "Here is the answer." in content, "post- text must be preserved" + + # The text must come BEFORE the tool call so the tool parser keeps it + text_pos = content.index("Here is the answer.") + tool_pos = content.index("") + assert text_pos < tool_pos, ( + "existing response text must appear before the promoted tool call " + f"(text at {text_pos}, tool at {tool_pos})" + ) + + +# --------------------------------------------------------------------------- +# End-to-end: promoted content must be parseable by Qwen3CoderToolParser +# --------------------------------------------------------------------------- + +def test_promoted_tool_call_is_parseable_by_qwen3coder_and_trailing_text_preserved(): + """ + Full pipeline: reasoning parser promotes the tool call, then + Qwen3CoderToolParser extracts it. Trailing assistant text must survive. + """ + parser = _make_parser() + tool_parser = Qwen3CoderToolParser(_FakeQwen3ToolTokenizer(), tools=None) + request = _make_request() + + _, content = parser.extract_reasoning( + "verify result\n" + + _TOOL_CALL_BLOCK + + "assistant trailing text", + request=request, + ) + + assert content is not None + assert "assistant trailing text" in content + + tool_call_info = tool_parser.extract_tool_calls(content, request=request) + + assert tool_call_info.tools_called is True + assert len(tool_call_info.tool_calls) == 1 + tool_call = tool_call_info.tool_calls[0] + assert tool_call.function.name == "Finish" + assert json.loads(tool_call.function.arguments) == {"answer": "204"} + + # FIX for reviewer comment: verify trailing text is preserved in the + # final extracted content field, not discarded by the tool parser. + assert tool_call_info.content is not None, ( + "trailing response text must be preserved by Qwen3CoderToolParser" + ) + assert "assistant trailing text" in tool_call_info.content + + +# --------------------------------------------------------------------------- +# Truncated output: no , but tool call still recoverable +# --------------------------------------------------------------------------- + +def test_truncated_reasoning_still_recovers_embedded_tool_call(): + """When output is cut off before , embedded tool calls still promote.""" + parser = _make_parser() + request = _make_request() + + reasoning, content = parser.extract_reasoning( + "verify result\n" + _TOOL_CALL_BLOCK, + request=request, + ) + + assert reasoning is not None + assert "" not in reasoning + assert content is not None + assert "" in content + + +# --------------------------------------------------------------------------- +# Normal reasoning: no tool call → unchanged behaviour +# --------------------------------------------------------------------------- + +def test_normal_reasoning_extraction_unchanged(): + """Reasoning without any tool call must pass through unmodified.""" + parser = _make_parser() + request = _make_request() + + raw_reasoning = "Let me think about this carefully.\nThe answer is 42." + reasoning, content = parser.extract_reasoning( + f"{raw_reasoning}The answer is 42.", + request=request, + ) + + assert reasoning == raw_reasoning + assert content == "The answer is 42." + + +# --------------------------------------------------------------------------- +# No regression: post- content preserved without tool call +# --------------------------------------------------------------------------- + +def test_post_think_content_preserved_without_tool_call(): + """Content after must be returned verbatim when no tool call.""" + parser = _make_parser() + request = _make_request() + + _, content = parser.extract_reasoning( + "some reasoningplain response text", + request=request, + ) + + assert content == "plain response text" diff --git a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py index e38b0de3d822..36f73a10767c 100644 --- a/vllm/reasoning/qwen3_reasoning_parser.py +++ b/vllm/reasoning/qwen3_reasoning_parser.py @@ -1,18 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Iterable, Sequence +import re +from collections.abc import Sequence from typing import TYPE_CHECKING from vllm.entrypoints.openai.engine.protocol import DeltaMessage from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser if TYPE_CHECKING: - from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest + from vllm.entrypoints.openai.chat_completion.protocol import ( + ChatCompletionRequest, + ) from vllm.entrypoints.openai.responses.protocol import ResponsesRequest from vllm.tokenizers import TokenizerLike +_EMBEDDED_TOOL_CALL_RE = re.compile( + r"(.*?)|.*$", + re.DOTALL, +) + + class Qwen3ReasoningParser(BaseThinkingReasoningParser): """ Reasoning parser for the Qwen3/Qwen3.5 model family. @@ -31,25 +40,15 @@ class Qwen3ReasoningParser(BaseThinkingReasoningParser): use an older chat template where the model generates itself. This parser handles both styles: if appears in the generated output it is stripped before extraction (non-streaming) or skipped (streaming). - - NOTE: Qwen3.5 models may emit inside the thinking block - without closing first. is treated as an implicit - end of reasoning, matching the approach in KimiK2ReasoningParser. """ def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs): super().__init__(tokenizer, *args, **kwargs) - chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} # Qwen3 defaults to thinking enabled; only treat output as # pure content when the user explicitly disables it. self.thinking_enabled = chat_kwargs.get("enable_thinking", True) - self._tool_call_tag = "" - self._tool_call_token_id = self.vocab.get(self._tool_call_tag) - self._tool_call_end_tag = "" - self._tool_call_end_token_id = self.vocab.get(self._tool_call_end_tag) - @property def start_token(self) -> str: """The token that starts reasoning content.""" @@ -60,60 +59,66 @@ def end_token(self) -> str: """The token that ends reasoning content.""" return "" - def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: - start_token_id = self.start_token_id - end_token_id = self.end_token_id - tool_call_token_id = self._tool_call_token_id - tool_call_end_token_id = self._tool_call_end_token_id - - for i in range(len(input_ids) - 1, -1, -1): - token_id = input_ids[i] - if token_id == start_token_id: - # Found before or - return False - if token_id == end_token_id: - return True - if tool_call_token_id is not None and token_id == tool_call_token_id: - # Only treat as implicit reasoning end if this - # is NOT followed by . Paired occurrences are - # template examples in the prompt, not model output. - if tool_call_end_token_id is not None and any( - input_ids[j] == tool_call_end_token_id - for j in range(i + 1, len(input_ids)) - ): - continue - return True - return False - - def is_reasoning_end_streaming( - self, input_ids: Sequence[int], delta_ids: Iterable[int] - ) -> bool: - if super().is_reasoning_end_streaming(input_ids, delta_ids): - return True - if self._tool_call_token_id is not None: - return self._tool_call_token_id in delta_ids - return False - - def extract_content_ids(self, input_ids: list[int]) -> list[int]: - """ - Extract content token ids from the input_ids. + @staticmethod + def _split_embedded_tool_calls( + reasoning: str | None, + content: str | None, + ) -> tuple[str | None, str | None]: + """Promote tool-call XML blocks out of reasoning into content. + + Qwen3.5 models can emit XML tool calls before ``. The serving + stack parses reasoning before tool calls, so these embedded tool calls + would otherwise be lost because downstream tool parsers only inspect + the content channel. + + Tool-call blocks are APPENDED to any existing content so that + pre-existing response text (which comes before the first tool marker) + is preserved by the downstream Qwen3CoderToolParser. Prepending + would place existing text *after* the tool marker, causing the tool + parser to discard it when it extracts content up to the first marker. """ - result = super().extract_content_ids(input_ids) - if result: - return result - # Fall back: content starts at (implicit reasoning end). if ( - self._tool_call_token_id is not None - and self._tool_call_token_id in input_ids + not reasoning + or "" not in reasoning + or " str: + block = match.group(0) + if " tuple[str | None, str | None]: """ Extract reasoning content from the model output. @@ -125,6 +130,7 @@ def extract_reasoning( When thinking is explicitly disabled and no appears, returns (None, model_output) — all output is content. + Otherwise (thinking enabled, default), a missing means the output was truncated and everything is reasoning: returns (model_output, None). @@ -132,30 +138,27 @@ def extract_reasoning( Returns: tuple[Optional[str], Optional[str]]: reasoning content and content """ - # Strip if present in the generated output. model_output_parts = model_output.partition(self.start_token) model_output = ( - model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + model_output_parts[2] + if model_output_parts[1] + else model_output_parts[0] ) - if self.end_token in model_output: - reasoning, _, content = model_output.partition(self.end_token) - return reasoning, content or None + if self.end_token not in model_output: + if not self.thinking_enabled: + # Thinking explicitly disabled — treat everything as content. + return None, model_output + # Thinking enabled but no : output was truncated. + # Everything generated so far is reasoning. + return self._split_embedded_tool_calls(model_output, None) - if not self.thinking_enabled: - # Thinking explicitly disabled — treat everything as content. - return None, model_output + # Extract reasoning content from the model output. + reasoning, _, content = model_output.partition(self.end_token) - # No — check for implicit reasoning end via . - tool_call_index = model_output.find(self._tool_call_tag) - if tool_call_index != -1: - reasoning = model_output[:tool_call_index] - content = model_output[tool_call_index:] - return reasoning or None, content or None - # Thinking enabled but no : output was truncated. - # Everything generated so far is reasoning. - return model_output, None + final_content = content or None + return self._split_embedded_tool_calls(reasoning, final_content) def extract_reasoning_streaming( self, @@ -183,14 +186,14 @@ def extract_reasoning_streaming( if self.start_token_id in delta_token_ids: start_idx = delta_text.find(self.start_token) if start_idx >= 0: - delta_text = delta_text[start_idx + len(self.start_token) :] + delta_text = delta_text[start_idx + len(self.start_token):] if self.end_token_id in delta_token_ids: # End token in this delta: split reasoning from content. end_index = delta_text.find(self.end_token) if end_index >= 0: reasoning = delta_text[:end_index] - content = delta_text[end_index + len(self.end_token) :] + content = delta_text[end_index + len(self.end_token):] if not reasoning and not content: return None return DeltaMessage( @@ -200,20 +203,6 @@ def extract_reasoning_streaming( # end_token_id in IDs but not in text (already stripped) return None - # Implicit reasoning end via . - if ( - self._tool_call_token_id is not None - and self._tool_call_token_id in delta_token_ids - ): - tool_index = delta_text.find(self._tool_call_tag) - if tool_index >= 0: - reasoning = delta_text[:tool_index] - content = delta_text[tool_index:] - return DeltaMessage( - reasoning=reasoning if reasoning else None, - content=content if content else None, - ) - # No end token in this delta. if not delta_text: # Nothing left after stripping start token. @@ -221,11 +210,6 @@ def extract_reasoning_streaming( elif self.end_token_id in previous_token_ids: # End token already passed: everything is content now. return DeltaMessage(content=delta_text) - elif ( - self._tool_call_token_id is not None - and self._tool_call_token_id in previous_token_ids - ): - return DeltaMessage(content=delta_text) else: # No end token yet: still in reasoning phase. return DeltaMessage(reasoning=delta_text)