Enable chunked NLL loss with VLM in SFT #5684
@codex review
Codex Review: Didn't find any major issues.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
Reviewed by Cursor Bugbot for commit 00cb84b.
# the model itself. We should investigate this further, but for now we just skip these params.
# fmt: off
if (
    model_id == "trl-internal-testing/tiny-Gemma3ForConditionalGeneration" and "model.vision_tower.vision_model.head" in n or
nit: can we refactor this a bit? Any reason these params didn't change?
I'm not sure; it's been an open question for a long time, but it's never been urgent enough for me to set aside time to investigate. My hunch is that the gradients reaching the vision tower are too weak for the weights to be updated, either because of the structure of the tiny model or because of the initialization values.
As for the refactor, I'd recommend keeping things as they are, mostly because it's consistent with TestDPOTrainer.test_train_vlm and TestSFTTrainer.test_train_vlm, and it explicitly shows which layers are problematic.
Although I agree it's not pretty.
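For reference, the check discussed in this thread can be sketched without the inline `and`/`or` chain. This is only an illustration of the pattern, not the code in the PR: the table name `SKIPPED_PARAM_FRAGMENTS` and the helper `unchanged_params` are hypothetical, and plain lists of floats stand in for parameter tensors. Since `and` binds tighter than `or` in Python, the chain in the diff already reads as a series of (model id, name fragment) pairs, and a lookup table only makes that pairing explicit.

```python
# Hypothetical skip table: model id -> name fragments of parameters that are
# known not to change during the tiny-model training test.
SKIPPED_PARAM_FRAGMENTS = {
    "trl-internal-testing/tiny-Gemma3ForConditionalGeneration": (
        "model.vision_tower.vision_model.head",
    ),
}


def unchanged_params(model_id, before, after, skip_table=SKIPPED_PARAM_FRAGMENTS):
    """Return names of parameters that did not change, ignoring skipped ones.

    `before` and `after` map parameter names to values (lists of floats here,
    standing in for tensors captured before and after training).
    """
    fragments = skip_table.get(model_id, ())
    stale = []
    for name, old_value in before.items():
        if any(fragment in name for fragment in fragments):
            continue  # known-problematic layer for this model, skipped explicitly
        if after[name] == old_value:
            stale.append(name)
    return stale


# Toy snapshot: the vision head stays frozen, the language model weight moves.
model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
before = {
    "model.vision_tower.vision_model.head.weight": [0.1, 0.2],
    "model.language_model.layers.0.mlp.weight": [0.3, 0.4],
}
after = {
    "model.vision_tower.vision_model.head.weight": [0.1, 0.2],
    "model.language_model.layers.0.mlp.weight": [0.31, 0.39],
}
stale = unchanged_params(model_id, before, after)
```

A test would then assert that `stale` is empty. The trade-off raised above still applies: the inline form matches the existing VLM tests, so a table like this would only pay off if it were shared across them.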

Requires #5676
Note
Medium Risk
Expands the chunked_nll training path to VLM and MoE wrappers by patching model forward, which can subtly affect loss/gradient behavior across many model families and transformers versions.
Overview
Enables loss_type='chunked_nll' for vision-language models by extending _patch_chunked_ce_lm_head to handle VLM config (text_config), run the multimodal wrapper (base_model/model) so vision token injection occurs, and compute MoE auxiliary loss using the correct config fields.
Updates SFTTrainer to apply the patched chunked-loss forward for VLMs (removing the prior VLM restriction) and relaxes SFTConfig docs/help text to reflect that chunked_nll is now only incompatible with use_liger_kernel.
Adds/expands tests to cover chunked NLL training on multiple VLM families, plus forward/backward equivalence tests for patched chunked CE on VLMs (including a VLM MoE aux-loss case), and tightens the PEFT chunked-NLL test to assert base weights stay frozen while adapter params update.
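As background for the equivalence tests mentioned above, here is a minimal, framework-free sketch of the general idea behind a chunked NLL: the output projection and cross-entropy are applied to a few tokens at a time, so logits for the whole sequence never need to exist at once, while the mean loss stays the same. Everything in it (`project`, `full_nll`, `chunked_nll`, the toy data, the `chunk_size` parameter) is an assumption for illustration and is not taken from `_patch_chunked_ce_lm_head`; it checks forward equivalence only and says nothing about gradients.

```python
import math


def token_nll(logits_row, target):
    # Negative log-likelihood of `target` under softmax(logits_row),
    # using the log-sum-exp trick for numerical stability.
    m = max(logits_row)
    lse = m + math.log(sum(math.exp(x - m) for x in logits_row))
    return lse - logits_row[target]


def project(hidden, weight):
    # Toy output head: logits[v] = dot(hidden, weight[v]).
    return [sum(h * w for h, w in zip(hidden, row)) for row in weight]


def full_nll(hidden_states, weight, labels, ignore_index=-100):
    # Reference path: logits for every token are materialized up front.
    logits = [project(h, weight) for h in hidden_states]
    losses = [token_nll(l, y) for l, y in zip(logits, labels) if y != ignore_index]
    return sum(losses) / max(len(losses), 1)


def chunked_nll(hidden_states, weight, labels, chunk_size=2, ignore_index=-100):
    # Chunked path: only `chunk_size` rows of logits are alive at any time.
    total, count = 0.0, 0
    for start in range(0, len(hidden_states), chunk_size):
        chunk_h = hidden_states[start:start + chunk_size]
        chunk_y = labels[start:start + chunk_size]
        chunk_logits = [project(h, weight) for h in chunk_h]
        for row, y in zip(chunk_logits, chunk_y):
            if y != ignore_index:  # masked tokens contribute nothing
                total += token_nll(row, y)
                count += 1
    return total / max(count, 1)


# Toy inputs: 5 tokens, hidden size 2, vocabulary of 3, one masked label.
hidden = [[0.5, -1.0], [1.5, 0.25], [-0.75, 2.0], [0.0, 1.0], [1.0, 1.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [0.5, -0.5]]
labels = [0, 2, -100, 1, 1]

reference = full_nll(hidden, weight, labels)
chunked = chunked_nll(hidden, weight, labels, chunk_size=2)
```

The two values should agree up to floating-point tolerance for any chunk size. In a real framework the benefit comes from also freeing (or recomputing) each chunk's logits during the backward pass, which this sketch does not model.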
Reviewed by Cursor Bugbot for commit ec0cad7.