Enable chunked NLL loss with PEFT in SFT#5676
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0cf0afb15f
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
|
Phenomenal results, thanks for implementing. About the comment with the LM head weight: I think it's an accurate concern. We could exclude this possibility for now (raise if the LM head is a PEFT adapter layer) or we could merge the weights into the head, something like this: self.lm_head.merge()
self.lm_head.weight # <= now includes PEFT weights
...
self.lm_head.unmerge()To be super safe, this could use |
|
The merge approach would fix the loss value but not the gradient problem. Let's just detect and nicely fail for now 68f4f94 |
Good point. We could manually merge in trl to avoid detaching the gradient. Otherwise, we could make |
There was a problem hiding this comment.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.
Reviewed by Cursor Bugbot for commit c48bb10. Configure here.
|
@qgallouedec Thanks for enabling PEFT to work with chunked NLL loss. Is there an example somewhere that I can point our users to of using TRL + PEFT + chunked NLL loss for extra memory savings? |

Follows #5575
Bechmark
trackio: https://qgallouedec-chunked-nll-benchmark-2.hf.space/?project=huggingface&run_ids=a191d6b9e72d44b9a7cbf1196604a8e6%2C142d25ebd5bd4ebaaf8d6364bdcae602&sidebar=hidden&navbar=hidden
Note
Medium Risk
Expands
chunked_nllto PEFT-wrapped models by patching base-modelforwardand changes how token-level metrics are computed, which could impact training correctness/metrics for adapter configurations (especially aroundlm_headand prompt-learning). Coverage is improved with new PEFT-focused regression tests and explicit validation errors for unsupported setups.Overview
Enables
loss_type='chunked_nll'for PEFT inSFTTrainerby patching the inner base model (get_base_model()) instead of rejecting PEFT, while adding a guard that errors iflm_headitself is wrapped by a PEFT tuner layer (to avoid silently dropping adapter deltas).Updates the chunked CE path to return and propagate
num_valid_tokens, and switchesmean_token_accuracy/entropy denominators to use this value (fixing prompt-learning PEFT cases where virtual-token label padding changes the valid-target count).Adds regression tests covering PEFT training with
chunked_nll(base params unchanged, adapter params updated), chunked CE valid-token counting, and patched-forward numerical/gradient equivalence across multiple PEFT types (LoRA,modules_to_save, and prompt-learning configs).Reviewed by Cursor Bugbot for commit 65a4c77. Bugbot is set up for automated code reviews on this repo. Configure here.