Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions vllm/entrypoints/openai/chat_completion/stream_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import (
sanitize_harmony_tool_name,
strip_harmony_control_tokens,
)


class TokenState(NamedTuple):
Expand Down Expand Up @@ -109,7 +113,9 @@ def extract_harmony_streaming_delta(
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_name = sanitize_harmony_tool_name(
group.recipient.split("functions.", 1)[1]
)
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
Expand Down Expand Up @@ -158,9 +164,13 @@ def extract_harmony_streaming_delta(
if content_encountered or combined_reasoning or tool_messages:
delta_kwargs: dict[str, str | list[DeltaToolCall]] = {}
if content_encountered:
delta_kwargs["content"] = combined_content
cleaned_content = strip_harmony_control_tokens(combined_content)
if cleaned_content is not None:
delta_kwargs["content"] = cleaned_content
if combined_reasoning:
delta_kwargs["reasoning"] = combined_reasoning
cleaned_reasoning = strip_harmony_control_tokens(combined_reasoning)
if cleaned_reasoning is not None:
delta_kwargs["reasoning"] = cleaned_reasoning
if tool_messages:
delta_kwargs["tool_calls"] = tool_messages
tools_streamed = True
Expand Down
63 changes: 48 additions & 15 deletions vllm/entrypoints/openai/parser/harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Iterable, Sequence
from typing import Literal

import regex as re
from openai.types.responses import (
ResponseFunctionToolCall,
ResponseOutputItem,
Expand Down Expand Up @@ -68,6 +69,30 @@
"container",
}

_HARMONY_CONTROL_TOKEN_RE = re.compile(r"<\|[^>]*?\|>")
_INVALID_TOOL_NAME = "__invalid_tool__"


def sanitize_harmony_tool_name(name: str | None) -> str:
if not name:
return ""
cleaned = name.strip()
if "<|" not in cleaned:
return cleaned
prefix = cleaned.split("<|", 1)[0].strip()
if prefix:
return prefix
cleaned = _HARMONY_CONTROL_TOKEN_RE.sub("", cleaned).strip()
return cleaned or _INVALID_TOOL_NAME


def strip_harmony_control_tokens(text: str | None) -> str | None:
if text is None:
return None
if "<|" not in text:
return text
return _HARMONY_CONTROL_TOKEN_RE.sub("", text)


def has_custom_tools(tool_types: set[str]) -> bool:
"""
Expand Down Expand Up @@ -220,7 +245,8 @@ def parse_response_input(
elif response_msg["type"] == "function_call":
msg = Message.from_role_and_content(Role.ASSISTANT, response_msg["arguments"])
msg = msg.with_channel("commentary")
msg = msg.with_recipient(f"functions.{response_msg['name']}")
tool_name = sanitize_harmony_tool_name(response_msg.get("name"))
msg = msg.with_recipient(f"functions.{tool_name}")
msg = msg.with_content_type("json")
else:
raise ValueError(f"Unknown input type: {response_msg['type']}")
Expand All @@ -238,9 +264,8 @@ def parse_chat_inputs_to_harmony_messages(chat_msgs: list) -> list[Message]:
# Collect tool id to name mappings for tool response recipient values
for chat_msg in chat_msgs:
for tool_call in chat_msg.get("tool_calls", []):
tool_id_names[tool_call.get("id")] = tool_call.get("function", {}).get(
"name"
)
raw_name = tool_call.get("function", {}).get("name")
tool_id_names[tool_call.get("id")] = sanitize_harmony_tool_name(raw_name)

for chat_msg in chat_msgs:
msgs.extend(parse_chat_input_to_harmony_message(chat_msg, tool_id_names))
Expand Down Expand Up @@ -333,7 +358,7 @@ def parse_chat_input_to_harmony_message(

for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
name = sanitize_harmony_tool_name(func.get("name"))
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
Expand All @@ -348,7 +373,7 @@ def parse_chat_input_to_harmony_message(
# Tool role message (tool output)
if role == "tool":
tool_call_id = chat_msg.get("tool_call_id", "")
name = tool_id_names.get(tool_call_id, "")
name = sanitize_harmony_tool_name(tool_id_names.get(tool_call_id, ""))
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)

Expand Down Expand Up @@ -410,7 +435,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:
msgs: list[Message] = []
for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
name = sanitize_harmony_tool_name(func.get("name"))
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
Expand All @@ -421,7 +446,7 @@ def parse_input_to_harmony_message(chat_msg) -> list[Message]:

# Tool role message (tool output)
if role == "tool":
name = chat_msg.get("name", "")
name = sanitize_harmony_tool_name(chat_msg.get("name"))
content = chat_msg.get("content", "") or ""
content = flatten_chat_text_content(content)

Expand Down Expand Up @@ -530,7 +555,7 @@ def _parse_browser_tool_call(message: Message, recipient: str) -> ResponseOutput

def _parse_function_call(message: Message, recipient: str) -> list[ResponseOutputItem]:
"""Parse function calls into function tool call items."""
function_name = recipient.split(".")[-1]
function_name = sanitize_harmony_tool_name(recipient.split(".")[-1])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic to extract the function name using recipient.split(".")[-1] can be incorrect if the function name itself contains dots. For consistency with other parts of the codebase (e.g., serving.py), it's better to slice the string after the functions. prefix. This will correctly handle function names that might contain dots.

Suggested change
function_name = sanitize_harmony_tool_name(recipient.split(".")[-1])
function_name = sanitize_harmony_tool_name(recipient[len("functions.") :])

output_items = []
for content in message.content:
random_id = random_uuid()
Expand All @@ -554,7 +579,10 @@ def _parse_reasoning_content(message: Message) -> list[ResponseOutputItem]:
summary=[],
type="reasoning",
content=[
ResponseReasoningTextContent(text=content.text, type="reasoning_text")
ResponseReasoningTextContent(
text=strip_harmony_control_tokens(content.text),
type="reasoning_text",
)
],
status=None,
)
Expand All @@ -567,7 +595,7 @@ def _parse_final_message(message: Message) -> ResponseOutputItem:
contents = []
for content in message.content:
output_text = ResponseOutputText(
text=content.text,
text=strip_harmony_control_tokens(content.text),
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
Expand Down Expand Up @@ -681,13 +709,14 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:

if current_recipient and parser.current_channel in ("commentary", "analysis"):
if current_recipient.startswith("functions."):
function_name = sanitize_harmony_tool_name(current_recipient.split(".")[-1])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to another comment, using current_recipient.split(".")[-1] to extract the function name can be problematic if the function name contains dots. To ensure correctness and consistency, it's safer to slice the string after the functions. prefix.

Suggested change
function_name = sanitize_harmony_tool_name(current_recipient.split(".")[-1])
function_name = sanitize_harmony_tool_name(current_recipient[len("functions.") :])

rid = random_uuid()
return [
ResponseFunctionToolCall(
arguments=parser.current_content,
call_id=f"call_{rid}",
type="function_call",
name=current_recipient.split(".")[-1],
name=function_name,
id=f"fc_{rid}",
status="in_progress",
)
Expand Down Expand Up @@ -720,7 +749,8 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
text=strip_harmony_control_tokens(parser.current_content),
type="reasoning_text",
)
],
status=None,
Expand All @@ -735,7 +765,8 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:
type="reasoning",
content=[
ResponseReasoningTextContent(
text=parser.current_content, type="reasoning_text"
text=strip_harmony_control_tokens(parser.current_content),
type="reasoning_text",
)
],
status=None,
Expand All @@ -744,7 +775,7 @@ def parse_remaining_state(parser: StreamableParser) -> list[ResponseOutputItem]:

if parser.current_channel == "final":
output_text = ResponseOutputText(
text=parser.current_content,
text=strip_harmony_control_tokens(parser.current_content),
annotations=[], # TODO
type="output_text",
logprobs=None, # TODO
Expand Down Expand Up @@ -814,6 +845,8 @@ def parse_chat_output(
final_content: str | None = "\n".join(final_texts)

# Return None instead of empty string since existing callers check for None
reasoning = strip_harmony_control_tokens(reasoning)
final_content = strip_harmony_control_tokens(final_content)
reasoning = reasoning or None
final_content = final_content or None

Expand Down
24 changes: 16 additions & 8 deletions vllm/entrypoints/openai/responses/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@
parse_remaining_state,
parse_response_input,
render_for_completion,
sanitize_harmony_tool_name,
strip_harmony_control_tokens,
)
from vllm.entrypoints.openai.responses.context import (
ConversationContext,
Expand Down Expand Up @@ -1616,7 +1618,9 @@ def _emit_function_call_done_events(
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a function call completes."""
function_name = previous_item.recipient[len("functions.") :]
function_name = sanitize_harmony_tool_name(
previous_item.recipient[len("functions.") :]
)
events = []
events.append(
ResponseFunctionCallArgumentsDoneEvent(
Expand Down Expand Up @@ -1699,8 +1703,9 @@ def _emit_reasoning_done_events(
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a reasoning (analysis) item completes."""
sanitized_text = strip_harmony_control_tokens(previous_item.content[0].text)
content = ResponseReasoningTextContent(
text=previous_item.content[0].text,
text=sanitized_text,
type="reasoning_text",
)
reasoning_item = ResponseReasoningItem(
Expand All @@ -1718,7 +1723,7 @@ def _emit_reasoning_done_events(
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
text=sanitized_text,
)
)
events.append(
Expand Down Expand Up @@ -1747,9 +1752,10 @@ def _emit_text_output_done_events(
state: HarmonyStreamingState,
) -> list[StreamingResponsesResponse]:
"""Emit events when a final text output item completes."""
sanitized_text = strip_harmony_control_tokens(previous_item.content[0].text)
text_content = ResponseOutputText(
type="output_text",
text=previous_item.content[0].text,
text=sanitized_text,
annotations=[],
)
events = []
Expand All @@ -1759,7 +1765,7 @@ def _emit_text_output_done_events(
sequence_number=-1,
output_index=state.current_output_index,
content_index=state.current_content_index,
text=previous_item.content[0].text,
text=sanitized_text,
logprobs=[],
item_id=state.current_item_id,
)
Expand Down Expand Up @@ -1859,7 +1865,7 @@ def _emit_final_channel_delta_events(
content_index=state.current_content_index,
output_index=state.current_output_index,
item_id=state.current_item_id,
delta=ctx.last_content_delta,
delta=strip_harmony_control_tokens(ctx.last_content_delta),
# TODO, use logprobs from ctx.last_request_output
logprobs=[],
)
Expand Down Expand Up @@ -1909,7 +1915,7 @@ def _emit_analysis_channel_delta_events(
item_id=state.current_item_id,
output_index=state.current_output_index,
content_index=state.current_content_index,
delta=ctx.last_content_delta,
delta=strip_harmony_control_tokens(ctx.last_content_delta),
sequence_number=-1,
)
)
Expand Down Expand Up @@ -2390,7 +2396,9 @@ def _emit_function_call_delta_events(
events = []
if state.is_first_function_call_delta is False:
state.is_first_function_call_delta = True
fc_name = ctx.parser.current_recipient[len("functions.") :]
fc_name = sanitize_harmony_tool_name(
ctx.parser.current_recipient[len("functions.") :]
)
state.current_item_id = f"fc_{random_uuid()}"
tool_call_item = ResponseFunctionToolCall(
name=fc_name,
Expand Down
20 changes: 14 additions & 6 deletions vllm/tool_parsers/openai_tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import parse_output_into_messages
from vllm.entrypoints.openai.parser.harmony_utils import (
parse_output_into_messages,
sanitize_harmony_tool_name,
strip_harmony_control_tokens,
)
from vllm.logger import init_logger
from vllm.tool_parsers.abstract_tool_parser import (
ToolParser,
Expand Down Expand Up @@ -71,24 +75,28 @@ def extract_tool_calls(
ToolCall(
type="function",
function=FunctionCall(
name=msg.recipient.split("functions.")[1],
name=sanitize_harmony_tool_name(
msg.recipient.split("functions.", 1)[1]
),
arguments=tool_args,
),
)
)
elif msg.channel == "final":
final_content = msg_text
final_content = strip_harmony_control_tokens(msg_text)
elif msg.channel == "commentary" and not msg.recipient:
commentary_content = msg_text
commentary_content = strip_harmony_control_tokens(msg_text)

# Extract partial content from the parser state if the generation was truncated
if parser.current_content:
if parser.current_channel == "final":
final_content = parser.current_content
final_content = strip_harmony_control_tokens(parser.current_content)
elif (
parser.current_channel == "commentary" and not parser.current_recipient
):
commentary_content = parser.current_content
commentary_content = strip_harmony_control_tokens(
parser.current_content
)

return ExtractedToolCallInformation(
tools_called=len(tool_calls) > 0,
Expand Down
Loading