Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@
title: PPO
- local: prm_trainer
title: PRM
- local: sdft_trainer
title: SDFT
- local: winrate_callback
title: WinRateCallback
- local: xpo_trainer
Expand Down
70 changes: 70 additions & 0 deletions docs/source/sdft_trainer.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Self-Distillation Fine-Tuning (SDFT) Trainer

## Overview

Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897).
SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between
student and teacher distributions.

The abstract from the paper is the following:

> Continual learning, enabling models to acquire new skills and knowledge without degrading existing capabilities, remains a fundamental challenge for foundation models. While on-policy reinforcement learning can reduce forgetting, it requires explicit reward functions that are often unavailable. Learning from expert demonstrations, the primary alternative, is dominated by supervised fine-tuning (SFT), which is inherently offpolicy. We introduce Self-Distillation Fine-Tuning (SDFT), a simple method that enables on-policy learning directly from demonstrations. SDFT leverages in-context learning by using a demonstration-conditioned model as its own teacher, generating on-policy training signals that preserve prior capabilities while acquiring new skills. Across skill learning and knowledge acquisition tasks, SDFT consistently outperforms SFT, achieving higher new-task accuracy while substantially reducing catastrophic forgetting. In sequential learning experiments, SDFT enables a single model to accumulate multiple skills over time without performance regression, establishing on-policy distillation as a practical path to continual learning from demonstrations.

> [!WARNING]
> **Experimental:** APIs under `trl.experimental` may change or be removed without notice.

## Usage tips

- Provide a teacher model via `ref_model`. If you omit it, the trainer will create a teacher from the same checkpoint
as the student.
- Your dataset must contain `prompt` and `teacher_prompt`. If you do not have distinct teacher prompts, set
`teacher_prompt = prompt`.
- Set `generate_from_teacher=True` to generate completions using the teacher model instead of the student.

## Quick Start

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.sdft import SDFTConfig, SDFTTrainer

student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
{
"prompt": ["Write a haiku about the ocean."],
"teacher_prompt": ["Write a haiku about the ocean."],
}
)

training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1)
trainer = SDFTTrainer(
model=student_model,
ref_model=teacher_model,
args=training_args,
processing_class=tokenizer,
train_dataset=train_dataset,
)
trainer.train()
```

### Expected dataset type

The dataset must be formatted with the following columns:

- `prompt`: text or conversational messages for the student input.
- `teacher_prompt`: text or conversational messages for the teacher input.

## SDFTTrainer

[[autodoc]] experimental.sdft.SDFTTrainer
- train
- save_model
- push_to_hub

## SDFTConfig

[[autodoc]] experimental.sdft.SDFTConfig
126 changes: 126 additions & 0 deletions tests/experimental/test_sdft_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch
from datasets import Dataset
from transformers import AutoTokenizer

from trl.experimental.sdft import SDFTConfig, SDFTTrainer

from ..testing_utils import TrlTestCase


MODEL_ID = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"


def build_dataset():
return Dataset.from_dict(
{
"prompt": ["Write a short poem about the sea."],
"teacher_prompt": ["Write a short poem about the sea."],
}
)


class TestSDFTTrainer(TrlTestCase):
def _build_args(self):
return SDFTConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=1,
num_generations=1,
report_to="none",
)

def _build_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
return tokenizer

def test_init_creates_default_teacher(self):
args = self._build_args()
tokenizer = self._build_tokenizer()
trainer = SDFTTrainer(
model=MODEL_ID,
args=args,
processing_class=tokenizer,
train_dataset=build_dataset(),
)
assert trainer.ref_model is not None
assert trainer.ref_model is not trainer.model

def test_init_with_ref_model_id(self):
args = self._build_args()
tokenizer = self._build_tokenizer()
trainer = SDFTTrainer(
model=MODEL_ID,
ref_model=MODEL_ID,
args=args,
processing_class=tokenizer,
train_dataset=build_dataset(),
)
assert trainer.ref_model is not None

def test_missing_teacher_prompt_raises(self):
args = self._build_args()
tokenizer = self._build_tokenizer()
bad_dataset = Dataset.from_dict({"prompt": ["Hello"]})
with pytest.raises(ValueError, match="teacher_prompt"):
SDFTTrainer(
model=MODEL_ID,
args=args,
processing_class=tokenizer,
train_dataset=bad_dataset,
)

@pytest.mark.low_priority
def test_train_updates_student_and_freezes_teacher(self):
args = SDFTConfig(
output_dir=self.tmp_dir,
per_device_train_batch_size=1,
num_generations=1,
max_completion_length=8,
max_steps=1,
logging_steps=1,
report_to="none",
save_strategy="no",
eval_strategy="no",
)
tokenizer = self._build_tokenizer()
trainer = SDFTTrainer(
model=MODEL_ID,
args=args,
processing_class=tokenizer,
train_dataset=build_dataset(),
)

student_before = {n: p.detach().clone() for n, p in trainer.model.named_parameters()}
teacher_before = {n: p.detach().clone() for n, p in trainer.ref_model.named_parameters()}

trainer.train()

# Student params should change
student_changed = False
for name, before in student_before.items():
after = trainer.model.get_parameter(name).detach()
if not torch.allclose(before, after):
student_changed = True
break
assert student_changed, "Student parameters did not update after training"

# Teacher params should remain frozen
for name, before in teacher_before.items():
after = trainer.ref_model.get_parameter(name).detach()
assert torch.allclose(before, after), f"Teacher parameter {name} changed during training"
19 changes: 19 additions & 0 deletions trl/experimental/sdft/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .sdft_config import SDFTConfig
from .sdft_trainer import SDFTTrainer


__all__ = ["SDFTConfig", "SDFTTrainer"]
Loading