
Fix mm_token_type_ids silently dropped in DPO VLM training #5279

Merged
albertvillanova merged 3 commits into huggingface:main from albertvillanova:fix-5277 on Mar 13, 2026

albertvillanova (Member) commented on Mar 12, 2026

Fix mm_token_type_ids silently dropped in DPO VLM training.

Fix #5277.


Note

Medium Risk
Touches the VLM batching/forward kwargs path for DPO, so mistakes could break multimodal training or cause subtle tensor-shape/runtime errors, but the change is narrow and covered by a targeted regression test.

Overview
Fixes DPO vision-language training for processors that return mm_token_type_ids (e.g. Qwen2.5-VL), ensuring the field is not dropped and its tensor shape stays aligned with input_ids.

DataCollatorForVisionPreference now pads/flushes mm_token_type_ids alongside other batch tensors and includes it in the batch output, and DPOTrainer forwards mm_token_type_ids into both reference-logprob computation and the main loss forward pass. Adds a regression test validating mm_token_type_ids presence and shape.
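
The collator behaviour described above can be pictured with a minimal, pure-Python sketch. It is not TRL's implementation: the `pad_batch` helper, the right-padding choice, and the pad values are illustrative assumptions. The only point is that every per-token field, including `mm_token_type_ids`, ends up with the same length as `input_ids`.

```python
from typing import Dict, List

# Hypothetical pad values for this sketch; a real collator takes them
# from the tokenizer/processor.
PAD_VALUES = {"input_ids": 0, "attention_mask": 0, "mm_token_type_ids": 0}


def pad_batch(examples: List[Dict[str, List[int]]]) -> Dict[str, List[List[int]]]:
    """Right-pad every per-token field to the longest input_ids in the batch."""
    max_len = max(len(ex["input_ids"]) for ex in examples)
    batch: Dict[str, List[List[int]]] = {}
    for key, pad_value in PAD_VALUES.items():
        # Skip a field unless every example provides it.
        if not all(key in ex for ex in examples):
            continue
        batch[key] = [ex[key] + [pad_value] * (max_len - len(ex[key])) for ex in examples]
    return batch


examples = [
    {"input_ids": [11, 12, 13], "attention_mask": [1, 1, 1], "mm_token_type_ids": [0, 1, 1]},
    {"input_ids": [21, 22], "attention_mask": [1, 1], "mm_token_type_ids": [0, 0]},
]
batch = pad_batch(examples)

# Every field, including mm_token_type_ids, has the padded length of input_ids.
assert all(len(row) == 3 for rows in batch.values() for row in rows)
```

If `mm_token_type_ids` were left out of `PAD_VALUES` in this sketch, it would simply be absent from the batch, which mirrors the "silently dropped" symptom this PR addresses.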

Written by Cursor Bugbot for commit 4842af6.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread trl/trainer/dpo_trainer.py

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

"image_sizes",
"token_type_ids",
"mm_token_type_ids",
):

mm_token_type_ids not truncated when max_length is set

Medium Severity

When max_length is set, _truncate_inputs truncates input_ids, attention_mask, and completion_mask but not mm_token_type_ids. Both compute_ref_log_probs and _compute_loss then pass the truncated input_ids alongside the original full-length mm_token_type_ids from inputs to the model, causing a shape mismatch that will crash during the forward pass.
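
A minimal sketch of the mismatch described here, and of one way to avoid it. The `truncate_fields` helper and the `PER_TOKEN_KEYS` list are hypothetical (this is not TRL's `_truncate_inputs`), and the sketch assumes truncation keeps the first `max_length` tokens. Every per-token field is cut to the same length, so `mm_token_type_ids` cannot end up longer than `input_ids`.

```python
from typing import Dict, List, Optional

# Fields assumed to carry one value per token; this list is illustrative, not TRL's.
PER_TOKEN_KEYS = (
    "input_ids",
    "attention_mask",
    "completion_mask",
    "token_type_ids",
    "mm_token_type_ids",
)


def truncate_fields(
    batch: Dict[str, List[List[int]]], max_length: Optional[int]
) -> Dict[str, List[List[int]]]:
    """Cut all per-token fields to max_length so their shapes stay aligned."""
    if max_length is None:
        return dict(batch)  # no truncation requested
    out: Dict[str, List[List[int]]] = {}
    for key, rows in batch.items():
        if key in PER_TOKEN_KEYS:
            out[key] = [row[:max_length] for row in rows]
        else:
            out[key] = rows  # leave non-token fields untouched
    return out


batch = {
    "input_ids": [[1, 2, 3, 4, 5]],
    "attention_mask": [[1, 1, 1, 1, 1]],
    "completion_mask": [[0, 0, 1, 1, 1]],
    "mm_token_type_ids": [[0, 1, 1, 0, 0]],
}
truncated = truncate_fields(batch, max_length=3)

# All per-token fields share one sequence length after truncation.
lengths = {key: len(rows[0]) for key, rows in truncated.items()}
assert len(set(lengths.values())) == 1
```

Removing `"mm_token_type_ids"` from `PER_TOKEN_KEYS` reproduces the reported situation: `input_ids` has length 3 while `mm_token_type_ids` keeps length 5.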

Member

Yep, this sounds like legitimate feedback.

Member Author

@qgallouedec,

First, this is a pre-existing condition, not introduced by this PR:

  • The same gap for token_type_ids has never caused a filed issue

Second, as already explicitly documented in the tests, truncation can remove image tokens for VLMs, leading to errors. Anyone doing VLM DPO uses max_length=None.

Therefore, if this is an issue we want to handle, I think it should be done in a separate PR. I am opening one! 🚀

Member

> The same gap for token_type_ids has never caused a filed issue

Interesting, curious to know why.

> Anyone doing VLM DPO uses max_length=None

Yes, that's recommended, but a careful user might still set max_length (technically it's supported, as long as you ensure that no image tokens are truncated).

ok a separate pr sounds good!
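
The condition mentioned above, that no image token is truncated, could be checked before training with something like the following sketch. The `IMAGE_TOKEN_ID` value and the `truncation_drops_image_tokens` helper are hypothetical; the real image token ID depends on the model's processor, and the sketch assumes truncation keeps the first `max_length` tokens.

```python
from typing import List, Optional

IMAGE_TOKEN_ID = 9999  # hypothetical placeholder; look up the real ID from the processor


def truncation_drops_image_tokens(input_ids: List[int], max_length: Optional[int]) -> bool:
    """Return True if cutting input_ids to max_length would remove any image token."""
    if max_length is None:
        return False  # nothing is truncated
    # Only the tokens past max_length are discarded, so inspect just that tail.
    return IMAGE_TOKEN_ID in input_ids[max_length:]


sequence = [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3, 4]

assert not truncation_drops_image_tokens(sequence, max_length=3)  # image tokens all kept
assert truncation_drops_image_tokens(sequence, max_length=2)      # one image token cut off
assert not truncation_drops_image_tokens(sequence, max_length=None)
```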

@albertvillanova albertvillanova merged commit e9c1821 into huggingface:main Mar 13, 2026
12 checks passed
Development

Successfully merging this pull request may close these issues.

DPOTrainer silently ignores mm_token_type_ids when training with Qwen2.5-VL and similar 3D-RoPE models

3 participants