diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py
index dd5f4fa4699..db7702c0524 100644
--- a/tests/test_dpo_trainer.py
+++ b/tests/test_dpo_trainer.py
@@ -22,7 +22,7 @@
 from transformers.utils import is_peft_available
 
 from trl import DPOConfig, DPOTrainer
-from trl.trainer.dpo_trainer import DataCollatorForPreference
+from trl.trainer.dpo_trainer import DataCollatorForPreference, DataCollatorForVisionPreference
 
 from .testing_utils import (
     TrlTestCase,
@@ -132,6 +132,38 @@ def test_with_pad_to_multiple_of(self):
         torch.testing.assert_close(result["input_ids"], expected_input_ids)
 
 
+class TestDataCollatorForVisionPreference(TrlTestCase):
+    @pytest.mark.skipif(
+        Version(transformers.__version__) < Version("5.3.0"),
+        reason="mm_token_type_ids are returned by default since transformers-5.3.0 (see transformers#43972)",
+    )
+    @require_vision
+    def test_mm_token_type_ids_shape(self):
+        # Regression test: when the processor returns mm_token_type_ids (e.g. Qwen2.5-VL after
+        # transformers#43972), the collator must concatenate it with zeros for the completion part
+        # so that its shape matches input_ids. Without the fix this raises an IndexError in the model.
+        from PIL import Image
+        from transformers import AutoProcessor
+
+        processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration")
+        collator = DataCollatorForVisionPreference(processor)
+        image = Image.new("RGB", (16, 16))
+        examples = [
+            {
+                "images": [image],
+                "prompt": [{"role": "user", "content": "What is this?"}],
+                "chosen": [{"role": "assistant", "content": "A red square."}],
+                "rejected": [{"role": "assistant", "content": "A blue circle."}],
+            }
+        ]
+        output = collator(examples)
+        assert "mm_token_type_ids" in output
+        assert output["mm_token_type_ids"].shape == output["input_ids"].shape, (
+            f"mm_token_type_ids shape {output['mm_token_type_ids'].shape} != "
+            f"input_ids shape {output['input_ids'].shape}"
+        )
+
+
 class TestDPOTrainer(TrlTestCase):
     @pytest.mark.parametrize(
         "model_id",
diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 8e78a1e2cc0..33c4f32ba6d 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -336,12 +336,24 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
             rejected_type_ids = processed_rejecteds["token_type_ids"]
             completion_token_type_ids = torch.cat(tuple(pad([chosen_type_ids, rejected_type_ids], padding_value=0)))
             token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
+        if "mm_token_type_ids" in processed_prompts:  # special case for Qwen2.5-VL
+            # One prompt copy each for chosen/rejected, mirroring how completion_token_type_ids is stacked above
+            prompt_mm_token_type_ids = processed_prompts["mm_token_type_ids"].repeat(2, 1)
+            mm_token_type_ids = torch.cat((prompt_mm_token_type_ids, torch.zeros_like(completion_ids)), dim=1)
 
         # Flush left to reduce padding
-        if "token_type_ids" in processed_prompts:
+        if "token_type_ids" in processed_prompts and "mm_token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids
+            )
+        elif "token_type_ids" in processed_prompts:
             attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
                 attention_mask, input_ids, completion_mask, token_type_ids
             )
+        elif "mm_token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, mm_token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, mm_token_type_ids
+            )
         else:
             attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
 
@@ -352,6 +364,8 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
         output["completion_mask"] = completion_mask
         if "token_type_ids" in processed_prompts:
             output["token_type_ids"] = token_type_ids
+        if "mm_token_type_ids" in processed_prompts:
+            output["mm_token_type_ids"] = mm_token_type_ids
         return output
 
 
@@ -992,7 +1006,14 @@ def compute_ref_log_probs(self, inputs):
         shift_completion_mask = completion_mask[..., 1:].contiguous()
 
         model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
-        for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
+        for key in (
+            "pixel_values",
+            "pixel_attention_mask",
+            "image_grid_thw",
+            "image_sizes",
+            "token_type_ids",
+            "mm_token_type_ids",
+        ):
             if key in inputs:
                 model_kwargs[key] = inputs[key]
 
@@ -1113,7 +1134,14 @@ def _compute_loss(self, model, inputs, return_outputs):
         input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
 
         model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
-        for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
+        for key in (
+            "pixel_values",
+            "pixel_attention_mask",
+            "image_grid_thw",
+            "image_sizes",
+            "token_type_ids",
+            "mm_token_type_ids",
+        ):
             if key in inputs:
                 model_kwargs[key] = inputs[key]
 