Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 174 additions & 0 deletions tests/test_normalize_messages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
"""
Tests for _normalize_messages() in vllm_mlx.server.

_normalize_messages() maps non-standard roles (developer -> system) and merges
consecutive same-role messages before chat template application. This prevents
crashes from Qwen 3.5 and Llama templates that require alternating roles.
"""


class TestNormalizeMessages:
    """Test _normalize_messages() for handling real-world client formats.

    Each test imports _normalize_messages locally so an import failure
    surfaces as an individual test error rather than a collection error.
    """

    def test_merge_consecutive_system_messages(self):
        """Consecutive system messages are merged into one."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "system", "content": "Always respond in JSON."},
            {"role": "user", "content": "Hello"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert result[0]["role"] == "system"
        # Both original contents must survive the merge.
        assert "helpful assistant" in result[0]["content"]
        assert "JSON" in result[0]["content"]
        assert result[1]["role"] == "user"
        assert result[1]["content"] == "Hello"

    def test_merge_consecutive_user_messages(self):
        """Consecutive user messages are merged into one."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "You are a helper."},
            {"role": "user", "content": "First part"},
            {"role": "user", "content": "Second part"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert result[1]["role"] == "user"
        assert "First part" in result[1]["content"]
        assert "Second part" in result[1]["content"]

    def test_opencode_format(self):
        """OpenCode's system+system+user+user format is normalized."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "System prompt part 1"},
            {"role": "system", "content": "System prompt part 2"},
            {"role": "user", "content": "User instruction"},
            {"role": "user", "content": "User question"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert result[0]["role"] == "system"
        assert result[1]["role"] == "user"

    def test_developer_role_mapped_to_system(self):
        """OpenAI Responses API 'developer' role is mapped to 'system'."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "developer", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ]
        result = _normalize_messages(messages)
        assert result[0]["role"] == "system"
        assert result[1]["role"] == "user"

    def test_developer_and_system_merged(self):
        """developer + system consecutive messages are merged after role mapping."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "developer", "content": "Part 1"},
            {"role": "system", "content": "Part 2"},
            {"role": "user", "content": "Hello"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert result[0]["role"] == "system"
        assert "Part 1" in result[0]["content"]
        assert "Part 2" in result[0]["content"]

    def test_already_alternating_unchanged(self):
        """Well-formed alternating messages pass through unchanged."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "You are a helper."},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi!"},
            {"role": "user", "content": "Bye"},
        ]
        result = _normalize_messages(messages)
        assert result == messages

    def test_single_message_unchanged(self):
        """Single message passes through unchanged."""
        from vllm_mlx.server import _normalize_messages

        messages = [{"role": "user", "content": "Hello"}]
        result = _normalize_messages(messages)
        assert result == messages

    def test_empty_messages(self):
        """Empty message list passes through."""
        from vllm_mlx.server import _normalize_messages

        assert _normalize_messages([]) == []

    def test_multimodal_content_preserved(self):
        """Messages with list content (multimodal) are not merged."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "user", "content": "Describe this:"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is in this image?"},
                    {
                        "type": "image_url",
                        "image_url": {"url": "http://example.com/img.png"},
                    },
                ],
            },
        ]
        result = _normalize_messages(messages)
        # List content can't be trivially merged with string: both messages
        # must be kept separate and the multimodal payload left intact.
        assert len(result) == 2
        assert result[0]["content"] == "Describe this:"
        assert isinstance(result[1]["content"], list)
        assert result[1]["content"] == messages[1]["content"]

    def test_preserves_non_content_fields(self):
        """Fields other than role/content are preserved on the first merged message."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "Part 1", "name": "sys1"},
            {"role": "system", "content": "Part 2"},
            {"role": "user", "content": "Hello"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert result[0]["role"] == "system"
        # The extra field from the first message survives the merge,
        # and both content parts are present.
        assert result[0].get("name") == "sys1"
        assert "Part 1" in result[0]["content"]
        assert "Part 2" in result[0]["content"]

    def test_null_content_not_merged(self):
        """Messages with None content (tool_calls pattern) are not merged."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]},
            {"role": "assistant", "content": "Follow-up"},
        ]
        result = _normalize_messages(messages)
        # None content can't be merged with string: kept separate, in order,
        # with the tool_calls payload untouched.
        assert len(result) == 2
        assert result[0]["content"] is None
        assert result[0]["tool_calls"] == [{"id": "tc1"}]
        assert result[1]["content"] == "Follow-up"

    def test_three_consecutive_system_messages(self):
        """Three consecutive system messages merge into one."""
        from vllm_mlx.server import _normalize_messages

        messages = [
            {"role": "system", "content": "Part 1"},
            {"role": "system", "content": "Part 2"},
            {"role": "system", "content": "Part 3"},
            {"role": "user", "content": "Hello"},
        ]
        result = _normalize_messages(messages)
        assert len(result) == 2
        assert "Part 1" in result[0]["content"]
        assert "Part 2" in result[0]["content"]
        assert "Part 3" in result[0]["content"]
62 changes: 62 additions & 0 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1326,12 +1326,14 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
messages.append(msg_dict)
images, videos = [], [] # MLLM extracts these from messages
logger.debug(f"MLLM: Processing {len(messages)} messages")
messages = _normalize_messages(messages)
else:
# For LLM, extract text, images, and videos separately
messages, images, videos = extract_multimodal_content(
request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
messages = _normalize_messages(messages)

has_media = bool(images or videos)

Expand Down Expand Up @@ -1434,6 +1436,64 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Re
)


def _normalize_messages(messages: list[dict]) -> list[dict]:
"""Normalize message roles and merge consecutive same-role messages.

1. Maps non-standard roles to standard ones (e.g. ``developer`` -> ``system``).
2. Merges consecutive same-role messages to satisfy chat template constraints
(Qwen 3.5, Llama, etc. require alternating roles).

Only merges when both messages have string content. Messages with list
content (multimodal) are left as-is to preserve image/video attachments.

Args:
messages: List of message dicts with 'role' and 'content' keys.

Returns:
New list with normalized roles and consecutive same-role messages merged.
"""
# OpenAI Responses API uses "developer" instead of "system".
# Map it so chat templates don't fail and fall back to raw prefill.
_ROLE_MAP = {"developer": "system"}

if not messages:
return messages

merged = [messages[0].copy()]
if merged[0]["role"] in _ROLE_MAP:
merged[0]["role"] = _ROLE_MAP[merged[0]["role"]]
for msg in messages[1:]:
prev = merged[-1]
role = _ROLE_MAP.get(msg["role"], msg["role"])
if (
role == prev["role"]
and isinstance(prev.get("content"), str)
and isinstance(msg.get("content"), str)
):
# Merge string content with double newline separator
prev["content"] = prev["content"] + "\n\n" + msg["content"]
logger.debug(
f"Merged consecutive {role} messages "
f"({len(prev['content'])} chars total)"
)
else:
copy = msg.copy()
copy["role"] = role
merged.append(copy)

mapped_roles = sum(1 for m in messages if m["role"] in _ROLE_MAP)
merged_count = len(messages) - len(merged)
if mapped_roles or merged_count:
parts = []
if mapped_roles:
parts.append(f"mapped {mapped_roles} role(s)")
if merged_count:
parts.append(f"merged {len(messages)} -> {len(merged)}")
logger.info(f"Normalized messages: {', '.join(parts)}")

return merged


def _inject_json_instruction(messages: list, instruction: str) -> list:
"""
Inject JSON instruction into messages.
Expand Down Expand Up @@ -1529,6 +1589,7 @@ async def create_anthropic_message(
openai_request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
messages = _normalize_messages(messages)

chat_kwargs = {
"max_tokens": openai_request.max_tokens or _default_max_tokens,
Expand Down Expand Up @@ -1686,6 +1747,7 @@ async def _stream_anthropic_messages(
openai_request.messages,
preserve_native_format=engine.preserve_native_tool_format,
)
messages = _normalize_messages(messages)

chat_kwargs = {
"max_tokens": openai_request.max_tokens or _default_max_tokens,
Expand Down
Loading