From 7d03fd7457944aaacd68c7180b18d577582fa573 Mon Sep 17 00:00:00 2001 From: ec-jt Date: Sun, 22 Mar 2026 21:26:01 +0000 Subject: [PATCH] tool_parser: stop string streaming at next boundary Signed-off-by: ec-jt --- vllm/tool_parsers/qwen3coder_tool_parser.py | 531 +++++++++++++------- 1 file changed, 357 insertions(+), 174 deletions(-) diff --git a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py index 216ae163b77a..1ec1ab9a6a21 100644 --- a/vllm/tool_parsers/qwen3coder_tool_parser.py +++ b/vllm/tool_parsers/qwen3coder_tool_parser.py @@ -62,8 +62,10 @@ def __init__(self, tokenizer: TokenizerLike): self.tool_call_function_regex = re.compile( r"| tag + # This allows to appear in parameter content self.tool_call_parameter_regex = re.compile( - r"|(?=)|$)", + r")|$)", re.DOTALL, ) @@ -108,6 +110,112 @@ def _reset_streaming_state(self): # Store accumulated parameters for type conversion self.accumulated_params = {} self.streaming_request = None + # Incremental string streaming state (ported from GLM4 parser) + self._streaming_string_value = False + self._value_buffer = "" + + @staticmethod + def _json_escape_string_content(s: str) -> str: + """JSON-escape string content for incremental streaming.""" + if not s: + return "" + return json.dumps(s, ensure_ascii=False)[1:-1] + + @staticmethod + def _strip_param_delimiter_prefix(value: str) -> str: + """Strip the single delimiter newline right after a parameter tag.""" + if value.startswith("\r\n"): + return value[2:] + if value.startswith("\n"): + return value[1:] + return value + + @staticmethod + def _strip_param_delimiter_suffix(value: str) -> str: + """Strip the single delimiter newline right before a close tag.""" + if value.endswith("\r\n"): + return value[:-2] + if value.endswith("\n"): + return value[:-1] + if value.endswith("\r"): + return value[:-1] + return value + + def _ensure_streaming_tool_state(self, tool_index: int) -> None: + """Ensure per-tool arrays are allocated for the given index.""" + while len(self.streamed_args_for_tool) <= tool_index: + self.streamed_args_for_tool.append("") + while len(self.prev_tool_call_arr) <= tool_index: + self.prev_tool_call_arr.append({"name": "", "arguments": "{}"}) + + def _emit_tool_args_delta(self, fragment: str) -> DeltaMessage | None: + """Emit a tool args fragment and keep streamed state synchronized.""" + if not fragment: + return None + + self._ensure_streaming_tool_state(self.current_tool_index) + self.streamed_args_for_tool[self.current_tool_index] += fragment + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_index, + function=DeltaFunctionCall(arguments=fragment), + ) + ] + ) + + def _sync_current_tool_call_arguments( + self, + tool_text: str, + request: ChatCompletionRequest, + ) -> None: + """Sync finalized function args into prev_tool_call_arr when parsable.""" + if self.current_tool_index >= len(self.prev_tool_call_arr): + return + if ( + self.tool_call_prefix not in tool_text + or self.function_end_token not in tool_text + ): + return + + start = tool_text.find(self.tool_call_prefix) + len(self.tool_call_prefix) + end = tool_text.find(self.function_end_token, start) + if end == -1: + return + + function_call_str = tool_text[start:end] + parsed = self._parse_xml_function_call(function_call_str, request.tools) + if parsed is None: + return + + self.prev_tool_call_arr[self.current_tool_index] = { + "name": parsed.function.name, + "arguments": parsed.function.arguments, + } + + def _is_string_type(self, param_name: str) -> bool: + """Check if a parameter is string type based on tool schema.""" + if not self.streaming_request or not self.streaming_request.tools: + return True # Default to string for unknown params + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools, + ) + if param_name not in param_config: + return True # Default to string for unknown params + param_type = ( + str(param_config[param_name].get("type", "string")).strip().lower() + if isinstance(param_config.get(param_name), dict) + else "string" + ) + return param_type in [ + "string", + "str", + "text", + "varchar", + "char", + "enum", + ] def _get_arguments_config( self, func_name: str, tools: list[ChatCompletionToolsParam] | None @@ -262,10 +370,12 @@ def _parse_xml_function_call( param_name = match_text[:idx] param_value = str(match_text[idx + 1 :]) # Remove prefix and trailing \n - if param_value.startswith("\n"): - param_value = param_value[1:] - if param_value.endswith("\n"): - param_value = param_value[:-1] + param_value = self._strip_param_delimiter_prefix(param_value) + param_value = self._strip_param_delimiter_suffix(param_value) + + # Strip trailing tag if present + # (since we use structural boundaries) + param_value = re.sub(r"\s*\s*$", "", param_value) param_dict[param_name] = self._convert_param_value( param_value, param_name, param_config, function_name @@ -362,6 +472,8 @@ def extract_tool_calls_streaming( if not previous_text: self._reset_streaming_state() self.streaming_request = request + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [] # If no delta text, return None unless it's an EOS token after tools if not delta_text: @@ -430,13 +542,21 @@ def extract_tool_calls_streaming( return DeltaMessage(content=content_before) return None else: - # Check if we're between tool calls - skip whitespace + # Check if we're between/after tool calls - skip whitespace if ( current_text.rstrip().endswith(self.tool_call_end_token) and delta_text.strip() == "" ): # We just ended a tool call, skip whitespace return None + # Also skip whitespace-only content if any tool calls + # have been completed (prevents trailing \n after last + # tool call from being emitted as content) + if ( + delta_text.strip() == "" + and self.tool_call_end_token in current_text + ): + return None # Normal content, no tool call return DeltaMessage(content=delta_text) @@ -490,19 +610,18 @@ def extract_tool_calls_streaming( # Always append — each tool call is a separate # invocation even if the function name is the same # (e.g. two consecutive "read" calls). - self.prev_tool_call_arr.append( - { - "name": self.current_function_name, - "arguments": "{}", - } - ) + self._ensure_streaming_tool_state(self.current_tool_index) + self.prev_tool_call_arr[self.current_tool_index] = { + "name": self.current_function_name, + "arguments": "{}", + } # Initialize streamed args tracking for this tool. # The serving layer reads streamed_args_for_tool to # compute remaining arguments at stream end. Without # this, IndexError occurs when the serving layer # accesses streamed_args_for_tool[index]. - self.streamed_args_for_tool.append("") + self.streamed_args_for_tool[self.current_tool_index] = "" # Send header with function info return DeltaMessage( @@ -528,185 +647,249 @@ def extract_tool_calls_streaming( # json_started from what was actually streamed. if not self.json_started: self.json_started = True - self.streamed_args_for_tool[self.current_tool_index] += "{" - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="{"), - ) - ] - ) - - # Find all parameter start positions in current tool_text - param_starts = [] - search_idx = 0 - while True: - search_idx = tool_text.find(self.parameter_prefix, search_idx) - if search_idx == -1: - break - param_starts.append(search_idx) - search_idx += len(self.parameter_prefix) - - # Process ALL complete params in a loop (spec decode fix). - # With speculative decoding a single delta can deliver - # multiple complete parameters at once. The old single-pass - # code would process one and ``return None`` if the next was - # incomplete — skipping any already-complete params that - # preceded it. Using a loop with ``break`` instead ensures - # we emit every complete parameter before yielding control. - json_fragments = [] - while not self.in_param and self.param_count < len(param_starts): - param_idx = param_starts[self.param_count] - param_start = param_idx + len(self.parameter_prefix) - remaining = tool_text[param_start:] + return self._emit_tool_args_delta("{") + + # ------------------------------------------------------- + # Handle incremental string value streaming + # ------------------------------------------------------- + if self._streaming_string_value: + # We're in the middle of streaming a string parameter + # value incrementally. Extract the current value from + # the accumulated tool_text. + param_tag = f"{self.parameter_prefix}{self.current_param_name}>" + param_tag_pos = tool_text.find(param_tag) + if param_tag_pos == -1: + return None - if ">" not in remaining: - break + value_start_pos = param_tag_pos + len(param_tag) + value_text = tool_text[value_start_pos:] + # Strip leading newline (Qwen3 format puts \n after >) + value_text = self._strip_param_delimiter_prefix(value_text) + + # Check if parameter value is complete. Function/tool close can + # also terminate a parameter in malformed/fragmented streams. + boundary_candidates: list[tuple[int, str]] = [] + for token, kind in ( + (self.parameter_end_token, "parameter"), + (self.parameter_prefix, "next_param"), + (self.function_end_token, "function"), + (self.tool_call_end_token, "tool"), + ): + pos = value_text.find(token) + if pos != -1: + boundary_candidates.append((pos, kind)) - name_end = remaining.find(">") - current_param_name = remaining[:name_end] + if boundary_candidates: + val_end, boundary_kind = min( + boundary_candidates, key=lambda x: x[0] + ) + # Parameter complete - emit remaining content and + # close the JSON string quote + remaining_value = value_text[:val_end] + remaining_value = self._strip_param_delimiter_suffix( + remaining_value + ) + new_content = remaining_value[len(self._value_buffer) :] + self._value_buffer = "" + self._streaming_string_value = False + self.param_count += 1 - value_start = param_start + name_end + 1 - value_text = tool_text[value_start:] - if value_text.startswith("\n"): - value_text = value_text[1:] + escaped = self._json_escape_string_content(new_content) + frag = escaped + '"' - param_end_idx = value_text.find(self.parameter_end_token) - if param_end_idx == -1: - next_param_idx = value_text.find(self.parameter_prefix) - func_end_idx = value_text.find(self.function_end_token) + if boundary_kind in {"function", "tool"}: + self._sync_current_tool_call_arguments(tool_text, request) + self.in_function = False + self.json_closed = True + frag += "}" - if next_param_idx != -1 and ( - func_end_idx == -1 or next_param_idx < func_end_idx + return self._emit_tool_args_delta(frag) + else: + # Parameter still streaming - emit safe content + # Hold back trailing \n and partial tags. + # In Qwen3-Coder format, \n precedes as a + # delimiter and should not be part of the value. + current_value = value_text + + # Find content safe to emit (not part of closing tag) + # Check for partial at end of buffer + safe_len = len(current_value) + + # Hold back trailing \n (delimiter before ) + if current_value.endswith("\n"): + safe_len = len(current_value) - 1 + + # Hold back partial boundary suffixes (close tags + next + # parameter start token). Covers cases like "\n<", + # "\n - # is missing. Use as a delimiter - # if present in the value so we don't include - # the closing tag as part of the param value. - tool_end_in_value = value_text.find(self.tool_call_end_token) - if tool_end_in_value != -1: - param_end_idx = tool_end_in_value - else: - # Parameter incomplete — break so we still - # emit any fragments accumulated by earlier - # loop iterations. - break + for i in range(1, len(close_token)): + if current_value.endswith(close_token[:i]): + candidate_len = len(current_value) - i + # Also hold back the delimiter newline that + # commonly precedes closing tags. + if ( + candidate_len > 0 + and current_value[candidate_len - 1] == "\n" + ): + candidate_len -= 1 + if ( + candidate_len > 0 + and current_value[candidate_len - 1] == "\r" + ): + candidate_len -= 1 + safe_len = min(safe_len, candidate_len) + break + + safe_value = current_value[: max(safe_len, 0)] + new_content = safe_value[len(self._value_buffer) :] + if new_content: + self._value_buffer = safe_value + escaped = self._json_escape_string_content(new_content) + if escaped: + return self._emit_tool_args_delta(escaped) + return None - if param_end_idx == -1: + # ------------------------------------------------------- + # Look for parameters + # ------------------------------------------------------- + param_starts = [] + idx = 0 + while True: + idx = tool_text.find(self.parameter_prefix, idx) + if idx == -1: break + param_starts.append(idx) + idx += len(self.parameter_prefix) - param_value = value_text[:param_end_idx] - if param_value.endswith("\n"): - param_value = param_value[:-1] - - self.current_param_name = current_param_name - self.accumulated_params[current_param_name] = param_value - - param_config = self._get_arguments_config( - self.current_function_name or "", - self.streaming_request.tools if self.streaming_request else None, - ) - - converted_value = self._convert_param_value( - param_value, - current_param_name, - param_config, - self.current_function_name or "", - ) - - serialized_value = json.dumps(converted_value, ensure_ascii=False) - - if self.param_count == 0: - json_fragment = f'"{current_param_name}": {serialized_value}' - else: - json_fragment = f', "{current_param_name}": {serialized_value}' + # Check if we should start a new parameter + if not self.in_param and len(param_starts) > self.param_count: + # Process the next parameter + param_idx = param_starts[self.param_count] + param_start = param_idx + len(self.parameter_prefix) + remaining = tool_text[param_start:] - self.param_count += 1 - json_fragments.append(json_fragment) + if ">" in remaining: + # We have the complete parameter name + name_end = remaining.find(">") + self.current_param_name = remaining[:name_end] + + # Check if this is a string type parameter + is_string = self._is_string_type(self.current_param_name) + + # Find the parameter value start + value_start = param_start + name_end + 1 + value_text = tool_text[value_start:] + value_text = self._strip_param_delimiter_prefix(value_text) + + if is_string: + # ----------------------------------------- + # String type: use incremental streaming + # ----------------------------------------- + # Emit the key and opening quote immediately + if self.param_count == 0: + key_frag = f'"{self.current_param_name}": "' + else: + key_frag = f', "{self.current_param_name}": "' - if json_fragments: - combined = "".join(json_fragments) + self._streaming_string_value = True + self._value_buffer = "" - if self.current_tool_index < len(self.streamed_args_for_tool): - self.streamed_args_for_tool[self.current_tool_index] += combined - else: - logger.warning( - "streamed_args_for_tool out of sync: index=%d len=%d", - self.current_tool_index, - len(self.streamed_args_for_tool), - ) + return self._emit_tool_args_delta(key_frag) + else: + # ----------------------------------------- + # Non-string type: wait for complete value + # ----------------------------------------- + param_end_idx = value_text.find(self.parameter_end_token) + boundary_kind = "parameter" + if param_end_idx == -1: + # No closing tag yet, look for boundaries + next_param_idx = value_text.find(self.parameter_prefix) + func_end_idx = value_text.find(self.function_end_token) + tool_end_idx = value_text.find(self.tool_call_end_token) + + boundary_candidates = [] + if next_param_idx != -1: + boundary_candidates.append( + (next_param_idx, "next_param") + ) + if func_end_idx != -1: + boundary_candidates.append((func_end_idx, "function")) + if tool_end_idx != -1: + boundary_candidates.append((tool_end_idx, "tool")) + + if boundary_candidates: + param_end_idx, boundary_kind = min( + boundary_candidates, key=lambda x: x[0] + ) + else: + return None + + if param_end_idx != -1: + param_value = value_text[:param_end_idx] + param_value = self._strip_param_delimiter_suffix( + param_value + ) - return DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments=combined), - ) - ] - ) + self.accumulated_params[self.current_param_name] = ( + param_value + ) - # Check for function end AFTER processing parameters. - # This ordering is critical: with speculative decoding a - # burst can deliver the final parameter value together with - # . If the close check ran first it would emit - # "}" and set in_function=False before the parameter loop - # ever ran, causing the parameter to be silently dropped. - if not self.json_closed and self.function_end_token in tool_text: - self.json_closed = True + param_config = self._get_arguments_config( + self.current_function_name or "", + self.streaming_request.tools + if self.streaming_request + else None, + ) - func_start = tool_text.find(self.tool_call_prefix) + len( - self.tool_call_prefix - ) - func_content_end = tool_text.find(self.function_end_token, func_start) - if func_content_end != -1: - func_content = tool_text[func_start:func_content_end] - try: - parsed_tool = self._parse_xml_function_call( - func_content, - self.streaming_request.tools - if self.streaming_request - else None, - ) - if parsed_tool and self.current_tool_index < len( - self.prev_tool_call_arr - ): - self.prev_tool_call_arr[self.current_tool_index][ - "arguments" - ] = parsed_tool.function.arguments - except Exception: - logger.debug( - "Failed to parse tool call during streaming: %s", - tool_text, - exc_info=True, - ) - - if self.current_tool_index < len(self.streamed_args_for_tool): - self.streamed_args_for_tool[self.current_tool_index] += "}" - else: - logger.warning( - "streamed_args_for_tool out of sync: index=%d len=%d", - self.current_tool_index, - len(self.streamed_args_for_tool), - ) + converted_value = self._convert_param_value( + param_value, + self.current_param_name, + param_config, + self.current_function_name or "", + ) - result = DeltaMessage( - tool_calls=[ - DeltaToolCall( - index=self.current_tool_index, - function=DeltaFunctionCall(arguments="}"), - ) - ] - ) + serialized_value = json.dumps( + converted_value, ensure_ascii=False + ) + if self.param_count == 0: + json_fragment = ( + f'"{self.current_param_name}": {serialized_value}' + ) + else: + json_fragment = ( + f', "{self.current_param_name}": {serialized_value}' + ) + + self.param_count += 1 + + if boundary_kind in {"function", "tool"}: + self._sync_current_tool_call_arguments( + tool_text, request + ) + self.in_function = False + self.json_closed = True + json_fragment += "}" + + return self._emit_tool_args_delta(json_fragment) + + # Function ended and all started parameters have been flushed: + # emit closing brace exactly once. + if ( + not self.json_closed + and not self._streaming_string_value + and self.function_end_token in tool_text + and self.param_count >= tool_text.count(self.parameter_prefix) + ): + self._sync_current_tool_call_arguments(tool_text, request) self.in_function = False self.json_closed = True - self.accumulated_params = {} - - return result + return self._emit_tool_args_delta("}") return None