Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 282 additions & 0 deletions tests/entrypoints/openai/parser/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
get_streamable_parser_for_assistant,
get_system_message,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
sanitize_harmony_name,
sanitize_harmony_recipient,
)
from vllm.entrypoints.openai.responses.harmony import (
response_input_to_harmony,
Expand Down Expand Up @@ -928,3 +931,282 @@ def test_reasoning_with_empty_content_returns_none(self):
msg = response_input_to_harmony(item, prev_responses=[])

assert msg is None


class TestSanitizeHarmonyName:
    """Unit tests for sanitize_harmony_name()."""

    def test_clean_name_unchanged(self) -> None:
        # A name with no control tokens passes through untouched.
        result = sanitize_harmony_name("get_weather")
        assert result == "get_weather"

    def test_strip_channel_token(self) -> None:
        # Everything from the first control token onward is dropped.
        result = sanitize_harmony_name("manage_cart<|channel|>commentary")
        assert result == "manage_cart"

    def test_strip_constrain_token(self) -> None:
        # A name that begins with a control token collapses to empty.
        result = sanitize_harmony_name("<|constrain|>json")
        assert result == ""

    def test_pure_control_token_returns_empty(self) -> None:
        result = sanitize_harmony_name("<|start|>")
        assert result == ""

    def test_multiple_tokens_earliest_wins(self) -> None:
        # Truncation happens at the earliest control token in the string.
        result = sanitize_harmony_name("foo<|channel|>bar<|constrain|>baz")
        assert result == "foo"

    def test_empty_string(self) -> None:
        assert sanitize_harmony_name("") == ""

    def test_trailing_whitespace_stripped(self) -> None:
        # Whitespace left behind by the truncation is stripped as well.
        result = sanitize_harmony_name("tool_name <|end|>")
        assert result == "tool_name"


class TestSanitizeHarmonyRecipient:
    """Unit tests for sanitize_harmony_recipient()."""

    def test_clean_dotted_name_unchanged(self) -> None:
        # Clean "namespace.tool" recipients pass through untouched.
        result = sanitize_harmony_recipient("browser.search")
        assert result == "browser.search"

    def test_clean_simple_name_unchanged(self) -> None:
        result = sanitize_harmony_recipient("python")
        assert result == "python"

    def test_contaminated_first_part_preserved_structure(self) -> None:
        """browser<|channel|>.search → browser.search"""
        result = sanitize_harmony_recipient("browser<|channel|>.search")
        assert result == "browser.search"

    def test_contaminated_second_part(self) -> None:
        """browser.search<|end|>garbage → browser.search"""
        result = sanitize_harmony_recipient("browser.search<|end|>garbage")
        assert result == "browser.search"

    def test_pure_control_token_returns_empty(self) -> None:
        result = sanitize_harmony_recipient("<|constrain|>json")
        assert result == ""

    def test_functions_dotted_contaminated(self) -> None:
        """functions.get_weather<|channel|>commentary → functions.get_weather"""
        result = sanitize_harmony_recipient(
            "functions.get_weather<|channel|>commentary"
        )
        assert result == "functions.get_weather"

    def test_empty_string(self) -> None:
        assert sanitize_harmony_recipient("") == ""

    def test_container_dotted_contaminated(self) -> None:
        """container<|channel|>.exec → container.exec"""
        result = sanitize_harmony_recipient("container<|channel|>.exec")
        assert result == "container.exec"

    def test_full_component_contamination_returns_empty(self) -> None:
        """functions.<|constrain|>json → "" (not "functions")"""
        assert sanitize_harmony_recipient("functions.<|constrain|>json") == ""

    def test_container_full_component_contamination_returns_empty(self) -> None:
        """container.<|channel|>commentary → "" (not "container")"""
        assert sanitize_harmony_recipient("container.<|channel|>commentary") == ""


class TestResilientStreamableParser:
    """Tests for ResilientStreamableParser error recovery."""

    def test_normal_sequence_unchanged(self) -> None:
        """Normal token sequence should produce same results as raw parser."""
        encoding = get_encoding()
        token_ids = encoding.encode(
            "<|channel|>final<|message|>Hello world<|end|>",
            allowed_special="all",
        )

        parser = get_streamable_parser_for_assistant()
        for token_id in token_ids:
            parser.process(token_id)

        # A well-formed sequence yields exactly one final-channel message.
        assert len(parser.messages) == 1
        msg = parser.messages[0]
        assert msg.content[0].text == "Hello world"
        assert msg.channel == "final"

def test_missing_start_recovery(self) -> None:
    """Parser should recover when <|start|> is missing between messages."""
    encoding = get_encoding()
    # First message completes normally; the second omits <|start|>.
    first_tokens = encoding.encode(
        "<|channel|>final<|message|>First.<|end|>", allowed_special="all"
    )
    second_tokens = encoding.encode(
        "<|channel|>final<|message|>Second.<|end|>", allowed_special="all"
    )

    parser = get_streamable_parser_for_assistant()
    for token_id in first_tokens:
        parser.process(token_id)
    # Feed second message tokens directly (missing <|start|>assistant).
    for token_id in second_tokens:
        parser.process(token_id)

    # Both messages are recovered despite the missing <|start|>.
    assert len(parser.messages) == 2
    first, second = parser.messages
    assert first.content[0].text == "First."
    assert second.content[0].text == "Second."

def test_constrain_in_header_skipped(self) -> None:
    """<|constrain|> in HEADER state should be skipped gracefully."""
    encoding = get_encoding()
    # First, complete a normal message so parser goes to EXPECT_START
    first_msg = "<|channel|>final<|message|>First.<|end|>"
    first_tokens = encoding.encode(first_msg, allowed_special="all")

    # Build a second message where <|constrain|> appears in the header
    # after <|start|>assistant, before <|channel|>
    start_tok = encoding.encode("<|start|>", allowed_special="all")
    role_toks = encoding.encode("assistant", allowed_special="all")
    constrain_tok = encoding.encode("<|constrain|>", allowed_special="all")
    # Garbage tokens that should be skipped
    json_toks = encoding.encode("json", allowed_special="all")
    message_tok = encoding.encode("<|message|>", allowed_special="all")
    text_toks = encoding.encode("Second.", allowed_special="all")
    end_tok = encoding.encode("<|end|>", allowed_special="all")

    parser = get_streamable_parser_for_assistant()
    # Complete first message
    for tok in first_tokens:
        parser.process(tok)
    assert len(parser.messages) == 1

    # Feed: <|start|>assistant → puts parser in HEADER state
    for tok in start_tok:
        parser.process(tok)
    for tok in role_toks:
        parser.process(tok)
    # Feed: <|constrain|> → should enter skip mode
    for tok in constrain_tok:
        parser.process(tok)
    # Feed: json tokens → should be discarded in skip mode
    for tok in json_toks:
        parser.process(tok)
    # Feed: <|message|> → should exit skip mode and resume
    for tok in message_tok:
        parser.process(tok)
    # Feed: text + <|end|>
    for tok in text_toks:
        parser.process(tok)
    for tok in end_tok:
        parser.process(tok)

    # Should have produced two messages despite the malformed sequence;
    # the first message's content must be intact (untouched by recovery).
    assert len(parser.messages) == 2
    assert parser.messages[0].content[0].text == "First."
Comment on lines +1095 to +1096
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The test test_constrain_in_header_skipped verifies that two messages are produced after recovering from a malformed sequence, but it only asserts the content of the first message. To ensure the recovery logic is fully correct and the second message is parsed as expected, you should also add an assertion for the content of the second message.

Suggested change
assert len(parser.messages) == 2
assert parser.messages[0].content[0].text == "First."
assert len(parser.messages) == 2
assert parser.messages[0].content[0].text == "First."
assert parser.messages[1].content[0].text == "Second."

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the missing assertion for the second message content. Fixed in 6ef1b2f.

assert parser.messages[1].content[0].text == "Second."

def test_messages_recipients_sanitized(self) -> None:
    """Messages returned by .messages should have sanitized recipients,
    preventing contaminated history in multi-turn interactions."""
    encoding = get_encoding()
    # Build a tool call message with a contaminated recipient.
    harmony_str = (
        "<|channel|>commentary"
        "<|message|>Let me search.<|end|>"
        "<|start|>assistant to=functions.get_weather<|channel|>commentary"
        '<|constrain|>json<|message|>{"loc": "SF"}<|end|>'
    )
    token_ids = encoding.encode(harmony_str, allowed_special="all")

    parser = get_streamable_parser_for_assistant()
    for token_id in token_ids:
        parser.process(token_id)

    control_tokens = (
        "<|channel|>",
        "<|constrain|>",
        "<|start|>",
        "<|end|>",
        "<|message|>",
    )
    # Every recipient on every parsed message must be free of control tokens.
    for msg in parser.messages:
        if msg.recipient is None:
            continue
        for tok_str in control_tokens:
            assert tok_str not in msg.recipient, (
                f"Leaked control token {tok_str!r} "
                f"in message recipient: {msg.recipient!r}"
            )

def test_last_consumed_token_tracks_normal_processing(self) -> None:
    """Normal tokens forwarded to inner parser update last_consumed_token."""
    encoding = get_encoding()
    token_ids = encoding.encode(
        "<|channel|>final<|message|>Hello world<|end|>", allowed_special="all"
    )

    parser = get_streamable_parser_for_assistant()
    # Before any token has been fed, nothing has been consumed yet.
    assert parser.last_consumed_token is None

    for token_id in token_ids:
        parser.process(token_id)

    # After processing, last_consumed_token should be the last token fed in.
    assert parser.last_consumed_token == token_ids[-1]

def test_pattern3_discarded_tokens_not_in_last_consumed(self) -> None:
    """Free-text tokens in EXPECT_START don't update last_consumed_token."""
    encoding = get_encoding()

    # Complete a message so the parser reaches the EXPECT_START state.
    parser = get_streamable_parser_for_assistant()
    for token_id in encoding.encode(
        "<|channel|>final<|message|>First.<|end|>", allowed_special="all"
    ):
        parser.process(token_id)

    baseline = parser.last_consumed_token
    assert baseline is not None

    # Now feed free-text tokens (not <|start|>) — these should be discarded.
    for token_id in encoding.encode("some free text", allowed_special="all"):
        parser.process(token_id)

    # Discarded tokens must not advance last_consumed_token.
    assert parser.last_consumed_token == baseline

def test_pattern2_skip_mode_discarded_tokens_not_in_last_consumed(self) -> None:
    """Tokens skipped during Pattern 2 don't update last_consumed_token."""
    encoding = get_encoding()

    def encode(text: str):
        # Shorthand: encode with all special tokens permitted.
        return encoding.encode(text, allowed_special="all")

    # Complete a first message so the parser returns to EXPECT_START.
    parser = get_streamable_parser_for_assistant()
    for token_id in encode("<|channel|>final<|message|>First.<|end|>"):
        parser.process(token_id)
    last_consumed_after_first = parser.last_consumed_token

    # Feed <|start|>assistant to enter HEADER state.
    for token_id in encode("<|start|>"):
        parser.process(token_id)
    for token_id in encode("assistant"):
        parser.process(token_id)
    last_consumed_after_header = parser.last_consumed_token

    # Feed <|constrain|> to enter skip mode.
    for token_id in encode("<|constrain|>"):
        parser.process(token_id)
    # last_consumed should not change (constrain triggers skip, not forwarded).
    assert parser.last_consumed_token == last_consumed_after_header

    # Feed garbage tokens in skip mode — should not update either.
    for token_id in encode("json"):
        parser.process(token_id)
    assert parser.last_consumed_token == last_consumed_after_header

    # Feed <|message|> to exit skip mode — this IS forwarded.
    for token_id in encode("<|message|>"):
        parser.process(token_id)
    assert parser.last_consumed_token != last_consumed_after_first
31 changes: 31 additions & 0 deletions tests/entrypoints/openai/responses/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,34 @@ def test_parser_state_to_response_output_analysis_channel() -> None:
assert len(builtin_items) == 1
assert not isinstance(builtin_items[0], McpCall)
assert builtin_items[0].type == "reasoning"


class TestHarmonyOutputSanitization:
    """Tests that leaked Harmony control tokens are sanitized in output."""

    def test_constrain_recipient_treated_as_no_recipient(self):
        """<|constrain|>json as recipient should be sanitized to empty,
        falling through to _parse_message_no_recipient (produces message)."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, "Some output text")
            .with_channel("commentary")
            .with_recipient("<|constrain|>json")
        )

        output_items = harmony_to_response_output(message)

        # Should produce a message (preamble), not an MCP call.
        assert len(output_items) == 1
        item = output_items[0]
        assert isinstance(item, ResponseOutputMessage)
        assert item.type == "message"

    def test_contaminated_tool_name_cleaned_in_function_call(self):
        """Function name with leaked <|channel|> should be sanitized."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "SF"}')
            .with_channel("commentary")
            .with_recipient("functions.get_weather<|channel|>commentary")
        )

        output_items = harmony_to_response_output(message)

        assert len(output_items) == 1
        call = output_items[0]
        assert isinstance(call, ResponseFunctionToolCall)
        assert call.name == "get_weather"
        assert "<|" not in call.name
5 changes: 4 additions & 1 deletion vllm/entrypoints/openai/chat_completion/stream_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
DeltaMessage,
DeltaToolCall,
)
from vllm.entrypoints.openai.parser.harmony_utils import sanitize_harmony_name


class TokenState(NamedTuple):
Expand Down Expand Up @@ -109,7 +110,9 @@ def extract_harmony_streaming_delta(
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_name = sanitize_harmony_name(
group.recipient.split("functions.", 1)[1]
)
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
Expand Down
Loading
Loading