diff --git a/tests/reasoning/test_gemma4_reasoning_parser.py b/tests/reasoning/test_gemma4_reasoning_parser.py index 699fc509d828..c1dec93af755 100644 --- a/tests/reasoning/test_gemma4_reasoning_parser.py +++ b/tests/reasoning/test_gemma4_reasoning_parser.py @@ -111,7 +111,7 @@ def generic_tokenizer(): } THOUGHT_PREFIX_ONLY = { "output": "<|channel>thought\n", - "reasoning": "", + "reasoning": None, # empty thinking block → no reasoning_content emitted "content": None, "is_reasoning_end": True, } @@ -273,3 +273,38 @@ def test_gemma4_previous_turn_reasoning_is_reasoning_end(generic_tokenizer): ) is_reasoning_end = parser.is_reasoning_end(output_tokens) assert not is_reasoning_end + + +def test_gemma4_tool_response_does_not_block_reasoning_end(generic_tokenizer): + """<|tool_response> in the same delta must not mask a preceding <|tool_call>. + + When --stream-interval batches all generated tokens into one chunk the + sequence is: + + <|channel>thought\\n<|tool_call>...<|tool_response> + + The old is_reasoning_end returned False immediately on <|tool_response> + (searching backwards), never reaching the <|tool_call> token. That kept + state.reasoning_ended=False, so DelegatingParser.parse_delta never entered + the tool-call phase and the raw Gemma4 format leaked as content. 
+ """ + vocab = generic_tokenizer.get_vocab() + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + generic_tokenizer + ) + + # Exact token sequence produced by gemma4-26b with stream-interval 20: + # thought\n<|tool_call>call:exec{...}<|tool_response> + output_tokens = ( + [vocab["<|channel>"]] # chunk N-1: reasoning start + + gemma4_encode_output( # chunk N: everything else in one batch + generic_tokenizer, + "thought\n<|tool_call>done", + ) + + [vocab["<|tool_response>"]] # stop token + ) + + assert parser.is_reasoning_end(output_tokens), ( + "is_reasoning_end must return True when <|tool_call> precedes " + "<|tool_response> in the same delta" + ) diff --git a/tests/reasoning/test_gemma4_reasoning_with_tool_call.py b/tests/reasoning/test_gemma4_reasoning_with_tool_call.py new file mode 100644 index 000000000000..58a372cd4820 --- /dev/null +++ b/tests/reasoning/test_gemma4_reasoning_with_tool_call.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Combined reasoning + tool-call parsing tests for Gemma4. + +Exercises DelegatingParser.parse_delta() with both Gemma4ReasoningParser +and Gemma4ToolParser active — the scenario where <|channel>thought... +precedes a tool call, covering both token-by-token and single-delta (large +stream-interval) delivery. 
+""" + +import pytest + +from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest +from vllm.parser.abstract_parser import _WrappedParser +from vllm.reasoning.gemma4_reasoning_parser import Gemma4ReasoningParser +from vllm.tool_parsers.gemma4_tool_parser import Gemma4ToolParser +from vllm.tokenizers.registry import get_tokenizer + +TOKENIZER_NAME = "google/gemma-4-E2B-it" + + +@pytest.fixture(scope="module") +def tokenizer(): + return get_tokenizer(TOKENIZER_NAME) + + +@pytest.fixture +def parser(tokenizer): + """Fresh parser per test — avoids _reasoning_text/_prefix_stripped state leak.""" + _WrappedParser.reasoning_parser_cls = Gemma4ReasoningParser + _WrappedParser.tool_parser_cls = Gemma4ToolParser + return _WrappedParser(tokenizer) + + +def _encode(tokenizer, text: str) -> list[int]: + """Encode text including Gemma4 special tokens into token IDs.""" + vocab = tokenizer.get_vocab() + enc = getattr(tokenizer, "tokenizer", tokenizer) + for special, tok_id in [ + ("<|channel>", vocab.get("<|channel>")), + ("<channel|>", vocab.get("<channel|>")), + ("<|tool_call>", vocab.get("<|tool_call>")), + ("<tool_call|>", vocab.get("<tool_call|>")), + ('<|"|>', vocab.get('<|"|>')), + ]: + if special in text and tok_id is not None: + parts = text.split(special, 1) + return _encode(tokenizer, parts[0]) + [tok_id] + _encode(tokenizer, parts[1]) + try: + return enc.encode(text, add_special_tokens=False) + except TypeError: + return enc.encode(text) + + +def _make_request(): + req = ChatCompletionRequest(messages=[], model="gemma4-test") + req.skip_special_tokens = False + return req + + +def _run_streaming(parser_instance, token_strings: list[str], tokenizer): + """Feed token strings one at a time through parse_delta.""" + vocab = tokenizer.get_vocab() + enc = getattr(tokenizer, "tokenizer", tokenizer) + request = _make_request() + reasoning_parts, content_parts, tool_calls = [], [], [] + + for tok_str in token_strings: + tok_id = vocab.get(tok_str) + if tok_id is not None: + ids = [tok_id] + 
else: + try: + ids = enc.encode(tok_str, add_special_tokens=False) + except TypeError: + ids = enc.encode(tok_str) + + delta = parser_instance.parse_delta(tok_str, ids, request) + if delta is None: + continue + if delta.reasoning: + reasoning_parts.append(delta.reasoning) + if delta.content: + content_parts.append(delta.content) + if delta.tool_calls: + tool_calls.extend(delta.tool_calls) + + return ( + "".join(reasoning_parts) or None, + "".join(content_parts) or None, + tool_calls, + ) + + +def _run_single_delta(parser_instance, full_text: str, tokenizer): + """Feed entire output as one delta (simulates large stream-interval).""" + request = _make_request() + full_ids = _encode(tokenizer, full_text) + delta = parser_instance.parse_delta(full_text, full_ids, request) + if delta is None: + return None, None, [] + return ( + delta.reasoning or None, + delta.content or None, + delta.tool_calls or [], + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +def test_reasoning_then_tool_call_token_by_token(parser, tokenizer): + """Token-by-token delivery: reasoning extracted, tool call parsed.""" + token_strings = ( + ["<|channel>", "thought", "\n", "I", " need", " to", " find", " files", + "<channel|>"] + + ["<|tool_call>", "call", ":", "find", "{", "path", ":", '<|"|>', + "research", '<|"|>', "}", "<tool_call|>"] + ) + reasoning, content, tool_calls = _run_streaming(parser, token_strings, tokenizer) + + assert reasoning is not None + assert not reasoning.startswith("thought"), ( + f"'thought\\n' prefix must be stripped; got {reasoning!r}" + ) + assert "<|channel>" not in reasoning + assert "<channel|>" not in reasoning + + assert len(tool_calls) >= 1 + assert tool_calls[0].function.name == "find" + + +def test_reasoning_then_tool_call_single_delta(parser, tokenizer): + """Single-delta delivery (large stream-interval): reasoning must not be lost.""" + full_text = ( + 
'<|channel>thought\nI need to find files<channel|>' + '<|tool_call>call:find{path:<|"|>research<|"|>}<tool_call|>' + ) + reasoning, content, tool_calls = _run_single_delta(parser, full_text, tokenizer) + + assert reasoning is not None, ( + "reasoning was silently dropped when tool call arrived in the same delta" + ) + assert not reasoning.startswith("thought"), ( + f"'thought\\n' prefix must be stripped; got {reasoning!r}" + ) + assert "<|channel>" not in reasoning + assert "<channel|>" not in reasoning + + assert len(tool_calls) >= 1 + assert tool_calls[0].function.name == "find" + + +def test_reasoning_after_tool_response(parser, tokenizer): + """Second-turn generation: reasoning must not leak when prompt has a prior + completed tool call + tool response (the multi-turn reasoning-leak bug). + + Simulates: prompt_token_ids ends with <|tool_call>...<tool_call|><|tool_response>...<tool_response|> + which used to make is_reasoning_end() return True (finding the prior + <|tool_call> while searching backward past <|tool_response>), causing + reasoning_ended=True at the very start and leaking <|channel>thought... + tokens as content. + """ + vocab = tokenizer.get_vocab() + + tool_call_tok = vocab.get("<|tool_call>") + tool_call_end_tok = vocab.get("<tool_call|>") + tool_resp_tok = vocab.get("<|tool_response>") + tool_resp_end_tok = vocab.get("<tool_response|>") + + # Synthetic prompt_token_ids: simulate a completed first-turn tool exchange. + # The structure mirrors the Gemma4 template output: + # <|tool_call>body<tool_call|><|tool_response>body<tool_response|> + # The <tool_response|> end marker is required for is_reasoning_end to + # distinguish this (completed exchange) from a bare stop token. 
+ prompt_ids: list[int] = [] + if tool_call_tok is not None: + prompt_ids.append(tool_call_tok) + prompt_ids += [1000, 1001, 1002] # tool call body tokens + if tool_call_end_tok is not None: + prompt_ids.append(tool_call_end_tok) + if tool_resp_tok is not None: + prompt_ids.append(tool_resp_tok) + prompt_ids += [2000, 2001] # tool response body tokens + if tool_resp_end_tok is not None: + prompt_ids.append(tool_resp_end_tok) + + request = _make_request() + reasoning_parts: list[str] = [] + content_parts: list[str] = [] + tool_calls_found: list = [] + + # Feed second-turn generation as individual token strings, passing + # prompt_token_ids only on the very first call (mimics parse_delta usage). + enc = getattr(tokenizer, "tokenizer", tokenizer) + first = True + for tok_str in ["<|channel>", "thought", "\n", "I", " need", " to", " answer", + "<channel|>", "The", " answer", " is", " 42"]: + tok_id = vocab.get(tok_str) + if tok_id is not None: + ids = [tok_id] + else: + try: + ids = enc.encode(tok_str, add_special_tokens=False) + except TypeError: + ids = enc.encode(tok_str) + + delta = parser.parse_delta( + tok_str, ids, request, + prompt_token_ids=prompt_ids if first else None, + ) + first = False + if delta is None: + continue + if delta.reasoning: + reasoning_parts.append(delta.reasoning) + if delta.content: + content_parts.append(delta.content) + if delta.tool_calls: + tool_calls_found.extend(delta.tool_calls) + + reasoning = "".join(reasoning_parts) or None + content = "".join(content_parts) or None + + assert reasoning is not None, ( + "reasoning was lost in second-turn generation after tool response in prompt" + ) + assert not reasoning.startswith("thought"), ( + f"'thought\\n' prefix must be stripped; got {reasoning!r}" + ) + assert "<|channel>" not in reasoning + assert "<channel|>" not in reasoning + + assert content is not None, "content after reasoning must not be dropped" + assert "42" in content, f"expected '42' in content, got {content!r}" + assert len(tool_calls_found) == 
0 + + # No raw thinking tokens should have leaked into content + assert "<|channel>" not in (content or ""), "thinking start token leaked into content" + assert "<channel|>" not in (content or ""), "thinking end token leaked into content" + + +def test_reasoning_only_no_tool_call(parser, tokenizer): + """Reasoning only (no tool call): content passes through cleanly.""" + token_strings = ( + ["<|channel>", "thought", "\n", "Let", " me", " think", "<channel|>"] + + ["The", " answer", " is", " 42"] + ) + reasoning, content, tool_calls = _run_streaming(parser, token_strings, tokenizer) + + assert reasoning is not None + assert not reasoning.startswith("thought"), ( + f"'thought\\n' prefix must be stripped; got {reasoning!r}" + ) + assert content is not None + assert "42" in content + assert len(tool_calls) == 0 + + +def test_empty_thinking_block_tool_call_no_reasoning_leak(parser, tokenizer): + """Empty thinking block (<|channel>thought\\n<channel|>) followed by a + tool call must NOT emit an empty-string reasoning_content delta. + + When the model produces only the 'thought\\n' role label (nothing after + it inside the channel) the prefix-stripping logic previously returned + DeltaMessage(reasoning='') — an empty string, not None. The harness + received {"reasoning_content": ""} and mis-rendered it. The fix makes + the parser return None (or forward the post-channel content only) so + no empty reasoning delta is ever emitted. + + Exercises both token-by-token and single-delta delivery. + """ + vocab = tokenizer.get_vocab() + enc = getattr(tokenizer, "tokenizer", tokenizer) + + # Token-by-token: each token arrives individually. 
+ token_strings = ( + ["<|channel>", "thought", "\n", "<channel|>"] + + ["<|tool_call>", "call", ":", "find", "{", "path", ":", '<|"|>', + "research", '<|"|>', "}", "<tool_call|>"] + ) + reasoning, content, tool_calls = _run_streaming(parser, token_strings, tokenizer) + + assert reasoning is None, ( + f"empty thinking block must not emit reasoning_content; got {reasoning!r}" + ) + assert len(tool_calls) >= 1, "tool call must still be parsed" + assert tool_calls[0].function.name == "find" + + # Single-delta: the whole output arrives in one chunk (stream-interval 20). + parser2_instance = type(parser)(tokenizer) + type(parser2_instance).reasoning_parser_cls = type(parser).reasoning_parser_cls + type(parser2_instance).tool_parser_cls = type(parser).tool_parser_cls + + # Build a fresh parser with the correct classes set. + from vllm.parser.abstract_parser import _WrappedParser + from vllm.reasoning.gemma4_reasoning_parser import Gemma4ReasoningParser + from vllm.tool_parsers.gemma4_tool_parser import Gemma4ToolParser + _WrappedParser.reasoning_parser_cls = Gemma4ReasoningParser + _WrappedParser.tool_parser_cls = Gemma4ToolParser + parser2 = _WrappedParser(tokenizer) + + full_text = ( + '<|channel>thought\n<channel|>' + '<|tool_call>call:find{path:<|"|>research<|"|>}<tool_call|>' + ) + reasoning2, content2, tool_calls2 = _run_single_delta(parser2, full_text, tokenizer) + + assert reasoning2 is None, ( + f"single-delta empty thinking must not emit reasoning_content; " + f"got {reasoning2!r}" + ) + assert len(tool_calls2) >= 1, "tool call must still be parsed in single-delta mode" + assert tool_calls2[0].function.name == "find" diff --git a/tests/tool_parsers/test_gemma4_tool_parser.py b/tests/tool_parsers/test_gemma4_tool_parser.py index 6f3709e19a45..5fd35467f3a6 100644 --- a/tests/tool_parsers/test_gemma4_tool_parser.py +++ b/tests/tool_parsers/test_gemma4_tool_parser.py @@ -728,3 +728,226 @@ def test_streaming_trailing_bare_bool_not_duplicated(self, parser, mock_request) } assert args_text.count("replace_all") 
== 1 + + def test_complete_tool_call_in_single_delta(self, parser, mock_request): + """Entire tool call arrives in one streaming chunk (stream-interval batching). + + When --stream-interval batches all tokens into one SSE event, start and + end tokens arrive in the same delta. Case 2 in _extract_streaming must + initialise tool state even when start_count == end_count, and + _handle_tool_call_end must emit both the function name and arguments. + """ + full_call = ( + "<|tool_call>call:exec{command:<|\"|>echo hello<|\"|>}<tool_call|>" + ) + results = self._simulate_streaming(parser, mock_request, [full_call]) + + name = self._collect_function_name(results) + args_text = self._collect_arguments(results) + + assert name == "exec", f"expected 'exec', got {name!r}" + assert args_text, "no arguments were streamed" + parsed = json.loads(args_text) + assert parsed == {"command": "echo hello"} + + # Exactly one tool call header must have been emitted. + headers = [ + delta + for delta, _ in results + if delta and delta.tool_calls and delta.tool_calls[0].id is not None + ] + assert len(headers) == 1 + + def test_multiple_tool_calls_in_single_delta(self, parser, mock_request): + """Multiple complete tool calls arriving in one streaming chunk. + + When N tool calls all arrive in the same batched delta, + Case 2 must initialise N entries (not just one) and + _handle_tool_call_end must emit all N tool calls whose + names have not been individually streamed yet. 
+ """ + full_delta = ( + "<|tool_call>call:read{path:<|\"|>a.py<|\"|>}<tool_call|>" + "<|tool_call>call:write{path:<|\"|>b.py<|\"|>,content:<|\"|>hello<|\"|>}<tool_call|>" + ) + results = self._simulate_streaming(parser, mock_request, [full_delta]) + + all_tcs = [ + tc + for delta, _ in results + if delta and delta.tool_calls + for tc in delta.tool_calls + ] + assert len(all_tcs) == 2, f"expected 2 tool calls, got {len(all_tcs)}: {all_tcs}" + + names = [ + (tc.function.name if hasattr(tc.function, "name") else tc.function.get("name")) + for tc in all_tcs + ] + assert "read" in names + assert "write" in names + + def test_streaming_mixed_partial_and_complete_in_one_delta( + self, parser, mock_request + ): + """Mixed scenario: partially-streamed call finishes in same delta + as a new complete call arrives. + + Simulates: + 1. Delta 1: starts first call, sends name, streams partial args + 2. Delta 2: continues streaming args (call still in middle) + 3. Delta 3: first call ends + second call begins and ends in one chunk + + This exercises the fix for PR #42875: the old _handle_tool_call_end + used a single current_tool_name_sent flag to branch between + 'single-delta' and 'normal' paths, causing partially-streamed calls + to be skipped when new complete calls arrived in the same delta. 
+ """ + chunks = [ + # Delta 1: start first tool call, send name, partial args + '<|tool_call>call:search{query:<|"|>hel', + # Delta 2: continue streaming args (still in middle of call) + "lo", + # Delta 3: first call finishes + second complete call arrives + # in the same delta (mixed scenario) + '<|"|>}<tool_call|><|tool_call>call:read{path:<|"|>/foo<|"|>}<tool_call|>', + ] + + results = self._simulate_streaming(parser, mock_request, chunks) + + # Collect all deltas by tool index + deltas_by_index: dict[int, dict] = {} + for delta, _ in results: + if delta and delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in deltas_by_index: + deltas_by_index[idx] = {"name": None, "arguments": ""} + func = tc.function if isinstance(tc.function, + dict) else tc.function + if isinstance(func, dict): + if func.get("name"): + deltas_by_index[idx]["name"] = func["name"] + if func.get("arguments"): + deltas_by_index[idx]["arguments"] += func["arguments"] + else: + if getattr(func, "name", None): + deltas_by_index[idx]["name"] = func.name + if getattr(func, "arguments", None): + deltas_by_index[idx]["arguments"] += func.arguments + + # We should have two tool calls + assert 0 in deltas_by_index, "Missing tool call 0" + assert 1 in deltas_by_index, "Missing tool call 1" + + # Both should have valid JSON arguments + for idx in [0, 1]: + args_str = deltas_by_index[idx]["arguments"] + assert args_str, f"Tool call {idx} has no arguments" + try: + parsed = json.loads(args_str) + assert isinstance(parsed, dict), f"Tool call {idx} args not a dict" + except json.JSONDecodeError as e: + pytest.fail( + f"Tool call {idx} arguments not valid JSON: " + f"'{args_str}' - {e}" + ) + + # Verify specific values + # call 0 (search): name may or may not be present depending on + # whether delta 1 emitted it; args should be complete + assert json.loads(deltas_by_index[0]["arguments"]) == { + "query": "hello" + }, f'Expected {{"query": "hello"}}, got {deltas_by_index[0]["arguments"]}' + + # call 1 
(read): should have name + full args + assert deltas_by_index[1]["name"] == "read", f'Expected name "read", got {deltas_by_index[1]["name"]}' + assert json.loads(deltas_by_index[1]["arguments"]) == { + "path": "/foo" + }, f'Expected {{"path": "/foo"}}, got {deltas_by_index[1]["arguments"]}' + + def _collect_content(self, results): + """Concatenate all content deltas from streaming results.""" + out = "" + for delta, _ in results: + if delta and getattr(delta, "content", None): + out += delta.content + return out + + def test_streaming_inter_call_text_preserved_in_single_delta( + self, parser, mock_request + ): + """Plain text between two tool calls in one delta must be preserved. + + When stream-interval batching produces + ``<|tool_call>...<tool_call|>X<|tool_call>...<tool_call|>``, + the inter-call text ``X`` must appear in the streamed content. + """ + delta = ( + '<|tool_call>call:read{path:<|"|>a.py<|"|>}<tool_call|>' + " and then " + '<|tool_call>call:read{path:<|"|>b.py<|"|>}<tool_call|>' + ) + results = self._simulate_streaming(parser, mock_request, [delta]) + + content = self._collect_content(results) + assert "and then" in content, ( + f"inter-call text lost; content={content!r}" + ) + # No raw tool call markers should leak. + assert TOOL_CALL_START not in content + assert TOOL_CALL_END not in content + + def test_streaming_no_arg_fragment_leak_when_started_inside( + self, parser, mock_request + ): + """A delta that starts inside a tool call must not leak arg fragments. + + Chunk 1 opens the tool call and streams partial args. Chunk 2 + contains the rest of the args plus the closing end token plus a + new complete tool call: the leftover argument bytes (e.g. ``}``, + ``"``) before the end token must NOT be emitted as content. + """ + chunks = [ + '<|tool_call>call:search{query:<|"|>hel', + 'lo<|"|>}<tool_call|> result: ' + '<|tool_call>call:noop{}<tool_call|>', + ] + results = self._simulate_streaming(parser, mock_request, chunks) + + content = self._collect_content(results) + # The legitimate text between calls IS content. 
+ assert "result:" in content, ( + f"inter-call text dropped; content={content!r}" + ) + # No tool call markers must leak into content. + assert TOOL_CALL_START not in content + assert TOOL_CALL_END not in content + # No raw argument fragments must leak (e.g. the closing brace + # or quote delimiter from the still-streaming first call). + assert "}" not in content, ( + f"raw arg fragment leaked into content: {content!r}" + ) + assert '<|"|>' not in content + + def test_streaming_end_then_start_no_duplication( + self, parser, mock_request + ): + """End-then-start in one delta must not duplicate inter-call text. + + Regression for the ``last_end < first_start`` slicing edge case + flagged in review: the text between the two markers must appear + exactly once in content. + """ + chunks = [ + '<|tool_call>call:a{x:<|"|>1<|"|>', + '}<tool_call|> middle <|tool_call>call:b{y:<|"|>2<|"|>}<tool_call|>', + ] + results = self._simulate_streaming(parser, mock_request, chunks) + + content = self._collect_content(results) + assert content.count("middle") == 1, ( + f"inter-call text duplicated; content={content!r}" + ) + assert TOOL_CALL_START not in content + assert TOOL_CALL_END not in content diff --git a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py index 2a13f138607b..f59af5f900e5 100644 --- a/vllm/parser/abstract_parser.py +++ b/vllm/parser/abstract_parser.py @@ -706,9 +706,10 @@ def parse_delta( delta_text = current_text delta_token_ids = current_token_ids - # A boundary delta may carry both reasoning and tool call, - # save it before the tool parser overwrites delta_message. - reasoning = delta_message.reasoning if delta_message else None + # Preserve any reasoning extracted in the same delta where + # reasoning ended and tool calls began (single-delta batching). 
+ saved_reasoning = delta_message.reasoning if delta_message else None + delta_message, state.function_name_returned = ( self._extract_tool_calls_streaming( previous_text=state.previous_text, @@ -723,11 +724,11 @@ def parse_delta( function_name_returned=state.function_name_returned, ) ) - if reasoning: - if not delta_message: - delta_message = DeltaMessage() - delta_message.reasoning = reasoning - + if saved_reasoning is not None: + if delta_message is None: + delta_message = DeltaMessage(reasoning=saved_reasoning) + else: + delta_message.reasoning = saved_reasoning if ( delta_message and delta_message.tool_calls diff --git a/vllm/reasoning/gemma4_reasoning_parser.py b/vllm/reasoning/gemma4_reasoning_parser.py index 6f2241603f9a..a9527bb4c58e 100644 --- a/vllm/reasoning/gemma4_reasoning_parser.py +++ b/vllm/reasoning/gemma4_reasoning_parser.py @@ -55,6 +55,9 @@ def __init__(self, tokenizer: TokenizerLike, *args, **kwargs): self.new_turn_token_id = self.vocab["<|turn>"] self.tool_call_token_id = self.vocab["<|tool_call>"] self.tool_response_token_id = self.vocab["<|tool_response>"] + # End-of-tool-response marker (closes <|tool_response>...<tool_response|>). + # May be None if not in vocabulary. + self.tool_response_end_token_id = self.vocab.get("<tool_response|>") def adjust_request( self, request: "ChatCompletionRequest | ResponsesRequest" @@ -79,6 +82,12 @@ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: new_turn_token_id = self.new_turn_token_id tool_call_token_id = self.tool_call_token_id tool_response_token_id = self.tool_response_token_id + tool_response_end_token_id = self.tool_response_end_token_id + + # Tracks whether we passed a <tool_response|> end marker (which only + # appears in prompt context when the tool exchange is complete) before + # reaching <|tool_response> start token while searching backward. + saw_tool_response_end = False # Search from the end of input_ids to find the last match. 
for i in range(len(input_ids) - 1, -1, -1): @@ -87,12 +96,28 @@ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: if input_ids[i] == tool_call_token_id: # We're generating a tool call, so reasoning must be ended. return True - if input_ids[i] in (new_turn_token_id, tool_response_token_id): - # We found a new turn or tool response token so don't consider - # reasoning ended yet, since the model starts new reasoning - # after these tokens. + if input_ids[i] == new_turn_token_id: + # A new conversation turn is starting; new reasoning may follow. return False - if input_ids[i] == end_token_id: + if ( + tool_response_end_token_id is not None + and input_ids[i] == tool_response_end_token_id + ): + # <tool_response|> closes a tool-response block in the prompt. + # Set a flag so we know the next <|tool_response> we encounter + # is part of a completed exchange, not a bare stop token. + saw_tool_response_end = True + elif input_ids[i] == tool_response_token_id: + if saw_tool_response_end: + # Completed tool exchange in the prompt: the model is in a + # fresh state and may start new reasoning in this turn. + # Returning False prevents a prior <|tool_call> from being + # found further back and incorrectly triggering return True. + return False + # else: <|tool_response> is a bare stop token appended to the + # delta after a tool call — keep searching backward to find the + # preceding <|tool_call> token and correctly return True. 
+ elif input_ids[i] == end_token_id: return True return False @@ -113,7 +138,7 @@ def extract_reasoning( reasoning, content = super().extract_reasoning(model_output, request) if reasoning is not None: - reasoning = _strip_thought_label(reasoning) + reasoning = _strip_thought_label(reasoning) or None return reasoning, content # ------------------------------------------------------------------ @@ -194,8 +219,14 @@ def extract_reasoning_streaming( else: if len(self._reasoning_text) >= prefix_len: self._prefix_stripped = True - result.reasoning = "" - return result + # The entire delta was the stripped prefix — + # suppress the empty reasoning delta. If the base + # parser extracted post-reasoning text (e.g. tool + # call markup), forward that content so the + # DelegatingParser can hand it to the tool parser. + if result.content: + return DeltaMessage(content=result.content) + return None return None # Case 2: Accumulated text is a strict prefix of diff --git a/vllm/tool_parsers/gemma4_tool_parser.py b/vllm/tool_parsers/gemma4_tool_parser.py index 9925284273f9..79fcad10989a 100644 --- a/vllm/tool_parsers/gemma4_tool_parser.py +++ b/vllm/tool_parsers/gemma4_tool_parser.py @@ -386,6 +386,7 @@ def _reset_streaming_state(self) -> None: self.current_tool_name_sent = False self.prev_tool_call_arr: list[dict] = [] self.streamed_args_for_tool: list[str] = [] + self.buffered_delta_text = "" def adjust_request( self, request: ChatCompletionRequest | ResponsesRequest @@ -497,6 +498,13 @@ def extract_tool_calls_streaming( delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: + # Reset streaming state at the start of each new request. + # The parser instance is reused across all streaming requests; + # without this, current_tool_id, prev_tool_call_arr, and + # streamed_args_for_tool leak between requests. 
+ if previous_text == "": + self._reset_streaming_state() + # Buffer delta text to handle multi-token special sequences delta_text = self._buffer_delta_text(delta_text) # Keep current_text from the upstream stream state. The buffered delta @@ -549,26 +557,58 @@ def _extract_streaming( return DeltaMessage(content=delta_text) return None - # Case 2: Starting a new tool call - if start_count > prev_start_count and start_count > end_count: - self.current_tool_id += 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - self.prev_tool_call_arr.append({}) - logger.debug("Starting new tool call %d", self.current_tool_id) - # Don't return yet — fall through to try parsing if there's - # content after <|tool_call> in this same delta - # (but usually it's just the token itself, so return None) - if len(delta_text) <= len(self.tool_call_start_token): + # Case 2: Starting a new tool call. + # Note: do NOT require start_count > end_count here — when the entire + # tool call arrives in one streaming chunk (common with stream-interval + # batching), start and end counts are equal in the same delta. We + # still need to initialise the tool state so _handle_tool_call_end + # (Case 3 below) can emit the complete call correctly. + if start_count > prev_start_count: + new_calls = start_count - prev_start_count + for _ in range(new_calls): + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + self.prev_tool_call_arr.append({}) + logger.debug("Starting new tool call %d", self.current_tool_id) + # If the delta contains only start token(s) and nothing more, wait. + # Otherwise fall through — end tokens may also be present. 
+ if len(delta_text) <= len(self.tool_call_start_token) * new_calls: return None # Case 3: Tool call just ended if end_count > prev_end_count: - return self._handle_tool_call_end(current_text) + result = self._handle_tool_call_end(current_text) + # Capture content from delta_text that lies outside tool call + # markers. Using delta_text (the buffered version from + # _buffer_delta_text) is critical: current_text includes + # partial tokens that were held back for the next chunk. + if delta_text: + started_inside = prev_start_count > prev_end_count + content = self._extract_content_outside_tool_calls( + delta_text, started_inside + ) + if content.strip(): + if result is not None: + result.content = (result.content or "") + content + else: + result = DeltaMessage(content=content) + return result # Case 4: In the middle of a tool call — parse partial content if start_count > end_count: - return self._handle_tool_call_middle(current_text) + result = self._handle_tool_call_middle(current_text) + if delta_text: + started_inside = prev_start_count > prev_end_count + content = self._extract_content_outside_tool_calls( + delta_text, started_inside + ) + if content.strip(): + if result is not None: + result.content = (result.content or "") + content + else: + result = DeltaMessage(content=content) + return result # Default: generate text outside tool calls if delta_text: @@ -578,6 +618,44 @@ def _extract_streaming( return DeltaMessage(content=text) return None + def _extract_content_outside_tool_calls( + self, delta_text: str, started_inside: bool + ) -> str: + """Collect text spans in delta_text that lie outside tool call markers. + + Walks the buffered delta and concatenates every region that is not + inside a ``<|tool_call>...`` pair, taking into account + whether the delta began inside an active tool call (i.e. the + previous text had an unclosed ``<|tool_call>``). 
When started + inside, the text before the first end token is part of the + ongoing call's arguments and is excluded from the returned + content. The returned string never contains tool call markers. + """ + parts: list[str] = [] + start_token = self.tool_call_start_token + end_token = self.tool_call_end_token + inside = started_inside + pos = 0 + n = len(delta_text) + while pos < n: + if inside: + end_idx = delta_text.find(end_token, pos) + if end_idx == -1: + # Remainder of delta is arguments — not content. + break + pos = end_idx + len(end_token) + inside = False + else: + start_idx = delta_text.find(start_token, pos) + if start_idx == -1: + parts.append(delta_text[pos:]) + break + if start_idx > pos: + parts.append(delta_text[pos:start_idx]) + pos = start_idx + len(start_token) + inside = True + return "".join(parts) + def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]: """Extract function name and raw argument string from partial text. @@ -655,8 +733,18 @@ def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None: def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None: """Handle streaming when a tool call has just completed. - Performs a final parse of the complete tool call and flushes - any remaining un-streamed argument fragments. + Performs a final parse of every tool call that ended in the current + streaming delta and emits either remaining argument diffs (for calls + that were already partially streamed) or complete tool calls (for + calls that arrived entirely in this delta). + + Handles three scenarios: + 1. Pure single-delta: N tool calls all arrived in one chunk (none + previously streamed). + 2. Pure normal: one tool call ending that was partially streamed in + previous deltas (name already sent). + 3. Mixed: a partially-streamed call finishing + one or more new + complete calls arriving in the same delta. 
""" if self.current_tool_id < 0 or self.current_tool_id >= len( self.prev_tool_call_arr @@ -667,31 +755,59 @@ def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None: ) return None - # Parse the complete tool call using regex for accuracy all_matches = self.tool_call_regex.findall(current_text) - if self.current_tool_id < len(all_matches): - _, args_str = all_matches[self.current_tool_id] + if not all_matches: + return None + + tool_calls: list[DeltaToolCall] = [] + + # Iterate through every tool call that ended in this delta, up to + # current_tool_id. For each, emit either a complete DeltaToolCall + # (if the name was never streamed) or the remaining argument diff + # (if it was already partially streamed in previous deltas). + for idx in range(min(len(all_matches), self.current_tool_id + 1)): + if idx >= len(self.prev_tool_call_arr): + break + + func_name, args_str = all_matches[idx] final_args = _parse_gemma4_args(args_str) final_args_json = json.dumps(final_args, ensure_ascii=False) - prev_streamed = self.streamed_args_for_tool[self.current_tool_id] - if len(final_args_json) > len(prev_streamed): - diff = final_args_json[len(prev_streamed) :] - self.streamed_args_for_tool[self.current_tool_id] = final_args_json - self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args - - return DeltaMessage( - tool_calls=[ + if not self.prev_tool_call_arr[idx].get("name"): + # Unstreamed call — emit complete tool call (name + full args) + self.prev_tool_call_arr[idx] = { + "name": func_name, + "arguments": final_args, + } + self.streamed_args_for_tool[idx] = final_args_json + tool_calls.append( + DeltaToolCall( + index=idx, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=func_name, + arguments=final_args_json, + ).model_dump(exclude_none=True), + ) + ) + else: + # Previously-streamed call — emit remaining argument diff + prev_streamed = self.streamed_args_for_tool[idx] + if len(final_args_json) > 
len(prev_streamed): + diff = final_args_json[len(prev_streamed) :] + self.streamed_args_for_tool[idx] = final_args_json + self.prev_tool_call_arr[idx]["arguments"] = final_args + tool_calls.append( DeltaToolCall( - index=self.current_tool_id, + index=idx, function=DeltaFunctionCall(arguments=diff).model_dump( exclude_none=True ), ) - ] - ) + ) - return None + return DeltaMessage(tool_calls=tool_calls) if tool_calls else None def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None: """Parse raw Gemma4 arguments, convert to JSON, diff, and emit.