
[Bugfix] Fix Qwen3Coder tool call streaming with speculative decoding#35421

Closed
voipmonitor wants to merge 4 commits into vllm-project:main from voipmonitor:fix/spec-decode-tool-call-streaming

Conversation

@voipmonitor
Contributor

Summary

Fixes broken tool call JSON when using Qwen3CoderToolParser (--tool-call-parser qwen3_coder) with speculative decoding (--num-speculative-tokens N where N >= 2).

PR #35347 partially addressed this for num_speculative_tokens: 2 by adding Qwen35CoderToolParser, but the fix still fails for num_speculative_tokens: 3, 4, 5 because the root cause is in the serving layer and in missing streamed_args_for_tool tracking in the base Qwen3CoderToolParser. This PR fixes the actual root causes directly in Qwen3CoderToolParser, so no separate parser class is needed.

Root Cause Analysis

Three bugs interact to produce malformed tool call JSON:

Bug 1: Double-serialization in serving.py (the primary cause)

prev_tool_call_arr[index]["arguments"] is already a JSON string (e.g. '{"city": "Prague"}'), but the serving layer calls json.dumps() on it again:

# BEFORE (broken):
expected_call = json.dumps(
    tool_parser.prev_tool_call_arr[index].get("arguments", {}),
    ensure_ascii=False,
)
# Result: '"{\\"city\\": \\"Prague\\"}"' (double-serialized!)

This causes expected_call.replace(actual_call, "", 1) to fail (the format doesn't match), so remaining_call equals the entire double-serialized string, which gets appended as a spurious final delta. The client then assembles:

{"city": "Prague""{\"city\": \"Prague\"}"
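The guard can be sketched as follows. This is a minimal illustration of the fix described above, not the exact serving.py diff; `serialize_arguments` is a hypothetical helper name:

```python
import json

def serialize_arguments(arguments):
    # Tool parsers may store arguments either as a dict or as an
    # already-serialized JSON string; only dump non-strings, so a
    # string like '{"city": "Prague"}' is never serialized twice.
    if isinstance(arguments, str):
        return arguments
    return json.dumps(arguments, ensure_ascii=False)

print(serialize_arguments({"city": "Prague"}))    # dict: serialized once
print(serialize_arguments('{"city": "Prague"}'))  # string: passed through unchanged
```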

Bug 2: Missing streamed_args_for_tool tracking in Qwen3CoderToolParser

The parser never populates self.streamed_args_for_tool, but the serving layer reads it at stream end (streamed_args_for_tool[index]), causing IndexError: list index out of range.
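The invariant the fix restores is that `streamed_args_for_tool` stays in lockstep with `prev_tool_call_arr` — one slot appended per tool call, accumulated at every emission point. A minimal sketch (class and method names are illustrative, not the parser's actual API):

```python
class ParserState:
    """Sketch of keeping streamed_args_for_tool indexable at stream end."""

    def __init__(self):
        self.prev_tool_call_arr = []
        self.streamed_args_for_tool = []

    def start_tool(self, name):
        self.prev_tool_call_arr.append({"name": name, "arguments": ""})
        self.streamed_args_for_tool.append("")  # one slot per tool call

    def stream_args(self, index, fragment):
        # Accumulate every streamed fragment ("{", parameters, "}") so the
        # serving layer can compute the remaining delta at stream end.
        self.streamed_args_for_tool[index] += fragment

state = ParserState()
state.start_tool("get_weather")
state.stream_args(0, '{"city": "Prague"}')
print(state.streamed_args_for_tool[0])  # indexable, no IndexError
```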

Bug 3: Conditional { sending can be skipped

The original condition if not self.json_started and self.parameter_prefix not in delta_text skips sending { when the delta contains parameter data (common with speculative decoding where multiple tokens arrive at once). But json_started is then set to True anyway, desyncing the tracked state from what was actually streamed.
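The desync can be demonstrated in isolation. This is a simplified model of the broken control flow (names assumed), showing `json_started` becoming `True` while no `{` was ever streamed:

```python
class BrokenState:
    """Models the pre-fix control flow around sending the opening brace."""

    def __init__(self):
        self.json_started = False
        self.streamed = ""  # what the client actually received

    def emit_open_brace(self, delta_text, parameter_prefix="<parameter="):
        # Pre-fix condition: skip "{" when the delta already contains
        # parameter data (common with speculative-decoding bursts)...
        if not self.json_started and parameter_prefix not in delta_text:
            self.streamed += "{"
        # ...but mark it as started regardless, desyncing the state.
        self.json_started = True

state = BrokenState()
state.emit_open_brace("<parameter=city>Prague")  # burst delta with a parameter
print(state.json_started, repr(state.streamed))  # True '' -> state desynced
```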

Changes

  • serving.py: Check if arguments is already a string before calling json.dumps(), preventing double-serialization.
  • qwen3coder_tool_parser.py:
    • Add streamed_args_for_tool tracking at all argument emission points (header, {, parameters, }).
    • Always send { first regardless of delta content, removing the parameter_prefix not in delta_text condition.
    • Remove dead code (if not self.json_started: self.json_started = True was unreachable after the fix).

How to Reproduce

# Start vLLM with speculative decoding and Qwen3Coder tool parser
vllm serve <model> \
  --tool-call-parser qwen3_coder \
  --speculative-model <draft_model> \
  --num-speculative-tokens 5

# Send a multi-turn tool call conversation. On the second turn
# (when the assistant's tool_call response is sent back), the
# server crashes with either:
#   - json.JSONDecodeError: Expecting ',' delimiter
#   - IndexError: list index out of range (streamed_args_for_tool)

With num_speculative_tokens: 1 it works. With 2 it may work, depending on token boundaries. With 3 and above it consistently fails.

Test Plan

  • Tested with num_speculative_tokens: 1, 2, 3, 4, 5 — all work after the fix
  • Tested with multi-parameter tool calls (arrays, strings, integers)
  • Verified no regression for non-speculative-decoding tool calls
  • Unit tests for streamed_args_for_tool population

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs will not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses several bugs related to tool call streaming with speculative decoding in the Qwen3CoderToolParser. The changes prevent double-serialization of arguments in the serving layer and correctly track streamed arguments in the parser to avoid IndexError and state desynchronization. The fixes appear correct and well-targeted. I've added a few comments with high severity regarding potential silent failures in exception handling and state management that could lead to incorrect JSON output, which would undermine the goal of this bugfix.

@mergify mergify bot added the frontend, qwen (Related to Qwen models), and bug (Something isn't working) labels on Feb 26, 2026
@mergify

mergify bot commented Feb 26, 2026

Hi @voipmonitor, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

@jmfirth-arkane

jmfirth-arkane commented Feb 27, 2026

Thanks, this patch is working great at MTP=3. Test Dockerfile:

FROM vllm/vllm-openai:qwen3_5

# Grab the PR diff and apply it directly to the installed package
RUN apt-get update && apt-get install -y patch curl && \
    curl -L https://github.com/vllm-project/vllm/pull/35421.diff -o /tmp/pr.diff && \
    # Find where vllm is installed
    VLLM_PATH=$(python3 -c "import vllm; import os; print(os.path.dirname(vllm.__file__))") && \
    # Apply the patch - strip the leading 'vllm/' prefix to match installed paths
    cd "$VLLM_PATH" && \
    patch -p2 --forward < /tmp/pr.diff || true && \
    rm /tmp/pr.diff


@chaunceyjiang chaunceyjiang left a comment


Thanks~

@chaunceyjiang
Collaborator

you need to DCO.

@chaunceyjiang chaunceyjiang self-assigned this Feb 28, 2026
voipmonitor and others added 3 commits February 28, 2026 09:22
Fix three bugs that cause broken tool call JSON when using
Qwen3CoderToolParser with speculative decoding (num_speculative_tokens >= 2):

1. serving.py: double-serialization of prev_tool_call_arr arguments.
   Tool parsers store arguments as a JSON string, but the serving layer
   called json.dumps() on it again, producing '"{\"k\":1}"'. This caused
   the replace() autocomplete logic to fail and append the entire
   double-serialized string as a spurious final delta.

2. qwen3coder_tool_parser.py: missing streamed_args_for_tool tracking.
   The parser never populated streamed_args_for_tool, causing IndexError
   when the serving layer accessed streamed_args_for_tool[index] at
   stream end.

3. qwen3coder_tool_parser.py: conditional "{" sending could be skipped.
   The condition `parameter_prefix not in delta_text` could prevent
   sending "{" while still setting json_started=True, desyncing the
   tracked state from what was actually streamed to the client.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Martin Vit <martin@voipmonitor.org>
- Replace bare `except: pass` with `logger.debug` + exc_info for
  tool call parsing errors during streaming.
- Add `logger.warning` in else branches of streamed_args_for_tool
  bounds checks to surface state inconsistencies instead of failing
  silently.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Martin Vit <martin@voipmonitor.org>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: Martin Vit <martin@voipmonitor.org>
@voipmonitor voipmonitor force-pushed the fix/spec-decode-tool-call-streaming branch from f40dde2 to 9c2e48b Compare February 28, 2026 08:22
@voipmonitor
Contributor Author

> you need to DCO.

fixed

voipmonitor added a commit to voipmonitor/vllm that referenced this pull request Feb 28, 2026
…loss

With speculative decoding (e.g. num_speculative_tokens=5), token bursts
can deliver multiple complete parameters and </function> in a single
delta.  The existing streaming parser had two ordering/control-flow bugs
that caused parameters to be silently dropped, leading to tool calls
with missing fields (e.g. `content: undefined` on a file-write tool).

Root cause 1 — close-before-params ordering:
The </function> check ran BEFORE the parameter extraction loop.  When a
burst delivered the final parameter together with </function>, the close
handler fired first, emitted "}", and set in_function=False — the
parameter was never processed.

Root cause 2 — single-pass + early return:
The parameter extraction used a single `if` block that processed exactly
one parameter per call. When the current parameter was incomplete it
executed `return None`, discarding any already-complete parameters from
the same burst that hadn't been emitted yet.

Root cause 3 — already_added dedup:
prev_tool_call_arr entries were deduplicated by function name, so two
calls to the same function (e.g. two consecutive `read` calls) would
share a single entry, causing IndexError or wrong argument updates.

Fixes:
- Reorder: process parameters FIRST (while-loop), check </function>
  AFTER, so no parameter can be skipped by an early close.
- Loop + break: replace the single-pass `if` with a `while` loop that
  accumulates all complete parameter fragments, using `break` instead
  of `return None` when one is incomplete, ensuring earlier complete
  params are still emitted.
- Always-append: remove the already_added check; each tool call gets
  its own entry indexed by current_tool_index.
- Index-based update: update prev_tool_call_arr by current_tool_index
  instead of name-based search (companion fix for always-append).
- streamed_args_for_tool tracking: append "" on header, accumulate
  fragments and "}" so the serving layer can compute remaining args.
- serving.py: guard against double-serialization when prev_tool_call_arr
  stores arguments as a JSON string (isinstance check).
- Remove dead in_param branch (never activated in practice).

Supersedes vllm-project#35421 which addressed streamed_args_for_tool tracking and
the json_started condition but did not fix the critical close-before-
params ordering or the single-pass early-return bugs.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
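The "loop + break" fix described above can be sketched as follows. The tag format and regex here are illustrative assumptions, not the parser's actual extraction code: the point is that all complete parameter fragments in a burst are collected, and an incomplete tail causes a `break` rather than an early return that would discard them.

```python
import re

# Hypothetical fragment format modeling Qwen3Coder-style parameter tags.
PARAM_RE = re.compile(r"<parameter=(\w+)>(.*?)</parameter>", re.DOTALL)

def extract_complete_params(buffer):
    """Collect ALL complete parameters from a burst; keep the incomplete tail."""
    params = []
    pos = 0
    while True:
        match = PARAM_RE.search(buffer, pos)
        if match is None:
            break  # incomplete tail stays buffered; earlier params are kept
        params.append((match.group(1), match.group(2)))
        pos = match.end()
    return params

# A speculative-decoding burst: two complete params plus an incomplete one.
burst = ("<parameter=city>Prague</parameter>"
         "<parameter=days>3</parameter>"
         "<parameter=unit>cel")
print(extract_complete_params(burst))  # [('city', 'Prague'), ('days', '3')]
```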
@voipmonitor
Contributor Author

voipmonitor commented Feb 28, 2026

All those fixes are still not fully correct and fail in some scenarios with speculative decoding > 3. I'm closing this PR and creating a new one: #35615. @chaunceyjiang please check.
