
Add SDPO (Self-Distillation Policy Optimization) trainer #4935

Merged
kashif merged 85 commits into huggingface:main from MengAiDev:4929
Mar 23, 2026


@MengAiDev

@MengAiDev MengAiDev commented Jan 30, 2026

Contributor

Implements SDPO algorithm from arxiv.org/abs/2601.20802. SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories, converting tokenized feedback into a dense learning signal.

  • Add SDPOConfig with distillation parameters (alpha, topk, ema_update_rate, etc.)
  • Add SDPOTrainer extending GRPOTrainer with self-distillation loss
  • Add comprehensive tests for SDPOConfig and SDPOTrainer
  • Add example script demonstrating SDPO usage

Fixes #4929


Note

Medium Risk
Introduces new experimental training algorithms (online rollout + self-distillation, EMA teacher syncing, reward-driven reprompting) that can affect training correctness and distributed behavior, though changes are mostly additive and isolated under trl.experimental.

Overview
Adds a new experimental self-distillation stack (SelfDistillationConfig, SelfDistillationMixin, OnlineRolloutMixin, BaseSelfDistillationTrainer) to support rollout reuse, reward scoring/normalization, self-distillation losses (token/logit-level + optional IS clipping), and callback hooks/diagnostics.

Introduces two new trainers: SDPOTrainer/SDPOConfig implementing online SDPO with successful-rollout/feedback-based teacher reprompting and optional EMA teacher synchronization, and SDFTTrainer/SDFTConfig for on-policy self-distilled fine-tuning using teacher-conditioned prompts and optional PEFT adapter EMA teacher.

Adds example scripts (trl/experimental/sdpo/sdpo.py, trl/experimental/sdft/sdft.py), new docs pages and paper index entries, and extensive tests covering training flows, callback payloads, PEFT EMA behavior, masking/attention correctness, and diagnostic warnings.
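
The EMA teacher synchronization mentioned in the overview can be pictured with a minimal sketch. It assumes the common convention that `ema_update_rate` is the fraction of the student's weights blended into the teacher on each sync; the update rule actually used in this PR may differ. `ema_sync` and the dict-of-floats "weights" are hypothetical stand-ins for real model tensors.

```python
def ema_sync(teacher, student, ema_update_rate):
    """Blend student weights into the teacher in place (hypothetical helper).

    Assumed convention: teacher <- (1 - rate) * teacher + rate * student.
    """
    if not 0.0 <= ema_update_rate <= 1.0:
        raise ValueError("ema_update_rate must be in [0, 1]")
    for name, student_value in student.items():
        teacher[name] = (1.0 - ema_update_rate) * teacher[name] + ema_update_rate * student_value
    return teacher


# Toy "weights": plain floats keyed by parameter name.
teacher_weights = {"w": 1.0, "b": 0.0}
student_weights = {"w": 2.0, "b": 1.0}
ema_sync(teacher_weights, student_weights, ema_update_rate=0.25)
print(teacher_weights)  # {'w': 1.25, 'b': 0.25}
```

Under this convention a small rate keeps the teacher close to its previous state, while a rate of 1.0 copies the student outright.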

Written by Cursor Bugbot for commit bf4cc67.

MengAiDev and others added 9 commits January 30, 2026 10:04
@kashif

kashif commented Feb 2, 2026

Collaborator

@MengAiDev I have cleaned up the structure and docs and tests. Next we need to address the main TODOs regarding the teacher logits.

@kashif

kashif commented Feb 2, 2026

Collaborator

cc @jonhue here is a port of SDPO for TRL

@jonhue

jonhue commented Feb 2, 2026


@MengAiDev @kashif Thanks so much for implementing this!! Let's coordinate with @Shekswess and #4941. It might be cleanest to have one implementation for SDFT & SDPO ("self-distillation") since both are algorithmically the same and they differ only in whether data is offline or online.

@kashif

kashif commented Feb 2, 2026

Collaborator

Agree! Let's try that if it's OK with you @MengAiDev

@Shekswess


Wohoo !
This is really awesome, bravo legends @kashif @jonhue @MengAiDev. Maybe we should also then have the offline version of the trainer, knowing that some folks (like me that are GPU poor hahahahaha) can experiment with the approaches

@LeonEricsson

Collaborator

Regarding the discussion on how to combine SDFT/SDPO PRs:

This PR inherits from GRPOTrainer, while the SDFT PR modifies it in place. Both approaches carry baggage from GRPOTrainer that isn’t necessarily applicable to SDPO/SDFT — but this also provides a nice playground for experimentation.

The tradeoff with inheritance is less control, but I like how it nicely isolates SDPO’s key contributions and exposes relevant hparams clearly. If future research demands more flexibility, we can revisit and consider breaking out SDPO into its own trainer.

If we proceed with this PR’s approach, extending it to cover the offline case should, at first glance, just require modifying the _build_teacher_inputs function.

@qgallouedec

Member

That's a good point, Leon. I need to review the PR carefully, but in general I’d rather isolate first and abstract later, if needed (abstractions are easy to do, hard to undo).

@Shekswess


@qgallouedec @LeonEricsson if you look at my implementation of the offline SDFT in #4941 (comment), I think it can be really, really improved. I tried to follow the official code from the authors with small modifications; feel free to ping us on how we can make this stuff better. Cannot wait to start to experiment hehehehe

@niksdagr8


Any progress on this is much appreciated

Comment thread trl/experimental/sdpo/sdpo_trainer.py Outdated
Comment thread trl/experimental/sdpo/sdpo_trainer.py Outdated
Comment thread trl/experimental/self_distillation/base_self_distillation_trainer.py Outdated
Comment thread trl/experimental/sdpo/sdpo_trainer.py
Comment thread trl/experimental/self_distillation/self_distillation_mixin.py Outdated
Comment thread trl/experimental/sdpo/sdpo_trainer.py
Comment thread trl/experimental/self_distillation/teacher_context.py Outdated
@HuggingFaceDocBuilderDev


The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment thread trl/experimental/sdpo/sdpo_trainer.py
Comment thread trl/experimental/sdpo/sdpo_trainer.py
Comment thread trl/experimental/self_distillation/teacher_context.py
callbacks=callbacks,
optimizers=optimizers,
peft_config=peft_config,
)


SDPOTrainer positional constructor bypasses keyword validation

Low Severity

SDPOTrainer.__init__ accepts reward_funcs as the second positional argument followed by args as the third, but the test test_training_with_positional_config_argument passes (model, reward_func, training_args, dataset) positionally. The train_dataset parameter is the fourth positional argument, matching correctly. However, the parent BaseSelfDistillationTrainer constructor does not call super().__init__ with a signature that maps train_dataset positionally — the fourth parameter alignment is coincidental and fragile if the constructor signature changes.


teacher_logits = teacher_model(**teacher_model_inputs).logits
teacher_logits = teacher_logits[:, :-1, :]
teacher_logits = teacher_logits[:, -logits_to_keep:, :]
teacher_logits = teacher_logits / self.temperature


Teacher context uses wrong model reference during inference

High Severity

_get_teacher_model_for_self_distillation returns self.teacher_model when it exists, but _get_teacher_context_for_self_distillation returns nullcontext() by default. In the SDPO EMA teacher case, teacher_model is a separate deep-copied model, so the teacher logits are correctly computed on it. However, the teacher_model_inputs use teacher_input_ids (the reprompted sequence) which has a different prompt length than the student. When logits_to_keep is set based on completion_ids.size(1), both the student and teacher forward passes use the same logits_to_keep + 1. If the teacher prompt is longer than the student prompt, the teacher model's logits_to_keep parameter will correctly slice the last N logits. This is fine since both models share the same completion tokens at the end. So this is actually correct.
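
The slicing argument in this comment can be checked on toy data. The sketch below uses plain lists in place of logits tensors and hypothetical token names. It assumes, as the comment states, that the student and teacher sequences end with the same completion tokens and differ only in prompt length.

```python
def completion_positions(sequence, logits_to_keep):
    """Mimic `logits[:, :-1][:, -logits_to_keep:]` on a 1-D list.

    Position i of a causal LM's output predicts token i + 1, so dropping the
    last position and keeping the final `logits_to_keep` ones yields the
    positions that predict the completion tokens.
    """
    positions = list(range(len(sequence)))  # one "logit" per input position
    positions = positions[:-1]              # the last position predicts no scored token
    return positions[-logits_to_keep:]


completion = ["c1", "c2", "c3"]
student_seq = ["p1", "p2"] + completion                 # short prompt
teacher_seq = ["p1", "p2", "fb1", "fb2"] + completion   # reprompted, longer prompt

n = len(completion)
student_pos = completion_positions(student_seq, n)
teacher_pos = completion_positions(teacher_seq, n)

# Each kept position predicts the token right after it.
student_targets = [student_seq[i + 1] for i in student_pos]
teacher_targets = [teacher_seq[i + 1] for i in teacher_pos]
print(student_targets, teacher_targets)
```

Both slices land on the positions that predict `c1`, `c2`, `c3`, even though the absolute indices differ, which is why the same `logits_to_keep` works for both forward passes.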


Comment thread trl/experimental/sdpo/sdpo_trainer.py Outdated
Comment thread trl/experimental/sdft/sdft_trainer.py Outdated
inputs["completion_mask"] = completion_mask

loss = self._compute_self_distillation_loss(model, inputs)
return loss / self.current_gradient_accumulation_steps


SDFTTrainer ignores inherited distillation_weight config parameter

Low Severity

SDFTTrainer.compute_loss applies _compute_self_distillation_loss and divides by gradient accumulation steps, but never multiplies by self.args.distillation_weight. The distillation_weight parameter is inherited from SelfDistillationConfig and is settable by users, but silently ignored here. In contrast, SDPOTrainer._compute_loss correctly applies self.args.distillation_weight. Setting a non-default distillation_weight on SDFTConfig has no effect on training.
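
As a rough illustration of the behavior this comment expects, the sketch below applies a weight before the accumulation scaling. `weighted_loss` is a hypothetical helper working on plain floats, not code from the PR, and it assumes the weight is meant to multiply the self-distillation loss in SDFT the same way the comment says it does in SDPO.

```python
def weighted_loss(distillation_loss, distillation_weight, grad_accum_steps):
    """Scale the raw self-distillation loss (hypothetical helper)."""
    return distillation_weight * distillation_loss / grad_accum_steps


raw = 2.0
default_weight = weighted_loss(raw, distillation_weight=1.0, grad_accum_steps=4)
half_weight = weighted_loss(raw, distillation_weight=0.5, grad_accum_steps=4)
print(default_weight)  # 0.5
print(half_weight)     # 0.25
```

With the multiplication in place, a non-default weight changes the returned loss; without it, both calls would return the same value.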


@qgallouedec

Member

Just running a final @codex review to see if there is any critical bug, otherwise we're good to merge


model_init_kwargs (`dict[str, Any]`, *optional*):
    Keyword arguments used when the `model` argument is passed as a string.
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):

Member


Not necessarily for this PR, but I think that in general truncating the prompt isn't a good idea, because it truncates the generation prompt (<|im_end|><|im_start|>assistant\n), so the generation basically completes the user query instead of answering it, which doesn't make sense, and I don't think there is anything to learn from it. Unless there is a good reason to keep it, I recommend removing it.

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: a16b033390


Comment on lines +294 to +296
if self.generation_batch_size % self.num_generations != 0:
    raise ValueError(
        f"generation_batch_size ({self.generation_batch_size}) must be divisible by num_generations ({self.num_generations})."


P1: Validate eval batch divisibility for num_generations_eval

SelfDistillationConfig.__post_init__() only checks that the training generation batch is divisible by num_generations, but SDPO evaluation also groups samples by num_generations_eval. If a user enables eval with a global eval batch that is not divisible by that value, OnlineRolloutMixin._generate_and_score_completions() later does rewards.view(-1, num_generations) and will either raise or mix prompt groups incorrectly during evaluate(). GRPO already guards against this shape constraint, so SDPO needs the same validation here.
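
The shape constraint behind this suggestion can be sketched in plain Python. `check_divisible` and `group_rewards` are hypothetical helpers, with a list-of-lists standing in for `rewards.view(-1, num_generations)`. How the PR computes the global eval batch size is not shown here and is left as an assumption.

```python
def check_divisible(batch_size, group_size, batch_name, group_name):
    """Raise if a batch cannot be split into whole prompt groups (hypothetical helper)."""
    if batch_size % group_size != 0:
        raise ValueError(
            f"{batch_name} ({batch_size}) must be divisible by {group_name} ({group_size})."
        )


def group_rewards(rewards, num_generations):
    """Pure-Python stand-in for `rewards.view(-1, num_generations)`."""
    check_divisible(len(rewards), num_generations, "number of rewards", "num_generations")
    return [rewards[i:i + num_generations] for i in range(0, len(rewards), num_generations)]


# Training-side check, as in the quoted snippet.
check_divisible(8, 4, "generation_batch_size", "num_generations")

# Eval-side check the review asks for: 6 samples cannot form groups of 4.
try:
    check_divisible(6, 4, "global eval batch size", "num_generations_eval")
    eval_ok = True
except ValueError as err:
    eval_ok = False
    print(err)

grouped = group_rewards([1.0, 0.0, 0.5, 0.5], num_generations=2)
print(grouped)  # [[1.0, 0.0], [0.5, 0.5]]
```

Running the same check at config time turns a late reshape failure, or silently mixed prompt groups, into an early and readable error.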


Comment thread trl/experimental/sdft/sdft_trainer.py Outdated
Comment on lines +493 to +494
loss = self._compute_self_distillation_loss(model, inputs)
return loss / self.current_gradient_accumulation_steps


P2: Avoid scaling SDFT eval loss by accumulation steps

SDFTTrainer.compute_loss() always divides by current_gradient_accumulation_steps, but prediction_step() reuses this path during evaluate(). When gradient_accumulation_steps > 1, the reported eval loss is therefore smaller by that factor, which can skew checkpoint selection or early-stopping decisions. The online self-distillation trainers already special-case eval here; SDFT should do the same.
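
A minimal sketch of the special-casing described here, using plain floats and a hypothetical `reported_loss` helper rather than the trainer's actual code path: the accumulation scaling is applied only while training, so the evaluation loss is reported unscaled.

```python
def reported_loss(raw_loss, grad_accum_steps, training):
    """Divide by accumulation steps only while training (hypothetical helper)."""
    return raw_loss / grad_accum_steps if training else raw_loss


train_value = reported_loss(2.0, 4, training=True)
eval_value = reported_loss(2.0, 4, training=False)
print(train_value)  # 0.5
print(eval_value)   # 2.0
```

In a real trainer the `training` flag would typically come from the model's mode, for example `model.training` in PyTorch; that mapping is an assumption and is not taken from the PR.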


if alpha == 0.0:
    kl = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
elif alpha == 1.0:
    kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)


KL divergence direction is swapped for forward/reverse

High Severity

The _compute_divergence method swaps forward and reverse KL directions. alpha=0.0 is documented as "forward KL" (KL(teacher||student)), but F.kl_div(input=student, target=teacher) computes sum(teacher * (log(teacher) - student)) which equals KL(teacher||student). Actually wait — F.kl_div with log_target=True computes exp(target) * (target - input), so F.kl_div(student, teacher) = exp(teacher) * (teacher - student) = KL(teacher||student). For alpha=1.0 (reverse KL = KL(student||teacher)), F.kl_div(teacher, student) = exp(student) * (student - teacher) = KL(student||teacher). This is actually correct. I withdraw this bug.
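
The conclusion of this comment can be verified numerically without PyTorch. The sketch below re-implements the pointwise rule quoted above, `exp(target) * (target - input)`, in plain Python and compares it with the textbook definition of KL divergence. The distributions are made-up toy values, and `kl_div_log_target` is a hypothetical stand-in for `F.kl_div(..., log_target=True)` summed over the vocabulary.

```python
import math


def kl_div_log_target(input_log_probs, target_log_probs):
    """Summed version of the pointwise rule exp(target) * (target - input)."""
    return sum(math.exp(t) * (t - i) for i, t in zip(input_log_probs, target_log_probs))


def kl(p, q):
    """Textbook KL(p || q) for discrete distributions given as probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
teacher_lp = [math.log(x) for x in teacher]
student_lp = [math.log(x) for x in student]

forward = kl_div_log_target(student_lp, teacher_lp)  # input=student, target=teacher
reverse = kl_div_log_target(teacher_lp, student_lp)  # input=teacher, target=student

print(math.isclose(forward, kl(teacher, student)))  # True
print(math.isclose(reverse, kl(student, teacher)))  # True
```

The target argument always supplies the distribution the expectation is taken under, so passing the teacher as the target gives KL(teacher||student) and swapping the arguments gives KL(student||teacher).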


@cursor cursor Bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

There are 7 total unresolved issues (including 5 from previous reviews).


top_entropy_quantile: float = field(
    default=1.0,
    metadata={"help": "Reserved for entropy-based token filtering."},
)


Unused config fields declared as reserved placeholders

Low Severity

top_entropy_quantile and use_transformers_paged are declared in SelfDistillationConfig with "Reserved" metadata but are never referenced anywhere in the trainer logic, loss computation, or generation code. These dead config fields add user-facing API surface that does nothing, and could confuse users into thinking they affect behavior.


Comment thread trl/experimental/sdft/sdft_trainer.py
@kashif kashif merged commit 9b59eed into huggingface:main Mar 23, 2026
1 check passed
@Neelectric


It looks like SDFTConfig has True as the default parameter for disable_dropout, while SelfDistillationConfig uses False as the default parameter.

The latter more closely matches the reference implementation by Shenfeld et al., where disable_dropout also uses False as the default parameter, and I did not find any overwrites for this elsewhere in their repo. Should SDFTConfig also default to False for consistency?

@1dividedby0

1dividedby0 commented Mar 25, 2026


Does this implementation do online feedback? I see that the privileged context is generated offline as part of the dataset. Are there any plans to make this online (i.e. using the privileged context that is generated after the rollout)?


Development

Successfully merging this pull request may close these issues.

SDPO: Reinforcement Learning via Self-Distillation