
[Bugfix] Fix Harmony streaming cross-channel delta accumulation#36011

Open
will-deines wants to merge 7 commits into vllm-project:main from will-deines:fix/harmony-streaming-cross-channel-delta

Conversation


@will-deines will-deines commented Mar 4, 2026

Summary

Fix Harmony streaming content leaks when a token batch crosses a channel boundary (e.g. analysis → commentary). With --stream-interval > 1, vLLM yields up to N tokens at a time. When a batch spans two channels, append_output accumulated all content deltas into a single string, and emit_content_delta_events classified the entire blob using the channel at the end of the batch — causing analysis-tail text to leak into response.output_text.delta events and actual commentary content to be lost or misclassified.

Symptom: Streaming clients see an empty output_text stream (no response.output_text.delta events for visible text), even though response.completed contains the correct output_text.

Fix: Track (channel, recipient, delta) triples per contiguous run within each batch, then emit each segment with the correct event type.


Related Issues & PRs

Directly Addressed

| # | Title | Status | Relation |
|---|-------|--------|----------|
| #27641 | [Bug]: Streaming tool call randomly failed when using gpt-oss-120b/20b | Open | Same root cause — cross-channel misrouting in multi-token streaming batches. Tool call arguments end up in reasoning/content instead of structured tool_calls. |
| #31501 | [Bug]: --stream-interval > 1 causes tool call arguments to be empty/lost | Open | Same root cause — when stream_interval > 1, multiple messages complete within a single yield, and content gets misclassified by end-of-batch channel state. |

Prior Fixes for the Same Class of Bug

| # | Title | Status | Relation |
|---|-------|--------|----------|
| #30437 | [Bugfix] missing tokens occur in harmony streaming | Merged | Predecessor — fixed multi-token yields dropping tokens entirely. Our bug is the next layer: tokens are all processed, but their channel attribution is wrong when the batch crosses a boundary. |
| #26291 | [MODEL] Fix handling of multiple channels for gpt-oss with speculative decoding | Merged | Same pattern — fixed multi-token cross-channel misrouting in the speculative decoding context. |
| #32114 | [Bugfix] Fix Harmony preamble visibility in Responses API | Merged | Same function — fixed 6 code paths in emit_content_delta_events() where preamble content was misrouted as reasoning. |
| #24306 | [gpt-oss] Fix streamableparser for missing handling of certain token_ids | Merged | Original fix — fixed the per-token processing loop that only handled one token per yield. |

Code Introduced By

| # | Title | Status | Relation |
|---|-------|--------|----------|
| #35148 | [Responses] Decouple SSE event helpers from Harmony context | Merged | Introduced the bug — created the dispatcher emit_content_delta_events(ctx, state) that reads ctx.last_content_delta + ctx.parser.current_channel (end-of-batch state). |
| #34909 | [Refactor] Extract Harmony streaming SSE event builders | Merged | Foundation — extracted ~800 lines of _emit_* methods into streaming_events.py, establishing the two-layer dispatcher + leaf-helper architecture. |

Related Open PRs

| # | Title | Status | Relation |
|---|-------|--------|----------|
| #35449 | [Bugfix] Fix tool call arguments parsed as content/reasoning in harmony streaming | Open | Same problem, different code path — fixes extract_harmony_streaming_delta (Chat Completions streaming) with a HarmonyStreamingState dataclass. Our fix addresses the Responses API streaming path (emit_content_delta_events). Both are needed. |
| #32997 | [Bugfix]: Prevent reasoning_content leak | Open | Related leak — fixes reasoning_content incorrectly flushed into content in the final streamed chunk. Different mechanism but same symptom family. |
| #33520 | [Bugfix] Fix tool call streaming for gpt-oss/Harmony models | Open | Related — fixes split tool calls where base_index changes as messages complete within a batch. Another multi-token batch state-tracking bug. |
| #37070 | [Bugfix] Fix harmony streaming tool call crash and argument splitting | Open | Related — recent fix for harmony streaming tool call crashes and argument splitting. Same symptom family (multi-token batch state issues). |
| #37071 | [Bugfix] Fix Responses API harmony streaming: token splitting, missing done events | Open | Related — recent fix for the Responses API harmony streaming path (token splitting, missing done events). Directly overlaps with the code paths this PR modifies. |

Root Cause

File: vllm/entrypoints/openai/responses/context.py — StreamingHarmonyContext.append_output()

```python
# BEFORE (buggy):
last_delta_text = ""
for tok in output.outputs[0].token_ids:
    self.parser.process(tok)
    last_delta_text += self.parser.last_content_delta or ""  # ← accumulates across channels
if last_delta_text:
    self.last_content_delta = last_delta_text  # ← one blob, channel info lost
```

File: vllm/entrypoints/openai/responses/streaming_events.py — emit_content_delta_events()

```python
# BEFORE (buggy):
delta = ctx.last_content_delta        # ← mixed content from both channels
channel = ctx.parser.current_channel  # ← channel at END of batch
# → analysis-tail text emitted as output_text.delta, commentary content lost
```

When --stream-interval 20 yields a batch that starts in analysis and transitions to commentary:

  1. Tokens 1–15 (analysis): content accumulated into last_delta_text
  2. Token 16: channel transitions to commentary
  3. Tokens 17–20 (commentary): content also accumulated into last_delta_text
  4. current_channel is now commentary (end of batch)
  5. emit_content_delta_events emits the entire blob as output_text.delta — leaking analysis content into visible text
  6. In subsequent batches, commentary content may arrive but current_channel has moved on, so it gets dropped
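The misrouting above can be reproduced in isolation. This is a minimal sketch with a hypothetical per-token `(channel, delta)` stream standing in for the real StreamableParser — it contrasts end-of-batch classification (the bug) with per-run segmentation (the fix):

```python
# Hypothetical per-token stream for a batch that crosses analysis → commentary.
tokens = [("analysis", "think "), ("analysis", "hard "),
          ("commentary", "Hello "), ("commentary", "there!")]

# BEFORE: accumulate everything, then classify by the LAST token's channel.
blob = "".join(delta for _, delta in tokens)
end_channel = tokens[-1][0]
buggy = [(end_channel, blob)]  # analysis text leaks into the visible channel

# AFTER: coalesce contiguous same-channel runs into separate segments.
segments: list[tuple[str, str]] = []
for channel, delta in tokens:
    if segments and segments[-1][0] == channel:
        segments[-1] = (channel, segments[-1][1] + delta)
    else:
        segments.append((channel, delta))

print(buggy)     # [('commentary', 'think hard Hello there!')]
print(segments)  # [('analysis', 'think hard '), ('commentary', 'Hello there!')]
```

With the buggy path, "think hard " would be emitted as visible `output_text.delta`; with the segmented path, each run keeps its own channel.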

Fix

1. context.py — Track per-channel deltas

```python
# AFTER:
self.channel_deltas: list[tuple[str | None, str | None, str]] = []
# (channel, recipient, delta) — one entry per contiguous run

for tok in output.outputs[0].token_ids:
    self.parser.process(tok)
    tok_delta = self.parser.last_content_delta
    if tok_delta:
        channel = self.parser.current_channel
        recipient = self.parser.current_recipient
        # Coalesce consecutive tokens in the same channel+recipient
        if (self.channel_deltas
                and self.channel_deltas[-1][0] == channel
                and self.channel_deltas[-1][1] == recipient):
            ch, rcp, prev = self.channel_deltas[-1]
            self.channel_deltas[-1] = (ch, rcp, prev + tok_delta)
        else:
            self.channel_deltas.append((channel, recipient, tok_delta))
```
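Note that the run key is the (channel, recipient) pair, not the channel alone, so two back-to-back tool calls in the same channel stay separate. A standalone sketch of just the coalescing rule (not vLLM code, inputs invented for illustration):

```python
def coalesce(per_token):
    """Merge consecutive (channel, recipient, delta) tokens into runs.
    A new run starts whenever channel OR recipient changes."""
    deltas: list[tuple[str | None, str | None, str]] = []
    for channel, recipient, delta in per_token:
        if deltas and deltas[-1][0] == channel and deltas[-1][1] == recipient:
            ch, rcp, prev = deltas[-1]
            deltas[-1] = (ch, rcp, prev + delta)
        else:
            deltas.append((channel, recipient, delta))
    return deltas

runs = coalesce([
    ("commentary", "functions.get_weather", '{"city"'),
    ("commentary", "functions.get_weather", ': "SF"}'),
    ("commentary", "functions.get_time", "{}"),  # recipient change → new run
])
print(runs)
# [('commentary', 'functions.get_weather', '{"city": "SF"}'),
#  ('commentary', 'functions.get_time', '{}')]
```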

2. streaming_events.py — Emit per-channel

Extract the dispatch logic into _emit_delta_for_channel(channel, recipient, delta, state), then iterate:

```python
def emit_content_delta_events(ctx, state):
    events = []
    for channel, recipient, delta in ctx.channel_deltas:
        events.extend(_emit_delta_for_channel(channel, recipient, delta, state))
    return events
```

This preserves the exact same dispatch rules (commentary → text, analysis → reasoning, functions.* → function call, etc.) but applies them per-segment instead of per-batch.
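The per-segment routing can be pictured with a toy dispatcher. This is a hedged sketch — the routing rules mirror the ones named above, but the event names are simplified stand-ins and the real emitters build full SSE event objects:

```python
# Toy dispatcher; real emitters construct full Responses API SSE events.
def _emit_delta_for_channel(channel, recipient, delta, state=None):
    if recipient and recipient.startswith("functions."):
        return [("function_call_arguments.delta", delta)]  # tool call path
    if channel == "analysis":
        return [("reasoning_text.delta", delta)]           # reasoning path
    if channel in ("commentary", "final"):
        return [("output_text.delta", delta)]              # visible text path
    return []

def emit_content_delta_events(channel_deltas, state=None):
    events = []
    for channel, recipient, delta in channel_deltas:
        events.extend(_emit_delta_for_channel(channel, recipient, delta, state))
    return events

events = emit_content_delta_events([
    ("analysis", None, "think hard "),
    ("commentary", None, "Hello there!"),
])
print(events)
# [('reasoning_text.delta', 'think hard '), ('output_text.delta', 'Hello there!')]
```

Under the old code, both deltas would have been fused and routed by the final channel; here each segment takes its own branch.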


Files Changed

| File | Change |
|------|--------|
| vllm/entrypoints/openai/responses/context.py | Add channel_deltas list to StreamingHarmonyContext; populate it per-token in append_output with coalescing for consecutive same-channel tokens |
| vllm/entrypoints/openai/responses/streaming_events.py | Extract _emit_delta_for_channel(); rewrite emit_content_delta_events() to iterate over ctx.channel_deltas |
| tests/entrypoints/openai/test_serving_responses.py | Update _make_ctx mock to set channel_deltas so existing tests work with the new code path |

Decisions to Debate

1. Per-token channel tracking vs. message-level diffing

What we chose: Record each token's (channel, recipient) at process time and coalesce consecutive runs.

Alternative: PR #35449 takes a different approach for the Chat Completions path — it diffs parser.messages before/after processing a batch to derive what changed. This is more robust against parser state quirks but requires more complex state management.

Why we chose this: For the Responses API path, the per-token approach is minimal and surgical. The StreamableParser already exposes current_channel and current_recipient per token. We just need to not throw that information away by accumulating into a single string. The message-diffing approach would be over-engineering for this specific dispatcher.

What reviewers might disagree with: Per-token tracking assumes current_channel is always accurate at the moment last_content_delta is set. If the parser ever updates current_channel before or after setting last_content_delta, the attribution could be wrong. In practice this doesn't happen — the parser updates both atomically in process().

2. Keeping last_content_delta alongside channel_deltas

What we chose: We still populate self.last_content_delta (the accumulated string) for backward compatibility, even though emit_content_delta_events now uses channel_deltas exclusively.

Alternative: Remove last_content_delta entirely and make all consumers use channel_deltas.

Why we chose this: last_content_delta is read in other code paths (e.g. _update_num_reasoning_tokens, potentially external consumers). Removing it would expand the blast radius. The cost of maintaining both is negligible — it's one extra string concatenation per batch.

3. Tuple vs. dataclass for channel_deltas entries

What we chose: list[tuple[str | None, str | None, str]] — lightweight, no new classes.

Alternative: A ChannelDelta dataclass with named fields (channel, recipient, delta) for clarity.

Why we chose this: The list is ephemeral (rebuilt every append_output call), consumed in one place (emit_content_delta_events), and destructured immediately. A dataclass would add a class definition for a 3-field struct used in exactly two lines of code.


Test Plan

  • All 13 existing test_serving_responses.py tests pass (including 5 preamble streaming tests)
  • All 22 existing test_context.py tests pass
  • E2E: streaming deltas reconstruct to match response.completed.output_text — test_streaming_response failed (assert '' == 'Hello there!') before the fix and now passes; test_streaming_text_matches_final asserts the same invariant with a Harmony-specific prompt
  • E2E with --stream-interval 20 — the server intentionally runs with a large stream interval to maximize the probability that a batch crosses a channel boundary; all streaming tests pass
  • Reasoning deltas unaffected — test_streaming_reasoning_events verifies reasoning delta events fire separately from text deltas with no control tokens in either; test_reasoning_items_present + test_multi_turn_reasoning_consistent confirm reasoning items intact in both single and multi-turn
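The core invariant the streaming tests check is that concatenated visible-text deltas equal the final output_text. A sketch of that check against a hypothetical collected event list (no live server; event dicts invented for illustration):

```python
def reconstruct_output_text(events):
    """Concatenate only visible-text deltas, ignoring reasoning/tool events."""
    return "".join(e["delta"] for e in events
                   if e["type"] == "response.output_text.delta")

events = [
    {"type": "response.reasoning_text.delta", "delta": "thinking..."},
    {"type": "response.output_text.delta", "delta": "Hello "},
    {"type": "response.output_text.delta", "delta": "there!"},
]
final_output_text = "Hello there!"  # what response.completed would carry
assert reconstruct_output_text(events) == final_output_text
```

Before the fix, the cross-channel leak made this reconstruction come back empty ('' == 'Hello there!'), which is exactly the failure mode quoted above.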

Out of Scope (follow-ups)

@will-deines will-deines force-pushed the fix/harmony-streaming-cross-channel-delta branch from 96395c1 to 111ce47 Compare March 4, 2026 13:54
@mergify mergify bot added frontend gpt-oss Related to GPT-OSS models bug Something isn't working labels Mar 4, 2026
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in Harmony streaming where content could leak across different channels within a single token batch. The fix introduces a mechanism to track content deltas per channel in StreamingHarmonyContext and refactors the event emission logic in streaming_events.py to process these deltas individually. This ensures that content is correctly attributed to its channel, preventing data leakage and loss. The implementation is clear, and the accompanying test modifications are appropriate. The changes effectively resolve the issue.

…hannel content leaks

When --stream-interval yields a batch of tokens that crosses a Harmony
channel boundary (e.g. analysis → commentary), append_output accumulated
all content into a single last_content_delta string. emit_content_delta_events
then classified the entire blob using the channel at the END of the batch,
causing analysis-tail text to leak into output_text.delta events and
commentary content to be lost.

Track (channel, recipient, delta) triples per contiguous run in the batch
so each segment is emitted with the correct event type.

Signed-off-by: Will Deines <will@garr.io>
@will-deines will-deines force-pushed the fix/harmony-streaming-cross-channel-delta branch from 9e7497e to d48bf7b Compare March 18, 2026 14:08

mergify bot commented Mar 19, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @will-deines.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 19, 2026
…ng-cross-channel-delta

Signed-off-by: Will Deines <will@garr.io>
@mergify mergify bot removed the needs-rebase label Mar 20, 2026
The Harmony streaming path only emits done events (output_text.done,
content_part.done, output_item.done, etc.) when is_expecting_start()
fires at a message boundary. For the last message in the stream, no
subsequent message triggers this, so done events are never emitted.

Add a post-loop flush that constructs a synthetic HarmonyMessage from
the parser's in-progress state and passes it to the existing
emit_previous_item_done_events() dispatcher, guarded by
state.sent_output_item_added to prevent double-emission.

Signed-off-by: Will Deines <will@garr.io>
…treaming

Within a single Harmony message, the channel can change (e.g.
analysis → final, analysis → functions.*) without a <|start|>
boundary. The previous fix only handled the final message's done
events via a post-loop flush, but mid-message channel transitions
were never emitting done events for the outgoing channel.

Add channel-transition tracking to StreamingState (last_channel,
last_recipient, accumulated_text). When emit_content_delta_events
detects a channel/recipient change, it emits done events for the
previous channel and resets state before starting the new one.

The post-loop flush now uses state.accumulated_text and
state.last_channel/last_recipient instead of parser state, which
is more reliable since the parser's current_content may span
multiple channels.

Signed-off-by: Will Deines <will@garr.io>
- Add sent_output_item_added=True to emit_function_call_delta_events,
  enabling post-loop flush and mid-message channel transitions for
  function calls (was the only delta emitter missing this flag)
- Use state-based done events at is_expecting_start() boundaries
  instead of parser message, eliminating race with emit_content_delta_events
- Always use arguments="" in first-delta output_item.added for simple
  path tool calls, preventing double-counted arguments
- Fix test_code_interpreter: client.timeout is a Timeout object, not a number
- Fix test_mcp_tool_multi_turn: Message.to_dict() flattens author fields
  to top level, so use msg.get("role") instead of msg.get("author", {}).get("role")
- Fix test_chat: use instructions= param instead of system role in input
  array (Harmony models build their own system message)
- Fix test_logprobs: skip for gpt-oss models (logprobs intentionally rejected)

Signed-off-by: Will Deines <will@garr.io>
Signed-off-by: Will Deines <will@garr.io>