Skip to content

[GKD] Buffer Implementation for Distillation Trainer#5137

Merged
cmpatino merged 40 commits into
huggingface:mainfrom
cmpatino:kd-buffering
Mar 17, 2026
Merged

[GKD] Buffer Implementation for Distillation Trainer#5137
cmpatino merged 40 commits into
huggingface:mainfrom
cmpatino:kd-buffering

Conversation

@cmpatino

@cmpatino cmpatino commented Feb 20, 2026

Copy link
Copy Markdown
Collaborator

Implement Buffer for Distillation Trainer (GOLDTrainer)

Implement generation buffering and multi-generation support for GOLDTrainer

Add a prompt-level generation buffer that decouples generation from the
optimization steps. We adopt a buffer similar to GRPO to generate all rollouts for all mini-batches within an optimization step, leveraging parallel inference engines. This means each worker handles a buffer of per_device_train_batch_size * gradient_accumulation_steps.

Buffer Details

We allow multiple rollouts per prompt, following Thinking Machine’s Tinker example. The number of rollouts per prompt is determined by the num_generations parameter. To keep the effective batch size constant, we introduce the generation_batch_size parameter, which controls how many unique prompts we pass to the inference engine. We enforce generation_batch_size = per_device_train_batch_size * gradient_accumulation_steps // num_generations to ensure the effective batch size is invariant across setups.

Benchmarks

We can replicate Thinking Machine’s results using both non-Liger and Liger losses, achieving a 3x speedup on a setup with 8 training nodes in colocate mode.

comparison_tinker_liger
Phase Tinker (s) TRL (s)
Sampling 329.83 130
Loss 37.96 -
Training 98.69 38
Total 492.28 173

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


Note

Medium Risk
Changes the GOLD training loop and dataloader sampling semantics to buffer generations across optimizer windows and support multiple rollouts per prompt, which can affect training dynamics and correctness across distributed setups.

Overview
Adds an optimizer-window generation buffer to GOLDTrainer, decoupling on-policy rollout generation from individual gradient-accumulation microsteps. Training now samples a full window batch via a RepeatSampler, generates (optionally via vLLM) once per window, and reuses buffered slices across accumulation steps.

Extends GOLDConfig with num_generations and generation_batch_size (with strict partitioning validation) and adds teacher_model_revision; updates model/teacher revision handling in gold.py and teacher model instantiation. Generation/label masking is refactored to correctly handle left-padded prompts and completion-only outputs, and Liger+ZeRO-3 gets an explicit parameter gather context.

Updates docs to describe buffering behavior and config changes, and adds/adjusts tests to cover left-padding retokenization and prompt-length masking expectations.

Written by Cursor Bugbot for commit abfadc1. This will update automatically on new commits. Configure here.

Comment thread trl/experimental/gold/gold_trainer.py Outdated
Comment thread trl/experimental/gold/gold_trainer.py
@HuggingFaceDocBuilderDev

Copy link
Copy Markdown

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread trl/experimental/gold/gold.py
Comment thread trl/experimental/gold/gold_trainer.py
Comment thread trl/experimental/gold/gold_trainer.py Outdated

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Comment thread trl/experimental/gold/gold_trainer.py
Comment thread trl/experimental/gold/gold_trainer.py
Comment thread trl/experimental/gold/gold_trainer.py
@qgallouedec

Copy link
Copy Markdown
Member

I haven't reviewed it in detail; I have a general idea of what it's about, but I'm leaving the implementation mostly up to you. In future PRs, we can try to align it better with the rest of the codebase, but what matters most right now are the results you're getting.

Make sure to run make precommit to make to CI happy

Comment thread docs/source/gold_trainer.md Outdated
Comment thread trl/experimental/gold/gold_config.py Outdated

@qgallouedec qgallouedec left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just ensure the CI is green before merging

