Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ The [SPPO](https://arxiv.org/abs/2405.00675) authors claim that SPPO is capable

The [NCA](https://arxiv.org/abs/2402.05369) authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood.

The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD with weight `ref_model_mixup_alpha` during DPO training. To toggle this callback, use the `sync_ref_model` flag in the `DPOConfig`.

## Logging

While training and evaluating we record the following reward metrics:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,48 @@ def test_dpo_trainer_w_dataset_num_proc(self):

trainer.train()

def test_tr_dpo_trainer(self):
    """Train briefly with TR-DPO syncing enabled and check the reference weights moved.

    The same model object is passed as both policy and reference, so a snapshot
    of the policy parameters taken before training doubles as a snapshot of the
    reference parameters.
    """
    with tempfile.TemporaryDirectory() as output_dir:
        config = DPOConfig(
            output_dir=output_dir,
            per_device_train_batch_size=2,
            max_steps=3,
            remove_unused_columns=False,
            gradient_accumulation_steps=4,
            learning_rate=9e-1,
            evaluation_strategy="steps",
            precompute_ref_log_probs=False,
            sync_ref_model=True,
            ref_model_mixup_alpha=0.5,
            ref_model_sync_steps=1,
        )

        dataset = self._init_dummy_dataset()

        trainer = DPOTrainer(
            model=self.model,
            ref_model=self.model,
            beta=0.1,
            args=config,
            tokenizer=self.tokenizer,
            train_dataset=dataset,
            eval_dataset=dataset,
        )

        # Snapshot the parameters before training; since the reference model
        # is the same object as the model, this also captures its weights.
        initial_weights = {}
        for name, weight in trainer.model.named_parameters():
            initial_weights[name] = weight.clone()

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # The reference model's parameters must differ from the snapshot.
        for name, before in initial_weights.items():
            after = trainer.ref_model.get_parameter(name)
            # All-zero tensors (e.g. zero-initialised biases) are skipped.
            if before.sum() == 0:
                continue
            assert not torch.equal(before, after)

@require_no_wandb
def test_dpo_trainer_generate_during_eval_no_wandb(self):
with tempfile.TemporaryDirectory() as tmp_dir:
Expand Down
9 changes: 9 additions & 0 deletions trl/trainer/dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ class DPOConfig(TrainingArguments):
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
force_use_ref_model (`bool`, defaults to `False`):
In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
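sync_ref_model (`bool`, defaults to `False`):
Whether to synchronize the reference model with the trained model during training, as proposed in the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
ref_model_mixup_alpha (`float`, defaults to `0.9`):
The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper: the weight given to the trained model when it is mixed into the reference model at each sync.
ref_model_sync_steps (`int`, defaults to `64`):
The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper: the number of training steps between two syncs of the reference model.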
"""

beta: float = 0.1
Expand All @@ -87,3 +93,6 @@ class DPOConfig(TrainingArguments):
ref_adapter_name: Optional[str] = None
reference_free: bool = False
force_use_ref_model: bool = False
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
12 changes: 12 additions & 0 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
SyncRefModelCallback,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
Expand Down Expand Up @@ -528,12 +529,23 @@ def make_inputs_require_grad(module, input, output):
raise ValueError(
"No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
)
if args.sync_ref_model:
raise ValueError(
"You currently cannot use `ref_model=None` with TR-DPO method. Please provide `ref_model`."
)
else:
if self.is_deepspeed_enabled:
self.ref_model = self._prepare_deepspeed(self.ref_model)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if args.sync_ref_model:
if precompute_ref_log_probs:
raise ValueError(
"You cannot use `precompute_ref_log_probs=True` with TR-DPO method. Please set `precompute_ref_log_probs=False`."
)

self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

Expand Down
43 changes: 41 additions & 2 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
import numpy as np
import pandas as pd
import torch
from accelerate import PartialState
from accelerate.state import AcceleratorState
from accelerate import Accelerator
from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import is_deepspeed_available
from rich.console import Console, Group
from rich.live import Live
from rich.panel import Panel
Expand All @@ -32,6 +33,7 @@
from transformers import (
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
PreTrainedModel,
PreTrainedTokenizerBase,
)
from transformers.trainer import TrainerCallback
Expand All @@ -45,6 +47,10 @@
from peft import LoraConfig, PeftConfig


if is_deepspeed_available():
import deepspeed


class AdaptiveKLController:
"""
Adaptive KL controller described in the paper:
Expand All @@ -63,6 +69,39 @@ def update(self, current, n_steps):
self.value *= mult


class SyncRefModelCallback(TrainerCallback):
    """Trainer callback that periodically mixes the trained weights into a reference model.

    At the end of every step whose `state.global_step` is a multiple of
    `args.ref_model_sync_steps`, each reference parameter is updated in place as
    ``ref = (1 - alpha) * ref + alpha * model`` with
    ``alpha = args.ref_model_mixup_alpha``.

    Args:
        ref_model: The reference model whose parameters are overwritten in place.
        accelerator: Optional accelerator used to unwrap the trained model
            before its parameters are read. When falsy, the model passed to
            the callback is used as is.
    """

    def __init__(
        self,
        ref_model: Union[PreTrainedModel, torch.nn.Module],
        accelerator: Optional[Accelerator],
    ):
        self.accelerator = accelerator
        self.ref_model = ref_model

    @staticmethod
    def _sync_target_model(model, target_model, alpha):
        """Blend `model` into `target_model` in place, parameter by parameter.

        Parameters are paired positionally via `zip`, so both models are
        expected to yield their parameters in the same order.
        """
        for target_param, copy_param in zip(target_model.parameters(), model.parameters()):
            # target <- (1 - alpha) * target + alpha * source, without autograd tracking.
            target_param.data.mul_(1.0 - alpha).add_(copy_param.data, alpha=alpha)

    @staticmethod
    def sync_target_model(model, target_model, alpha):
        """Blend `model` into `target_model`, gathering parameters first under ZeRO stage 3.

        Args:
            model: Source model whose parameters are read.
            target_model: Model whose parameters are updated in place.
            alpha: Weight given to `model`; `target_model` keeps `1 - alpha`.
        """
        deepspeed_plugin = AcceleratorState().deepspeed_plugin
        if deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3:
            # NOTE(review): a reviewer asked whether this gathering is only
            # required for ZeRO stage 3; the question is unanswered here.
            #
            # Both models' parameters are gathered: the source is read and the
            # target is modified in place, so gathering only the source would
            # leave the in-place write acting on ungathered target parameters.
            params_to_gather = list(model.parameters()) + list(target_model.parameters())
            with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
                # Only the modifier rank performs the write.
                if deepspeed.comm.get_rank() == 0:
                    SyncRefModelCallback._sync_target_model(model, target_model, alpha)
        else:
            SyncRefModelCallback._sync_target_model(model, target_model, alpha)

    def on_step_end(self, args, state, control, **kwargs):
        """Sync the reference model every `args.ref_model_sync_steps` global steps."""
        model: PreTrainedModel = kwargs["model"]

        if self.ref_model is not None and state.global_step % args.ref_model_sync_steps == 0:
            if self.accelerator:
                # Strip the accelerator wrapper before reading parameters.
                model = self.accelerator.unwrap_model(model)
            self.sync_target_model(model, self.ref_model, args.ref_model_mixup_alpha)


class FixedKLController:
"""Fixed KL controller."""

Expand Down