diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index db7702c0524..2f96df2b722 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1212,6 +1212,26 @@ def test_train_vlm_text_only_data(self, model_id, dataset_config): else: assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" + @require_vision + def test_train_vlm_with_max_length(self): + # Regression test for #5283: mm_token_type_ids must be truncated alongside input_ids when max_length is set, + # otherwise a shape mismatch crashes the model forward pass. + # max_length=37 truncates 1 completion token (total_len=38) while keeping all image tokens (prompt_len=34) safe. + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_preference", split="train") + training_args = DPOConfig( + output_dir=self.tmp_dir, + max_length=37, # total_len=38, prompt_len=34 — truncates completion, not image tokens + per_device_train_batch_size=2, + report_to="none", + ) + trainer = DPOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + args=training_args, + train_dataset=dataset, + ) + trainer.train() + assert trainer.state.log_history[-1]["train_loss"] is not None + @require_peft @require_bitsandbytes def test_peft_with_quantization(self): diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 33c4f32ba6d..85bff240c16 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -970,27 +970,38 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) -> return dataset def _truncate_inputs( - self, input_ids: torch.Tensor, attention_mask: torch.Tensor, completion_mask: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + completion_mask: torch.Tensor, + *extra: torch.Tensor, + ) -> tuple[torch.Tensor, ...]: if self.args.max_length is None: - return input_ids, attention_mask, completion_mask + return input_ids, attention_mask, completion_mask, *extra if self.args.truncation_mode == "keep_start": input_ids = input_ids[:, : self.args.max_length] attention_mask = attention_mask[:, : self.args.max_length] completion_mask = completion_mask[:, : self.args.max_length] + extra = tuple(t[:, : self.args.max_length] for t in extra) elif self.args.truncation_mode == "keep_end": - attention_mask, input_ids, completion_mask = flush_right(attention_mask, input_ids, completion_mask) + attention_mask, input_ids, completion_mask, *extra = flush_right( + attention_mask, input_ids, completion_mask, *extra + ) input_ids = input_ids[:, -self.args.max_length :] attention_mask = attention_mask[:, -self.args.max_length :] completion_mask = completion_mask[:, -self.args.max_length :] - attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask) + extra = tuple(t[:, -self.args.max_length :] for t in extra) + attention_mask, input_ids, completion_mask, *extra = flush_left( + attention_mask, input_ids, completion_mask, *extra + ) + extra = tuple(extra) else: raise ValueError( f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'" ) - return input_ids, attention_mask, completion_mask + return input_ids, attention_mask, completion_mask, *extra def compute_ref_log_probs(self, inputs): """Computes reference log probabilities for a single padded batch.""" @@ -999,20 +1010,19 @@ def compute_ref_log_probs(self, inputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) + # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids + extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] + input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( + input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] + ) shift_labels = input_ids[..., 1:].contiguous() shift_completion_mask = completion_mask[..., 1:].contiguous() model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key in ( - "pixel_values", - "pixel_attention_mask", - "image_grid_thw", - "image_sizes", - "token_type_ids", - "mm_token_type_ids", - ): + for key, val in zip(extra_keys, extra, strict=False): + model_kwargs[key] = val + for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): if key in inputs: model_kwargs[key] = inputs[key] @@ -1130,17 +1140,16 @@ def _compute_loss(self, model, inputs, return_outputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) + # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids + extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] + input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( + input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] + ) model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key in ( - "pixel_values", - "pixel_attention_mask", - "image_grid_thw", - "image_sizes", - "token_type_ids", - "mm_token_type_ids", - ): + for key, val in zip(extra_keys, extra, strict=False): + model_kwargs[key] = val + for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): if key in inputs: model_kwargs[key] = inputs[key]