diff --git a/tests/test_prompt_canonicalize.py b/tests/test_prompt_canonicalize.py
new file mode 100644
index 00000000..5d11fef1
--- /dev/null
+++ b/tests/test_prompt_canonicalize.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for system-prompt canonicalization."""
+
+from vllm_mlx.api.prompt_canonicalize import canonicalize_system_prompt
+
+
+def test_canonicalize_system_prompt_strips_anthropic_billing_header_line():
+    text = (
+        "You are a coding assistant.\n"
+        "x-anthropic-billing-header: account=abc; cch=rotating-hash\n"
+        "Follow the repository instructions."
+    )
+
+    expected = "\n".join(
+        ["You are a coding assistant.", "Follow the repository instructions."]
+    )
+    assert canonicalize_system_prompt(text) == expected
+
+
+def test_canonicalize_system_prompt_is_idempotent_for_clean_input():
+    text = "You are a coding assistant.\nCurrent time: 2026-05-12T01:00:00Z"
+
+    assert canonicalize_system_prompt(text) == text
+
+
+def test_canonicalize_system_prompt_does_not_strip_user_visible_timestamp():
+    text = "Current time: 2026-05-12T01:00:00Z\nUse it when answering."
+
+    assert canonicalize_system_prompt(text) == text
+
+
+def test_canonicalize_system_prompt_handles_empty_and_none():
+    assert canonicalize_system_prompt("") == ""
+    assert canonicalize_system_prompt(None) is None
diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py
index 1c4d9d98..2406104b 100644
--- a/tests/test_responses_api.py
+++ b/tests/test_responses_api.py
@@ -416,6 +416,30 @@ def test_developer_role_is_normalized_to_system(self, client):
         assert messages[1]["role"] == "user"
         assert messages[1]["content"] == "Hi"
 
+    def test_system_prompt_canonicalization_strips_billing_header(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Ready"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "instructions": (
+                    "Be terse.\n"
+                    "x-anthropic-billing-header: account=abc; cch=rotating-hash\n"
+                    "Answer directly."
+                ),
+                "input": "Say hello",
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert messages[0]["role"] == "system"
+        assert messages[0]["content"] == "Be terse.\nAnswer directly."
+
     def test_instructions_and_developer_message_are_merged(self, client):
         import vllm_mlx.server as srv
 
diff --git a/tests/test_server.py b/tests/test_server.py
index bcd2a9a6..931ded6e 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -170,6 +170,46 @@ def test_max_tokens_must_be_positive(self):
         )
 
 
+class TestPromptCanonicalization:
+    """Test OpenAI-path system prompt canonicalization."""
+
+    def test_prepare_chat_completion_canonicalizes_system_prompt(self):
+        from vllm_mlx.server import (
+            ChatCompletionRequest,
+            Message,
+            _prepare_chat_completion_invocation,
+        )
+
+        engine = SimpleNamespace(is_mllm=False, preserve_native_tool_format=False)
+        system_text = (
+            "You are a coding assistant.\n"
+            "x-anthropic-billing-header: account=abc; cch=rotating-hash\n"
+            "Follow the repository instructions."
+        )
+        request = ChatCompletionRequest(
+            model="test-model",
+            messages=[
+                Message(role="system", content=system_text),
+                Message(
+                    role="user",
+                    content="x-anthropic-billing-header: keep user content",
+                ),
+            ],
+        )
+
+        prepared = _prepare_chat_completion_invocation(engine, request, 128)
+
+        expected_system_text = "\n".join(
+            ["You are a coding assistant.", "Follow the repository instructions."]
+        )
+        assert prepared.messages[0]["content"] == expected_system_text
+        assert (
+            prepared.messages[1]["content"]
+            == "x-anthropic-billing-header: keep user content"
+        )
+        assert request.messages[0].content == system_text
+
+
 class TestCompletionRequest:
     """Test CompletionRequest model."""
 
diff --git a/vllm_mlx/api/prompt_canonicalize.py b/vllm_mlx/api/prompt_canonicalize.py
new file mode 100644
index 00000000..dcc1c9f3
--- /dev/null
+++ b/vllm_mlx/api/prompt_canonicalize.py
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+"""System-prompt canonicalization helpers."""
+
+from __future__ import annotations
+
+import re
+
+_STRIPPERS: tuple[tuple[str, re.Pattern[str], str], ...] = (
+    (
+        "anthropic_billing_header",
+        re.compile(r"(?im)^x-anthropic-billing-header:[^\n]*(?:\n|$)"),
+        "",
+    ),
+)
+
+
+def canonicalize_system_prompt(text: str | None) -> str | None:
+    """Remove known non-semantic volatile lines from system prompt text."""
+    if text is None:
+        return None
+
+    for _name, pattern, replacement in _STRIPPERS:
+        text = pattern.sub(replacement, text)
+    return text
+
+
+def canonicalize_system_messages(messages: list[dict]) -> list[dict]:
+    """Canonicalize string content on system-role messages without mutation."""
+    canonicalized: list[dict] = []
+    changed = False
+    for message in messages:
+        if message.get("role") != "system":
+            canonicalized.append(message)
+            continue
+
+        content = message.get("content")
+        if not isinstance(content, str):
+            canonicalized.append(message)
+            continue
+
+        next_content = canonicalize_system_prompt(content)
+        if next_content == content:
+            canonicalized.append(message)
+            continue
+
+        changed = True
+        next_message = message.copy()
+        next_message["content"] = next_content
+        canonicalized.append(next_message)
+
+    return canonicalized if changed else messages
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 0776bc2f..57402376 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -105,6 +105,7 @@
     Usage,  # noqa: F401
     VideoUrl,  # noqa: F401
 )
+from .api.prompt_canonicalize import canonicalize_system_messages
 from .api.responses_models import (
     ResponseCompletedEvent,
     ResponseContentPartAddedEvent,
@@ -306,6 +307,8 @@ def _prepare_chat_messages(
         )
 
     messages = _normalize_messages(messages)
+    messages = canonicalize_system_messages(messages)
+
     has_media = bool(images or videos or audios)
     if is_mllm and not has_media:
         # MLLM extracts media from messages directly, so images/videos are
@@ -2082,6 +2085,7 @@ def _prepare_responses_request(
         chat_request.messages,
         preserve_native_format=engine.preserve_native_tool_format,
    )
+    messages = canonicalize_system_messages(messages)
 
     chat_kwargs = {
         "max_tokens": chat_request.max_tokens or _default_max_tokens,
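
A minimal usage sketch of the helpers introduced above, with behavior taken from this change's tests; the import path and function names come from the new vllm_mlx/api/prompt_canonicalize.py:

from vllm_mlx.api.prompt_canonicalize import (
    canonicalize_system_messages,
    canonicalize_system_prompt,
)

prompt = (
    "Be terse.\n"
    "x-anthropic-billing-header: account=abc; cch=rotating-hash\n"
    "Answer directly."
)
# The volatile billing-header line is dropped; everything else is untouched.
assert canonicalize_system_prompt(prompt) == "Be terse.\nAnswer directly."

# Only system-role messages are rewritten, and the original dicts are not mutated.
system = {"role": "system", "content": prompt}
out = canonicalize_system_messages([system])
assert out[0]["content"] == "Be terse.\nAnswer directly."
assert system["content"] == prompt

# When nothing changes, the same list object is returned (no copy is made).
clean = [{"role": "user", "content": prompt}]
assert canonicalize_system_messages(clean) is clean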