Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions vllm/tool_parsers/gemma4_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,8 @@ def _reset_streaming_state(self) -> None:
self.current_tool_name_sent = False
self.prev_tool_call_arr: list[dict] = []
self.streamed_args_for_tool: list[str] = []
self.tool_call_ids: dict[int, str] = {}
self.tool_names: dict[int, str] = {}

def adjust_request(
self, request: ChatCompletionRequest | ResponsesRequest
Expand Down Expand Up @@ -628,6 +630,11 @@ def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
# Step 1: Send function name (once)
if not self.current_tool_name_sent and func_name:
self.current_tool_name_sent = True
tool_call_id = make_tool_call_id()
# Store id and name for re-emission in subsequent chunks
# (required for strict clients like @ai-sdk/OpenCode)
self.tool_call_ids[self.current_tool_id] = tool_call_id
self.tool_names[self.current_tool_id] = func_name
self.prev_tool_call_arr[self.current_tool_id] = {
"name": func_name,
"arguments": {},
Expand All @@ -637,7 +644,7 @@ def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
DeltaToolCall(
index=self.current_tool_id,
type="function",
id=make_tool_call_id(),
id=tool_call_id,
function=DeltaFunctionCall(
name=func_name,
arguments="",
Expand Down Expand Up @@ -680,13 +687,21 @@ def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
self.streamed_args_for_tool[self.current_tool_id] = final_args_json
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args

# Re-emit id, type, and function.name for strict clients
# (e.g., @ai-sdk/OpenCode) that validate every chunk
tool_call_id = self.tool_call_ids.get(self.current_tool_id, "")
function_name = self.tool_names.get(self.current_tool_id, "")
Comment on lines +692 to +693

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If a tool call is completed within a single chunk (or before _handle_tool_call_middle is invoked), self.tool_call_ids and self.tool_names will not contain entries for the current tool ID. This results in empty strings being emitted for id and name, which will likely cause validation failures in strict clients like @ai-sdk/OpenCode. We should ensure these fields are populated using fallbacks in _handle_tool_call_end.

                tool_call_id = self.tool_call_ids.get(self.current_tool_id)
                if tool_call_id is None:
                    tool_call_id = make_tool_call_id()
                    self.tool_call_ids[self.current_tool_id] = tool_call_id

                function_name = self.tool_names.get(self.current_tool_id)
                if function_name is None:
                    function_name = all_matches[self.current_tool_id][0]
                    self.tool_names[self.current_tool_id] = function_name


return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
id=tool_call_id,
type="function",
function=DeltaFunctionCall(
name=function_name,
arguments=diff,
).model_dump(exclude_none=True),
)
]
)
Expand Down Expand Up @@ -776,13 +791,21 @@ def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
self.streamed_args_for_tool[self.current_tool_id] = safe_json
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args

# Re-emit id, type, and function.name for strict clients
# (e.g., @ai-sdk/OpenCode) that validate every chunk
tool_call_id = self.tool_call_ids.get(self.current_tool_id, "")
function_name = self.tool_names.get(self.current_tool_id, "")

return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_id,
function=DeltaFunctionCall(arguments=diff).model_dump(
exclude_none=True
),
id=tool_call_id,
type="function",
function=DeltaFunctionCall(
name=function_name,
arguments=diff,
).model_dump(exclude_none=True),
)
]
)
Expand Down
Loading