From 2f8476704c65a014f2b17706850ea08054ac8792 Mon Sep 17 00:00:00 2001 From: Martin Vit Date: Thu, 26 Feb 2026 16:16:59 +0000 Subject: [PATCH 1/3] [Bugfix] Fix Qwen3Coder tool call streaming with speculative decoding Fix three bugs that cause broken tool call JSON when using Qwen3CoderToolParser with speculative decoding (num_speculative_tokens >= 2): 1. serving.py: double-serialization of prev_tool_call_arr arguments. Tool parsers store arguments as a JSON string, but the serving layer called json.dumps() on it again, producing '"{\"k\":1}"'. This caused the replace() autocomplete logic to fail and append the entire double-serialized string as a spurious final delta. 2. qwen3coder_tool_parser.py: missing streamed_args_for_tool tracking. The parser never populated streamed_args_for_tool, causing IndexError when the serving layer accessed streamed_args_for_tool[index] at stream end. 3. qwen3coder_tool_parser.py: conditional "{" sending could be skipped. The condition `parameter_prefix not in delta_text` could prevent sending "{" while still setting json_started=True, desyncing the tracked state from what was actually streamed to the client. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Martin Vit --- .../openai/chat_completion/serving.py | 24 ++++++++--- vllm/tool_parsers/qwen3coder_tool_parser.py | 43 +++++++++++++++---- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 39f8635bf297..555f2966f912 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1249,13 +1249,25 @@ async def chat_completion_stream_generator( ) # get the expected call based on partial JSON - # parsing which "autocompletes" the JSON - expected_call = json.dumps( - tool_parser.prev_tool_call_arr[index].get( - "arguments", {} - ), - ensure_ascii=False, + # parsing which "autocompletes" the JSON. + # Tool parsers (e.g. Qwen3Coder) store + # arguments as a JSON string in + # prev_tool_call_arr. Calling json.dumps() + # on an already-serialized string would + # double-serialize it (e.g. '{"k":1}' becomes + # '"{\\"k\\":1}"'), which then causes the + # replace() below to fail and append the + # entire double-serialized string as a + # spurious final delta. + args = tool_parser.prev_tool_call_arr[index].get( + "arguments", {} ) + if isinstance(args, str): + expected_call = args + else: + expected_call = json.dumps( + args, ensure_ascii=False + ) # get what we've streamed so far for arguments # for the current tool diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index dfe790ee752e..68dafebe4e52 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -490,10 +490,17 @@ def extract_tool_calls_streaming( self.prev_tool_call_arr.append( { "name": self.current_function_name, - "arguments": "{}", # Placeholder, will be updated later + "arguments": "{}", } ) + # Initialize streamed args tracking for this tool. + # The serving layer reads streamed_args_for_tool to + # compute remaining arguments at stream end. Without + # this, IndexError occurs when the serving layer + # accesses streamed_args_for_tool[index]. + self.streamed_args_for_tool.append("") + # Send header with function info return DeltaMessage( tool_calls=[ @@ -511,9 +518,14 @@ def extract_tool_calls_streaming( # We've sent header, now handle function body if self.in_function: - # Send opening brace if not sent yet - if not self.json_started and self.parameter_prefix not in delta_text: + # Always send opening brace first, regardless of whether + # parameter_prefix is in the current delta. With speculative + # decoding, a single delta may contain both the opening brace + # and parameter data; skipping "{" here would desync + # json_started from what was actually streamed. + if not self.json_started: self.json_started = True + self.streamed_args_for_tool[self.current_tool_index] += "{" return DeltaMessage( tool_calls=[ DeltaToolCall( @@ -523,10 +535,6 @@ def extract_tool_calls_streaming( ] ) - # Make sure json_started is set if we're processing parameters - if not self.json_started: - self.json_started = True - # Check for function end in accumulated text if not self.json_closed and self.function_end_token in tool_text: # Close JSON @@ -558,7 +566,17 @@ def extract_tool_calls_streaming( self.prev_tool_call_arr[i]["arguments"] = args break except Exception: - pass # Ignore parsing errors during streaming + pass + + # Send closing brace; the serving layer autocomplete + # will fill in any missing arguments based on + # prev_tool_call_arr vs streamed_args_for_tool. + if self.current_tool_index < len( + self.streamed_args_for_tool + ): + self.streamed_args_for_tool[ + self.current_tool_index + ] += "}" result = DeltaMessage( tool_calls=[ @@ -676,6 +694,15 @@ def extract_tool_calls_streaming( self.param_count += 1 + # Track what we've streamed so the serving + # layer can compute remaining args at the end. + if self.current_tool_index < len( + self.streamed_args_for_tool + ): + self.streamed_args_for_tool[ + self.current_tool_index + ] += json_fragment + return DeltaMessage( tool_calls=[ DeltaToolCall( From fbe88beb24dfb2ee4ce40e94cc8f0168a9a25ebb Mon Sep 17 00:00:00 2001 From: Martin Vit Date: Thu, 26 Feb 2026 16:24:33 +0000 Subject: [PATCH 2/3] Address review feedback: add logging for silent failures - Replace bare `except: pass` with `logger.debug` + exc_info for tool call parsing errors during streaming. - Add `logger.warning` in else branches of streamed_args_for_tool bounds checks to surface state inconsistencies instead of failing silently. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Martin Vit --- vllm/tool_parsers/qwen3coder_tool_parser.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index 68dafebe4e52..314f8d77944e 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -566,7 +566,12 @@ def extract_tool_calls_streaming( self.prev_tool_call_arr[i]["arguments"] = args break except Exception: - pass + logger.debug( + "Failed to parse tool call during " + "streaming: %s", + tool_text, + exc_info=True, + ) # Send closing brace; the serving layer autocomplete # will fill in any missing arguments based on @@ -577,6 +582,13 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ self.current_tool_index ] += "}" + else: + logger.warning( + "streamed_args_for_tool out of sync: " + "index=%d len=%d", + self.current_tool_index, + len(self.streamed_args_for_tool), + ) result = DeltaMessage( tool_calls=[ @@ -702,6 +714,13 @@ def extract_tool_calls_streaming( self.streamed_args_for_tool[ self.current_tool_index ] += json_fragment + else: + logger.warning( + "streamed_args_for_tool out of sync: " + "index=%d len=%d", + self.current_tool_index, + len(self.streamed_args_for_tool), + ) return DeltaMessage( tool_calls=[ From 9c2e48b197f0ca77021fe8b7c8ee91761d6f9699 Mon Sep 17 00:00:00 2001 From: Martin Vit Date: Thu, 26 Feb 2026 19:00:15 +0000 Subject: [PATCH 3/3] Style: apply ruff format Co-Authored-By: Claude Opus 4.6 Signed-off-by: Martin Vit --- .../openai/chat_completion/serving.py | 4 +-- vllm/tool_parsers/qwen3coder_tool_parser.py | 27 +++++++------------ 2 files changed, 10 insertions(+), 21 deletions(-) diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 555f2966f912..06b16cde6748 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -1265,9 +1265,7 @@ async def chat_completion_stream_generator( if isinstance(args, str): expected_call = args else: - expected_call = json.dumps( - args, ensure_ascii=False - ) + expected_call = json.dumps(args, ensure_ascii=False) # get what we've streamed so far for arguments # for the current tool diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index 314f8d77944e..14a67ec38e47 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -567,8 +567,7 @@ def extract_tool_calls_streaming( break except Exception: logger.debug( - "Failed to parse tool call during " - "streaming: %s", + "Failed to parse tool call during streaming: %s", tool_text, exc_info=True, ) @@ -576,16 +575,11 @@ def extract_tool_calls_streaming( # Send closing brace; the serving layer autocomplete # will fill in any missing arguments based on # prev_tool_call_arr vs streamed_args_for_tool. - if self.current_tool_index < len( - self.streamed_args_for_tool - ): - self.streamed_args_for_tool[ - self.current_tool_index - ] += "}" + if self.current_tool_index < len(self.streamed_args_for_tool): + self.streamed_args_for_tool[self.current_tool_index] += "}" else: logger.warning( - "streamed_args_for_tool out of sync: " - "index=%d len=%d", + "streamed_args_for_tool out of sync: index=%d len=%d", self.current_tool_index, len(self.streamed_args_for_tool), ) @@ -708,16 +702,13 @@ def extract_tool_calls_streaming( # Track what we've streamed so the serving # layer can compute remaining args at the end. - if self.current_tool_index < len( - self.streamed_args_for_tool - ): - self.streamed_args_for_tool[ - self.current_tool_index - ] += json_fragment + if self.current_tool_index < len(self.streamed_args_for_tool): + self.streamed_args_for_tool[self.current_tool_index] += ( + json_fragment + ) else: logger.warning( - "streamed_args_for_tool out of sync: " - "index=%d len=%d", + "streamed_args_for_tool out of sync: index=%d len=%d", self.current_tool_index, len(self.streamed_args_for_tool), )