[algo] fix: Add seq mean mask denominator option #4510
tongyx361 merged 12 commits into verl-project:main
Conversation
Code Review
This pull request introduces a useful `exclude_fully_masked_seq` option to control how the denominator is calculated for sequence-mean loss aggregation, and fixes a pre-existing issue where global batch information was not correctly propagated in `dp_actor.py` and `megatron_actor.py`. The changes are well-structured and clearly described.
My main feedback is a critical issue in `verl/trainer/ppo/core_algos.py`, where the calculation of `global_batch_size` for non-fully-masked sequences assumes a uniform distribution across data-parallel workers. This can lead to gradient mismatches and cause distributed training to fail. I've provided suggestions to use `torch.distributed.all_reduce` for a robust and correct implementation.
Loss/gradient aggregation in DP training (token losses -> DP-rank losses -> global loss) is tricky, and it has not been thoroughly considered in the codebase so far.
For now, the `dp_size` is calculated as:
- Distributed training frameworks sometimes apply aggregation implicitly. For example, FSDP by default mean-reduces (it sums the loss/gradient and divides by the `reduce_scatter` world size), which makes the loss/gradient in this PR's implementation shrink by `dp_size` from the targeted ground truth. Setting aside the numerical scale, multiplying this loss by `dp_size` first is an acceptable workaround. Besides, as far as I know, for FSDP2 in `torch>=2.8.0` we can resolve this by calling `set_gradient_divide_factor(1)`, but I am not sure about other setups like FSDP1 and Megatron.
- The "mean" strategy assumes that the batch sizes are even between DP ranks, but this is not always the case: 1) if the `seq_mask` is valid, each DP rank's `sum` is very likely to be uneven between DP ranks; 2) DP balancing (`balance_batch`) might be optimized to allow dispatching with uneven `batch_size` for better workload balance in the future. So the implementation using `all_reduce(SUM)` suggested by Gemini is indeed more robust (but still problematic; see the next point).
- If Ulysses SP is enabled, the data will be all-gathered within each USP group of `sp_size`, which might cause the global `batch_size` to be multiplied by `sp_size` if it is simply summed up with `all_reduce(SUM)` as Gemini suggests (while the original implementation takes care of USP).
cc @wuxibin89, maybe we can further improve the aggregation logic in the future.
For this PR individually, I think it can be approved because it is at least no worse than the original implementation: it adds an option under which the `seq_mask` does not affect the aggregation, avoiding the uneven case.
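The uneven-`seq_mask` case above can be sketched without any distributed setup. The snippet below is a plain-Python simulation (all names are illustrative, not verl's code): each "rank" holds a per-sequence mask, and we compare the original assume-uniform estimate against the `all_reduce(SUM)`-style count, here stood in for by a simple `sum` over per-rank counts.

```python
def local_valid_seqs(seq_mask):
    """Number of sequences on this rank that are not fully masked."""
    return sum(1 for m in seq_mask if m)

# Two DP ranks with uneven numbers of valid sequences.
rank_masks = [
    [True, True, True, False],    # rank 0: 3 valid of 4
    [True, False, False, False],  # rank 1: 1 valid of 4
]
dp_size = len(rank_masks)

# Assuming uniformity: each rank extrapolates local_count * dp_size,
# so different ranks compute different denominators.
uniform_estimates = [local_valid_seqs(m) * dp_size for m in rank_masks]

# Robust: all_reduce(SUM) of local counts -- simulated here by sum() --
# gives every rank the same global count.
global_count = sum(local_valid_seqs(m) for m in rank_masks)

print(uniform_estimates)  # [6, 2] -- ranks disagree, gradients mismatch
print(global_count)       # 4 -- consistent across ranks
```

This is why the uniform-distribution assumption breaks gradient consistency across ranks, while a summed global count does not (modulo the Ulysses SP double-counting caveat noted above).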
```diff
 loss_mask: micro batch loss mask, (bs, response_length)
 loss_agg_mode: method to aggregate the loss matrix into a scalar
-dp_size: data parallel size
+dp_size: data parallel size. When applying manual aggregation,
```
is it possible that we remove the `dp_size` from the algorithm loss function?
`dp_size` was brought into the loss by @wuxibin89
This comment is kept unresolved to show why @wuxibin89 introduced the `dp_size` below.
For example, we have `dp_size=2` and `num_micro_batches=2` (the gradient is accumulated across 2 micro batches in each DP group, then averaged across the 2 DP groups).
Then for
@vermouth1992 created an example for explanation: https://gist.github.com/vermouth1992/6c273240765c4f223478081042bfcd4a
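The `dp_size=2`, `num_micro_batches=2` setup can be worked through with toy numbers (a sketch, not verl's code; all names below are illustrative). The key point is that with gradient accumulation per rank plus a mean reduction across DP ranks, each micro-batch loss must be divided by the *per-rank* sequence count, i.e. `global_batch_size / dp_size`, for the result to equal the true global sequence mean.

```python
dp_size = 2
# seq_losses[rank][micro_batch] -> per-sequence losses (2 seqs per micro batch)
seq_losses = [
    [[1.0, 2.0], [3.0, 4.0]],  # rank 0
    [[5.0, 6.0], [7.0, 8.0]],  # rank 1
]

# Target: the global seq-mean over all 8 sequences.
all_losses = [x for rank in seq_losses for mb in rank for x in mb]
target = sum(all_losses) / len(all_losses)  # 36 / 8 = 4.5

# Per-rank gradient accumulation: each rank sums its micro-batch losses,
# each divided by the per-rank sequence count (global_batch_size / dp_size).
global_batch_size = len(all_losses)                  # 8
per_rank_denominator = global_batch_size / dp_size   # 4
rank_losses = [
    sum(sum(mb) for mb in rank) / per_rank_denominator for rank in seq_losses
]

# DP then averages the rank losses (what a mean reduction across ranks does).
dp_mean = sum(rank_losses) / dp_size
print(dp_mean, target)  # 4.5 4.5 -- they match
```

With any other divisor (e.g. `global_batch_size` alone, or `global_batch_size * dp_size`), `dp_mean` would be off by a factor of `dp_size` from the target.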
Since the global loss aggregation logic is not compatible with the legacy model engine, which conducts the aggregation outside `agg_loss`, the comments above are either resolved or avoided. @wuxibin89 @ISEEKYAN
```python
elif loss_agg_mode.startswith("seq-mean"):
    # TODO: Correct and unify the denominator logic.
    if global_batch_size is not None:
        seq_denominator = global_batch_size * dp_size
```
The `seq_denominator` is not right; it should be `global_batch_size / dp_size`.
## Summary

Refactor the `agg_loss` function and fix entropy/KL loss scaling in distributed training.

**Changes:**

- **Refactor**: Unify `seq-mean-*` modes with shared denominator logic using `masked_sum`
- **Behavior change**: `seq-mean-token-sum-norm` now applies seq-mean division (denominator = `global_batch_size * dp_size` or `local_bsz`), matching the mode name
- **Simplification**: Remove fully-masked sequence exclusion from the denominator; use the total batch size consistently

NOTE: Since the global loss aggregation logic is not compatible with the legacy model engine, which conducts the aggregation outside `agg_loss` and is going to be deprecated, we keep this PR from modifying the legacy model engine.

⚠️ **Breaking**: `seq-mean-token-sum-norm` now divides by both `loss_scale_factor` AND `seq_denominator`. Previously it only divided by `loss_scale_factor`.

## Test plan

- [ ] Verify PPO training with `seq-mean-token-sum` mode
- [ ] Verify PPO training with `seq-mean-token-mean` mode
- [ ] Verify PPO training with `seq-mean-token-sum-norm` mode (note: behavior changed)
- [ ] Confirm entropy/KL loss values are correctly scaled in multi-GPU training

---

Co-authored-by: Shawn/Yuxuan Tong <tongyuxuan361@gmail.com>
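The shared-denominator idea behind the `seq-mean-*` unification can be sketched in plain Python (this is not verl's actual torch implementation; `masked_sum` and the mode names come from the summary, everything else is an assumption):

```python
def agg_seq_mean_loss(loss_mat, loss_mask, loss_agg_mode, seq_denominator):
    """Sketch of unified seq-mean aggregation with a shared denominator.

    seq_denominator would be global_batch_size * dp_size in a distributed
    run, or the local batch size in a single-process run.
    """
    seq_losses = []
    for row, mask in zip(loss_mat, loss_mask):
        s = sum(v for v, m in zip(row, mask) if m)  # masked_sum per sequence
        if loss_agg_mode == "seq-mean-token-mean":
            s /= max(sum(mask), 1)                  # per-sequence token mean
        elif loss_agg_mode != "seq-mean-token-sum":
            raise ValueError(f"unknown mode: {loss_agg_mode}")
        seq_losses.append(s)
    # Shared seq-mean division, regardless of the per-sequence reduction.
    return sum(seq_losses) / seq_denominator

loss_mat = [[1.0, 3.0], [2.0, 2.0]]
loss_mask = [[1, 1], [1, 0]]
print(agg_seq_mean_loss(loss_mat, loss_mask, "seq-mean-token-mean", 2))
# (2.0 + 2.0) / 2 = 2.0
```

Note how the fully-masked-sequence question only affects `seq_denominator`: with the simplification described above, the total batch size is used whether or not some rows of `loss_mask` are all zero.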