2 changes: 2 additions & 0 deletions docs/source/dpo_trainer.mdx
@@ -119,6 +119,8 @@ The [NCA](https://arxiv.org/abs/2402.05369) authors shows that NCA optimizes the

The [TR-DPO](https://arxiv.org/pdf/2404.09656) paper suggests syncing the reference model weights after every `ref_model_sync_steps` steps of SGD, with weight `ref_model_mixup_alpha`, during DPO training. To enable this callback, set the `sync_ref_model` flag in the `DPOConfig`.
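For illustration only, a minimal sketch of enabling this callback; the argument values are arbitrary and not a recommendation:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-tr-dpo",   # illustrative path
    sync_ref_model=True,         # enable the TR-DPO reference-model sync callback
    ref_model_mixup_alpha=0.9,   # mixing weight used when the reference model is updated
    ref_model_sync_steps=64,     # sync the reference model every 64 optimization steps
)
```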

The [RPO](https://arxiv.org/abs/2404.19733) paper implements an iterative preference tuning algorithm using a loss related to the RPO loss in this [paper](https://arxiv.org/abs/2405.16436), which essentially consists of the SFT loss on the chosen preferences together with a weighted DPO loss. To use this loss, set `rpo_alpha` in the `DPOConfig` to an appropriate value.
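For concreteness, a hedged sketch of turning the RPO term on; `rpo_alpha=0.5` simply mirrors the value used in the test below and is not a tuned recommendation:

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-rpo",  # illustrative path
    rpo_alpha=0.5,           # weights the DPO term; the NLL of the chosen completion is added on top
)
```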

## Logging

While training and evaluating we record the following reward metrics:
1 change: 1 addition & 0 deletions tests/test_dpo_trainer.py
@@ -160,6 +160,7 @@ def test_dpo_trainer_without_providing_ref_model(self):
eval_strategy="steps",
beta=0.1,
precompute_ref_log_probs=True,
rpo_alpha=0.5,
)

dummy_dataset = self._init_dummy_dataset()
3 changes: 3 additions & 0 deletions trl/trainer/dpo_config.py
@@ -71,6 +71,8 @@ class DPOConfig(TrainingArguments):
The alpha parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
ref_model_sync_steps ('int', defaults to 64):
The tau parameter from the [TR-DPO](https://arxiv.org/pdf/2404.09656) paper.
rpo_alpha ('float', defaults to `None`):
The alpha parameter from the [RPO](https://arxiv.org/pdf/2404.19733) paper. If None, no weighting is applied and the loss is the same as the DPO loss.
"""

beta: float = 0.1
@@ -98,3 +100,4 @@ class DPOConfig(TrainingArguments):
sync_ref_model: bool = False
ref_model_mixup_alpha: float = 0.9
ref_model_sync_steps: int = 64
rpo_alpha: Optional[float] = None
35 changes: 23 additions & 12 deletions trl/trainer/dpo_trainer.py
@@ -901,13 +901,15 @@ def compute_reference_log_probs(self, padded_batch: Dict) -> Dict:
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.model, padded_batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.ref_model, padded_batch)

return reference_chosen_logps, reference_rejected_logps
@@ -1089,21 +1091,19 @@ def dpo_loss(
def get_batch_logps(
logits: torch.FloatTensor,
labels: torch.LongTensor,
average_log_prob: bool = False,
label_pad_token_id: int = -100,
is_encoder_decoder: bool = False,
) -> torch.FloatTensor:
) -> Tuple[torch.FloatTensor, torch.LongTensor]:
"""Compute the log probabilities of the given labels under the given logits.

Args:
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
label_pad_token_id: The label pad token id.
is_encoder_decoder: Whether the model is an encoder-decoder model.

Returns:
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
A tuple of two tensors of shape (batch_size,): the first contains the sum of the log probabilities of the given labels under the given logits, and the second contains the number of non-masked tokens.
"""
if logits.shape[:-1] != labels.shape:
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
@@ -1118,10 +1118,7 @@ def get_batch_logps(

per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

if average_log_prob:
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
return (per_token_logps * loss_mask).sum(-1)
return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1)
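As a standalone illustration of the new return contract (toy tensors and names invented for this sketch, not part of the trainer):

```python
import torch

label_pad_token_id = -100
logits = torch.randn(2, 3, 5)                           # (batch, seq, vocab)
labels = torch.tensor([[1, 2, -100], [0, -100, -100]])  # -100 marks padded positions

loss_mask = labels != label_pad_token_id
safe_labels = labels.clone()
safe_labels[~loss_mask] = 0                             # padded positions gather a dummy token

per_token_logps = torch.gather(
    logits.log_softmax(-1), dim=2, index=safe_labels.unsqueeze(2)
).squeeze(2)

sum_logps = (per_token_logps * loss_mask).sum(-1)       # first return value: summed log-probs
token_counts = loss_mask.sum(-1)                        # second return value: non-masked token counts
avg_logps = sum_logps / token_counts                    # callers (e.g. the IPO path) can rebuild the average
```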

def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
@@ -1154,21 +1151,25 @@ def concatenated_forward(
**model_kwargs,
).logits

all_logps = self.get_batch_logps(
all_logps, size_completion = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
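# Mean log-prob of each chosen completion; used below for the RPO-style NLL term.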
chosen_logps_avg = all_logps[:len_chosen] / size_completion[:len_chosen]

if self.loss_type == "ipo":
all_logps = all_logps / size_completion

chosen_logps = all_logps[:len_chosen]
rejected_logps = all_logps[len_chosen:]

chosen_logits = all_logits[:len_chosen]
rejected_logits = all_logits[len_chosen:]

return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps_avg)

def get_batch_loss_metrics(
self,
@@ -1184,10 +1185,15 @@ def get_batch_loss_metrics(
policy_rejected_logps,
policy_chosen_logits,
policy_rejected_logits,
policy_chosen_logps_avg,
) = self.concatenated_forward(model, batch)

# if reference_chosen_logps and reference_rejected_logps in batch use them, otherwise use the reference model
if "reference_chosen_logps" in batch and "reference_rejected_logps" in batch:
if (
"reference_chosen_logps" in batch
and "reference_rejected_logps" in batch
and self.args.rpo_alpha is not None
Reviewer comment (Member): why this condition?

):
reference_chosen_logps = batch["reference_chosen_logps"]
reference_rejected_logps = batch["reference_rejected_logps"]
else:
@@ -1199,13 +1205,15 @@ def get_batch_loss_metrics(
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)

losses, chosen_rewards, rejected_rewards = self.dpo_loss(
@@ -1216,6 +1224,9 @@ def get_batch_loss_metrics(
)
reward_accuracies = (chosen_rewards > rejected_rewards).float()

if self.args.rpo_alpha is not None:
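# RPO-style objective: scale the DPO loss by rpo_alpha and add the NLL of the
# chosen completion (i.e. subtract its average log-prob).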
losses = losses * self.args.rpo_alpha - policy_chosen_logps_avg

prefix = "eval_" if train_eval == "eval" else ""
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()