diff --git a/trl/trainer/orpo_trainer.py b/trl/trainer/orpo_trainer.py index 84a800ba4c3..fe11d136c19 100644 --- a/trl/trainer/orpo_trainer.py +++ b/trl/trainer/orpo_trainer.py @@ -714,6 +714,8 @@ def cross_entropy_loss(logits, labels): labels = concatenated_batch["concatenated_labels"].clone() else: labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])