diff --git a/tests/tool_parsers/test_qwen3coder_tool_parser.py b/tests/tool_parsers/test_qwen3coder_tool_parser.py index c62e95830243..d761ab6580cd 100644 --- a/tests/tool_parsers/test_qwen3coder_tool_parser.py +++ b/tests/tool_parsers/test_qwen3coder_tool_parser.py @@ -1146,3 +1146,157 @@ def test_no_double_serialization_string_args(qwen3_tool_parser): args = json.loads(raw_arguments) assert args["message"] == "hello world" assert '\\"hello world\\"' not in raw_arguments + + +def test_extract_tool_calls_streaming_split_tag(qwen3_tool_parser): + """ + This highlights the need to use current_text instead of delta_text. + """ + request = ChatCompletionRequest(model=MODEL, messages=[]) + + # Iteration 1: "" + prev_text_2 = curr_text_1 + delta_text_2 = "_call>" + curr_text_2 = prev_text_2 + delta_text_2 + + msg2 = qwen3_tool_parser.extract_tool_calls_streaming( + previous_text=prev_text_2, + current_text=curr_text_2, + delta_text=delta_text_2, + previous_token_ids=[1, 2, 3, 4], + current_token_ids=[1, 2, 3, 4, 5], + delta_token_ids=[5], + request=request + ) + + # The assertion must verify that the is_tool_call_started variable correctly switches to True + assert qwen3_tool_parser.is_tool_call_started is True, "is_tool_call_started should be True when '' is completed in current_text." + + # and that the function does not return fragments of the tag in DeltaMessage(content=...) + if msg1 and msg1.content: + assert "" not in msg2.content + + + +def test_extract_tool_calls_streaming_speculative_decode_loss(qwen3_tool_parser): + """ + if json_started=False, and the delta contains the parameters AND the end of the tool call, + the parser should not just return '{' and lose the parameters. + """ + + request = ChatCompletionRequest(model="test", messages=[]) + + text1 = "\n\n" + qwen3_tool_parser.extract_tool_calls_streaming( + "", text1, text1, [], [1], [1], request + ) + + # Delta 2 has the rest of the tool call + delta_str = "\nParis\n\n\n" + text2 = text1 + delta_str + delta2 = qwen3_tool_parser.extract_tool_calls_streaming( + text1, text2, delta_str, [1], [1,2], [2], request + ) + + # The parameters should be in delta2! + assert delta2 is not None + assert delta2.tool_calls is not None + assert len(delta2.tool_calls) == 1 + args = delta2.tool_calls[0].function.arguments + assert "Paris" in args, f"Arguments lost! Got: {args}" + + +def test_extract_tool_calls_streaming_various_chunk_sizes(qwen3_tool_parser): + """ + Test streaming with various chunk sizes using the exact template from Qwen 3.6. + """ + + request = ChatCompletionRequest(model="test", messages=[]) + + # Exact template format from Qwen 3.6 + template_text = """ + + +value_1 + + +This is the value for the second parameter +that can span +multiple lines + + +""" + + # Test with different chunk sizes to simulate different network/speculative decoding behaviors + for chunk_size in [1, 3, 15, len(template_text)]: + # Reset parser state + qwen3_tool_parser._reset_streaming_state() + + tool_states = {} + + # Simulate custom streaming to precisely control chunk sizes + current_text = "" + previous_text = "" + ptr = 0 + + while ptr < len(template_text): + delta = template_text[ptr:ptr+chunk_size] + previous_text = current_text + current_text += delta + ptr += chunk_size + + delta_message = qwen3_tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta, + previous_token_ids=[], + current_token_ids=[], + delta_token_ids=[], + request=request + ) + + if delta_message and delta_message.tool_calls: + for tool_call in delta_message.tool_calls: + idx = tool_call.index + if idx not in tool_states: + tool_states[idx] = { + "id": None, + "name": None, + "arguments": "", + "type": None, + } + + if tool_call.id: + tool_states[idx]["id"] = tool_call.id + if tool_call.type: + tool_states[idx]["type"] = tool_call.type + if tool_call.function: + if tool_call.function.name: + tool_states[idx]["name"] = tool_call.function.name + if tool_call.function.arguments is not None: + tool_states[idx]["arguments"] += tool_call.function.arguments + + assert 0 in tool_states + assert tool_states[0]["name"] == "example_function_name" + + import json + args = json.loads(tool_states[0]["arguments"]) + assert args["example_parameter_1"] == "value_1" + assert args["example_parameter_2"] == "This is the value for the second parameter\nthat can span\nmultiple lines" diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index 7b089ceffbc0..7f75f2451554 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -25,7 +25,7 @@ Tool, ToolParser, ) -from vllm.tool_parsers.utils import find_tool_properties +from vllm.tool_parsers.utils import find_tool_properties, partial_tag_overlap logger = init_logger(__name__) @@ -109,6 +109,8 @@ def _reset_streaming_state(self): # Store accumulated parameters for type conversion self.accumulated_params = {} self.streaming_request = None + self._sent_content_idx = 0 + self.current_tool_index = 0 def _convert_param_value( self, param_value: str, param_name: str, param_config: dict, func_name: str @@ -372,6 +374,22 @@ def extract_tool_calls_streaming( # Check if this tool call has ended tool_ends = current_text.count(self.tool_call_end_token) if tool_ends > self.current_tool_index: + # Find the end of the tool call that just finished and update + # _sent_content_idx to prevent it from leaking into content. + search_idx = 0 + for _ in range(self.current_tool_index + 1): + search_idx = current_text.find(self.tool_call_start_token, + search_idx) + if search_idx == -1: + break + end_idx = current_text.find(self.tool_call_end_token, + search_idx) + if end_idx != -1: + self._sent_content_idx = max( + self._sent_content_idx, + end_idx + len(self.tool_call_end_token)) + search_idx += len(self.tool_call_start_token) + # This tool has ended, advance to next self.current_tool_index += 1 self.header_sent = False @@ -380,47 +398,55 @@ def extract_tool_calls_streaming( self.json_closed = False self.accumulated_params = {} - # Check if there are more tool calls - tool_starts = current_text.count(self.tool_call_start_token) - if self.current_tool_index >= tool_starts: - # No more tool calls - self.is_tool_call_started = False + # Always reset is_tool_call_started when a tool call ends. + # This allows correctly sending any content between or after + # tool calls. + self.is_tool_call_started = False # Continue processing next tool return None + content_message = None # Handle normal content before tool calls if not self.is_tool_call_started: # Check if tool call is starting + tool_starts_count = current_text.count(self.tool_call_start_token) if ( self.tool_call_start_token_id in delta_token_ids - or self.tool_call_start_token in delta_text + or tool_starts_count > self.current_tool_index ): self.is_tool_call_started = True # Return any content before the tool call - if self.tool_call_start_token in delta_text: - content_before = delta_text[ - : delta_text.index(self.tool_call_start_token) - ] + last_start = current_text.find(self.tool_call_start_token, self._sent_content_idx) + if last_start != -1 and last_start > self._sent_content_idx: + content_before = current_text[self._sent_content_idx:last_start] + self._sent_content_idx = last_start if content_before: - return DeltaMessage(content=content_before) - return None + content_message = DeltaMessage(content=content_before) else: + overlap = partial_tag_overlap(current_text, self.tool_call_start_token) + sendable_idx = len(current_text) - overlap + # Check if we're between tool calls - skip whitespace if ( current_text.rstrip().endswith(self.tool_call_end_token) and delta_text.strip() == "" ): # We just ended a tool call, skip whitespace + self._sent_content_idx = len(current_text) return None - # Normal content, no tool call - return DeltaMessage(content=delta_text) + + if sendable_idx > self._sent_content_idx: + content = current_text[self._sent_content_idx:sendable_idx] + self._sent_content_idx = sendable_idx + if content: + return DeltaMessage(content=content) + return None # Check if we're between tool calls (waiting for next one) # Count tool calls we've seen vs processed tool_starts_count = current_text.count(self.tool_call_start_token) if self.current_tool_index >= tool_starts_count: - # We're past all tool calls, shouldn't be here - return None + return content_message # We're in a tool call, find the current tool call portion # Need to find the correct tool call based on current_tool_index @@ -434,8 +460,7 @@ def extract_tool_calls_streaming( idx += len(self.tool_call_start_token) if self.current_tool_index >= len(tool_start_positions): - # No more tool calls to process yet - return None + return content_message tool_start_idx = tool_start_positions[self.current_tool_index] # Find where this tool call ends (or current position if not ended yet) @@ -447,6 +472,7 @@ def extract_tool_calls_streaming( tool_start_idx : tool_end_idx + len(self.tool_call_end_token) ] + tool_call_fragments = None # Looking for function header if not self.header_sent: if self.tool_call_prefix in tool_text: @@ -479,21 +505,16 @@ def extract_tool_calls_streaming( # accesses streamed_args_for_tool[index]. self.streamed_args_for_tool.append("") - # Send header with function info - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - id=self.current_tool_id, - function=DeltaFunctionCall( - name=self.current_function_name, arguments="" - ), - type="function", - ) - ] + tool_call_fragments = DeltaToolCall( + index=self.current_tool_index, + id=self.current_tool_id, + function=DeltaFunctionCall(name=self.current_function_name, arguments=""), + type="function", ) - return None + if not self.header_sent: + return content_message + arguments_to_emit = "" # We've sent header, now handle function body if self.in_function: # Always send opening brace first, regardless of whether @@ -504,16 +525,8 @@ def extract_tool_calls_streaming( if not self.json_started: self.json_started = True self.streamed_args_for_tool[self.current_tool_index] += "{" - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ] - ) + arguments_to_emit += "{" - # Find all parameter start positions in current tool_text param_starts = [] search_idx = 0 while True: @@ -614,15 +627,7 @@ def extract_tool_calls_streaming( self.current_tool_index, len(self.streamed_args_for_tool), ) - - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments=combined), - ) - ] - ) + arguments_to_emit += combined # Check for function end AFTER processing parameters. # This ordering is critical: with speculative decoding a @@ -664,20 +669,24 @@ def extract_tool_calls_streaming( self.current_tool_index, len(self.streamed_args_for_tool), ) - - result = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ] - ) - + arguments_to_emit += "}" self.in_function = False self.json_closed = True self.accumulated_params = {} - return result + if tool_call_fragments or arguments_to_emit: + if not tool_call_fragments: + tool_call_fragments = DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=arguments_to_emit), + ) + else: + tool_call_fragments.function.arguments += arguments_to_emit + + if content_message: + content_message.tool_calls = [tool_call_fragments] + return content_message + else: + return DeltaMessage(tool_calls=[tool_call_fragments]) - return None + return content_message