diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index a4232739735..fabc29798ac 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -244,6 +244,29 @@ def test_message_with_tool_calling_turns(self): assert messages == expected + def test_prepared_image_blocks_without_new_images(self): + """Test that existing image payloads are preserved when no new images are provided.""" + image = Image.new("RGB", (10, 10), color="blue") + messages = [ + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + {"role": "assistant", "content": "It is blue."}, + ] + + messages = prepare_multimodal_messages(messages) + + expected = [ + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + ] + + assert messages == expected + @require_vision class TestPrepareMultimodalMessagesVLLM: diff --git a/trl/data_utils.py b/trl/data_utils.py index 46fd4a96b68..617f0052767 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -53,6 +53,8 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N the function transforms them into the structured format by wrapping text in `{"type": "text", "text": ...}` and inserting `{"type": "image"}` placeholders for the images *before* the first user message. If the number of placeholders does not match the number of provided images, an error is raised. + - Existing image blocks that already include an `"image"` payload are preserved as-is. Only unfilled image + placeholders are counted and populated from `images`. Example: ```python @@ -94,7 +96,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N # Then, check that the number of image placeholders matches the number of images provided num_placeholders = sum( - sum(1 for part in message["content"] if part["type"] == "image") + sum(1 for part in message["content"] if part["type"] == "image" and "image" not in part) for message in new_messages if message.get("content") and message["role"] != "tool" ) @@ -104,18 +106,19 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list | N ) # Then, fill in the actual images in the placeholders - img_idx = 0 - for i, message in enumerate(new_messages): - if not message.get("content") or message["role"] == "tool": - continue - new_content = [] - for part in message["content"]: - if part["type"] == "image": - new_content.append({**part, "image": images[img_idx]}) - img_idx += 1 - else: - new_content.append(part) - new_messages[i] = {**message, "content": new_content} + if images: + img_idx = 0 + for i, message in enumerate(new_messages): + if not message.get("content") or message["role"] == "tool": + continue + new_content = [] + for part in message["content"]: + if part["type"] == "image" and "image" not in part: + new_content.append({**part, "image": images[img_idx]}) + img_idx += 1 + else: + new_content.append(part) + new_messages[i] = {**message, "content": new_content} return new_messages diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bcabd5e17be..1aafe324b27 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1275,17 +1275,8 @@ def _tokenize_prompts(self, prompts: list): """Tokenize prompts and extract images/multimodal fields for generation.""" if is_conversational({"prompt": prompts[0]}): # Normalize string content to content blocks for VLM processors that don't handle plain strings. - # Use copies to avoid mutating the original prompts. if self._is_vlm: - prompts = [ - [ - {**msg, "content": [{"type": "text", "text": msg["content"]}]} - if isinstance(msg.get("content"), str) - else msg - for msg in prompt - ] - for prompt in prompts - ] + prompts = [prepare_multimodal_messages(prompt) for prompt in prompts] # Extract images from messages for VLM support images = []