23 changes: 3 additions & 20 deletions docs/source/sdft_trainer.md
@@ -2,9 +2,8 @@

## Overview

Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation for Language Models](https://arxiv.org/pdf/2601.19897).
SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between
student and teacher distributions.
Self-Distillation Fine-Tuning (SDFT) is described in [Self-Distillation Enables Continual Learning](https://huggingface.co/papers/2601.19897) by Idan Shenfeld, Mehul Damani, [Jonas Hübotter](https://huggingface.co/jonhue), Pulkit Agrawal.
SDFT trains a student model using a teacher model on the student's generated completions, using a divergence between student and teacher distributions.
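
To make "a divergence between student and teacher distributions" concrete, here is a minimal sketch in plain Python. It is an illustration only, not the trainer's implementation: the choice of KL divergence, its direction (student relative to teacher), the per-token averaging, and the helper names `log_softmax` and `kl_divergence` are all assumptions, and the logits are made-up toy values.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(p_logits, q_logits):
    """KL(p || q) for two categorical distributions given as logits."""
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Hypothetical per-token logits over a 3-token vocabulary, scored on the
# same completion by the student and by the teacher.
student_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
teacher_logits = [[1.5, 1.0, -0.5], [0.1, 0.2, 0.3]]

# Assumed direction: KL(student || teacher), averaged over token positions.
per_token = [kl_divergence(s, t) for s, t in zip(student_logits, teacher_logits)]
loss = sum(per_token) / len(per_token)
```

The second position has identical logits for both models, so it contributes zero to the loss; only positions where the two distributions disagree produce a training signal.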

The abstract from the paper is the following:

@@ -13,26 +12,12 @@ The abstract from the paper is the following:
> [!WARNING]
> **Experimental:** APIs under `trl.experimental` may change or be removed without notice.

## Usage tips

- Provide a teacher model via `ref_model`. If you omit it, the trainer will create a teacher from the same checkpoint
as the student.
- Your dataset must contain `prompt` and `teacher_prompt`. If you do not have distinct teacher prompts, set
`teacher_prompt = prompt`.
- Set `generate_from_teacher=True` to generate completions using the teacher model instead of the student.

## Quick Start

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.experimental.sdft import SDFTConfig, SDFTTrainer

student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
{
"prompt": ["Write a haiku about the ocean."],
@@ -42,10 +27,8 @@ train_dataset = Dataset.from_dict(

training_args = SDFTConfig(output_dir="sdft-model", per_device_train_batch_size=1)
trainer = SDFTTrainer(
model=student_model,
ref_model=teacher_model,
model="Qwen/Qwen2-0.5B-Instruct",
args=training_args,
processing_class=tokenizer,
train_dataset=train_dataset,
)
trainer.train()