Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/liger_kernel/chunked_loss/dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,12 @@ def preference_loss_fn(
chosen_logratios = chosen_logps - ref_chosen_logps
rejected_logratios = rejected_logps - ref_rejected_logps

chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (rejected_logps - ref_rejected_logps)

logits_diff = beta * (chosen_logratios - rejected_logratios)
loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
return loss
return loss, chosen_rewards, rejected_rewards

@staticmethod
def forward(
Expand Down Expand Up @@ -99,7 +102,7 @@ def __init__(
beta: float = 0.1,
compute_nll_loss: bool = False,
compiled: bool = True,
use_ref_model: bool = False,
use_ref_model: bool = True,
):
"""
Args:
Expand Down
2 changes: 2 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,5 +430,7 @@ def _compute_loss(
chosen_logits_mean,
rejected_logits_mean,
chosen_nll_loss,
ref_chosen_logps if use_ref_model else None,
ref_rejected_logps if use_ref_model else None,
)
return loss, (*return_vars, *aux_outputs)
5 changes: 4 additions & 1 deletion test/chunked_loss/test_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ def alignment_loss(
chosen_logratios = policy_chosen_logps - ref_chosen_logps
rejected_logratios = policy_rejected_logps - ref_rejected_logps

chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)

logits_diff = self.beta * (chosen_logratios - rejected_logratios)
losses = -F.logsigmoid(logits_diff)
return losses
return losses, chosen_rewards, rejected_rewards


class TorchLMHeadDPO(torch.nn.Module):
Expand Down
2 changes: 2 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,8 @@ def get_batch_loss_metrics(
policy_chosen_logits.detach().mean(),
policy_rejected_logits.detach().mean(),
policy_nll_loss,
ref_chosen_logps if self.use_ref_model else None,
ref_rejected_logps if self.use_ref_model else None,
)
return loss, (*return_vars, *aggregated_aux_outputs)

Expand Down