Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from transformers.utils import is_peft_available

from trl import DPOConfig, DPOTrainer
from trl.trainer.dpo_trainer import DataCollatorForPreference
from trl.trainer.dpo_trainer import DataCollatorForPreference, DataCollatorForVisionPreference

from .testing_utils import (
TrlTestCase,
Expand Down Expand Up @@ -132,6 +132,38 @@ def test_with_pad_to_multiple_of(self):
torch.testing.assert_close(result["input_ids"], expected_input_ids)


class TestDataCollatorForVisionPreference(TrlTestCase):
@pytest.mark.skipif(
Version(transformers.__version__) < Version("5.3.0"),
reason="mm_token_type_ids are returned by default since transformers-5.3.0 (see transformers#43972)",
)
@require_vision
def test_mm_token_type_ids_shape(self):
# Regression test: when the processor returns mm_token_type_ids (e.g. Qwen2.5-VL after
# transformers#43972), the collator must concatenate it with zeros for the completion part
# so that its shape matches input_ids. Without the fix this raises an IndexError in the model.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration")
collator = DataCollatorForVisionPreference(processor)
image = Image.new("RGB", (16, 16))
examples = [
{
"images": [image],
"prompt": [{"role": "user", "content": "What is this?"}],
"chosen": [{"role": "assistant", "content": "A red square."}],
"rejected": [{"role": "assistant", "content": "A blue circle."}],
}
]
output = collator(examples)
assert "mm_token_type_ids" in output
assert output["mm_token_type_ids"].shape == output["input_ids"].shape, (
f"mm_token_type_ids shape {output['mm_token_type_ids'].shape} != "
f"input_ids shape {output['input_ids'].shape}"
)


class TestDPOTrainer(TrlTestCase):
@pytest.mark.parametrize(
"model_id",
Expand Down
33 changes: 30 additions & 3 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,23 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
rejected_type_ids = processed_rejecteds["token_type_ids"]
completion_token_type_ids = torch.cat(tuple(pad([chosen_type_ids, rejected_type_ids], padding_value=0)))
token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
if "mm_token_type_ids" in processed_prompts: # special case for Qwen2.5-VL
prompt_mm_token_type_ids = processed_prompts["mm_token_type_ids"]
mm_token_type_ids = torch.cat((prompt_mm_token_type_ids, torch.zeros_like(completion_ids)), dim=1)

# Flush left to reduce padding
if "token_type_ids" in processed_prompts:
if "token_type_ids" in processed_prompts and "mm_token_type_ids" in processed_prompts:
attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids = flush_left(
attention_mask, input_ids, completion_mask, token_type_ids, mm_token_type_ids
)
elif "token_type_ids" in processed_prompts:
attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
attention_mask, input_ids, completion_mask, token_type_ids
)
elif "mm_token_type_ids" in processed_prompts:
attention_mask, input_ids, completion_mask, mm_token_type_ids = flush_left(
attention_mask, input_ids, completion_mask, mm_token_type_ids
)
else:
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)

Expand All @@ -352,6 +363,8 @@ def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
output["completion_mask"] = completion_mask
if "token_type_ids" in processed_prompts:
output["token_type_ids"] = token_type_ids
if "mm_token_type_ids" in processed_prompts:
output["mm_token_type_ids"] = mm_token_type_ids
return output


Expand Down Expand Up @@ -992,7 +1005,14 @@ def compute_ref_log_probs(self, inputs):
shift_completion_mask = completion_mask[..., 1:].contiguous()

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
for key in (
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):
Comment thread
cursor[bot] marked this conversation as resolved.
if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down Expand Up @@ -1113,7 +1133,14 @@ def _compute_loss(self, model, inputs, return_outputs):
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)

model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask, "use_cache": False}
for key in ("pixel_values", "pixel_attention_mask", "image_grid_thw", "image_sizes", "token_type_ids"):
for key in (
"pixel_values",
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mm_token_type_ids not truncated when max_length is set

Medium Severity

When max_length is set, _truncate_inputs truncates input_ids, attention_mask, and completion_mask but not mm_token_type_ids. Both compute_ref_log_probs and _compute_loss then pass the truncated input_ids alongside the original full-length mm_token_type_ids from inputs to the model, causing a shape mismatch that will crash during the forward pass.

Additional Locations (2)
Fix in Cursor Fix in Web

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, it sounds like a legit feedback

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qgallouedec,

First, this is a pre-existing condition, not introduced by this PR:

  • The same gap for token_type_ids has never caused a filed issue

Second, as already explicitly documented in the tests, for VLMs, truncating can remove image tokens, leading to errors. Anyone doing VLM DPO uses max_length=None

Therefore, in my opinion, if this is an issue we want to handle, I think it should be done in separate PR. I am opening one! 🚀

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same gap for token_type_ids has never caused a filed issue

interesting, curious to know why

Anyone doing VLM DPO uses max_length=None

yes that's recommended, but careful user might use max_length!=0 (technically it's supported, if you ensure that no image token are truncated)

ok a separate pr sounds good!

if key in inputs:
model_kwargs[key] = inputs[key]

Expand Down
Loading