Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 18 additions & 55 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
disable_dropout_in_model,
entropy_from_logits,
flush_left,
flush_right,
get_config_model_id,
hash_module,
pad,
Expand Down Expand Up @@ -446,7 +445,8 @@ class DPOTrainer(_BaseTrainer):
data_collator ([`~transformers.DataCollator`], *optional*):
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
Will default to [`~trainer.dpo_trainer.DataCollatorForPreference`] if the model is a language model and
[`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model.
[`~trainer.dpo_trainer.DataCollatorForVisionPreference`] if the model is a vision-language model. Custom

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add the same comment to `SFTTrainer.data_collator`?

@albertvillanova albertvillanova Mar 24, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am planning to add the same comment to SFT in my subsequent PR, when I remove post-collation truncation from SFT as well.

collators must truncate sequences before padding; the trainer does not apply post-collation truncation.
Comment thread
albertvillanova marked this conversation as resolved.
train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
Dataset to use for training. This trainer supports both [language modeling](#language-modeling) type and
[prompt-completion](#prompt-completion) type. The format of the samples can be either:
Expand Down Expand Up @@ -1014,60 +1014,25 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) ->

return dataset

def _truncate_inputs(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    completion_mask: torch.Tensor,
    *extra: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
    """Clip a batch along dimension 1 to at most `self.args.max_length` positions.

    Args:
        input_ids: Tensor sliced along its second dimension.
        attention_mask: Tensor sliced the same way; also passed as the leading argument to
            `flush_right` / `flush_left` in `"keep_end"` mode.
        completion_mask: Tensor sliced the same way as `input_ids`.
        *extra: Any further tensors to slice identically to `input_ids`.

    Returns:
        A tuple `(input_ids, attention_mask, completion_mask, *extra)`. The inputs are returned untouched when
        `self.args.max_length` is `None`.

    Raises:
        ValueError: If `self.args.max_length` is set and `self.args.truncation_mode` is neither `"keep_start"`
            nor `"keep_end"`.
    """
    limit = self.args.max_length
    if limit is None:
        # No limit configured: hand everything back unchanged (the truncation mode is not inspected).
        return input_ids, attention_mask, completion_mask, *extra

    mode = self.args.truncation_mode

    if mode == "keep_start":
        # Keep the first `limit` positions of every tensor.
        head = slice(None, limit)
        return tuple(t[:, head] for t in (input_ids, attention_mask, completion_mask, *extra))

    if mode == "keep_end":
        # NOTE(review): flush_right / flush_left are helpers defined outside this block; presumably they shift
        # the non-padding positions to the right / left edge using the mask given first. Confirm against them.
        mask, ids, completion, *rest = flush_right(attention_mask, input_ids, completion_mask, *extra)
        # Keep the last `limit` positions of every tensor.
        tail = slice(-limit, None)
        mask, ids, completion, *rest = flush_left(
            mask[:, tail],
            ids[:, tail],
            completion[:, tail],
            *(t[:, tail] for t in rest),
        )
        return ids, mask, completion, *rest

    raise ValueError(
        f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'"
    )

def compute_ref_log_probs(self, inputs):
"""Computes reference log probabilities for a single padded batch."""
device = self.accelerator.device

input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
# token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids
extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs]
input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs(
input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys]
)

shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key, val in zip(extra_keys, extra, strict=False):
model_kwargs[key] = val
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"):
for key in (
"token_type_ids",
"mm_token_type_ids",
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
):
if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down Expand Up @@ -1112,7 +1077,6 @@ def _compute_loss_liger(self, model, inputs, return_outputs):
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

decoder = model.get_decoder()
outputs = decoder(input_ids, attention_mask=attention_mask, use_cache=False)
Expand Down Expand Up @@ -1185,16 +1149,15 @@ def _compute_loss(self, model, inputs, return_outputs):
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
# token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids
extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs]
input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs(
input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys]
)

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
Comment thread
qgallouedec marked this conversation as resolved.
for key, val in zip(extra_keys, extra, strict=False):
model_kwargs[key] = val
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"):
for key in (
"token_type_ids",
"mm_token_type_ids",
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
):
if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down
Loading