Revert "[algo] fix: Add seq mean mask denominator option"#4769
Revert "[algo] fix: Add seq mean mask denominator option"#4769
Conversation
This reverts commit 6a58521.
Code Review
This pull request reverts a previous fix to the loss aggregation logic. The revert re-introduces several issues, including potential numerical instability from `NaN` propagation and an incorrect loss calculation in the `seq-mean-token-sum-norm` mode, which computes a scaled sum instead of a mean. This could lead to incorrect gradients and training instability, especially in distributed environments. I've provided a detailed comment and a code suggestion to address these critical issues.
```diff
 elif loss_agg_mode == "seq-mean-token-sum":
     seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
     seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
     if global_batch_size is None:
         global_batch_size = seq_mask.sum()
     loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
 elif loss_agg_mode == "seq-mean-token-mean":
     seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
     seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
     seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
     if global_batch_size is None:
         global_batch_size = seq_mask.sum()
     loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
 elif loss_agg_mode == "seq-mean-token-sum-norm":
     seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
     if loss_scale_factor is None:
         loss_scale_factor = loss_mask.shape[-1]
     loss = torch.sum(seq_losses) / loss_scale_factor
 else:
-    raise ValueError(f"Invalid {loss_agg_mode=}")
+    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
```
This revert re-introduces critical issues in the loss aggregation logic:

- **Numerical Instability**: The restored code uses `torch.sum(loss_mat * loss_mask, dim=-1)`, which is not robust to `NaN` values in padded regions of `loss_mat`. This can lead to `NaN` losses and training instability. The reverted code correctly used `verl_F.masked_sum`, which handles this case (see the sketch after this list).
- **Incorrect Aggregation**: The `seq-mean-token-sum-norm` mode now calculates a scaled sum of sequence losses, not a mean as its name implies. This is inconsistent with the other `seq-mean-*` modes and will result in incorrect loss scaling.
- **Inconsistent Distributed Handling**: The `seq-mean-token-sum-norm` mode also omits `global_batch_size` and `dp_size`, making it behave differently from the other modes in a distributed setting.
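To make the first point concrete, here is a minimal sketch of the `NaN` propagation on toy tensors; `masked_sum` below is a hypothetical stand-in for `verl_F.masked_sum`, whose actual implementation may differ:

```python
import torch

loss_mat = torch.tensor([[1.0, 2.0, float("nan")]])  # NaN in a padded slot
loss_mask = torch.tensor([[1.0, 1.0, 0.0]])

# Multiply-then-sum does not mask the NaN out: 0 * NaN is still NaN.
print(torch.sum(loss_mat * loss_mask, dim=-1))  # tensor([nan])

def masked_sum(values, mask, axis=-1):
    # Hypothetical NaN-robust helper: drop masked entries before summing.
    return torch.where(mask.bool(), values, torch.zeros_like(values)).sum(dim=axis)

print(masked_sum(loss_mat, loss_mask))  # tensor([3.])
```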
I've provided a suggestion to fix these issues by using the more robust `verl_F.masked_sum` and correcting the logic for `seq-mean-token-sum-norm` to be consistent with the other aggregation modes.
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if loss_scale_factor is None:
loss_scale_factor = loss_mask.shape[-1]
seq_losses = seq_losses / loss_scale_factor
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
Reverts #4510