Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 222 additions & 0 deletions vllm/tool_parsers/qwen3coder_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ def _convert_param_value(
)
return param_value

def _ensure_streamed_args_for_tool_entry(self) -> None:
"""Ensures self.streamed_args_for_tool has an entry for current tool."""
if len(self.streamed_args_for_tool) <= self.current_tool_index:
self.streamed_args_for_tool.append("")

def _parse_xml_function_call(self, function_call_str: str) -> ToolCall | None:
# Extract function name
end_index = function_call_str.find(">")
Expand Down Expand Up @@ -486,6 +491,12 @@ def extract_tool_calls_streaming(
# accesses streamed_args_for_tool[index].
self.streamed_args_for_tool.append("")

# Ensure streamed_args_for_tool list has entry for this tool
self._ensure_streamed_args_for_tool_entry()

# Ensure streamed_args_for_tool list has entry for this tool
self._ensure_streamed_args_for_tool_entry()

# Send header with function info
return DeltaMessage(
tool_calls=[
Expand All @@ -510,6 +521,8 @@ def extract_tool_calls_streaming(
# json_started from what was actually streamed.
if not self.json_started:
self.json_started = True
# Update streamed_args_for_tool to track what we've sent
self._ensure_streamed_args_for_tool_entry()
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage(
tool_calls=[
Expand Down Expand Up @@ -687,6 +700,215 @@ def extract_tool_calls_streaming(

return result

# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)

# Check if we should start a new parameter
if (
not self.in_param
and self.param_count < len(param_starts)
and len(param_starts) > self.param_count
):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]

if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]

# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]

# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or
# function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token)

if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.tool_call_end_token in tool_text:
# Tool call is complete, so parameter
# must be complete too. Use all
# remaining text before function end
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None

if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]

# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value

# Get parameter configuration for type conversion
param_config = find_tool_properties(
self.streaming_request.tools
if self.streaming_request
else self.tools,
self.current_function_name or "",
)

# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value,
self.current_param_name,
param_config,
self.current_function_name or "",
)

# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)

if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)

self.param_count += 1

self._ensure_streamed_args_for_tool_entry()
self.streamed_args_for_tool[self.current_tool_index] += (
json_fragment
)

return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)

# Continue parameter value - Not used in the current implementation
# since we process complete parameters above
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]

# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]

if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]

# Store complete value
full_value = self.current_param_value + value_chunk
self.accumulated_params[self.current_param_name] = full_value

# Get parameter configuration for type conversion
param_config = find_tool_properties(
self.streaming_request.tools
if self.streaming_request
else self.tools,
self.current_function_name or "",
)

# Convert the parameter value to the appropriate type
converted_value = self._convert_param_value(
full_value,
self.current_param_name or "",
param_config,
self.current_function_name or "",
)

# Serialize the converted value
serialized_value = json.dumps(converted_value, ensure_ascii=False)

# Since we've been streaming the quoted version,
# we need to close it properly
# This is complex - for now just complete the value
self.in_param = False
self.current_param_value = ""

# Just close the current parameter string
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments='"'
), # Close the string quote
)
]
)
else:
# Continue accumulating value
value_chunk = delta_text

# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]

if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]

if value_chunk:
# Stream the escaped delta
prev_escaped = (
json.dumps(self.current_param_value, ensure_ascii=False)[
1:-1
]
if self.current_param_value
else ""
)
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]

if delta_escaped:
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped
),
)
]
)

return None

def get_structural_tag(self, request: ChatCompletionRequest):
Expand Down
Loading