Skip to content
Merged
23 changes: 23 additions & 0 deletions tests/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,29 @@ def test_message_with_tool_calling_turns(self):

assert messages == expected

def test_prepared_image_blocks_without_new_images(self):
    """Check that image blocks already carrying a payload survive a call made without `images`.

    The user turn is expected to come back unchanged, while the assistant's plain-string content is
    expected to be wrapped into a single text block.
    """
    image = Image.new("RGB", (10, 10), color="blue")

    def build_user_turn():
        # Build fresh dicts on every call so the input and the expectation never alias each other.
        return {
            "role": "user",
            "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
        }

    conversation = [
        build_user_turn(),
        {"role": "assistant", "content": "It is blue."},
    ]

    # Deliberately omit the `images` argument: the payload is already embedded in the message.
    prepared = prepare_multimodal_messages(conversation)

    expected = [
        build_user_turn(),
        {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
    ]

    assert prepared == expected
Comment thread
cursor[bot] marked this conversation as resolved.


@require_vision
class TestPrepareMultimodalMessagesVLLM:
Expand Down
29 changes: 16 additions & 13 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
the function transforms them into the structured format by wrapping text in `{"type": "text", "text": ...}`
and inserting `{"type": "image"}` placeholders for the images *before* the first user message.
If the number of placeholders does not match the number of provided images, an error is raised.
- Existing image blocks that already include an `"image"` payload are preserved as-is. Only unfilled image
placeholders are counted and populated from `images`.

Example:
```python
Expand Down Expand Up @@ -94,7 +96,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N

# Then, check that the number of image placeholders matches the number of images provided
num_placeholders = sum(
sum(1 for part in message["content"] if part["type"] == "image")
sum(1 for part in message["content"] if part["type"] == "image" and "image" not in part)
for message in new_messages
if message.get("content") and message["role"] != "tool"
)
Expand All @@ -104,18 +106,19 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N
)

# Then, fill in the actual images in the placeholders
img_idx = 0
for i, message in enumerate(new_messages):
if not message.get("content") or message["role"] == "tool":
continue
new_content = []
for part in message["content"]:
if part["type"] == "image":
new_content.append({**part, "image": images[img_idx]})
img_idx += 1
else:
new_content.append(part)
new_messages[i] = {**message, "content": new_content}
if images:
img_idx = 0
for i, message in enumerate(new_messages):
if not message.get("content") or message["role"] == "tool":
continue
new_content = []
for part in message["content"]:
if part["type"] == "image" and "image" not in part:
new_content.append({**part, "image": images[img_idx]})
img_idx += 1
else:
new_content.append(part)
new_messages[i] = {**message, "content": new_content}

return new_messages

Expand Down
11 changes: 1 addition & 10 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,17 +1275,8 @@ def _tokenize_prompts(self, prompts: list):
"""Tokenize prompts and extract images/multimodal fields for generation."""
if is_conversational({"prompt": prompts[0]}):
# Normalize string content to content blocks for VLM processors that don't handle plain strings.
# Use copies to avoid mutating the original prompts.
if self._is_vlm:
prompts = [
[
{**msg, "content": [{"type": "text", "text": msg["content"]}]}
if isinstance(msg.get("content"), str)
else msg
for msg in prompt
]
for prompt in prompts
]
prompts = [prepare_multimodal_messages(prompt) for prompt in prompts]
Comment thread
cursor[bot] marked this conversation as resolved.

# Extract images from messages for VLM support
images = []
Expand Down
Loading