
Revert "[algo] fix: Add seq mean mask denominator option" #4769

Merged

wuxibin89 merged 1 commit into main from revert-4510-add_seq_mean_mask_denominator_option on Jan 2, 2026

Conversation

@wuxibin89
Collaborator

Reverts #4510

@gemini-code-assist bot (Contributor) left a comment


Code Review

This pull request reverts a previous fix related to loss aggregation. The revert re-introduces several issues, including potential numerical instability due to NaN propagation and an incorrect loss calculation for the `seq-mean-token-sum-norm` mode, which computes a sum instead of a mean. This could lead to incorrect gradients and training instability, especially in distributed environments. I've provided a detailed comment and a code suggestion to address these critical issues.

Comment on lines +806 to +825
```python
elif loss_agg_mode == "seq-mean-token-sum":
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
    seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
    if global_batch_size is None:
        global_batch_size = seq_mask.sum()
    loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
    seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
    seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
    if global_batch_size is None:
        global_batch_size = seq_mask.sum()
    loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
    seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
    if loss_scale_factor is None:
        loss_scale_factor = loss_mask.shape[-1]
    loss = torch.sum(seq_losses) / loss_scale_factor
else:
    raise ValueError(f"Invalid {loss_agg_mode=}")
```

Severity: critical

This revert re-introduces critical issues in the loss aggregation logic:

  1. Numerical Instability: The restored code uses `torch.sum(loss_mat * loss_mask, dim=-1)`, which is not robust to NaN values in padded regions of `loss_mat`: since `NaN * 0.0` is still NaN, a single NaN in a masked-out position poisons the whole sum. This can lead to NaN losses and training instability. The reverted code correctly used `verl_F.masked_sum`, which handles this case.

  2. Incorrect Aggregation: The `seq-mean-token-sum-norm` mode now calculates a scaled sum of sequence losses, not a mean as its name implies. This is inconsistent with the other `seq-mean-*` modes and will result in incorrect loss scaling.

  3. Inconsistent Distributed Handling: The `seq-mean-token-sum-norm` mode also omits `global_batch_size` and `dp_size`, making it behave differently from the other modes in a distributed setting.

I've provided a suggestion to fix these issues by using the more robust `verl_F.masked_sum` and correcting the logic for `seq-mean-token-sum-norm` to be consistent with the other aggregation modes.

```python
elif loss_agg_mode == "seq-mean-token-sum":
    seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)  # token-sum
    seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
    if global_batch_size is None:
        global_batch_size = seq_mask.sum()
    loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
    seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
    seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (seq_mask + 1e-8)  # token-mean
    seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
    if global_batch_size is None:
        global_batch_size = seq_mask.sum()
    loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
    seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)
    seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
    if loss_scale_factor is None:
        loss_scale_factor = loss_mask.shape[-1]
    seq_losses = seq_losses / loss_scale_factor
    if global_batch_size is None:
        global_batch_size = seq_mask.sum()
    loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size
else:
    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
```
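The NaN-propagation point in the review can be seen with plain floats, no torch required. This is a minimal sketch (hypothetical helper names, not the verl API): multiply-then-sum keeps NaN from padded slots because `NaN * 0.0` is still NaN under IEEE 754, while a select-then-sum masked sum (the pattern a `where()`-based `masked_sum` uses) never lets masked entries into the reduction.

```python
import math

def naive_masked_sum(vals, mask):
    # multiply-then-sum: a NaN in a masked-out (padded) slot
    # survives NaN * 0.0 and poisons the whole sum
    return sum(v * m for v, m in zip(vals, mask))

def robust_masked_sum(vals, mask):
    # select-then-sum: masked-out entries never enter the sum,
    # mirroring a where()-based masked_sum
    return sum(v for v, m in zip(vals, mask) if m)

vals = [1.0, 2.0, math.nan]  # NaN sits in a padded position
mask = [1.0, 1.0, 0.0]

print(naive_masked_sum(vals, mask))   # nan
print(robust_masked_sum(vals, mask))  # 3.0
```

This is why the review prefers `verl_F.masked_sum(loss_mat, loss_mask, axis=-1)` over `torch.sum(loss_mat * loss_mask, dim=-1)` when padded logits can be NaN.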
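The sum-vs-mean distinction for `seq-mean-token-sum-norm` can also be checked with toy numbers. In this sketch (hypothetical helper names, toy per-sequence losses; assumes `dp_size=1` and `global_batch_size` defaulting to the number of non-empty sequences, as in the suggestion), the reverted path's scaled sum grows with batch size, while the suggested per-sequence mean stays constant:

```python
def reverted_agg(seq_losses, max_len):
    # reverted path: scaled SUM over sequences, so the value
    # scales linearly with the number of sequences in the batch
    return sum(seq_losses) / max_len

def suggested_agg(seq_losses, max_len, dp_size=1):
    # suggested path: normalize each sequence by max_len, then
    # take the MEAN over (non-empty) sequences
    normed = [s / max_len for s in seq_losses]
    return sum(normed) / len(normed) * dp_size

batch = [3.0, 1.0]  # token-sum loss per sequence (toy values)

print(reverted_agg(batch, max_len=4))       # 1.0
print(suggested_agg(batch, max_len=4))      # 0.5
print(reverted_agg(batch * 2, max_len=4))   # 2.0  (doubles with batch size)
print(suggested_agg(batch * 2, max_len=4))  # 0.5  (invariant to batch size)
```

The batch-size dependence of the sum is what the review means by "incorrect loss scaling": gradients would implicitly scale with micro-batch size.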

@wuxibin89 wuxibin89 enabled auto-merge (squash) January 2, 2026 16:15
@wuxibin89 wuxibin89 disabled auto-merge January 2, 2026 16:15
@wuxibin89 wuxibin89 merged commit 78014a2 into main Jan 2, 2026
59 of 63 checks passed
@wuxibin89 wuxibin89 deleted the revert-4510-add_seq_mean_mask_denominator_option branch January 2, 2026 16:20
jsfanfanfan pushed a commit to meituan-search/verl that referenced this pull request Jan 9, 2026
