Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 191 additions & 0 deletions tests/entrypoints/openai/parser/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
get_streamable_parser_for_assistant,
get_system_message,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
sanitize_harmony_name,
sanitize_harmony_recipient,
)
from vllm.entrypoints.openai.responses.harmony import (
response_previous_input_to_harmony,
Expand Down Expand Up @@ -841,3 +844,191 @@ def test_all_standard_channels_present(self) -> None:
assert channel in valid_channels, (
f"{channel} missing when with_custom_tools={with_tools}"
)


class TestSanitizeHarmonyName:
    """Unit tests for sanitize_harmony_name().

    The sanitizer truncates a name at the first leaked Harmony control
    token (e.g. <|channel|>, <|end|>) and strips leftover whitespace.
    """

    def test_clean_name_unchanged(self) -> None:
        # A name without control tokens passes through untouched.
        result = sanitize_harmony_name("get_weather")
        assert result == "get_weather"

    def test_strip_channel_token(self) -> None:
        # Everything from the control token onward is dropped.
        result = sanitize_harmony_name("manage_cart<|channel|>commentary")
        assert result == "manage_cart"

    def test_strip_constrain_token(self) -> None:
        # A name that begins with a control token collapses to empty.
        result = sanitize_harmony_name("<|constrain|>json")
        assert result == ""

    def test_pure_control_token_returns_empty(self) -> None:
        result = sanitize_harmony_name("<|start|>")
        assert result == ""

    def test_multiple_tokens_earliest_wins(self) -> None:
        # Truncation happens at the earliest control token, not the last.
        result = sanitize_harmony_name("foo<|channel|>bar<|constrain|>baz")
        assert result == "foo"

    def test_empty_string(self) -> None:
        assert sanitize_harmony_name("") == ""

    def test_trailing_whitespace_stripped(self) -> None:
        # Whitespace left behind by the truncation is removed as well.
        result = sanitize_harmony_name("tool_name <|end|>")
        assert result == "tool_name"


class TestSanitizeHarmonyRecipient:
    """Unit tests for sanitize_harmony_recipient().

    Recipients may be dotted (namespace.tool); the dotted structure
    survives contamination of an individual part (see the
    contaminated-first-part cases below).
    """

    def test_clean_dotted_name_unchanged(self) -> None:
        result = sanitize_harmony_recipient("browser.search")
        assert result == "browser.search"

    def test_clean_simple_name_unchanged(self) -> None:
        result = sanitize_harmony_recipient("python")
        assert result == "python"

    def test_contaminated_first_part_preserved_structure(self) -> None:
        # browser<|channel|>.search -> browser.search
        result = sanitize_harmony_recipient("browser<|channel|>.search")
        assert result == "browser.search"

    def test_contaminated_second_part(self) -> None:
        # browser.search<|end|>garbage -> browser.search
        result = sanitize_harmony_recipient("browser.search<|end|>garbage")
        assert result == "browser.search"

    def test_pure_control_token_returns_empty(self) -> None:
        result = sanitize_harmony_recipient("<|constrain|>json")
        assert result == ""

    def test_functions_dotted_contaminated(self) -> None:
        # functions.get_weather<|channel|>commentary -> functions.get_weather
        contaminated = "functions.get_weather<|channel|>commentary"
        assert sanitize_harmony_recipient(contaminated) == "functions.get_weather"

    def test_empty_string(self) -> None:
        assert sanitize_harmony_recipient("") == ""

    def test_container_dotted_contaminated(self) -> None:
        # container<|channel|>.exec -> container.exec
        result = sanitize_harmony_recipient("container<|channel|>.exec")
        assert result == "container.exec"


class TestResilientStreamableParser:
    """Tests for ResilientStreamableParser error recovery.

    Each test drives the parser token-by-token, mimicking streaming
    decode, and checks that malformed Harmony sequences still yield
    well-formed messages.
    """

    @staticmethod
    def _feed(parser, text: str, encoding) -> None:
        # Encode *text* (control tokens permitted) and stream every
        # resulting token into the parser, one process() call at a time.
        for token in encoding.encode(text, allowed_special="all"):
            parser.process(token)

    def test_normal_sequence_unchanged(self) -> None:
        """Normal token sequence should produce same results as raw parser."""
        encoding = get_encoding()
        parser = get_streamable_parser_for_assistant()

        self._feed(parser, "<|channel|>final<|message|>Hello world<|end|>", encoding)

        assert len(parser.messages) == 1
        assert parser.messages[0].content[0].text == "Hello world"
        assert parser.messages[0].channel == "final"

    def test_missing_start_recovery(self) -> None:
        """Parser should recover when <|start|> is missing between messages."""
        encoding = get_encoding()
        parser = get_streamable_parser_for_assistant()

        # First message completes normally; the second is fed directly
        # without the expected <|start|>assistant prefix.
        self._feed(parser, "<|channel|>final<|message|>First.<|end|>", encoding)
        self._feed(parser, "<|channel|>final<|message|>Second.<|end|>", encoding)

        assert len(parser.messages) == 2
        assert parser.messages[0].content[0].text == "First."
        assert parser.messages[1].content[0].text == "Second."

    def test_constrain_in_header_skipped(self) -> None:
        """<|constrain|> in HEADER state should be skipped gracefully."""
        encoding = get_encoding()
        parser = get_streamable_parser_for_assistant()

        # Complete a normal message so the parser reaches EXPECT_START.
        self._feed(parser, "<|channel|>final<|message|>First.<|end|>", encoding)
        assert len(parser.messages) == 1

        # <|start|>assistant puts the parser in the HEADER state.
        self._feed(parser, "<|start|>", encoding)
        self._feed(parser, "assistant", encoding)
        # <|constrain|> should enter skip mode; the "json" tokens that
        # follow are garbage that must be discarded while skipping.
        self._feed(parser, "<|constrain|>", encoding)
        self._feed(parser, "json", encoding)
        # <|message|> exits skip mode and resumes normal parsing.
        self._feed(parser, "<|message|>", encoding)
        self._feed(parser, "Second.", encoding)
        self._feed(parser, "<|end|>", encoding)

        # Two messages should exist despite the malformed header sequence.
        assert len(parser.messages) == 2
        assert parser.messages[0].content[0].text == "First."

    def test_messages_recipients_sanitized(self) -> None:
        """Messages returned by .messages should have sanitized recipients,
        preventing contaminated history in multi-turn interactions."""
        encoding = get_encoding()
        parser = get_streamable_parser_for_assistant()

        # A tool-call message whose recipient is contaminated with
        # control tokens, preceded by a clean commentary message.
        self._feed(
            parser,
            "<|channel|>commentary"
            "<|message|>Let me search.<|end|>"
            "<|start|>assistant to=functions.get_weather<|channel|>commentary"
            '<|constrain|>json<|message|>{"loc": "SF"}<|end|>',
            encoding,
        )

        control_tokens = (
            "<|channel|>",
            "<|constrain|>",
            "<|start|>",
            "<|end|>",
            "<|message|>",
        )
        # No recipient on any parsed message may carry a control token.
        for message in parser.messages:
            if message.recipient is None:
                continue
            for control in control_tokens:
                assert control not in message.recipient, (
                    f"Leaked control token {control!r} "
                    f"in message recipient: {message.recipient!r}"
                )
31 changes: 31 additions & 0 deletions tests/entrypoints/openai/responses/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,34 @@ def test_parser_state_to_response_output_analysis_channel() -> None:
assert len(builtin_items) == 1
assert not isinstance(builtin_items[0], McpCall)
assert builtin_items[0].type == "reasoning"


class TestHarmonyOutputSanitization:
    """Tests that leaked Harmony control tokens are sanitized in output."""

    def test_constrain_recipient_treated_as_no_recipient(self):
        """<|constrain|>json as recipient should be sanitized to empty,
        falling through to _parse_message_no_recipient (produces message)."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, "Some output text")
            .with_channel("commentary")
            .with_recipient("<|constrain|>json")
        )

        output_items = harmony_to_response_output(message)

        # The result is a message (preamble), not an MCP call.
        assert len(output_items) == 1
        item = output_items[0]
        assert isinstance(item, ResponseOutputMessage)
        assert item.type == "message"

    def test_contaminated_tool_name_cleaned_in_function_call(self):
        """Function name with leaked <|channel|> should be sanitized."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "SF"}')
            .with_channel("commentary")
            .with_recipient("functions.get_weather<|channel|>commentary")
        )

        output_items = harmony_to_response_output(message)

        assert len(output_items) == 1
        call = output_items[0]
        assert isinstance(call, ResponseFunctionToolCall)
        assert call.name == "get_weather"
        assert "<|" not in call.name
5 changes: 4 additions & 1 deletion vllm/entrypoints/openai/chat_completion/stream_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import sanitize_harmony_name


class TokenState(NamedTuple):
Expand Down Expand Up @@ -109,7 +110,9 @@ def extract_harmony_streaming_delta(
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_name = sanitize_harmony_name(
group.recipient.split("functions.", 1)[1]
)
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
Expand Down
Loading