Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 35 additions & 20 deletions verl/trainer/ppo/core_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,15 +781,21 @@ def agg_loss(
"""
Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.

NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine
for now, while the legacy model engines conduct the aggregation outside ``agg_loss``.

NOTE: The returned loss has different behaviors for different backend:
- FSDP: the loss is directly used for backward.
- Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.

# TODO: Consider numerical stability.

Args:
loss_mat: micro batch loss matrix, (bs, response_length)
loss_mask: micro batch loss mask, (bs, response_length)
loss_agg_mode: method to aggregate the loss matrix into a scalar
dp_size: data parallel size
dp_size: data parallel size. When applying manual aggregation,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible that we remove the dp_size from the algorithm loss function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dp_size was brought into loss by @wuxibin89

Copy link
Collaborator

@tongyx361 tongyx361 Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is kept unresolved to show why @wuxibin89 introduced the dp_size below.

scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging.
batch_num_tokens: number of valid tokens in global batch
global_batch_size: global batch size
loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
Expand All @@ -799,30 +805,39 @@ def agg_loss(
loss: `a scalar torch.Tensor`
aggregated loss
"""
# NOTE: `masked_sum` is more robust than multiplying the `mask`.
if loss_agg_mode == "token-mean":
if batch_num_tokens is None:
batch_num_tokens = loss_mask.sum()
loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
elif loss_agg_mode == "seq-mean-token-sum":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum
seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-mean":
seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean
seq_mask = (seq_mask > 0).float() # exclude fully masked sequences
if global_batch_size is None:
global_batch_size = seq_mask.sum()
loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean
elif loss_agg_mode == "seq-mean-token-sum-norm":
seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
if loss_scale_factor is None:
loss_scale_factor = loss_mask.shape[-1]
loss = torch.sum(seq_losses) / loss_scale_factor
elif loss_agg_mode.startswith("seq-mean"):
# TODO: Correct and unify the denominator logic.
if global_batch_size is not None:
seq_denominator = global_batch_size * dp_size
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The seq_denominator is not right, should be global_batch_size / dp_size

else: # The default logic which is only correct when the batch sizes are even.
local_bsz = loss_mat.shape[0]
seq_denominator = local_bsz

if loss_agg_mode.startswith("seq-mean-token-sum"):
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) # token-sum per sequence

if loss_agg_mode == "seq-mean-token-sum":
pass # TODO: Add assertion.
elif loss_agg_mode == "seq-mean-token-sum-norm":
if loss_scale_factor is None:
loss_scale_factor = loss_mask.shape[-1]
seq_losses = seq_losses / loss_scale_factor
else:
raise ValueError(f"Invalid {loss_agg_mode=}")
elif loss_agg_mode == "seq-mean-token-mean":
token_counts = torch.sum(loss_mask, dim=-1) # per-sequence token count
# token-mean per sequence
seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8)
else:
raise ValueError(f"Invalid {loss_agg_mode=}")
loss = torch.sum(seq_losses) / seq_denominator # seq-mean
else:
raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
raise ValueError(f"Invalid {loss_agg_mode=}")

return loss

Expand Down
1 change: 0 additions & 1 deletion verl/workers/actor/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,6 @@ def loss_func(output, data, meta_info):

entropy_coeff = self.config.entropy_coeff
loss_agg_mode = self.config.loss_agg_mode

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")

policy_loss_fn = get_policy_loss_fn(loss_mode)
Expand Down
Loading