Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,11 +1824,10 @@ def _compute_loss(self, model, inputs):

# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
if self.loss_type == "cispo":
clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps

else:
elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
# Two-sided clipping
if self.args.delta is not None:
Expand All @@ -1837,6 +1836,8 @@ def _compute_loss(self, model, inputs):
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
else:
raise ValueError(f"Unknown loss type: {self.loss_type}")

if entropy_mask is not None:
per_token_loss = per_token_loss * entropy_mask
Expand Down Expand Up @@ -1880,7 +1881,7 @@ def masked_batch_mean(x):
mean_entropy = masked_batch_mean(entropies)
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

if self.loss_type != "cispo":
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
# Compute the clipped probability ratios
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
Expand Down
Loading