Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/test_prompt_canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for system-prompt canonicalization."""

from vllm_mlx.api.prompt_canonicalize import canonicalize_system_prompt


def test_canonicalize_system_prompt_strips_anthropic_billing_header_line():
    """A volatile billing-header line is dropped; surrounding lines survive."""
    lines = [
        "You are a coding assistant.",
        "x-anthropic-billing-header: account=abc; cch=rotating-hash",
        "Follow the repository instructions.",
    ]

    result = canonicalize_system_prompt("\n".join(lines))

    assert result == "You are a coding assistant.\nFollow the repository instructions."


def test_canonicalize_system_prompt_is_idempotent_for_clean_input():
    """Text without any stripper match is returned unchanged."""
    clean = "You are a coding assistant.\nCurrent time: 2026-05-12T01:00:00Z"

    assert canonicalize_system_prompt(clean) == clean


def test_canonicalize_system_prompt_does_not_strip_user_visible_timestamp():
    """Timestamp lines are semantic content and must not be removed."""
    prompt = "Current time: 2026-05-12T01:00:00Z\nUse it when answering."

    assert canonicalize_system_prompt(prompt) == prompt


def test_canonicalize_system_prompt_handles_empty_and_none():
    """Degenerate inputs pass straight through: '' stays '', None stays None."""
    assert canonicalize_system_prompt(None) is None
    assert canonicalize_system_prompt("") == ""
24 changes: 24 additions & 0 deletions tests/test_responses_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,30 @@ def test_developer_role_is_normalized_to_system(self, client):
assert messages[1]["role"] == "user"
assert messages[1]["content"] == "Hi"

def test_system_prompt_canonicalization_strips_billing_header(self, client):
import vllm_mlx.server as srv

engine = _mock_engine(_output("Ready"))
srv._engine = engine

resp = client.post(
"/v1/responses",
json={
"model": "test-model",
"instructions": (
"Be terse.\n"
"x-anthropic-billing-header: account=abc; cch=rotating-hash\n"
"Answer directly."
),
"input": "Say hello",
},
)

assert resp.status_code == 200
messages = engine.chat.call_args.kwargs["messages"]
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "Be terse.\nAnswer directly."

def test_instructions_and_developer_message_are_merged(self, client):
import vllm_mlx.server as srv

Expand Down
40 changes: 40 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,46 @@ def test_max_tokens_must_be_positive(self):
)


class TestPromptCanonicalization:
    """Test OpenAI-path system prompt canonicalization."""

    def test_prepare_chat_completion_canonicalizes_system_prompt(self):
        """System text is cleaned; user text and the original request are not."""
        from vllm_mlx.server import (
            ChatCompletionRequest,
            Message,
            _prepare_chat_completion_invocation,
        )

        # Volatile billing-header line embedded between two meaningful lines.
        raw_system = "\n".join(
            [
                "You are a coding assistant.",
                "x-anthropic-billing-header: account=abc; cch=rotating-hash",
                "Follow the repository instructions.",
            ]
        )
        fake_engine = SimpleNamespace(
            is_mllm=False, preserve_native_tool_format=False
        )
        request = ChatCompletionRequest(
            model="test-model",
            messages=[
                Message(role="system", content=raw_system),
                Message(
                    role="user",
                    content="x-anthropic-billing-header: keep user content",
                ),
            ],
        )

        prepared = _prepare_chat_completion_invocation(fake_engine, request, 128)

        # Only the system message loses the header line.
        assert prepared.messages[0]["content"] == (
            "You are a coding assistant.\nFollow the repository instructions."
        )
        # A header-lookalike in user content is untouched.
        assert (
            prepared.messages[1]["content"]
            == "x-anthropic-billing-header: keep user content"
        )
        # Canonicalization must not mutate the incoming request in place.
        assert request.messages[0].content == raw_system


class TestCompletionRequest:
"""Test CompletionRequest model."""

Expand Down
51 changes: 51 additions & 0 deletions vllm_mlx/api/prompt_canonicalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-License-Identifier: Apache-2.0
"""System-prompt canonicalization helpers."""

from __future__ import annotations

import re

# Registry of (name, pattern, replacement) rules applied to system-prompt text.
# Each pattern removes a whole line that is volatile/non-semantic, so that
# otherwise-identical prompts canonicalize to the same string.
_STRIPPERS: tuple[tuple[str, re.Pattern[str], str], ...] = (
    (
        "anthropic_billing_header",
        # (?im): case-insensitive, multiline — `^` anchors at every line start.
        # Matches the full header line plus its trailing newline (or end of text).
        re.compile(r"(?im)^x-anthropic-billing-header:[^\n]*(?:\n|$)"),
        "",
    ),
)


def canonicalize_system_prompt(text: str | None) -> str | None:
    """Remove known non-semantic volatile lines from system prompt text.

    ``None`` passes straight through. Otherwise every registered stripper
    in ``_STRIPPERS`` is applied in order and the cleaned text returned.
    """
    if text is None:
        return None

    cleaned = text
    for _rule_name, pattern, replacement in _STRIPPERS:
        cleaned = pattern.sub(replacement, cleaned)
    return cleaned


def canonicalize_system_messages(messages: list[dict]) -> list[dict]:
    """Canonicalize string content on system-role messages without mutation.

    Returns the original list object when no message needed rewriting, so
    callers can detect "nothing changed" by identity. Messages that are
    rewritten are shallow-copied; the inputs are never mutated in place.
    """
    result: list[dict] = []
    dirty = False

    for message in messages:
        content = message.get("content")
        if message.get("role") == "system" and isinstance(content, str):
            cleaned = canonicalize_system_prompt(content)
            if cleaned != content:
                # Copy-on-write: only replaced messages get a new dict.
                rewritten = message.copy()
                rewritten["content"] = cleaned
                result.append(rewritten)
                dirty = True
                continue
        result.append(message)

    return result if dirty else messages
4 changes: 4 additions & 0 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
Usage, # noqa: F401
VideoUrl, # noqa: F401
)
from .api.prompt_canonicalize import canonicalize_system_messages
from .api.responses_models import (
ResponseCompletedEvent,
ResponseContentPartAddedEvent,
Expand Down Expand Up @@ -306,6 +307,8 @@ def _prepare_chat_messages(
)
messages = _normalize_messages(messages)

messages = canonicalize_system_messages(messages)

has_media = bool(images or videos or audios)
if is_mllm and not has_media:
# MLLM extracts media from messages directly, so images/videos are
Expand Down Expand Up @@ -2082,6 +2085,7 @@ def _prepare_responses_request(
chat_request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
messages = canonicalize_system_messages(messages)

chat_kwargs = {
"max_tokens": chat_request.max_tokens or _default_max_tokens,
Expand Down
Loading