diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 39f8635bf297..06b16cde6748 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1249,13 +1249,23 @@ async def chat_completion_stream_generator( ) # get the expected call based on partial JSON - # parsing which "autocompletes" the JSON - expected_call = json.dumps( - tool_parser.prev_tool_call_arr[index].get( - "arguments", {} - ), - ensure_ascii=False, + # parsing which "autocompletes" the JSON. + # Tool parsers (e.g. Qwen3Coder) store + # arguments as a JSON string in + # prev_tool_call_arr. Calling json.dumps() + # on an already-serialized string would + # double-serialize it (e.g. '{"k":1}' becomes + # '"{\\"k\\":1}"'), which then causes the + # replace() below to fail and append the + # entire double-serialized string as a + # spurious final delta. + args = tool_parser.prev_tool_call_arr[index].get( + "arguments", {} ) + if isinstance(args, str): + expected_call = args + else: + expected_call = json.dumps(args, ensure_ascii=False) # get what we've streamed so far for arguments # for the current tool diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index dfe790ee752e..14a67ec38e47 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -490,10 +490,17 @@ def extract_tool_calls_streaming( self.prev_tool_call_arr.append( { "name": self.current_function_name, - "arguments": "{}", # Placeholder, will be updated later + "arguments": "{}", } ) + # Initialize streamed args tracking for this tool. + # The serving layer reads streamed_args_for_tool to + # compute remaining arguments at stream end. Without + # this, IndexError occurs when the serving layer + # accesses streamed_args_for_tool[index]. + self.streamed_args_for_tool.append("") + # Send header with function info return DeltaMessage( tool_calls=[ @@ -511,9 +518,14 @@ def extract_tool_calls_streaming( # We've sent header, now handle function body if self.in_function: - # Send opening brace if not sent yet - if not self.json_started and self.parameter_prefix not in delta_text: + # Always send opening brace first, regardless of whether + # parameter_prefix is in the current delta. With speculative + # decoding, a single delta may contain both the opening brace + # and parameter data; skipping "{" here would desync + # json_started from what was actually streamed. + if not self.json_started: self.json_started = True + self.streamed_args_for_tool[self.current_tool_index] += "{" return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -523,10 +535,6 @@ def extract_tool_calls_streaming( ] ) - # Make sure json_started is set if we're processing parameters - if not self.json_started: - self.json_started = True - # Check for function end in accumulated text if not self.json_closed and self.function_end_token in tool_text: # Close JSON @@ -558,7 +566,23 @@ def extract_tool_calls_streaming( self.prev_tool_call_arr[i]["arguments"] = args break except Exception: - pass # Ignore parsing errors during streaming + logger.debug( + "Failed to parse tool call during streaming: %s", + tool_text, + exc_info=True, + ) + + # Send closing brace; the serving layer autocomplete + # will fill in any missing arguments based on + # prev_tool_call_arr vs streamed_args_for_tool. + if self.current_tool_index < len(self.streamed_args_for_tool): + self.streamed_args_for_tool[self.current_tool_index] += "}" + else: + logger.warning( + "streamed_args_for_tool out of sync: index=%d len=%d", + self.current_tool_index, + len(self.streamed_args_for_tool), + ) result = DeltaMessage( tool_calls=[ @@ -676,6 +700,19 @@ def extract_tool_calls_streaming( self.param_count += 1 + # Track what we've streamed so the serving + # layer can compute remaining args at the end. + if self.current_tool_index < len(self.streamed_args_for_tool): + self.streamed_args_for_tool[self.current_tool_index] += ( + json_fragment + ) + else: + logger.warning( + "streamed_args_for_tool out of sync: index=%d len=%d", + self.current_tool_index, + len(self.streamed_args_for_tool), + ) + return DeltaMessage( tool_calls=[ DeltaToolCall(