Fix mm_token_type_ids silently dropped in DPO VLM training #5279
Cursor Bugbot has reviewed your changes and found 1 potential issue.
    "image_sizes",
    "token_type_ids",
    "mm_token_type_ids",
):
mm_token_type_ids not truncated when max_length is set
Medium Severity
When max_length is set, _truncate_inputs truncates input_ids, attention_mask, and completion_mask but not mm_token_type_ids. Both compute_ref_log_probs and _compute_loss then pass the truncated input_ids alongside the original full-length mm_token_type_ids from inputs to the model, causing a shape mismatch that will crash during the forward pass.
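The mismatch described above goes away if every sequence-aligned field is sliced to the same length. The following is a minimal sketch of that idea using plain Python lists; `truncate_aligned` and `SEQUENCE_ALIGNED_KEYS` are hypothetical names for illustration and are not TRL's actual `_truncate_inputs` implementation.

```python
# Hypothetical illustration, not TRL's _truncate_inputs.
# Keys whose second dimension is assumed to follow the token sequence.
SEQUENCE_ALIGNED_KEYS = (
    "input_ids",
    "attention_mask",
    "completion_mask",
    "token_type_ids",
    "mm_token_type_ids",
)


def truncate_aligned(batch, max_length):
    """Return a copy of `batch` with every sequence-aligned key cut to `max_length`."""
    out = dict(batch)
    if max_length is None:
        # No truncation requested: leave all fields untouched.
        return out
    for key in SEQUENCE_ALIGNED_KEYS:
        if key in out:
            # Slice each row so all aligned fields keep the same length.
            out[key] = [row[:max_length] for row in out[key]]
    return out
```

Fields that are not indexed by token position (for example image features) are deliberately left alone in this sketch.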
Yep, this sounds like legitimate feedback.
First, this is a pre-existing condition, not introduced by this PR:
- The same gap for token_type_ids has never caused a filed issue
Second, as already explicitly documented in the tests, truncating can remove image tokens for VLMs, leading to errors. Anyone doing VLM DPO uses max_length=None.
Therefore, if this is an issue we want to handle, I think it should be done in a separate PR. I am opening one! 🚀
The same gap for token_type_ids has never caused a filed issue
interesting, curious to know why
Anyone doing VLM DPO uses max_length=None
Yes, that's recommended, but a careful user might use max_length!=0 (technically it's supported, if you ensure that no image tokens are truncated).
OK, a separate PR sounds good!
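The condition mentioned in this exchange (no image tokens may fall in the truncated part) can be checked before training. This is a rough sketch under stated assumptions: `truncation_keeps_image_tokens` is a hypothetical helper, and `image_token_id` stands for whatever placeholder id the processor uses for image tokens.

```python
# Hypothetical pre-check, assuming image placeholders are marked by a single token id.
def truncation_keeps_image_tokens(input_ids, max_length, image_token_id):
    """Return True if cutting `input_ids` to `max_length` removes no image token."""
    if max_length is None:
        # Nothing is cut, so nothing can be lost.
        return True
    dropped = input_ids[max_length:]
    return image_token_id not in dropped
```

A user who sets max_length could run a check like this over the dataset and skip or shorten any example for which it returns False.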


Fix mm_token_type_ids silently dropped in DPO VLM training. Fix #5277.
Note
Medium Risk
Touches the VLM batching/forward kwargs path for DPO, so mistakes could break multimodal training or cause subtle tensor-shape/runtime errors, but the change is narrow and covered by a targeted regression test.
Overview
Fixes DPO vision-language training for processors that return mm_token_type_ids (e.g. Qwen2.5-VL), ensuring the field is not dropped and its tensor shape stays aligned with input_ids.
DataCollatorForVisionPreference now pads/flushes mm_token_type_ids alongside other batch tensors and includes it in the batch output, and DPOTrainer forwards mm_token_type_ids into both reference-logprob computation and the main loss forward pass.
Adds a regression test validating mm_token_type_ids presence and shape.
Written by Cursor Bugbot for commit 4842af6.
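The collator behavior summarized above, padding mm_token_type_ids with the same width as input_ids and including it in the batch, can be sketched with plain Python lists. This is only an illustration under stated assumptions: `pad_rows` and `collate` are hypothetical names, right-padding with 0 is an assumed choice, and none of this reproduces DataCollatorForVisionPreference itself.

```python
# Hypothetical sketch of the padding idea, not the actual TRL collator.
def pad_rows(rows, pad_value):
    """Right-pad every row to the length of the longest row."""
    width = max(len(row) for row in rows)
    return [list(row) + [pad_value] * (width - len(row)) for row in rows]


def collate(examples, pad_token_id=0):
    """Build a batch whose sequence-aligned fields all share one shape."""
    batch = {
        "input_ids": pad_rows([ex["input_ids"] for ex in examples], pad_token_id),
        "attention_mask": pad_rows([[1] * len(ex["input_ids"]) for ex in examples], 0),
    }
    # Only emit the field when the processor actually produced it.
    if all("mm_token_type_ids" in ex for ex in examples):
        batch["mm_token_type_ids"] = pad_rows(
            [ex["mm_token_type_ids"] for ex in examples], 0
        )
    return batch
```

A regression check in the spirit of the one described would then assert that the key is present and that its row lengths match those of input_ids.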