Enable chunked NLL loss with PEFT in SFT#5676

Merged
qgallouedec merged 12 commits into main from chunked_nll_peft on May 5, 2026

@qgallouedec

@qgallouedec qgallouedec commented Apr 28, 2026

Member

Follows #5575

[Figure: nll_vs_chunked_nll-2]

Benchmark

import argparse
import time

import torch
from datasets import load_dataset
from peft import LoraConfig

from trl import SFTConfig, SFTTrainer


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("--loss_type", choices=["nll", "chunked_nll"], required=True)
    args = parser.parse_args()

    dataset = load_dataset("trl-lib/Capybara")
    dataset = dataset.filter(lambda ex: len(ex["messages"]) == 2)
    dataset = dataset.map(
        lambda ex: {"prompt": [ex["messages"][0]], "completion": [ex["messages"][1]]},
        remove_columns=["messages"],
    )

    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.0,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )

    training_args = SFTConfig(
        output_dir=f"/tmp/chunked_nll_peft_bench/peft/{args.loss_type}",
        loss_type=args.loss_type,
        run_name=f"peft-{args.loss_type}",
        max_length=1024,
        max_steps=100,
        logging_steps=1,
        save_strategy="no",
        report_to="trackio",
        trackio_space_id="qgallouedec/chunked-nll-benchmark-2",
        seed=1234,
        data_seed=1234,
    )

    trainer = SFTTrainer(
        model="Qwen/Qwen3-1.7B",
        args=training_args,
        train_dataset=dataset["train"],
        peft_config=peft_config,
    )

    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    trainer.train()
    torch.cuda.synchronize()
    train_time = time.perf_counter() - start

    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    if torch.distributed.is_initialized():
        t = torch.tensor(peak_gb, device="cuda")
        torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.MAX)
        peak_gb = t.item()
        if torch.distributed.get_rank() != 0:
            raise SystemExit(0)

    print(f"loss_type={args.loss_type} peak={peak_gb:.2f} GB time={train_time:.2f} s")


if __name__ == "__main__":
    main()

[Figures: train_grad_norm-2, train_mean_token_accuracy, train_entropy, train_loss-2]

trackio: https://qgallouedec-chunked-nll-benchmark-2.hf.space/?project=huggingface&run_ids=a191d6b9e72d44b9a7cbf1196604a8e6%2C142d25ebd5bd4ebaaf8d6364bdcae602&sidebar=hidden&navbar=hidden


Note

Medium Risk
Expands chunked_nll to PEFT-wrapped models by patching base-model forward and changes how token-level metrics are computed, which could impact training correctness/metrics for adapter configurations (especially around lm_head and prompt-learning). Coverage is improved with new PEFT-focused regression tests and explicit validation errors for unsupported setups.

Overview
Enables loss_type='chunked_nll' for PEFT in SFTTrainer by patching the inner base model (get_base_model()) instead of rejecting PEFT, while adding a guard that errors if lm_head itself is wrapped by a PEFT tuner layer (to avoid silently dropping adapter deltas).

Updates the chunked CE path to return and propagate num_valid_tokens, and switches mean_token_accuracy/entropy denominators to use this value (fixing prompt-learning PEFT cases where virtual-token label padding changes the valid-target count).
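The valid-token bookkeeping described above can be illustrated with a small PyTorch sketch. This is not the PR's `_chunked_cross_entropy_loss`: the function name `chunked_nll`, its signature, and the assumption that labels are already shifted and masked with `-100` are all hypothetical. It shows only the arithmetic (sum-reduced loss per chunk, divided by the number of non-ignored targets) and returns that count so a caller could reuse it as a metric denominator. It does not reproduce the memory-saving backward pass.

```python
import torch
import torch.nn.functional as F


def chunked_nll(hidden, lm_head_weight, labels, chunk_size=2, ignore_index=-100):
    """Hypothetical sketch: chunked cross-entropy that also returns num_valid_tokens.

    Assumes `labels` are already shifted to align with `hidden` and that
    positions to skip are marked with `ignore_index`.
    """
    hidden = hidden.reshape(-1, hidden.size(-1))
    labels = labels.reshape(-1)
    total = hidden.new_zeros(())
    for start in range(0, hidden.size(0), chunk_size):
        h = hidden[start : start + chunk_size]
        y = labels[start : start + chunk_size]
        # Only one chunk of logits is computed per iteration.
        logits = h @ lm_head_weight.T
        # reduction="sum" so chunks with different valid counts combine correctly.
        total = total + F.cross_entropy(
            logits.float(), y, ignore_index=ignore_index, reduction="sum"
        )
    num_valid_tokens = (labels != ignore_index).sum()
    # clamp avoids a division by zero when every target is ignored.
    loss = total / num_valid_tokens.clamp(min=1)
    return loss, num_valid_tokens
```

Returning the count alongside the loss means downstream metrics can divide by the number of targets that actually contributed, rather than re-deriving it from a label tensor whose length may have changed.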

Adds regression tests covering PEFT training with chunked_nll (base params unchanged, adapter params updated), chunked CE valid-token counting, and patched-forward numerical/gradient equivalence across multiple PEFT types (LoRA, modules_to_save, and prompt-learning configs).

Reviewed by Cursor Bugbot for commit 65a4c77.

Comment thread tests/test_sft_trainer.py Outdated

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 0cf0afb15f

Comment thread trl/trainer/sft_trainer.py
@qgallouedec

Member Author

cc @BenjaminBossan

Comment thread trl/trainer/sft_trainer.py
@BenjaminBossan

Member

Phenomenal results, thanks for implementing.

About the comment with the LM head weight: I think it's an accurate concern. We could exclude this possibility for now (raise if the LM head is a PEFT adapter layer) or we could merge the weights into the head, something like this:

self.lm_head.merge()
self.lm_head.weight  # <= now includes PEFT weights
...
self.lm_head.unmerge()

To be super safe, this could use merge(safe_merge=True) and be wrapped in try ... finally (safe merge means we check for NaNs, but it creates an extra copy of the weights).

@qgallouedec

Member Author

The merge approach would fix the loss value but not the gradient problem. LoraLinear.merge() does base_weight += scaling * (lora_B @ lora_A) in-place under torch.no_grad(). There's no autograd edge back to lora_A/lora_B. So x @ self.lm_head.weight.T would compute the right logits but still give zero gradient to the LoRA params.

Let's just detect this case and fail nicely for now: 68f4f94
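The gradient argument above can be checked with plain PyTorch tensors. This is a sketch with made-up shapes and a hand-rolled merge that mirrors the described `base_weight += scaling * (lora_B @ lora_A)` under `torch.no_grad()`. It does not call PEFT, and all variable names are illustrative.

```python
import torch

torch.manual_seed(0)
d_in, d_out, rank, scaling = 8, 5, 2, 0.5

base_weight = torch.randn(d_out, d_in)                 # frozen lm_head weight
lora_A = torch.randn(rank, d_in, requires_grad=True)   # trainable adapter factor
lora_B = torch.randn(d_out, rank, requires_grad=True)  # trainable adapter factor
hidden = torch.randn(3, d_in, requires_grad=True)      # stands in for upstream activations

# Unmerged path: the LoRA delta is built inside the autograd graph.
delta = scaling * (lora_B @ lora_A)
logits_unmerged = hidden @ (base_weight + delta).T
grads_unmerged = torch.autograd.grad(logits_unmerged.sum(), [lora_A, lora_B])

# Merged path: the delta is added in-place under no_grad, so the merged
# weight carries no autograd edge back to lora_A or lora_B.
merged_weight = base_weight.clone()
with torch.no_grad():
    merged_weight += scaling * (lora_B @ lora_A)
logits_merged = hidden @ merged_weight.T
grads_merged = torch.autograd.grad(
    logits_merged.sum(), [lora_A, lora_B], allow_unused=True
)

print("same logits:", torch.allclose(logits_unmerged, logits_merged))
print("unmerged grads present:", all(g is not None for g in grads_unmerged))
print("merged grads:", grads_merged)
```

In the merged path the logits agree with the unmerged ones, but `torch.autograd.grad` reports no gradient for either LoRA factor, which is the failure mode the guard is meant to avoid.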

@BenjaminBossan

Member

> The merge approach would fix the loss value but not the gradient problem. LoraLinear.merge() does base_weight += scaling * (lora_B @ lora_A) in-place under torch.no_grad(). There's no autograd edge back to lora_A/lora_B. So x @ self.lm_head.weight.T would compute the right logits but still give zero gradient to the LoRA params.

Good point. We could manually merge in TRL to avoid detaching the gradient. Otherwise, we could make _chunked_cross_entropy_loss LoRA-aware and pass the weights there, but I think both are overkill. This is already a pretty big win; edge cases can be dealt with later if there is demand.

@albertvillanova albertvillanova left a comment

Thanks.

@cursor cursor Bot left a comment

Cursor Bugbot has reviewed your changes and found 1 potential issue.

Reviewed by Cursor Bugbot for commit c48bb10.

Comment thread trl/trainer/sft_trainer.py Outdated
@qgallouedec qgallouedec merged commit b8e6fc0 into main May 5, 2026
13 checks passed
@qgallouedec qgallouedec deleted the chunked_nll_peft branch May 5, 2026 17:07
@BenjaminBossan

Member

@qgallouedec Thanks for enabling PEFT to work with chunked NLL loss. Is there an example somewhere that I can point our users to of using TRL + PEFT + chunked NLL loss for extra memory savings?

4 participants