generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.5k
🤫 TR-DPO implementation #1593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
🤫 TR-DPO implementation #1593
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
6fe29e5
🤫 TR-DPO implementation baseline
a98e2ec
fix comments
687a2e8
docs
9f6d795
fix linters
4e0adbc
test added
62de841
Merge branch 'main' into tr-dpo
kashif f918480
move configs to DPOConfig
kashif f3b409e
fix typo
kashif d087bca
add docs
kashif fb01df7
fix import
kashif b9aabc2
use state.global_step
kashif 3b63bdd
fix order of arguments
kashif 247a88b
make sure plugins are not none
kashif 9932409
Update trl/trainer/utils.py
kashif e230875
Update trl/trainer/utils.py
kashif b333511
checking that reference model weights have changed
kashif b2b6a50
Merge branch 'main' into tr-dpo
kashif 2f161b6
sync_target_model as staticmethod
kashif 9f50f7d
set reference model
kashif d16daba
Merge branch 'main' into tr-dpo
kashif File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,8 +20,9 @@ | |
| import numpy as np | ||
| import pandas as pd | ||
| import torch | ||
| from accelerate import PartialState | ||
| from accelerate.state import AcceleratorState | ||
| from accelerate import Accelerator | ||
| from accelerate.state import AcceleratorState, PartialState | ||
| from accelerate.utils import is_deepspeed_available | ||
| from rich.console import Console, Group | ||
| from rich.live import Live | ||
| from rich.panel import Panel | ||
|
|
@@ -32,6 +33,7 @@ | |
| from transformers import ( | ||
| BitsAndBytesConfig, | ||
| DataCollatorForLanguageModeling, | ||
| PreTrainedModel, | ||
| PreTrainedTokenizerBase, | ||
| ) | ||
| from transformers.trainer import TrainerCallback | ||
|
|
@@ -45,6 +47,10 @@ | |
| from peft import LoraConfig, PeftConfig | ||
|
|
||
|
|
||
| if is_deepspeed_available(): | ||
| import deepspeed | ||
|
|
||
|
|
||
| class AdaptiveKLController: | ||
| """ | ||
| Adaptive KL controller described in the paper: | ||
|
|
@@ -63,6 +69,39 @@ def update(self, current, n_steps): | |
| self.value *= mult | ||
|
|
||
|
|
||
| class SyncRefModelCallback(TrainerCallback): | ||
| def __init__( | ||
| self, | ||
| ref_model: Union[PreTrainedModel, torch.nn.Module], | ||
| accelerator: Optional[Accelerator], | ||
| ): | ||
| self.accelerator = accelerator | ||
| self.ref_model = ref_model | ||
|
|
||
| @staticmethod | ||
| def _sync_target_model(model, target_model, alpha): | ||
| for target_param, copy_param in zip(target_model.parameters(), model.parameters()): | ||
| target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha) | ||
|
|
||
| @staticmethod | ||
| def sync_target_model(model, target_model, alpha): | ||
| deepspeed_plugin = AcceleratorState().deepspeed_plugin | ||
| if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a DS expert. Is this only required for zero stage3? |
||
| with deepspeed.zero.GatheredParameters(list(model.parameters()), modifier_rank=0): | ||
| if deepspeed.comm.get_rank() == 0: | ||
| SyncRefModelCallback._sync_target_model(model, target_model, alpha) | ||
| else: | ||
| SyncRefModelCallback._sync_target_model(model, target_model, alpha) | ||
|
|
||
| def on_step_end(self, args, state, control, **kwargs): | ||
| model: PreTrainedModel = kwargs["model"] | ||
|
|
||
| if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0: | ||
| if self.accelerator: | ||
| model = self.accelerator.unwrap_model(model) | ||
| self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha) | ||
|
|
||
|
|
||
| class FixedKLController: | ||
| """Fixed KL controller.""" | ||
|
|
||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.