diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md
index 8e140630f62..6467548d8ea 100644
--- a/docs/source/paper_index.md
+++ b/docs/source/paper_index.md
@@ -605,3 +605,47 @@ def add_margin(example):
 dataset = dataset.map(add_margin)
 ```
+
+## Distillation
+
+Papers relating to training a student model with the help of a teacher model.
+
+### On-Policy Distillation
+
+**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/
+
+In on-policy distillation, the student model generates rollouts for each batch of training data. We then obtain the per-token probability distributions over these rollouts from both the student and the teacher model, and optimize the student to minimize the reverse Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher (a minimal sketch of this objective is given at the end of this section).
+
+| Method                 | Sampling   | Reward signal |
+|------------------------|------------|---------------|
+| Supervised finetuning  | off-policy | dense         |
+| Reinforcement learning | on-policy  | sparse        |
+| On-policy distillation | on-policy  | dense         |
+
+On-policy distillation has been shown to outperform both SFT and GRPO, and it can be used to restore generalization capabilities lost during SFT.
+
+Additionally, on-policy distillation is more compute-efficient and less prone to overfitting when training data is limited.
+
+To train a model with on-policy distillation in TRL, use the [`GKDTrainer`] with the following [`GKDConfig`]:
+
+```python
+from trl import GKDConfig
+
+config = GKDConfig(
+    lmbda=1.0,  # student produces rollouts for all batches
+    beta=1.0,  # use reverse KL as the loss function
+    teacher_model_name_or_path="teacher-model",  # specify the teacher model
+)
+```
+
+Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:
+
+```python
+from trl.experimental import GOLDConfig
+
+config = GOLDConfig(
+    lmbda=1.0,  # student produces rollouts for all batches
+    beta=1.0,  # use reverse KL as the loss function
+    teacher_model_name_or_path="teacher-model",  # specify the teacher model
+)
+```
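+
+To make the objective concrete, below is a minimal sketch of the per-token reverse-KL loss computed on student rollouts. It is illustrative only: the function name and tensor shapes are assumptions, not TRL internals, and prompt/padding masking is omitted for brevity.
+
+```python
+import torch.nn.functional as F
+
+def reverse_kl_loss(student_logits, teacher_logits):
+    """KL(student || teacher), averaged over rollout tokens.
+
+    Both tensors are assumed to have shape (batch, seq_len, vocab_size)
+    and to score the same student-generated rollout tokens.
+    """
+    student_logp = F.log_softmax(student_logits, dim=-1)
+    teacher_logp = F.log_softmax(teacher_logits, dim=-1)
+    # per-token KL(p_s || p_t) = sum_v p_s(v) * (log p_s(v) - log p_t(v))
+    per_token_kl = (student_logp.exp() * (student_logp - teacher_logp)).sum(-1)
+    return per_token_kl.mean()
+```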
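+
+For completeness, here is a sketch of launching training with the [`GKDConfig`] defined above, forwarding `teacher_model_name_or_path` to the trainer. The model and dataset names are placeholders, and the dataset is assumed to be in TRL's conversational format:
+
+```python
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from trl import GKDTrainer
+
+# placeholder names; substitute your own student checkpoint and dataset
+tokenizer = AutoTokenizer.from_pretrained("student-model")
+train_dataset = load_dataset("your-conversational-dataset", split="train")
+
+trainer = GKDTrainer(
+    model="student-model",
+    teacher_model=config.teacher_model_name_or_path,
+    args=config,
+    processing_class=tokenizer,
+    train_dataset=train_dataset,
+)
+trainer.train()
+```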