2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -107,6 +107,8 @@ The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we

The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass `loss_type="kto_pair"` to `DPOTrainer`, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.
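
As a rough illustration, the conversion from n paired preferences to 2n unpaired examples could look like the sketch below. The `prompt`/`completion`/`label` column names are assumptions about the unpaired feedback format, not something defined by this change.

```python
# Minimal sketch (assumed column names): split each {prompt, chosen, rejected}
# preference pair into two unpaired examples for use with KTOTrainer.
from datasets import Dataset


def pairs_to_unpaired(paired_dataset: Dataset) -> Dataset:
    rows = {"prompt": [], "completion": [], "label": []}
    for example in paired_dataset:
        # The chosen completion becomes a desirable (label=True) example ...
        rows["prompt"].append(example["prompt"])
        rows["completion"].append(example["chosen"])
        rows["label"].append(True)
        # ... and the rejected completion an undesirable (label=False) one.
        rows["prompt"].append(example["prompt"])
        rows["completion"].append(example["rejected"])
        rows["label"].append(False)
    return Dataset.from_dict(rows)
```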

The [BCO](https://arxiv.org/abs/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. The `DPOTrainer` can be switched to this loss via the `loss_type="bco_pair"` argument.
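
A hedged usage sketch follows; the model, tokenizer, `TrainingArguments`, and toy dataset are placeholders chosen only for illustration, and `loss_type="bco_pair"` is the only part introduced here.

```python
# Illustrative only: switch DPOTrainer to the BCO pairwise loss.
# Everything except loss_type="bco_pair" is a placeholder.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
ref_model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Tiny dummy preference dataset with the expected prompt/chosen/rejected columns.
train_dataset = Dataset.from_dict(
    {
        "prompt": ["Hello, ", "How are you? "],
        "chosen": ["nice to meet you.", "I am fine, thanks."],
        "rejected": ["go away.", "no idea."],
    }
)

trainer = DPOTrainer(
    model,
    ref_model,
    beta=0.1,
    loss_type="bco_pair",  # the BCO pairwise loss
    args=TrainingArguments(output_dir="dpo_bco", remove_unused_columns=False),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=128,
    max_prompt_length=64,
)
trainer.train()
```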

## Logging

While training and evaluating we record the following reward metrics:
6 changes: 6 additions & 0 deletions tests/test_dpo_trainer.py
@@ -88,6 +88,8 @@ def _init_dummy_dataset(self):
["t5", "ipo", True],
["gpt2", "kto_pair", True],
["t5", "kto_pair", False],
["gpt2", "bco_pair", False],
["t5", "bco_pair", True],
]
)
def test_dpo_trainer(self, name, loss_type, pre_compute):
@@ -453,6 +455,10 @@ def test_dpo_lora_bf16_autocast_llama(self):
["gpt2", "kto_pair", False, True],
["gpt2", "kto_pair", True, False],
["gpt2", "kto_pair", True, True],
["gpt2", "bco_pair", False, False],
["gpt2", "bco_pair", False, True],
["gpt2", "bco_pair", True, False],
["gpt2", "bco_pair", True, True],
]
)
@require_bitsandbytes
26 changes: 22 additions & 4 deletions trl/trainer/dpo_trainer.py
@@ -44,6 +44,7 @@
from ..models import PreTrainedModelWrapper, create_reference_model
from .utils import (
DPODataCollatorWithPadding,
RunningMoments,
disable_dropout_in_model,
pad_to_length,
peft_module_casting_to_bf16,
@@ -77,7 +78,8 @@ class DPOTrainer(Trainer):
label_smoothing (`float`, defaults to 0):
The robust DPO label smoothing parameter from the [cDPO](https://ericmitchell.ai/cdpo.pdf) report that should be between 0 and 0.5.
loss_type (`str`, defaults to `"sigmoid"`):
The type of DPO loss to use. Either `"sigmoid"` the default DPO loss,`"hinge"` loss from [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from [IPO](https://arxiv.org/abs/2310.12036) paper, or `"kto"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf).
The type of DPO loss to use. Either `"sigmoid"` (the default DPO loss), `"hinge"` loss from the [SLiC](https://arxiv.org/abs/2305.10425) paper, `"ipo"` from the [IPO](https://arxiv.org/abs/2310.12036) paper,
`"kto_pair"` from the HALOs [report](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf), or `"bco_pair"` from the [BCO](https://arxiv.org/abs/2404.04656) paper.
args (`transformers.TrainingArguments`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
@@ -147,7 +149,7 @@ def __init__(
ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
beta: float = 0.1,
label_smoothing: float = 0,
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid",
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair", "bco_pair"] = "sigmoid",
args: Optional[TrainingArguments] = None,
data_collator: Optional[DataCollator] = None,
label_pad_token_id: int = -100,
@@ -359,7 +361,7 @@ def make_inputs_require_grad(module, input, output):
self._precomputed_train_ref_log_probs = False
self._precomputed_eval_ref_log_probs = False

if loss_type in ["hinge", "ipo", "kto_pair"] and label_smoothing > 0:
if loss_type in ["hinge", "ipo", "kto_pair", "bco_pair"] and label_smoothing > 0:
warnings.warn(
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
)
@@ -421,6 +423,9 @@ def make_inputs_require_grad(module, input, output):
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)

if self.loss_type == "bco_pair":
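# Track a running mean of the pairwise rewards; it serves as the moving baseline (delta) in the bco_pair loss.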
self.running = RunningMoments(self.accelerator)

def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
Expand Down Expand Up @@ -899,9 +904,22 @@ def dpo_loss(
),
0,
)
elif self.loss_type == "bco_pair":
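# BCO pairwise loss: rewards are the beta-scaled policy/reference log-ratios, and a running
# mean of those rewards (delta) acts as a moving baseline inside the log-sigmoid terms below.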
chosen_logratios = policy_chosen_logps - reference_chosen_logps
rejected_logratios = policy_rejected_logps - reference_rejected_logps

chosen_rewards = self.beta * chosen_logratios
rejected_rewards = self.beta * rejected_logratios
rewards = torch.cat((chosen_rewards, rejected_rewards), 0).mean().detach()
self.running.update(rewards)
delta = self.running.mean

losses = -F.logsigmoid((self.beta * chosen_logratios) - delta) - F.logsigmoid(
-(self.beta * rejected_logratios - delta)
)
else:
raise ValueError(
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair', 'bco_pair']"
)

chosen_rewards = (