From 4f14a3c6d338e8202524a399245b37407ce92e54 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:01:08 +0100 Subject: [PATCH 1/5] Do not call _truncate_inputs in DPOTrainer --- trl/trainer/dpo_trainer.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 217dc55214f..ff2399edec6 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -1055,19 +1055,18 @@ def compute_ref_log_probs(self, inputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids - extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] - input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( - input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] - ) - shift_labels = input_ids[..., 1:].contiguous() shift_completion_mask = completion_mask[..., 1:].contiguous() model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key, val in zip(extra_keys, extra, strict=False): - model_kwargs[key] = val - for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): + for key in ( + "token_type_ids", + "mm_token_type_ids", + "pixel_values", + "pixel_attention_mask", + "image_grid_thw", + "image_sizes", + ): if key in inputs: model_kwargs[key] = inputs[key] @@ -1112,7 +1111,6 @@ def _compute_loss_liger(self, model, inputs, return_outputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) decoder = model.get_decoder() outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False) @@ -1185,16 +1183,15 @@ def _compute_loss(self, model, inputs, return_outputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids - extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] - input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( - input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] - ) - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key, val in zip(extra_keys, extra, strict=False): - model_kwargs[key] = val - for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): + for key in ( + "token_type_ids", + "mm_token_type_ids", + "pixel_values", + "pixel_attention_mask", + "image_grid_thw", + "image_sizes", + ): if key in inputs: model_kwargs[key] = inputs[key] From c07503248766319c8d7043363d48feab6c148eec Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:01:36 +0100 Subject: [PATCH 2/5] Remove DPOTrainer._truncate_inputs --- trl/trainer/dpo_trainer.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index ff2399edec6..8f5245b1a83 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -56,7 +56,6 @@ disable_dropout_in_model, entropy_from_logits, flush_left, - flush_right, get_config_model_id, hash_module, pad, @@ -1014,40 +1013,6 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) -> return dataset - def _truncate_inputs( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - completion_mask: torch.Tensor, - *extra: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - if self.args.max_length is None: - return input_ids, attention_mask, completion_mask, *extra - - if self.args.truncation_mode == "keep_start": - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - completion_mask = completion_mask[:, : self.args.max_length] - extra = tuple(t[:, : self.args.max_length] for t in extra) - elif self.args.truncation_mode == "keep_end": - attention_mask, input_ids, completion_mask, *extra = flush_right( - attention_mask, input_ids, completion_mask, *extra - ) - input_ids = input_ids[:, -self.args.max_length :] - attention_mask = attention_mask[:, -self.args.max_length :] - completion_mask = completion_mask[:, -self.args.max_length :] - extra = tuple(t[:, -self.args.max_length :] for t in extra) - attention_mask, input_ids, completion_mask, *extra = flush_left( - attention_mask, input_ids, completion_mask, *extra - ) - extra = tuple(extra) - else: - raise ValueError( - f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'" - ) - - return input_ids, attention_mask, completion_mask, *extra - def compute_ref_log_probs(self, inputs): """Computes reference log probabilities for a single padded batch.""" device = self.accelerator.device From 5dade1bea5395fcb5a2f3ca0a715f8c4a0acb639 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:02:52 +0100 Subject: [PATCH 3/5] Raise ValueError for custom collator and max_length --- trl/trainer/dpo_trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 8f5245b1a83..7c92ae1fbc2 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -651,6 +651,12 @@ def __init__( max_length=args.max_length, pad_to_multiple_of=args.pad_to_multiple_of, ) + elif data_collator is not None and args.max_length is not None: + raise ValueError( + "Cannot use `max_length` with a custom `data_collator`. `max_length` is passed to the default " + "collators to control truncation; with a custom collator it has no effect. Configure truncation " + "directly in your collator." + ) # Training arguments self.beta = args.beta From bbdc3ec9f7152545f98938af66211ec6b77216d5 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:03:30 +0100 Subject: [PATCH 4/5] Update data_collator docstring to specify it must truncate --- trl/trainer/dpo_trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 7c92ae1fbc2..4f1a18b5af4 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -445,7 +445,8 @@ class DPOTrainer(_BaseTrainer): data_collator ([`~transformers.DataCollator`], *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model and - [`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model. + [`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model. Custom + collators must truncate sequences before padding; the trainer does not apply post-collation truncation. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: From 36c06515521894bfcafad2006a753f2e03efb276 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 23 Mar 2026 11:39:22 +0100 Subject: [PATCH 5/5] Revert "Raise ValueError for custom collator and max_length" This reverts commit 5dade1bea5395fcb5a2f3ca0a715f8c4a0acb639. --- trl/trainer/dpo_trainer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 4f1a18b5af4..4a86e9ce4f9 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -652,12 +652,6 @@ def __init__( max_length=args.max_length, pad_to_multiple_of=args.pad_to_multiple_of, ) - elif data_collator is not None and args.max_length is not None: - raise ValueError( - "Cannot use `max_length` with a custom `data_collator`. `max_length` is passed to the default " - "collators to control truncation; with a custom collator it has no effect. Configure truncation " - "directly in your collator." - ) # Training arguments self.beta = args.beta