Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 20 additions & 35 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,21 +781,15 @@ def agg_loss(
"""
Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.

NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine
for now, while the legacy model engines conduct the aggregation outside ``agg_loss``.

NOTE: The returned loss has different behaviors for different backend:
- FSDP: the loss is directly used for backward.
- Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.

# TODO: Consider the numerical stability?

Args:
loss_mat: micro batch loss matrix, (bs, response_length)
loss_mask: micro batch loss mask, (bs, response_length)
loss_agg_mode: method to aggregate the loss matrix into a scalar
        dp_size: data parallel size. When applying manual aggregation,
            scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging.
dp_size: data parallel size
batch_num_tokens: number of valid tokens in global batch
global_batch_size: global batch size
loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
Expand All @@ -805,39 +799,30 @@ def agg_loss(
loss: `a scalar torch.Tensor`
aggregated loss
"""
# NOTE: `masked_sum` is more robust than multiplying the `mask`.
if loss_agg_mode == "token-mean":
if batch_num_tokens is None:
batch_num_tokens = loss_mask.sum()
loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
elif loss_agg_mode.startswith("seq-mean"):
# TODO: Correct and unify the denominator logic.
if global_batch_size is not None:
seq_denominator = global_batch_size * dp_size
else: # The default logic which is only correct when the batch sizes are even.
local_bsz = loss_mat.shape[0]
seq_denominator = local_bsz

if loss_agg_mode.startswith("seq-mean-token-sum"):
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) # token-sum per sequence

if loss_agg_mode == "seq-mean-token-sum":
            pass  # TODO: Add assertion.
elif loss_agg_mode == "seq-mean-token-sum-norm":
if loss_scale_factor is None:
loss_scale_factor = loss_mask.shape[-1]
seq_losses = seq_losses / loss_scale_factor
else:
raise ValueError(f"Invalid {loss_agg_mode=}")
elif loss_agg_mode == "seq-mean-token-mean":
token_counts = torch.sum(loss_mask, dim=-1) # per-sequence token count
# token-mean per sequence
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8)
else:
raise ValueError(f"Invalid {loss_agg_mode=}")
loss = torch.sum(seq_losses) / seq_denominator # seq-mean
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
if loss_scale_factor is None:
loss_scale_factor = loss_mask.shape[-1]
loss = torch.sum(seq_losses) / loss_scale_factor
else:
raise ValueError(f"Invalid {loss_agg_mode=}")
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
Comment on lines +806 to +825
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This revert re-introduces critical issues in the loss aggregation logic:

  1. Numerical Instability: The restored code uses torch.sum(loss_mat * loss_mask, dim=-1), which is not robust to NaN values in padded regions of loss_mat. This can lead to NaN losses and training instability. The reverted code correctly used verl_F.masked_sum, which handles this case.

  2. Incorrect Aggregation: The seq-mean-token-sum-norm mode now calculates a scaled sum of sequence losses, not a mean as its name implies. This is inconsistent with other seq-mean-* modes and will result in incorrect loss scaling.

  3. Inconsistent Distributed Handling: The seq-mean-token-sum-norm mode also omits global_batch_size and dp_size, making it behave differently from other modes in a distributed setting.

I've provided a suggestion to fix these issues by using the more robust verl_F.masked_sum and correcting the logic for seq-mean-token-sum-norm to be consistent with the other aggregation modes.

    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)  # token-sum
        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
        if global_batch_size is None:
            global_batch_size = seq_mask.sum()
        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
        seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (seq_mask + 1e-8)  # token-mean
        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
        if global_batch_size is None:
            global_batch_size = seq_mask.sum()
        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)
        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
        if loss_scale_factor is None:
            loss_scale_factor = loss_mask.shape[-1]
        seq_losses = seq_losses / loss_scale_factor
        if global_batch_size is None:
            global_batch_size = seq_mask.sum()
        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")


return loss

Expand Down
1 change: 1 addition & 0 deletions verl/workers/actor/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ def loss_func(output, data, meta_info):

entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

policy_loss_fn = get_policy_loss_fn(loss_mode)
Expand Down
Loading