# Add On-Policy Distillation from thinking labs to paper index. #4410
`@@ -605,3 +605,47 @@ def add_margin(example):`
## Distillation

Papers relating to training a student model with the help of a teacher model.

### On-Policy Distillation

**📰 Blog**: https://thinkingmachines.ai/blog/on-policy-distillation/

In on-policy distillation, the student model generates rollouts for each batch of training data. The per-token probability distributions over those rollouts are then obtained from both the student and the teacher model, and the student is optimized to minimize the reverse Kullback-Leibler (KL) divergence between its own token distributions and those of the teacher (equivalently, the negative reverse KL acts as a dense, per-token reward).
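To make the objective concrete, here is a minimal sketch of the per-token reverse KL term. This is an illustration, not TRL's internal implementation; the `per_token_reverse_kl` helper and the `student_logits` / `teacher_logits` tensors are hypothetical.

```python
import torch.nn.functional as F


def per_token_reverse_kl(student_logits, teacher_logits):
    """Reverse KL, KL(student || teacher), for every rollout token.

    Both inputs are assumed to be tensors of shape (batch, seq_len, vocab),
    obtained by scoring the student's own rollouts with each model.
    """
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
    # Expectation under the student distribution of (log student - log teacher).
    kl = (student_log_probs.exp() * (student_log_probs - teacher_log_probs)).sum(dim=-1)
    return kl  # shape (batch, seq_len); averaging over tokens gives the training loss
```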
| Method                 | Sampling   | Reward signal |
|------------------------|------------|---------------|
| Supervised finetuning  | off-policy | dense         |
| Reinforcement learning | on-policy  | sparse        |
| On-policy distillation | on-policy  | dense         |
On-policy distillation has been shown to outperform SFT and GRPO, and it can be used to restore generalization capabilities lost during SFT.

Additionally, on-policy distillation is more compute-efficient and less prone to overfitting when trained on limited data.
To train a model with on-policy distillation using TRL, you can use the following configuration with the [`GKDTrainer`] and [`GKDConfig`]:
```python
from trl import GKDConfig

config = GKDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # to ensure reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```

**Review discussion:**

- **Collaborator (author):** Can someone please confirm if …? From thinking labs: … (HF blog's linked recipe for distillation)
- **Member:** cc @kashif
- **Collaborator:** Let's see: the GKD paper had beta values of 0.1, 0.5 and 0.9, and as beta → 1 the gradient of the loss behaves like inverse KL. In the code, when …
- **Collaborator (author):** Yes, so it seems like setting …
- **Collaborator:** That's a good catch. Using … You're right that for exact reproduction of the Thinking Machines blog we should use … Running the experiments with the Thinking Machines setup will likely result in marginally better performance, but the conclusion would be essentially the same.
- **Collaborator (author):** Thanks for the clarification!
- **Collaborator (author):** @qgallouedec this one should be fine to merge.
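As a usage illustration that goes beyond the diff itself, here is a hedged sketch of wiring the configuration above into the [`GKDTrainer`]. The Qwen checkpoints and the tiny in-memory dataset are placeholders, and the snippet assumes a recent TRL version in which `GKDTrainer` accepts `model`, `teacher_model`, `args`, `processing_class`, and `train_dataset`:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

# Placeholder student/teacher checkpoints; swap in your own pair.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

# Tiny placeholder dataset in the conversational format the trainer expects.
train_dataset = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "Hi, how are you?"},
                   {"role": "assistant", "content": "I'm great, thanks!"}]] * 8}
)

config = GKDConfig(
    output_dir="gkd-on-policy-distillation",
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # reverse KL as the loss function
)

trainer = GKDTrainer(
    model=student,
    teacher_model=teacher,
    args=config,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```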
Alternatively, you can use the [`GOLDTrainer`] and [`GOLDConfig`] from `trl.experimental` to perform on-policy distillation with a similar configuration:
```python
from trl.experimental import GOLDConfig

config = GOLDConfig(
    lmbda=1.0,  # student produces rollouts for all batches
    beta=1.0,   # to ensure reverse KL as the loss function
    teacher_model_name_or_path="teacher-model",  # specify the teacher model
)
```