diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py index 217dc55214f..4a86e9ce4f9 100644 --- a/trl/trainer/dpo_trainer.py +++ b/trl/trainer/dpo_trainer.py @@ -56,7 +56,6 @@ disable_dropout_in_model, entropy_from_logits, flush_left, - flush_right, get_config_model_id, hash_module, pad, @@ -446,7 +445,8 @@ class DPOTrainer(_BaseTrainer): data_collator ([`~transformers.DataCollator`], *optional*): Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`. Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model and - [`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model. + [`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model. Custom + collators must truncate sequences before padding; the trainer does not apply post-collation truncation. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and [prompt-completion](#prompt-completion) type. The format of the samples can be either: @@ -1014,40 +1014,6 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) -> return dataset - def _truncate_inputs( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - completion_mask: torch.Tensor, - *extra: torch.Tensor, - ) -> tuple[torch.Tensor, ...]: - if self.args.max_length is None: - return input_ids, attention_mask, completion_mask, *extra - - if self.args.truncation_mode == "keep_start": - input_ids = input_ids[:, : self.args.max_length] - attention_mask = attention_mask[:, : self.args.max_length] - completion_mask = completion_mask[:, : self.args.max_length] - extra = tuple(t[:, : self.args.max_length] for t in extra) - elif self.args.truncation_mode == "keep_end": - attention_mask, input_ids, completion_mask, *extra = flush_right( - attention_mask, input_ids, completion_mask, *extra - ) - input_ids = input_ids[:, -self.args.max_length :] - attention_mask = attention_mask[:, -self.args.max_length :] - completion_mask = completion_mask[:, -self.args.max_length :] - extra = tuple(t[:, -self.args.max_length :] for t in extra) - attention_mask, input_ids, completion_mask, *extra = flush_left( - attention_mask, input_ids, completion_mask, *extra - ) - extra = tuple(extra) - else: - raise ValueError( - f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'" - ) - - return input_ids, attention_mask, completion_mask, *extra - def compute_ref_log_probs(self, inputs): """Computes reference log probabilities for a single padded batch.""" device = self.accelerator.device @@ -1055,19 +1021,18 @@ def compute_ref_log_probs(self, inputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids - extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] - input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( - input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] - ) - shift_labels = input_ids[..., 1:].contiguous() shift_completion_mask = completion_mask[..., 1:].contiguous() model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key, val in zip(extra_keys, extra, strict=False): - model_kwargs[key] = val - for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): + for key in ( + "token_type_ids", + "mm_token_type_ids", + "pixel_values", + "pixel_attention_mask", + "image_grid_thw", + "image_sizes", + ): if key in inputs: model_kwargs[key] = inputs[key] @@ -1112,7 +1077,6 @@ def _compute_loss_liger(self, model, inputs, return_outputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask) decoder = model.get_decoder() outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False) @@ -1185,16 +1149,15 @@ def _compute_loss(self, model, inputs, return_outputs): input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] completion_mask = inputs["completion_mask"] - # token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids - extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs] - input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs( - input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys] - ) - model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False} - for key, val in zip(extra_keys, extra, strict=False): - model_kwargs[key] = val - for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"): + for key in ( + "token_type_ids", + "mm_token_type_ids", + "pixel_values", + "pixel_attention_mask", + "image_grid_thw", + "image_sizes", + ): if key in inputs: model_kwargs[key] = inputs[key]