Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/chat_completion/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,9 @@ async def chat_completion_stream_generator(
get_streamable_parser_for_assistant() for _ in range(num_choices)
]
harmony_tools_streamed = [False] * num_choices
harmony_tool_call_ids: list[dict[str, tuple[int, str]]] = [
{} for _ in range(num_choices)
]
tools_streamed = [False] * num_choices

if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
Expand Down Expand Up @@ -764,6 +767,7 @@ async def chat_completion_stream_generator(
token_states=token_states,
prev_recipient=prev_recipient,
include_reasoning=request.include_reasoning,
tool_call_ids=harmony_tool_call_ids[i],
)
)
harmony_tools_streamed[i] |= tools_streamed_flag
Expand Down Expand Up @@ -1109,6 +1113,7 @@ async def chat_completion_stream_generator(
delta_message, output
)
and tool_parser
and auto_tools_called
):
latest_delta_len = 0
if (
Expand Down
66 changes: 51 additions & 15 deletions vllm/entrypoints/openai/chat_completion/stream_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
DeltaMessage,
DeltaToolCall,
)
from vllm.logger import init_logger

logger = init_logger(__name__)


class TokenState(NamedTuple):
Expand All @@ -30,6 +33,7 @@
token_states: list[TokenState],
prev_recipient: str | None,
include_reasoning: bool,
tool_call_ids: dict[str, tuple[int, str]] | None = None,
) -> tuple[DeltaMessage | None, bool]:
"""
Extract a DeltaMessage from harmony parser state during streaming.
Expand All @@ -39,10 +43,14 @@
token_states: List of TokenState tuples for each token
prev_recipient: Previous recipient for detecting tool call transitions
include_reasoning: Whether to include reasoning content
tool_call_ids: Optional dict mapping a tool call's recipient to its
(index, ID) pair, used to include the ID in continuation chunks

Returns:
A tuple of (DeltaMessage or None, tools_streamed_flag)
"""
if tool_call_ids is None:
tool_call_ids = {}

if not token_states:
return None, False
Expand Down Expand Up @@ -106,43 +114,51 @@
and group.recipient
and group.recipient.startswith("functions.")
):
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
# New tool call - emit the opening message with any arguments
tool_name = group.recipient.split("functions.", 1)[1]
tool_id = make_tool_call_id()
# Store by recipient so lookup still works if base_index changes

Check failure on line 121 in vllm/entrypoints/openai/chat_completion/stream_harmony.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/entrypoints/openai/chat_completion/stream_harmony.py:121:89: E501 Line too long (90 > 88)
tool_call_ids[group.recipient] = (next_tool_index, tool_id)
logger.debug(
"New tool call: index=%d, id=%s, recipient=%s",
next_tool_index, tool_id, group.recipient
)
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
id=tool_id,
type="function",
function=DeltaFunctionCall(
name=tool_name,
arguments="",
arguments=group.text or "",
),
index=next_tool_index,
)
)
opened_new_call = True
prev_recipient = group.recipient
# Increment for subsequent new tool calls
next_tool_index += 1

if group.text:
# Stream arguments for the ongoing tool call
if opened_new_call:
# Just opened in this group
tool_call_index = next_tool_index - 1
elif group.text:
# Continuing arguments for an existing tool call
# Look up the index and ID by recipient
stored = tool_call_ids.get(group.recipient)
if stored:
tool_call_index, tool_id = stored
else:
# Continuing from previous chunk
# If ongoing_tool_index is None here, it means
# we're continuing a call but prev_recipient
# wasn't a function. Use base_index.
# Fallback if not found (shouldn't happen)
tool_call_index = (
ongoing_tool_index
if ongoing_tool_index is not None
else base_index
)
tool_id = None
logger.debug(
"Continue tool call: index=%d, id=%s, recipient=%s",
tool_call_index, tool_id, group.recipient
)
tool_messages.append(
DeltaToolCall(
id=tool_id,
index=tool_call_index,
function=DeltaFunctionCall(arguments=group.text),
)
Expand All @@ -154,6 +170,26 @@
elif group.channel == "analysis" and include_reasoning:
combined_reasoning += group.text

# Merge tool messages with the same index to avoid sending multiple
# entries for the same tool call in one SSE chunk
if tool_messages:
merged_tools: dict[int, DeltaToolCall] = {}
for tc in tool_messages:
if tc.index in merged_tools:
# Merge arguments into existing entry
existing = merged_tools[tc.index]
if tc.function and tc.function.arguments:
if existing.function:
existing.function.arguments = (
(existing.function.arguments or "")
+ tc.function.arguments
)
else:
existing.function = tc.function
else:
merged_tools[tc.index] = tc
tool_messages = list(merged_tools.values())

# Combine all non-empty fields into a single message
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
Expand Down
Loading