
fix: integrate tool call parsing with reasoning parser in streaming mode #148

Open
jsirish wants to merge 1 commit into waybarrios:main from dynamic:fix/streaming-tool-call-with-reasoning

Conversation


@jsirish jsirish commented Mar 9, 2026

Problem

When both --reasoning-parser and --tool-call-parser are enabled, the reasoning parser branch in stream_chat_completion consumes all streaming tokens without routing them through the tool call parser. This causes <tool_call> XML markup to be emitted as raw text in the reasoning or content fields instead of being parsed into structured tool_calls in the response.

This affects models like Qwen3-Coder that emit tool calls directly inside reasoning blocks without a clean </think> transition.

Fix

The fix integrates tool call detection into the reasoning parser's streaming branch:

  1. After the reasoning parser processes each delta, the effective text (whether content or reasoning) is also fed through the tool call parser
  2. If tool markup is detected, the output is suppressed until the full tool call is parsed
  3. When a complete tool call is found, it's emitted as a structured tool_calls chunk with finish_reason: "tool_calls"
  4. Non-tool-call content continues to flow through the reasoning parser's normal path
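The four steps above can be sketched as a minimal, self-contained simulation of the streaming branch. StubToolParser and the event tuples are illustrative stand-ins, not the actual vllm-mlx classes:

```python
# Toy simulation of the fix: accumulate deltas, suppress output while tool
# markup is incomplete, and emit a structured event when a call completes.
import json
import re


class StubToolParser:
    """Toy parser recognizing <tool_call>{...}</tool_call> markup."""

    def extract_tool_calls_streaming(self, previous, current, delta):
        if "<tool_call>" not in current:
            return {"content": delta}  # plain text, pass through (step 4)
        match = re.search(r"<tool_call>(.*?)</tool_call>", current, re.S)
        if match is None:
            return None  # markup still incomplete: suppress (step 2)
        return {"tool_calls": [json.loads(match.group(1))]}  # step 3


def stream_with_tool_detection(deltas, parser):
    """Yield ('content', text) or ('tool_calls', calls) events per delta."""
    accumulated = ""
    emitted_tool_call = False
    for delta in deltas:
        previous = accumulated
        accumulated += delta  # step 1: feed effective text through the parser
        if emitted_tool_call:
            continue
        result = parser.extract_tool_calls_streaming(previous, accumulated, delta)
        if result is None:
            continue  # inside tool markup, emit nothing
        if "tool_calls" in result:
            emitted_tool_call = True
            yield ("tool_calls", result["tool_calls"])
        elif result.get("content"):
            yield ("content", result["content"])
```

Running this over deltas like `["Let me ", "check. ", "<tool_call>", '{"name": "ls"}', "</tool_call>"]` yields two content events followed by a single structured tool_calls event, with no raw markup leaking out.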

Testing

Tested with:

  • Model: Qwen3-Coder-Next-8bit
  • Flags: --reasoning-parser qwen3 --tool-call-parser qwen3_coder --continuous-batching
  • Load: 4 concurrent agents (CoPaw, OpenCode) over ~200 requests
  • Result: All requests returned structured tool calls correctly; zero errors

When both --reasoning-parser and --tool-call-parser are enabled,
the reasoning parser branch in stream_chat_completion would consume
all tokens without routing them through the tool call parser. This
meant <tool_call> XML was emitted as raw text in reasoning or
content fields instead of being parsed into structured tool_calls.

This fix feeds the reasoning parser's output (whether content or
reasoning text) through the tool call parser to detect and emit
structured tool calls. Tested with Qwen3-Coder-Next-8bit using
--reasoning-parser qwen3 --tool-call-parser qwen3_coder under
concurrent 4-agent load.
Copilot AI review requested due to automatic review settings March 9, 2026 06:23

Copilot AI left a comment


Pull request overview

This PR fixes streaming chat completion behavior when both --reasoning-parser and --tool-call-parser are enabled, ensuring tool-call markup emitted inside reasoning streams is parsed into structured tool_calls chunks rather than leaking as raw <tool_call> text.

Changes:

  • Feed reasoning-parser streaming deltas through the tool-call parser (using “effective” text from reasoning/content).
  • Suppress output while tool markup is incomplete, and emit a structured tool_calls chunk when a tool call completes.
  • Preserve normal reasoning/content streaming behavior when no tool markup is present.


Comment on lines +1941 to +1988
effective_text = delta_msg.content or delta_msg.reasoning or ""
if tool_parser and effective_text:
    if not tool_markup_possible and "<" not in effective_text:
        tool_accumulated_text += effective_text
        # No tool markup yet — emit the delta as-is
    else:
        if not tool_markup_possible:
            tool_markup_possible = True
            tool_previous = tool_accumulated_text
        tool_accumulated_text += effective_text
        tool_result = tool_parser.extract_tool_calls_streaming(
            tool_previous, tool_accumulated_text, effective_text,
        )

        if tool_result is None:
            # Inside tool markup — suppress output entirely
            continue

        if "tool_calls" in tool_result:
            # Emit structured tool calls instead of reasoning/content
            tool_calls_detected = True
            chunk = ChatCompletionChunk(
                id=response_id,
                model=request.model,
                choices=[
                    ChatCompletionChunkChoice(
                        delta=ChatCompletionChunkDelta(
                            tool_calls=tool_result["tool_calls"]
                        ),
                        finish_reason=(
                            "tool_calls" if output.finished else None
                        ),
                    )
                ],
                usage=get_usage(output) if output.finished else None,
            )
            yield f"data: {chunk.model_dump_json()}\n\n"
            continue

        # Tool parser returned content (not a tool call) — use it
        tool_content = tool_result.get("content", "")
        if tool_content:
            if delta_msg.reasoning:
                delta_msg.reasoning = tool_content
                delta_msg.content = None
            else:
                delta_msg.content = tool_content
                delta_msg.reasoning = None

Copilot AI Mar 9, 2026


effective_text = delta_msg.content or delta_msg.reasoning breaks transition deltas where the reasoning parser returns both reasoning and content (see DeltaMessage contract). In that case: (1) only content is fed into the tool parser (so tool markup in the reasoning portion can be missed), and (2) if the tool parser returns content, the current rewrite branch overwrites delta_msg.reasoning and drops delta_msg.content, causing the client to receive content in the wrong field and lose the original reasoning.

Suggestion: handle the transition chunk explicitly. Either feed both parts to the tool parser in-order (reasoning then content) while updating tool_accumulated_text, or build effective_text from both and track which field(s) it came from so you only rewrite the corresponding field(s) (and never discard the other).
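The in-order option can be sketched as follows. This is a hypothetical helper, not code from the PR; the dict-shaped delta, the shared accumulator state, and feed_transition_delta are all illustrative:

```python
# Sketch of the reviewer's first option: feed reasoning then content through
# the tool parser in stream order, and rewrite only the field a given piece
# of parsed content came from, so the other field is never discarded.
def feed_transition_delta(delta, parser, state):
    out = {"reasoning": None, "content": None, "tool_calls": []}
    for field in ("reasoning", "content"):  # stream order
        text = delta.get(field)
        if not text:
            continue
        previous = state["accumulated"]
        state["accumulated"] += text
        result = parser.extract_tool_calls_streaming(
            previous, state["accumulated"], text
        )
        if result is None:
            continue  # markup incomplete: suppress this piece
        if "tool_calls" in result:
            out["tool_calls"].extend(result["tool_calls"])
        elif result.get("content"):
            out[field] = result["content"]  # never clobber the other field
    return out
```

Because both fields flow through the same accumulator, markup that starts in reasoning and ends in content is still detected, and a transition chunk carrying both fields keeps each one intact.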

Comment on lines +1960 to +1973
# Emit structured tool calls instead of reasoning/content
tool_calls_detected = True
chunk = ChatCompletionChunk(
    id=response_id,
    model=request.model,
    choices=[
        ChatCompletionChunkChoice(
            delta=ChatCompletionChunkDelta(
                tool_calls=tool_result["tool_calls"]
            ),
            finish_reason=(
                "tool_calls" if output.finished else None
            ),
        )

Copilot AI Mar 9, 2026


When a tool call is detected (tool_calls_detected = True), this branch emits a tool_calls delta chunk, but the reasoning-parser path does not appear to update the final non-tool chunk's finish_reason based on tool_calls_detected the way the non-reasoning path does (it uses output.finish_reason directly). That can lead to inconsistent finish_reason behavior depending on whether _reasoning_parser is enabled.

Suggestion: ensure the reasoning-parser path mirrors the non-reasoning path’s finish_reason selection once tool_calls_detected is set (or explicitly justify the difference).

Comment on lines +1937 to +1946
# Check if tool call markup appears in reasoning or content.
# Some models (e.g. Qwen3-Coder) emit <tool_call> directly
# inside reasoning without a </think> transition, so we need to
# intercept tool call tokens regardless of which field they land in.
effective_text = delta_msg.content or delta_msg.reasoning or ""
if tool_parser and effective_text:
    if not tool_markup_possible and "<" not in effective_text:
        tool_accumulated_text += effective_text
        # No tool markup yet — emit the delta as-is
    else:

Copilot AI Mar 9, 2026


This change is behavioral and fairly subtle (tool-call parsing integrated into the reasoning streaming path), but there’s no unit test exercising stream_chat_completion with both _reasoning_parser and _tool_call_parser enabled (especially the Qwen3-style case where <tool_call> appears inside reasoning). Adding a regression test with a stub engine that streams deltas containing a tool call inside a think/reasoning block would help prevent future regressions and validate that raw <tool_call> text is suppressed and a structured tool_calls chunk is emitted.
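A regression test along those lines might look like the following sketch. run_stream is a toy stand-in for stream_chat_completion with both parsers enabled, and the tag handling is deliberately simplified:

```python
# Sketch of the suggested regression test: stream deltas that carry a tool
# call inside a reasoning block, then assert the raw markup is suppressed
# and a structured tool_calls event is emitted instead.
import json
import re


def run_stream(deltas):
    """Minimal stand-in for the reasoning + tool-call streaming branch."""
    events, accumulated = [], ""
    for delta in deltas:
        accumulated += delta
        if "<tool_call>" in accumulated:
            m = re.search(r"<tool_call>(.*?)</tool_call>", accumulated, re.S)
            if m is None:
                continue  # suppress incomplete markup
            events.append({"tool_calls": [json.loads(m.group(1))]})
            accumulated = ""  # reset after a complete call
        else:
            events.append({"reasoning": delta})
    return events


def test_tool_call_inside_reasoning_is_structured():
    deltas = ["<think>I will ", "list files ", "<tool_call>",
              '{"name": "ls", "arguments": {}}', "</tool_call>"]
    events = run_stream(deltas)
    text = "".join(e.get("reasoning", "") for e in events)
    assert "<tool_call>" not in text  # raw markup never leaks
    assert {"tool_calls": [{"name": "ls", "arguments": {}}]} in events
```

In the real suite the same shape would use a stub engine yielding these deltas through stream_chat_completion, rather than a local loop.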

jackzampolin added a commit to jackzampolin/vllm-mlx that referenced this pull request Mar 11, 2026
…arser in streaming

Models like Qwen3-Coder emit <tool_call> inside reasoning blocks without
a clean </think> transition. This fix feeds both content and reasoning
through the tool parser, while still emitting reasoning chunks when
content is being suppressed for tool markup parsing.

Cherry-picked from: waybarrios#148
Original author: Dynamic LLM

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@waybarrios
Owner

I reviewed it in detail and found two small bugs that need fixing before we can merge:

  1. effective_text uses or, which drops data on transition chunks. When DeltaMessage has both reasoning AND content set (the transition from thinking to answering), only one field gets fed to the tool parser. If a tool_call marker spans that boundary it won't get detected.

Fix: concatenate both fields instead of short-circuiting:

effective_text = (delta_msg.reasoning or "") + (delta_msg.content or "")

  2. finish_reason in the reasoning branch doesn't check tool_calls_detected. When tool calls were found earlier in the stream, the final chunk says "stop" instead of "tool_calls", which breaks clients that rely on that value to know they should execute tools.

Fix: match the non-reasoning branch behavior:

finish_reason=(
    "tool_calls"
    if (output.finished and tool_calls_detected)
    else (output.finish_reason if output.finished else None)
),

Also, the CI lint fails because it's running Black on Python 3.11, which can't parse 3.13 syntax; that's a CI infra issue, not your code. It would be good to enable "allow edits from maintainers" on this PR so I can push fixes directly.

Let me know if you want to apply these or if you'd rather I handle it.
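The finish_reason selection the maintainer proposes can be sanity-checked with a tiny helper. select_finish_reason is a hypothetical name, and the arguments stand in for fields on the engine's output object:

```python
# Hypothetical helper mirroring the proposed finish_reason expression:
# report "tool_calls" on the final chunk whenever tool calls were detected
# earlier in the stream, otherwise fall back to the engine's finish reason.
def select_finish_reason(finished, tool_calls_detected, finish_reason):
    if finished and tool_calls_detected:
        return "tool_calls"
    return finish_reason if finished else None
```

This makes the contract explicit: intermediate chunks carry None, and only the final chunk reports either "tool_calls" or the engine's own reason.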

waybarrios added a commit to otarkhan/vllm-mlx that referenced this pull request Mar 21, 2026