Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions tests/entrypoints/openai/parser/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
from vllm.entrypoints.openai.parser.harmony_utils import (
auto_drop_analysis_messages,
get_encoding,
get_streamable_parser_for_assistant,
get_system_message,
has_custom_tools,
parse_chat_input_to_harmony_message,
parse_chat_output,
sanitize_harmony_name,
)
from vllm.entrypoints.openai.responses.harmony import (
response_previous_input_to_harmony,
Expand Down Expand Up @@ -841,3 +843,119 @@ def test_all_standard_channels_present(self) -> None:
assert channel in valid_channels, (
f"{channel} missing when with_custom_tools={with_tools}"
)


class TestSanitizeHarmonyName:
    """Unit tests for sanitize_harmony_name()."""

    def test_clean_name_unchanged(self) -> None:
        # A name containing no control tokens passes through untouched.
        assert sanitize_harmony_name("get_weather") == "get_weather"

    def test_strip_channel_token(self) -> None:
        contaminated = "manage_cart<|channel|>commentary"
        assert sanitize_harmony_name(contaminated) == "manage_cart"

    def test_strip_constrain_token(self) -> None:
        # Nothing precedes the token, so the sanitized name is empty.
        assert sanitize_harmony_name("<|constrain|>json") == ""

    def test_pure_control_token_returns_empty(self) -> None:
        assert sanitize_harmony_name("<|start|>") == ""

    def test_multiple_tokens_earliest_wins(self) -> None:
        # Truncation happens at the first token, not the last.
        contaminated = "foo<|channel|>bar<|constrain|>baz"
        assert sanitize_harmony_name(contaminated) == "foo"

    def test_empty_string(self) -> None:
        assert sanitize_harmony_name("") == ""

    def test_trailing_whitespace_stripped(self) -> None:
        # Whitespace left behind by the cut is removed as well.
        assert sanitize_harmony_name("tool_name <|end|>") == "tool_name"


class TestResilientStreamableParser:
    """Tests for ResilientStreamableParser error recovery."""

    def test_normal_sequence_unchanged(self) -> None:
        """Normal token sequence should produce same results as raw parser."""
        encoding = get_encoding()
        harmony_str = "<|channel|>final<|message|>Hello world<|end|>"
        token_ids = encoding.encode(harmony_str, allowed_special="all")

        parser = get_streamable_parser_for_assistant()
        for tok in token_ids:
            parser.process(tok)

        assert len(parser.messages) == 1
        assert parser.messages[0].content[0].text == "Hello world"
        assert parser.messages[0].channel == "final"

    def test_missing_start_recovery(self) -> None:
        """Parser should recover when <|start|> is missing between messages."""
        encoding = get_encoding()
        # First message completes normally, second is missing <|start|>
        first_msg = "<|channel|>final<|message|>First.<|end|>"
        second_msg = "<|channel|>final<|message|>Second.<|end|>"
        first_tokens = encoding.encode(first_msg, allowed_special="all")
        second_tokens = encoding.encode(second_msg, allowed_special="all")

        parser = get_streamable_parser_for_assistant()
        for tok in first_tokens:
            parser.process(tok)
        # Feed second message tokens directly (missing <|start|>assistant)
        for tok in second_tokens:
            parser.process(tok)

        assert len(parser.messages) == 2
        assert parser.messages[0].content[0].text == "First."
        assert parser.messages[1].content[0].text == "Second."

    def test_constrain_in_header_skipped(self) -> None:
        """<|constrain|> in HEADER state should be skipped gracefully."""
        encoding = get_encoding()
        # First, complete a normal message so parser goes to EXPECT_START
        first_msg = "<|channel|>final<|message|>First.<|end|>"
        first_tokens = encoding.encode(first_msg, allowed_special="all")

        # Build a second message where <|constrain|> appears in the header
        # after <|start|>assistant, before <|channel|>
        start_tok = encoding.encode("<|start|>", allowed_special="all")
        role_toks = encoding.encode("assistant", allowed_special="all")
        constrain_tok = encoding.encode("<|constrain|>", allowed_special="all")
        # Garbage tokens that should be skipped
        json_toks = encoding.encode("json", allowed_special="all")
        message_tok = encoding.encode("<|message|>", allowed_special="all")
        text_toks = encoding.encode("Second.", allowed_special="all")
        end_tok = encoding.encode("<|end|>", allowed_special="all")

        parser = get_streamable_parser_for_assistant()
        # Complete first message
        for tok in first_tokens:
            parser.process(tok)
        assert len(parser.messages) == 1

        # Feed: <|start|>assistant → puts parser in HEADER state
        for tok in start_tok:
            parser.process(tok)
        for tok in role_toks:
            parser.process(tok)
        # Feed: <|constrain|> → should enter skip mode
        for tok in constrain_tok:
            parser.process(tok)
        # Feed: json tokens → should be discarded in skip mode
        for tok in json_toks:
            parser.process(tok)
        # Feed: <|message|> → should exit skip mode and resume
        for tok in message_tok:
            parser.process(tok)
        # Feed: text + <|end|>
        for tok in text_toks:
            parser.process(tok)
        for tok in end_tok:
            parser.process(tok)

        # Should have produced two messages despite the malformed sequence
        assert len(parser.messages) == 2
        assert parser.messages[0].content[0].text == "First."
        # The second message must itself be intact: the garbage "json"
        # tokens were skipped, but the real body was parsed correctly.
        assert parser.messages[1].content[0].text == "Second."
Comment on lines +959 to +960
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This test verifies that the parser produces two messages, but it only asserts the content of the first message. To ensure the error recovery logic for malformed headers is fully functional, it's important to also assert the content of the second message, which should have been parsed correctly after skipping the garbage tokens.

Suggested change
assert len(parser.messages) == 2
assert parser.messages[0].content[0].text == "First."
assert len(parser.messages) == 2
assert parser.messages[0].content[0].text == "First."
assert parser.messages[1].content[0].text == "Second."


37 changes: 37 additions & 0 deletions tests/entrypoints/openai/responses/test_harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,40 @@ def test_parser_state_to_response_output_analysis_channel() -> None:
assert len(builtin_items) == 1
assert not isinstance(builtin_items[0], McpCall)
assert builtin_items[0].type == "reasoning"


class TestHarmonyOutputSanitization:
    """Tests that leaked Harmony control tokens are sanitized in output."""

    def test_constrain_recipient_treated_as_no_recipient(self):
        """<|constrain|>json as recipient should be sanitized to empty,
        falling through to _parse_message_no_recipient (produces message)."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, "Some output text")
            .with_channel("commentary")
            .with_recipient("<|constrain|>json")
        )

        output_items = harmony_to_response_output(message)

        # Should produce a message (preamble), not an MCP call
        assert len(output_items) == 1
        item = output_items[0]
        assert isinstance(item, ResponseOutputMessage)
        assert item.type == "message"

    def test_contaminated_tool_name_cleaned_in_function_call(self):
        """Function name with leaked <|channel|> should be sanitized."""
        message = (
            Message.from_role_and_content(Role.ASSISTANT, '{"location": "SF"}')
            .with_channel("commentary")
            .with_recipient("functions.get_weather<|channel|>commentary")
        )

        output_items = harmony_to_response_output(message)

        assert len(output_items) == 1
        call = output_items[0]
        assert isinstance(call, ResponseFunctionToolCall)
        assert call.name == "get_weather"
        assert "<|" not in call.name
5 changes: 4 additions & 1 deletion vllm/entrypoints/openai/chat_completion/stream_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from openai_harmony import StreamableParser

from vllm.entrypoints.chat_utils import make_tool_call_id
from vllm.entrypoints.openai.parser.harmony_utils import sanitize_harmony_name
from vllm.entrypoints.openai.engine.protocol import (
DeltaFunctionCall,
DeltaMessage,
Expand Down Expand Up @@ -109,7 +110,9 @@ def extract_harmony_streaming_delta(
opened_new_call = False
if prev_recipient != group.recipient:
# New tool call - emit the opening message
tool_name = group.recipient.split("functions.", 1)[1]
tool_name = sanitize_harmony_name(
group.recipient.split("functions.", 1)[1]
)
tool_messages.append(
DeltaToolCall(
id=make_tool_call_id(),
Expand Down
129 changes: 126 additions & 3 deletions vllm/entrypoints/openai/parser/harmony_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ReasoningEffort,
Role,
StreamableParser,
StreamState,
SystemContent,
TextContent,
ToolDescription,
Expand All @@ -27,6 +28,126 @@

logger = init_logger(__name__)

# Harmony special token strings that may leak into tool names or recipients
# during generation by GPT-OSS models.
_HARMONY_SPECIAL_TOKEN_STRS = (
    "<|channel|>",
    "<|constrain|>",
    "<|start|>",
    "<|end|>",
    "<|message|>",
    "<|call|>",
)

# Harmony special token IDs (GPT-OSS encoding)
_TOK_CONSTRAIN = 200003
_TOK_CHANNEL = 200005
_TOK_START = 200006
_TOK_END = 200007
_TOK_MESSAGE = 200008


def sanitize_harmony_name(name: str) -> str:
    """Strip leaked Harmony control tokens from a tool/recipient name.

    Truncates *name* at the earliest occurrence of any Harmony control
    token string and removes trailing whitespace from the kept prefix.
    A name containing no control tokens is returned unchanged.
    """
    # Earliest control-token position; fall back to the full length when
    # no token is present so the slice keeps the whole string.
    cut_at = min(
        (
            pos
            for pos in (name.find(tok) for tok in _HARMONY_SPECIAL_TOKEN_STRS)
            if pos != -1
        ),
        default=len(name),
    )
    return name[:cut_at].rstrip()


class ResilientStreamableParser:
    """Wrapper around ``StreamableParser`` that recovers from two common
    malformed-output patterns produced by GPT-OSS models:

    1. **Missing ``<|start|>`` recovery** – When the parser expects a
       ``<|start|>`` token but receives ``<|channel|>`` instead, inject the
       missing ``<|start|>`` + assistant role token before forwarding.

    2. **Malformed ``<|constrain|>`` in headers** – When the parser is in
       ``HEADER`` state and the inner parser rejects a ``<|constrain|>``
       token, enter skip mode and discard tokens until ``<|message|>`` or
       ``<|end|>`` is seen. Valid constrain headers (e.g.
       ``...<|constrain|>json<|message|>...``) are forwarded unchanged so
       their content-type metadata is preserved.

    All public properties are delegated to the underlying parser. The
    ``current_recipient`` getter additionally sanitizes leaked tokens.
    """

    def __init__(self, inner: StreamableParser, encoding):
        # ``encoding`` is only needed to re-encode the "assistant" role
        # tokens when injecting a missing <|start|> header.
        self._inner = inner
        self._encoding = encoding
        # True while discarding tokens after a rejected <|constrain|>.
        self._skip_until_message_or_end = False

    # --- error-recovering process() -----------------------------------

    def process(self, token_id: int) -> None:
        """Feed a single token to the parser, applying error recovery."""
        # Pattern 2 (continued): skip mode – discard until <|message|>
        # or <|end|>, then resume normal parsing with that token.
        if self._skip_until_message_or_end:
            if token_id in (_TOK_MESSAGE, _TOK_END):
                self._skip_until_message_or_end = False
                self._inner.process(token_id)
            # else: silently discard the token
            return

        state = self._inner.state

        # Pattern 1: missing <|start|> before <|channel|>
        if state == StreamState.EXPECT_START and token_id == _TOK_CHANNEL:
            # Inject <|start|> + assistant role token
            self._inner.process(_TOK_START)
            role_tokens = self._encoding.encode("assistant", allowed_special="all")
            for rt in role_tokens:
                self._inner.process(rt)
            self._inner.process(token_id)
            return

        # Pattern 2: <|constrain|> during HEADER. <|constrain|> is part of
        # legitimate Harmony tool-call headers, so first let the inner
        # parser try to consume it; only when it rejects the token as
        # malformed do we fall back to skip mode. This avoids stripping
        # content_type metadata from otherwise valid outputs.
        if state == StreamState.HEADER and token_id == _TOK_CONSTRAIN:
            try:
                self._inner.process(token_id)
            except Exception:
                self._skip_until_message_or_end = True
            return

        self._inner.process(token_id)

    # --- delegated properties -----------------------------------------

    @property
    def messages(self):
        return self._inner.messages

    @property
    def current_content(self):
        return self._inner.current_content

    @property
    def current_channel(self):
        return self._inner.current_channel

    @property
    def current_recipient(self):
        # Sanitize leaked control tokens; an empty result after
        # sanitization is reported as "no recipient".
        raw = self._inner.current_recipient
        if raw is not None:
            sanitized = sanitize_harmony_name(raw)
            return sanitized if sanitized else None
        return raw

    @property
    def current_role(self):
        return self._inner.current_role

    @property
    def state(self):
        return self._inner.state

    @property
    def last_content_delta(self):
        return self._inner.last_content_delta


REASONING_EFFORT = {
"high": ReasoningEffort.HIGH,
"medium": ReasoningEffort.MEDIUM,
Expand Down Expand Up @@ -256,7 +377,7 @@ def parse_chat_input_to_harmony_message(

for call in tool_calls:
func = call.get("function", {})
name = func.get("name", "")
name = sanitize_harmony_name(func.get("name", ""))
arguments = func.get("arguments", "") or ""
msg = Message.from_role_and_content(Role.ASSISTANT, arguments)
msg = msg.with_channel("commentary")
Expand Down Expand Up @@ -328,8 +449,10 @@ def get_stop_tokens_for_assistant_actions() -> list[int]:
return get_encoding().stop_tokens_for_assistant_actions()


def get_streamable_parser_for_assistant() -> StreamableParser:
return StreamableParser(get_encoding(), role=Role.ASSISTANT)
def get_streamable_parser_for_assistant() -> ResilientStreamableParser:
    """Build an assistant-role streaming parser with error recovery.

    Wraps a raw ``StreamableParser`` in ``ResilientStreamableParser`` so
    callers transparently get the malformed-output recovery behavior.
    """
    encoding = get_encoding()
    raw_parser = StreamableParser(encoding, role=Role.ASSISTANT)
    return ResilientStreamableParser(raw_parser, encoding)


def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
Expand Down
13 changes: 10 additions & 3 deletions vllm/entrypoints/openai/responses/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
get_encoding,
get_streamable_parser_for_assistant,
render_for_completion,
sanitize_harmony_name,
)
from vllm.entrypoints.openai.parser.responses_parser import (
get_responses_parser_for_simple_context,
Expand Down Expand Up @@ -669,7 +670,9 @@ def messages(self) -> list:
def need_builtin_tool_call(self) -> bool:
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is None:
if recipient is not None:
recipient = sanitize_harmony_name(recipient)
if not recipient:
return False
if recipient.startswith("browser."):
return "browser" in self.available_tools
Expand All @@ -685,6 +688,8 @@ async def call_tool(self) -> list[Message]:
last_msg = self.messages[-1]
recipient = last_msg.recipient
if recipient is not None:
recipient = sanitize_harmony_name(recipient)
if recipient:
Comment on lines 690 to +692
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The last_msg.recipient is used unsafely to set the Author name in tool responses. Since last_msg.recipient can contain leaked Harmony control tokens (as acknowledged by this PR), an attacker can manipulate the LLM to output a contaminated recipient that, when reflected back in the conversation history as an author name, injects protocol delimiters. This allows for 'protocol smuggling' where a single tool response can be interpreted as multiple messages in subsequent turns, potentially bypassing security controls or misrepresenting the conversation state.

To remediate this, ensure that the recipient is sanitized before being assigned back to the message object or used as an author name.

        if recipient is not None:
            recipient = sanitize_harmony_name(recipient)
            last_msg.recipient = recipient
        if recipient:

if recipient.startswith("browser."):
return await self.call_search_tool(
self._tool_sessions["browser"], last_msg
Expand All @@ -708,7 +713,7 @@ async def call_search_tool(
self.called_tools.add("browser")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1]
tool_name = sanitize_harmony_name(last_msg.recipient.split(".")[1])
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
Expand Down Expand Up @@ -795,7 +800,9 @@ async def call_container_tool(
self.called_tools.add("container")
if isinstance(tool_session, Tool):
return await tool_session.get_result(self)
tool_name = last_msg.recipient.split(".")[1].split(" ")[0]
tool_name = sanitize_harmony_name(
last_msg.recipient.split(".")[1].split(" ")[0]
)
if envs.VLLM_TOOL_JSON_ERROR_AUTOMATIC_RETRY:
try:
args = json.loads(last_msg.content[0].text)
Expand Down
Loading
Loading