Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,26 @@ def test_train_vlm_text_only_data(self, model_id, dataset_config):
else:
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"

@require_vision
def test_train_vlm_with_max_length(self):
# Regression test for #5283: mm_token_type_ids must be truncated alongside input_ids when max_length is set,
# otherwise a shape mismatch crashes the model forward pass.
# max_length=37 truncates 1 completion token (total_len=38) while keeping all image tokens (prompt_len=34) safe.
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_preference", split="train")
training_args = DPOConfig(
output_dir=self.tmp_dir,
max_length=37, # total_len=38, prompt_len=34 — truncates completion, not image tokens
per_device_train_batch_size=2,
report_to="none",
)
trainer = DPOTrainer(
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
args=training_args,
train_dataset=dataset,
)
trainer.train()
assert trainer.state.log_history[-1]["train_loss"] is not None

@require_peft
@require_bitsandbytes
def test_peft_with_quantization(self):
Expand Down
57 changes: 33 additions & 24 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,27 +970,38 @@ def _precompute_ref_logps(self, dataset: Dataset, name: str, batch_size: int) ->
return dataset

def _truncate_inputs(
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, completion_mask: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
completion_mask: torch.Tensor,
*extra: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
if self.args.max_length is None:
return input_ids, attention_mask, completion_mask
return input_ids, attention_mask, completion_mask, *extra

if self.args.truncation_mode == "keep_start":
input_ids = input_ids[:, : self.args.max_length]
attention_mask = attention_mask[:, : self.args.max_length]
completion_mask = completion_mask[:, : self.args.max_length]
extra = tuple(t[:, : self.args.max_length] for t in extra)
elif self.args.truncation_mode == "keep_end":
attention_mask, input_ids, completion_mask = flush_right(attention_mask, input_ids, completion_mask)
attention_mask, input_ids, completion_mask, *extra = flush_right(
attention_mask, input_ids, completion_mask, *extra
)
input_ids = input_ids[:, -self.args.max_length :]
attention_mask = attention_mask[:, -self.args.max_length :]
completion_mask = completion_mask[:, -self.args.max_length :]
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
extra = tuple(t[:, -self.args.max_length :] for t in extra)
attention_mask, input_ids, completion_mask, *extra = flush_left(
attention_mask, input_ids, completion_mask, *extra
)
extra = tuple(extra)
else:
raise ValueError(
f"Unsupported truncation mode: {self.args.truncation_mode}, expected 'keep_start' or 'keep_end'"
)

return input_ids, attention_mask, completion_mask
return input_ids, attention_mask, completion_mask, *extra

def compute_ref_log_probs(self, inputs):
"""Computes reference log probabilities for a single padded batch."""
Expand All @@ -999,20 +1010,19 @@ def compute_ref_log_probs(self, inputs):
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
# token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids
extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs]
input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs(
input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys]
)

shift_labels = input_ids[..., 1:].contiguous()
shift_completion_mask = completion_mask[..., 1:].contiguous()

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key in (
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
for key, val in zip(extra_keys, extra, strict=False):
model_kwargs[key] = val
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"):
if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down Expand Up @@ -1130,17 +1140,16 @@ def _compute_loss(self, model, inputs, return_outputs):
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
completion_mask = inputs["completion_mask"]
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
# token_type_ids and mm_token_type_ids are sequence-length-aligned: truncate to match input_ids
extra_keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs]
input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs(
input_ids, attention_mask, completion_mask, *[inputs[k] for k in extra_keys]
)

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key in (
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
for key, val in zip(extra_keys, extra, strict=False):
model_kwargs[key] = val
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes"):
if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down
Loading