diff --git a/trl/trainer/dpo_trainer.py b/trl/trainer/dpo_trainer.py
index 95aa3f4e69..7423728a52 100644
--- a/trl/trainer/dpo_trainer.py
+++ b/trl/trainer/dpo_trainer.py
@@ -1277,8 +1277,9 @@ def dpo_loss(
                     "'apo_down', 'sft']"
                 )
 
-        chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
-        rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
+        if loss_type != "sft":
+            chosen_rewards = self.beta * (chosen_logps.to(device) - ref_chosen_logps.to(device)).detach()
+            rejected_rewards = self.beta * (rejected_logps.to(device) - ref_rejected_logps.to(device)).detach()
 
         return losses, chosen_rewards, rejected_rewards
 