@cmpatino cmpatino merged commit ee339a0 into huggingface:main Mar 17, 2026
4 checks passed
@cmpatino cmpatino deleted the kd-buffering branch March 17, 2026 13:01
qgallouedec added a commit that referenced this pull request Mar 18, 2026
commit 52cd0cc
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:31:26 2026 +0100

    Fix UNEXPECTED lm_head.weight warning when loading a CausalLM as a reward model (#5295)

commit 7b42fc4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:29:11 2026 +0100

    Prevent corruption of DPO VLM training if "keep_end" truncation_mode (#5286)

commit 3acb8e8
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:27:10 2026 +0100

    Support max_length in DPO VLM training (#5284)

commit ee339a0
Author: Carlos Miguel Patiño <carlos.patino@huggingface.co>
Date:   Tue Mar 17 14:01:44 2026 +0100

    [GKD] Buffer Implementation for Distillation Trainer (#5137)

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

commit d46131f
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:27:19 2026 +0100

    Remove custom get_train/eval_dataloader from OnlineDPO (#5291)

commit 85cf8f4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:24:24 2026 +0100

    Remove TrainingArguments import from experimental trainers (#5290)

commit 91e3da0
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Mar 16 07:19:51 2026 -0600

    Fix `accuracy_reward` crash when called from non-main thread (#5281)

commit 4996631
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:44:28 2026 +0100

    Fix support for model_init_kwargs in MiniLLM when passed as CLI JSON string (#5274)

commit 5fceaa7
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:43:34 2026 +0100

    Simplify structured outputs logic across vLLM versions in scripts/vllm_serve (#5273)

commit 406d406
Author: casinca <47400729+casinca@users.noreply.github.com>
Date:   Sat Mar 14 04:12:49 2026 +0100

    feat(`grpo_trainer.py`): Variational Sequence-Level Soft Policy Optimization (VESPO) (#5199)

commit d0ac7ef
Author: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Date:   Sat Mar 14 02:53:33 2026 +0100

    Allow nullable logprobs in vLLM serve responses  (#5203)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>

commit c0eabc4
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Mar 13 18:19:15 2026 -0600

    Change default `vllm_mode` to `"colocate"` and add v0→v1 migration guide (#5255)

    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 6c0fccd
Author: Mario Šaško <mariosasko777@gmail.com>
Date:   Sat Mar 14 00:19:38 2026 +0100

    35% faster packing + rename `bfd-requeue` to `bfd_split` (#5189)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
qgallouedec added a commit that referenced this pull request Mar 18, 2026
commit 3972d66
Author: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Date:   Wed Mar 18 22:26:44 2026 +0100

    Suggest the `Json()` type for tool calling dataset format (#5307)

    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 5c6e915
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 14:55:19 2026 -0600

    Update `RewardFunc` type annotation to allow `None`values in reward list (#5297)

commit ee96845
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Wed Mar 18 17:03:54 2026 +0100

    Fix DPOTrainer collators to truncate sequences before padding (#5305)

commit 435c2ae
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 08:09:42 2026 -0600

    Add guidance to avoid `hasattr` and `getattr` with defaults in `AGENTS.md` (#5294)

commit 26ce6a3
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Mar 18 00:44:12 2026 -0600

    Apply docstyle (#5296)

commit 52cd0cc
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:31:26 2026 +0100

    Fix UNEXPECTED lm_head.weight warning when loading a CausalLM as a reward model (#5295)

commit 7b42fc4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:29:11 2026 +0100

    Prevent corruption of DPO VLM training if "keep_end" truncation_mode (#5286)

commit 3acb8e8
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Mar 17 15:27:10 2026 +0100

    Support max_length in DPO VLM training (#5284)

commit ee339a0
Author: Carlos Miguel Patiño <carlos.patino@huggingface.co>
Date:   Tue Mar 17 14:01:44 2026 +0100

    [GKD] Buffer Implementation for Distillation Trainer (#5137)

    Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

commit d46131f
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:27:19 2026 +0100

    Remove custom get_train/eval_dataloader from OnlineDPO (#5291)

commit 85cf8f4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 15:24:24 2026 +0100

    Remove TrainingArguments import from experimental trainers (#5290)

commit 91e3da0
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Mar 16 07:19:51 2026 -0600

    Fix `accuracy_reward` crash when called from non-main thread (#5281)

commit 4996631
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:44:28 2026 +0100

    Fix support for model_init_kwargs in MiniLLM when passed as CLI JSON string (#5274)

commit 5fceaa7
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Mar 16 07:43:34 2026 +0100

    Simplify structured outputs logic across vLLM versions in scripts/vllm_serve (#5273)

commit 406d406
Author: casinca <47400729+casinca@users.noreply.github.com>
Date:   Sat Mar 14 04:12:49 2026 +0100

    feat(`grpo_trainer.py`): Variational Sequence-Level Soft Policy Optimization (VESPO) (#5199)

commit d0ac7ef
Author: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Date:   Sat Mar 14 02:53:33 2026 +0100

    Allow nullable logprobs in vLLM serve responses  (#5203)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>

commit c0eabc4
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Mar 13 18:19:15 2026 -0600

    Change default `vllm_mode` to `"colocate"` and add v0→v1 migration guide (#5255)

    Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>

commit 6c0fccd
Author: Mario Šaško <mariosasko777@gmail.com>
Date:   Sat Mar 14 00:19:38 2026 +0100

    35% faster packing + rename `bfd-requeue` to `bfd_split` (#5189)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
    Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
@cmpatino cmpatino mentioned this pull request Mar 23, 2026
3 tasks
songhappy pushed a commit to songhappy/trl that referenced this pull request Apr 20, 2026
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants