Patch summary (free text ahead of the first "diff --git" header):
In _compute_loss, the clamped-ratio loss (coef_1 clamped at epsilon_high and
detached) is now selected when loss_type == "cispo"; before this patch that
branch was selected for "grpo", "bnpo", "dr_grpo" and "dapo". The two-sided
clipping branch, previously the bare `else`, is now selected explicitly for
those four loss types, and a new final `else` raises ValueError for any other
loss_type. The guard on the clipped-probability-ratio block further down
changes from `!= "cispo"` to membership in the same four-type list.

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 74a85563d49..8e8021635ec 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1824,11 +1824,10 @@ def _compute_loss(self, model, inputs):
         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
 
-        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
+        if self.loss_type == "cispo":
             clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
             per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
-
-        else:
+        elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
             coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
             # Two-sided clipping
             if self.args.delta is not None:
@@ -1837,6 +1836,8 @@ def _compute_loss(self, model, inputs):
             per_token_loss1 = coef_1 * advantages.unsqueeze(1)
             per_token_loss2 = coef_2 * advantages.unsqueeze(1)
             per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
 
         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
@@ -1880,7 +1881,7 @@ def masked_batch_mean(x):
         mean_entropy = masked_batch_mean(entropies)
         self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
 
-        if self.loss_type != "cispo":
+        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
             # Compute the clipped probability ratios
             is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
             is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)