Merged
docs/source/paper_index.md (44 additions, 0 deletions)
@@ -605,3 +605,47 @@ def add_margin(example):

dataset = dataset.map(add_margin)
```

## Distillation
Papers relating to training a student model with the help of a teacher model.

### On-Policy Distillation
**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/

In On-Policy Distillation, the student model generates rollouts for each batch of training data. The per-token probability distributions over these rollouts are then obtained from both the student and the teacher model, and the student is optimized to minimize the reverse Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher.
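
This per-token reverse KL can be sketched as follows; it is a minimal illustration with assumed tensor names and shapes, not necessarily TRL's exact implementation:

```python
import torch.nn.functional as F

def per_token_reverse_kl(student_logits, teacher_logits):
    """KL(student || teacher) for each token of the student's rollout.

    Both tensors are assumed to have shape (batch, seq_len, vocab_size) and to be
    aligned on the same rollout tokens; masking of prompt/padding positions is omitted.
    """
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # F.kl_div(input=log q, target=log p, log_target=True) computes p * (log p - log q),
    # so passing the teacher as `input` and the student as `target` gives KL(student || teacher).
    kl = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
    return kl.sum(dim=-1)  # shape: (batch, seq_len)
```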

| Method | Sampling | Reward signal |
|-------------------------|------------|---------------|
| Supervised finetuning | off-policy | dense |
| Reinforcement learning | on-policy | sparse |
| On-policy distillation | on-policy | dense |

On-policy distillation has been shown to outperform both SFT and GRPO, and can be used to restore generalization capabilities lost during SFT.

Additionally, on-policy distillation is more compute-efficient and less prone to overfitting when trained on limited data.

To train a model with on-policy distillation in TRL, use the [`GKDTrainer`] with a [`GKDConfig`] like the following:

```python
from trl import GKDConfig

config = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # use reverse KL as the loss function

**@pramodith (Collaborator Author), Oct 30, 2025:**

Can someone please confirm whether `beta` should be 1.0? The Thinking Machines blog says we should use reverse KL, but I see that the HF blog recommends setting `beta=0` for the distillation step, so I'm a bit confused.

From the Thinking Machines blog:

> Reverse KL has natural synergy with RL, which generally optimizes a form of sequence-level reverse KL induced by the reward model. However, unlike most reward models in practice, the reverse KL is "unhackable" in the sense that low KL always corresponds to a high probability of desirable behavior from the teacher model's point of view. Two other useful properties of reverse KL are that it is "mode seeking" (see Eric Jang's post for more discussion of mode-seeking behaviors) — it learns one specific behavior (the teacher's) instead of spreading its distribution across several suboptimal options — and it reduces exposure bias.

The HF blog's linked recipe for distillation:

```shell
accelerate launch \
    --config_file examples/accelerate_configs/multi_gpu.yaml trl/experimental/gold/gold.py \
    --model_name_or_path <sft-model> \
    --dtype auto \
    --attn_implementation kernels-community/flash-attn \
    --dataset_name allenai/tulu-3-sft-mixture \
    --dataset_train_split train \
    --bf16 \
    --learning_rate 1e-7 \
    --gradient_checkpointing \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 64 \
    --num_train_epochs 1 \
    --eval_strategy steps \
    --eval_steps 100 \
    --temperature 1.0 \
    --top_p 0.95 \
    --top_k 0 \
    --max_new_tokens 2048 \
    --max_prompt_length 512 \
    --lmbda 0.25 \
    --beta 0.0 \
```

**Member:**

cc @kashif

**Collaborator:**

Let's see: the GKD paper used `beta` values of 0.1, 0.5, and 0.9, and as `beta -> 1` the gradient of the loss behaves like the reverse KL. In the code, when `beta=1.0` we do `jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)`, which is KL(student || teacher).

**Collaborator Author:**

Yes, so it seems like setting `beta=1` is the right thing to do for reproducing what the Thinking Machines blog states. Was the decision to use `beta=0` instead of `1.0` in the HF blog intentional?

**Collaborator:**

That's a good catch.

Using `beta=0` was intentional because it was more sample-efficient than `beta=1` in some ablations on the Countdown task. We also used `lambda=0.25` instead of `lambda=1.0`, as lower lambda values are faster to run than `lambda=1` while giving results comparable to an entirely online setup (see Figure 5 of the blog post).

You're right that for exact reproduction of the Thinking Machines blog, we should use `beta=1` and `lambda=1`. However, we wanted to include the parameters we used for the blog post in case someone wanted to reproduce those results.

Running the experiments with the Thinking Machines setup will likely result in marginally better performance, but the conclusion would be essentially the same.

**Collaborator Author:**

Thanks for the clarification!

**Collaborator Author:**

@qgallouedec this one should be fine to merge.

    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```
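
For context, a minimal end-to-end sketch of how such a config might be passed to the trainer is shown below. The checkpoint names are placeholders, the dataset is only illustrative, and the teacher is assumed to be picked up from `teacher_model_name_or_path` as in the config above; check the [`GKDTrainer`] API reference for the exact argument names.

```python
from datasets import load_dataset
from trl import GKDConfig, GKDTrainer

# Conversational SFT-style data; the dataset is only illustrative
# (it is the one used in the recipe quoted in the review thread above).
train_dataset = load_dataset("allenai/tulu-3-sft-mixture", split="train")

config = GKDConfig(
    output_dir="on-policy-distillation",
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # use reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # placeholder teacher checkpoint
)

trainer = GKDTrainer(
    model="student-model",  # placeholder student checkpoint
    args=config,
    train_dataset=train_dataset,
)
trainer.train()
```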

Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] to perform on-policy distillation with a similar configuration:

```python
from trl.experimental import GOLDConfig

config = GOLDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # use reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```