diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index dc30adfc343..1043c8f207f 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -781,15 +781,21 @@ def agg_loss( """ Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism. + NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine + for now, while the legacy model engines conduct the aggregation outside ``agg_loss``. + NOTE: The returned loss has different behaviors for different backend: - FSDP: the loss is directly used for backward. - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule. + # TODO: Consider the numerical stability? + Args: loss_mat: micro batch loss matrix, (bs, response_length) loss_mask: micro batch loss mask, (bs, response_length) loss_agg_mode: method to aggregate the loss matrix into a scalar - dp_size: data parallel size + dp_size: data parallel size. When appling manual aggregation, + scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging. batch_num_tokens: number of valid tokens in global batch global_batch_size: global batch size loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1]. @@ -799,30 +805,39 @@ def agg_loss( loss: `a scalar torch.Tensor` aggregated loss """ + # NOTE: `masked_sum` is more robust than multiplying the `mask`. if loss_agg_mode == "token-mean": if batch_num_tokens is None: batch_num_tokens = loss_mask.sum() loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size - elif loss_agg_mode == "seq-mean-token-sum": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) # token-sum - seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float() # exclude fully masked sequences - if global_batch_size is None: - global_batch_size = seq_mask.sum() - loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean - elif loss_agg_mode == "seq-mean-token-mean": - seq_mask = torch.sum(loss_mask, dim=-1) # per-sequence token count - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8) # token-mean - seq_mask = (seq_mask > 0).float() # exclude fully masked sequences - if global_batch_size is None: - global_batch_size = seq_mask.sum() - loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size # seq-mean - elif loss_agg_mode == "seq-mean-token-sum-norm": - seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) - if loss_scale_factor is None: - loss_scale_factor = loss_mask.shape[-1] - loss = torch.sum(seq_losses) / loss_scale_factor + elif loss_agg_mode.startswith("seq-mean"): + # TODO: Correct and unify the denominator logic. + if global_batch_size is not None: + seq_denominator = global_batch_size * dp_size + else: # The default logic which is only correct when the batch sizes are even. + local_bsz = loss_mat.shape[0] + seq_denominator = local_bsz + + if loss_agg_mode.startswith("seq-mean-token-sum"): + seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) # token-sum per sequence + + if loss_agg_mode == "seq-mean-token-sum": + pass # TODO: Add assertation. + elif loss_agg_mode == "seq-mean-token-sum-norm": + if loss_scale_factor is None: + loss_scale_factor = loss_mask.shape[-1] + seq_losses = seq_losses / loss_scale_factor + else: + raise ValueError(f"Invalid {loss_agg_mode=}") + elif loss_agg_mode == "seq-mean-token-mean": + token_counts = torch.sum(loss_mask, dim=-1) # per-sequence token count + # token-mean per sequence + seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8) + else: + raise ValueError(f"Invalid {loss_agg_mode=}") + loss = torch.sum(seq_losses) / seq_denominator # seq-mean else: - raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}") + raise ValueError(f"Invalid {loss_agg_mode=}") return loss diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 466402cab4e..09beb478477 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -479,7 +479,6 @@ def loss_func(output, data, meta_info): entropy_coeff = self.config.entropy_coeff loss_agg_mode = self.config.loss_agg_mode - loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") policy_loss_fn = get_policy_loss_fn(loss_mode)