diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index 7a57b2c80cc..a95e5e71849 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -781,21 +781,15 @@ def agg_loss(
     """
     Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.

-    NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine
-    for now, while the legacy model engines conduct the aggregation outside ``agg_loss``.
-
     NOTE: The returned loss has different behaviors for different backend:
     - FSDP: the loss is directly used for backward.
     - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.

-    # TODO: Consider the numerical stability?
-
     Args:
         loss_mat: micro batch loss matrix, (bs, response_length)
         loss_mask: micro batch loss mask, (bs, response_length)
         loss_agg_mode: method to aggregate the loss matrix into a scalar
-        dp_size: data parallel size. When appling manual aggregation,
-            scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging.
+        dp_size: data parallel size
         batch_num_tokens: number of valid tokens in global batch
         global_batch_size: global batch size
         loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
@@ -805,39 +799,30 @@ def agg_loss(
         loss: `a scalar torch.Tensor`
             aggregated loss
     """
-    # NOTE: `masked_sum` is more robust than multiplying the `mask`.
     if loss_agg_mode == "token-mean":
         if batch_num_tokens is None:
             batch_num_tokens = loss_mask.sum()
         loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
-    elif loss_agg_mode.startswith("seq-mean"):
-        # TODO: Correct and unify the denominator logic.
-        if global_batch_size is not None:
-            seq_denominator = global_batch_size * dp_size
-        else:  # The default logic which is only correct when the batch sizes are even.
-            local_bsz = loss_mat.shape[0]
-            seq_denominator = local_bsz
-
-        if loss_agg_mode.startswith("seq-mean-token-sum"):
-            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)  # token-sum per sequence
-
-            if loss_agg_mode == "seq-mean-token-sum":
-                pass  # TODO: Add assertation.
-            elif loss_agg_mode == "seq-mean-token-sum-norm":
-                if loss_scale_factor is None:
-                    loss_scale_factor = loss_mask.shape[-1]
-                seq_losses = seq_losses / loss_scale_factor
-            else:
-                raise ValueError(f"Invalid {loss_agg_mode=}")
-        elif loss_agg_mode == "seq-mean-token-mean":
-            token_counts = torch.sum(loss_mask, dim=-1)  # per-sequence token count
-            # token-mean per sequence
-            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8)
-        else:
-            raise ValueError(f"Invalid {loss_agg_mode=}")
-        loss = torch.sum(seq_losses) / seq_denominator  # seq-mean
+    elif loss_agg_mode == "seq-mean-token-sum":
+        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
+        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
+        if global_batch_size is None:
+            global_batch_size = seq_mask.sum()
+        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
+    elif loss_agg_mode == "seq-mean-token-mean":
+        seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
+        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
+        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
+        if global_batch_size is None:
+            global_batch_size = seq_mask.sum()
+        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
+    elif loss_agg_mode == "seq-mean-token-sum-norm":
+        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
+        if loss_scale_factor is None:
+            loss_scale_factor = loss_mask.shape[-1]
+        loss = torch.sum(seq_losses) / loss_scale_factor
     else:
-        raise ValueError(f"Invalid {loss_agg_mode=}")
+        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")

     return loss

diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index 6b279fc8903..5b64f1901f0 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -484,6 +484,7 @@ def loss_func(output, data, meta_info):
             entropy_coeff = self.config.entropy_coeff
             loss_agg_mode = self.config.loss_agg_mode

+            loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
             policy_loss_fn = get_policy_loss_fn(loss_mode)
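To make the rewritten branches easy to try outside verl, here is a minimal standalone sketch of the new control flow. It assumes only `torch`; `masked_sum` and `agg_loss_sketch` are local stand-ins I introduce here (the real helper is `verl_F.masked_sum`), and the toy tensors at the bottom are made up for illustration:

```python
# Minimal standalone sketch of the rewritten agg_loss branches -- illustration only.
# masked_sum() is a local stand-in for verl_F.masked_sum.
import torch


def masked_sum(values, mask, axis=None):
    """Sum of `values` wherever `mask` is nonzero."""
    return (values * mask).sum() if axis is None else (values * mask).sum(dim=axis)


def agg_loss_sketch(loss_mat, loss_mask, loss_agg_mode, dp_size=1,
                    batch_num_tokens=None, global_batch_size=None, loss_scale_factor=None):
    if loss_agg_mode == "token-mean":
        if batch_num_tokens is None:
            batch_num_tokens = loss_mask.sum()
        return masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)   # token-sum per sequence
        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # drop fully masked rows
        if global_batch_size is None:
            global_batch_size = seq_mask.sum()                 # count valid sequences only
        return masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size
    elif loss_agg_mode == "seq-mean-token-mean":
        token_counts = torch.sum(loss_mask, dim=-1)            # tokens per sequence
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (token_counts + 1e-8)
        seq_mask = (token_counts > 0).float()
        if global_batch_size is None:
            global_batch_size = seq_mask.sum()
        return masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
        if loss_scale_factor is None:
            loss_scale_factor = loss_mask.shape[-1]
        return torch.sum(seq_losses) / loss_scale_factor
    raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")


# Toy batch: two sequences, the second fully masked (e.g. pure padding).
loss_mat = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
print(agg_loss_sketch(loss_mat, loss_mask, "seq-mean-token-mean"))  # ~1.5: mean over 1 valid seq
print(agg_loss_sketch(loss_mat, loss_mask, "seq-mean-token-sum"))   # 3.0: (1+2) over 1 valid seq
```

The point of the per-sequence `seq_mask` is visible in the toy batch: a fully masked row no longer drags the sequence mean toward zero, and when `global_batch_size` is left as `None` the denominator falls back to the number of valid sequences in the local micro batch.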
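The `* dp_size` factor in the token-mean and seq-mean branches compensates for FSDP averaging gradients across data-parallel ranks: scaling each rank's local contribution up by `dp_size` makes that cross-rank average equal to the true global mean. A tiny numeric check, with made-up per-rank values:

```python
# Why the `* dp_size` factor: FSDP averages across dp_size ranks, so each rank
# scales its local contribution up by dp_size to recover a global sum.
rank_token_sums = [6.0, 10.0]  # hypothetical masked loss sums on each rank
batch_num_tokens = 8           # valid tokens across the global batch
dp_size = len(rank_token_sums)

per_rank_losses = [s / batch_num_tokens * dp_size for s in rank_token_sums]
fsdp_effective = sum(per_rank_losses) / dp_size  # FSDP's cross-rank averaging

# Matches the global token-mean: 16 / 8 == 2.0
assert fsdp_effective == sum(rank_token_sums) / batch_num_tokens
```

The same cancellation argument carries over to the seq-mean branches, with `global_batch_size` (a count of sequences) taking the place of `batch_num_tokens`.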