-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Add SDPO (Self-Distillation Policy Optimization) trainer #4935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 83 commits
Commits
Show all changes
85 commits
Select commit
Hold shift + click to select a range
a189337
Add SDPO (Self-Distillation Policy Optimization) trainer
MengAiDev b382ea5
move to experimental
kashif 4139122
rename
kashif 9afaa0b
remove example
kashif 4de7cfb
add docs
kashif 2ece95a
fix tests and formatting
kashif 63e9423
added paper index
kashif 0d07988
align loss hyper-params with paper suggestion
kashif 0c0f4d7
update the docs
kashif 4c321e9
add helper to make teacher prompt
kashif cbf221c
Merge branch 'main' into 4929
kashif 067322f
Merge branch 'main' into 4929
kashif fec16e5
Merge branch 'main' into 4929
kashif b91901b
refactored to a base self-distillation trainer and specific sdpo and …
kashif 220ed91
add expected dataset format
kashif 20cfdf0
added sdft paper index
kashif 12cbe91
cleanup config
kashif 90f13f6
add sdft example
kashif 56754c9
add sdft test
kashif a65a4fa
initial sdpo example
kashif 51d9c10
fix example script
kashif 83b4434
use gsmk
kashif 15037d5
cleanup
kashif 1af9efa
clean up tests
kashif 6d167a8
fix review issues
kashif 922b2ad
fix __init__
kashif 88041c4
add online_rollout_mixin.py
kashif 920f065
Moved the shared sampled-token log-prob helper into self_distillation…
kashif c604497
consolidate the test callbacks
kashif 2ba8a9a
added generation-side diagnostics metrics
kashif 3455cb3
Change the example dataset for SDPO
cmpatino 11b991a
Moved the shared buffered-generation logic into self_distillation_mix…
kashif 6500121
add docs and callback info
kashif 64aa6c5
Merge branch 'main' into 4929
kashif 5174610
global gathering for the part that actually needs it
kashif f46fce2
formatting
kashif c0a7b13
fix dr_grpo
kashif 75ec3ab
remove double deepcopy
kashif 3533e82
Merge branch '4929' into sdpo_example
cmpatino 26c2d58
docstrings
kashif 9605b85
remove redundant computation and makes the parent/child data flow cl…
kashif 425c0ac
privileged_context is only extra teacher-only information
kashif ffdf44a
formatting
kashif 93a5ba4
fix num generation bug
kashif 2abd07d
privileged_context always needed
kashif c778f77
focused tests
kashif 661a26e
Upload working minimal GSM8k example for SDPO
cmpatino d117127
Merge branch '4929' into sdpo_example
cmpatino a1a0ddd
Fix type hint in SDPO example
cmpatino bea5b07
distillation_only mode requires
kashif 7aceaee
refactor teacher_context.py
kashif b0adc22
fix use_topk_distillation duplicatation
kashif d843945
fix formatting
kashif 7c0fe58
slice the rewards too
kashif 60ec6da
Merge branch 'main' into 4929
kashif 7798b80
add PEFTAdapterEMACallback
kashif 8bdfcd6
fix Multi-GPU index mismatch
kashif 34addee
remove arg coerce methods
kashif 9d1d838
remove duplications
kashif ce9ba1e
remove dead code and use shared extract_last_user_text
kashif 354e483
fix issues
kashif fb5e832
Merge branch 'main' into 4929
kashif 6c50a59
cleanup tests
kashif 3ffcb16
Update docs/source/sdft_trainer.md
kashif cacbd1f
Merge branch 'main' into 4929
kashif 9b6a1fe
remove SDPO brace escaping
kashif e3b0577
refactor base trainer
kashif d12e6da
use conversational format
kashif 1772110
Apply suggestion from @qgallouedec
kashif df867e4
remove unneeded properties from BaseSelfDistillationTrainer
kashif f9c9ef8
add _paper
kashif 4b89333
move PEFTAdapterEMACallback to experimental
kashif a6e586d
remove stale import
kashif 175230d
fix test
kashif 5983d8c
Completion tensors are padded to the local max length per rank; align…
kashif c2ab993
Merge branch 'main' into 4929
kashif ab7f630
pad completion_mask
kashif 7ac41fa
only padded_completion_ids is used for the cross-rank gather
kashif 20e17db
check role validation
kashif f0b8246
_set_signature_columns_if_needed moved to mixin
kashif d0e8cbc
Count groups with any successful rollout
kashif a16b033
Merge branch 'main' into 4929
qgallouedec c4a3bab
check num_generations_eval are divisible
kashif 33b8d5b
scale loss for grad acc only during training
kashif bf4cc67
remove ref_model reference
kashif File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| # SDFT | ||
|
|
||
| Self-Distilled Fine-Tuning (SDFT) is described in [Self-Training with On-Policy Self-Distillation for Language Model Alignment](https://huggingface.co/papers/2601.19897). | ||
|
|
||
| The TRL implementation adapts SDFT to the experimental trainer API while reusing the shared self-distillation infrastructure also used by SDPO. | ||
|
|
||
| In the current TRL implementation: | ||
|
|
||
| - SDFT uses an explicit `ref_model` teacher | ||
| - the dataset must provide both `prompt` and `privileged_context` | ||
| - `privileged_context` contains only the extra teacher-only information; the trainer combines it with `prompt` to build the teacher prompt | ||
| - `teacher_prompt_template` controls how `prompt` and `privileged_context` are combined into the teacher prompt | ||
| - on-policy generation can use either the student prompt or the teacher-conditioned prompt via `generate_from_teacher` | ||
| - `num_loss_tokens_to_skip` can exclude initial completion tokens from the distillation loss | ||
| - SDFT currently supports text-only training and does not support `use_vllm=True` | ||
| - the shared dataset contract is `prompt` plus `privileged_context` | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| from datasets import Dataset | ||
|
|
||
| from trl.experimental.sdft import SDFTConfig, SDFTTrainer | ||
|
|
||
| dataset = Dataset.from_dict( | ||
| { | ||
| "prompt": [[{"role": "user", "content": "Solve 2+2."}]], | ||
| "privileged_context": ["Example answer: 4."], | ||
| } | ||
| ) | ||
|
|
||
| training_args = SDFTConfig( | ||
| output_dir="sdft-model", | ||
| distillation_alpha=0.5, | ||
| distillation_topk=5, | ||
| max_completion_length=64, | ||
| ) | ||
|
|
||
| trainer = SDFTTrainer( | ||
| model="Qwen/Qwen2.5-1.5B-Instruct", | ||
| ref_model="Qwen/Qwen2.5-1.5B-Instruct", | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
| trainer.train() | ||
| ``` | ||
|
|
||
| To generate from the teacher-conditioned prompt instead of the student prompt, set `generate_from_teacher=True`. | ||
| To customize how the teacher prompt is built, set `teacher_prompt_template` on [`SDFTConfig`]. | ||
|
|
||
| ## Expected dataset columns | ||
|
|
||
| Each example must provide: | ||
|
|
||
| - `prompt`: the student-facing prompt | ||
| - `privileged_context`: only the extra teacher-only information, such as a demonstration, hint, or privileged feedback | ||
|
|
||
| Both standard text prompts and conversational prompts are supported by the trainer prompt handling. | ||
|
|
||
| ## Callbacks | ||
|
|
||
| The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows. | ||
|
|
||
| Shared self-distillation hooks: | ||
|
|
||
| - `on_self_distillation_batch_prepared`: fired when a self-distillation batch is ready. The payload includes `prompt_ids`, `completion_ids`, and `old_per_token_logps` when importance-sampling clipping inputs are available. | ||
| - `on_generation_batch_built`: fired when a new buffered generation batch is created. The payload includes `generate_every` and `steps_per_generation`. | ||
|
|
||
| SDFT-specific hook: | ||
|
|
||
| - `on_generation_prompts_selected`: fired when SDFT chooses the prompt source for on-policy generation. The payload includes the selected `generation_prompts` and the corresponding `generation_prompt_text`. | ||
|
|
||
| ## SDFTConfig | ||
|
|
||
| [[autodoc]] experimental.sdft.SDFTConfig | ||
|
|
||
| ## SDFTTrainer | ||
|
|
||
| [[autodoc]] experimental.sdft.SDFTTrainer | ||
| - train | ||
| - save_model | ||
| - push_to_hub | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # SDPO | ||
|
|
||
| Self-Distillation Policy Optimization (SDPO) was introduced in [Reinforcement Learning via Self-Distillation](https://huggingface.co/papers/2601.20802) by [Jonas Hübotter](https://huggingface.co/jonhue), Frederike Lübeck, Lejs Behric, [Anton Baumann](https://huggingface.co/antonbaumann), Marco Bagatella, Daniel Marta, Ido Hakimi, Idan Shenfeld, Thomas Kleine Buening, Carlos Guestrin, and Andreas Krause. | ||
|
|
||
| > Large language models are increasingly post-trained with reinforcement learning in verifiable domains such as code and math. Yet, current methods for reinforcement learning with verifiable rewards (RLVR) learn only from a scalar outcome reward per attempt, creating a severe credit-assignment bottleneck. Many verifiable environments actually provide rich textual feedback, such as runtime errors or judge evaluations, that explain why an attempt failed. We formalize this setting as reinforcement learning with rich feedback and introduce Self-Distillation Policy Optimization (SDPO), which converts tokenized feedback into a dense learning signal without any external teacher or explicit reward model. SDPO treats the current model conditioned on feedback as a self-teacher and distills its feedback-informed next-token predictions back into the policy. In this way, SDPO leverages the model's ability to retrospectively identify its own mistakes in-context. Across scientific reasoning, tool use, and competitive programming on LiveCodeBench v6, SDPO improves sample efficiency and final accuracy over strong RLVR baselines. Notably, SDPO also outperforms baselines in standard RLVR environments that only return scalar feedback by using successful rollouts as implicit feedback for failed attempts. Finally, applying SDPO to individual questions at test time accelerates discovery on difficult binary-reward tasks, achieving the same discovery probability as best-of-k sampling or multi-turn conversations with 3x fewer attempts. | ||
|
|
||
| The SDPO trainer is built on TRL's experimental shared self-distillation stack. It keeps the online rollout-and-reward training flow, then builds a teacher-conditioned view of the same completions from successful rollouts and optional environment feedback. | ||
|
|
||
| In the current TRL implementation: | ||
|
|
||
| - the default SDPO policy loss mode is `distillation_only` | ||
| - `hybrid` mode is also available to combine the base policy loss with the self-distillation loss | ||
| - supported teacher regularization modes are `ema` and `none` | ||
| - `distillation_topk` is only valid when `full_logit_distillation=True` | ||
| - when `full_logit_distillation=False`, SDPO uses token-level reverse KL and requires `distillation_alpha=1.0` | ||
| - environment feedback can be injected into teacher reprompts when the dataset exposes a `privileged_context` column | ||
|
|
||
| ## Expected dataset columns | ||
|
|
||
| Each example must provide: | ||
|
|
||
| - `prompt`: the student-facing prompt | ||
| - `privileged_context`: optional privileged text, such as environment feedback, used when `include_environment_feedback=True` | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| from datasets import Dataset | ||
|
|
||
| from trl.experimental.sdpo import SDPOConfig, SDPOTrainer | ||
|
|
||
| dataset = Dataset.from_dict( | ||
| { | ||
| "prompt": [[{"role": "user", "content": "Solve 2+2."}]], | ||
| "privileged_context": ["Your earlier answer used the wrong format."], | ||
| } | ||
| ) | ||
|
|
||
| training_args = SDPOConfig( | ||
| output_dir="sdpo-model", | ||
| distillation_topk=100, # Top-K logit distillation approximation | ||
| full_logit_distillation=True, # Required for top-K; enables non-reverse divergences | ||
| include_environment_feedback=True, # Use dataset privileged_context for teacher reprompts | ||
| ) | ||
|
|
||
| trainer = SDPOTrainer( | ||
| model="Qwen/Qwen2.5-1.5B-Instruct", | ||
| reward_funcs=reward_func, | ||
| args=training_args, | ||
| train_dataset=dataset, | ||
| ) | ||
| trainer.train() | ||
| ``` | ||
|
|
||
| SDPO always requires a `prompt` column. To use environment feedback, also include a `privileged_context` column and set `include_environment_feedback=True`. SDPO will use successful rollouts and, when enabled, that text to build teacher reprompts for self-distillation. | ||
|
|
||
| ## Callbacks | ||
|
|
||
| The trainer emits a small set of callback hooks that are useful for debugging, observability, and tests. These hooks are intended as practical integration points for experimental self-distillation workflows. | ||
|
|
||
| Shared self-distillation hooks: | ||
|
|
||
| - `on_self_distillation_batch_prepared`: fired when a self-distillation batch is ready. The payload includes `prompt_ids`, `completion_ids`, and `old_per_token_logps` when importance-sampling clipping inputs are available. | ||
| - `on_generation_batch_built`: fired when a new buffered generation batch is created. The payload includes `generate_every` and `steps_per_generation`. | ||
|
|
||
| SDPO-specific hook: | ||
|
|
||
| - `on_teacher_context_built`: fired after SDPO constructs the teacher-conditioned inputs. The payload includes `teacher_input_ids`, `teacher_attention_mask`, `completion_mask`, and `self_distillation_mask`. | ||
|
|
||
| ## SDPOConfig | ||
|
|
||
| [[autodoc]] experimental.sdpo.SDPOConfig | ||
|
|
||
| ## SDPOTrainer | ||
|
|
||
| [[autodoc]] experimental.sdpo.SDPOTrainer | ||
| - train | ||
| - save_model | ||
| - push_to_hub |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is ref_model the teacher? or are these two separate models?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes its the teacher