
[algo] fix: Add seq mean mask denominator option #4510

Merged
tongyx361 merged 12 commits into verl-project:main from szrlee:add_seq_mean_mask_denominator_option
Dec 17, 2025

Conversation

@szrlee
Collaborator

@szrlee szrlee commented Dec 14, 2025

Summary

Refactor agg_loss function and fix entropy/KL loss scaling in distributed training.

Changes:

  • Refactor: Unify seq-mean-* modes with shared denominator logic using masked_sum
  • Behavior change: seq-mean-token-sum-norm now applies seq-mean division (denominator = global_batch_size * dp_size or local_bsz), matching the mode name
  • Simplification: Remove fully-masked sequence exclusion from denominator; use total batch size consistently

NOTE: Since the global loss aggregation logic is not compatible with the legacy model engine, which conducts the aggregation outside agg_loss and is going to be deprecated, we keep this PR from modifying the legacy model engine.

⚠️ Breaking: seq-mean-token-sum-norm now divides by both loss_scale_factor AND seq_denominator. Previously only divided by loss_scale_factor.
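The unified seq-mean denominator logic described above can be sketched in pure Python. This is an illustration of the PR description, not verl's actual implementation; the function name `agg_seq_mean_loss` and the exact per-mode normalizations are assumptions for the sketch.

```python
# Pure-Python sketch of the unified seq-mean-* aggregation described in this
# PR. Hypothetical illustration only, not verl's actual code.

def masked_sum(row, mask):
    # Sum of loss values at positions where the mask is set.
    return sum(v for v, m in zip(row, mask) if m)

def agg_seq_mean_loss(loss_mat, loss_mask, mode, global_batch_size=None, dp_size=1):
    # Shared denominator logic: global_batch_size * dp_size when global batch
    # info is available, otherwise fall back to the local batch size.
    if global_batch_size is not None:
        seq_denominator = global_batch_size * dp_size
    else:
        seq_denominator = len(loss_mat)
    per_seq = []
    for row, mask in zip(loss_mat, loss_mask):
        s = masked_sum(row, mask)
        if mode == "seq-mean-token-mean":
            s /= max(sum(mask), 1)   # mean over valid tokens in the sequence
        elif mode == "seq-mean-token-sum-norm":
            s /= len(mask)           # assumed: normalize by max response length
        elif mode != "seq-mean-token-sum":
            raise ValueError(mode)
        per_seq.append(s)
    # Behavior change: every seq-mean-* mode, including token-sum-norm,
    # now divides by the shared seq_denominator.
    return sum(per_seq) / seq_denominator
```

Under this sketch, `seq-mean-token-sum-norm` picks up the extra division by `seq_denominator`, which is exactly the breaking change flagged above.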

Test plan

  • Verify PPO training with seq-mean-token-sum mode
  • Verify PPO training with seq-mean-token-mean mode
  • Verify PPO training with seq-mean-token-sum-norm mode (note: behavior changed)
  • Confirm entropy/KL loss values are correctly scaled in multi-GPU training

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a useful exclude_fully_masked_seq option to control how the denominator is calculated for sequence-mean loss aggregation, and fixes a pre-existing issue where global batch information was not correctly propagated in dp_actor.py and megatron_actor.py. The changes are well-structured and clearly described.

My main feedback is a critical issue in verl/trainer/ppo/core_algos.py where the calculation of global_batch_size for non-fully-masked sequences assumes a uniform distribution across data-parallel workers. This can lead to gradient mismatches and cause distributed training to fail. I've provided suggestions to use torch.distributed.all_reduce for a robust and correct implementation.
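The all_reduce(SUM) approach the review refers to can be shown with a single-process simulation; in real code this would be `torch.distributed.all_reduce` with `ReduceOp.SUM` over the DP process group. The helper name and the per-rank counts below are made up for illustration.

```python
# Single-process simulation of the all_reduce(SUM) fix suggested by the
# review. In distributed code, each rank would call
# torch.distributed.all_reduce(count_tensor, op=dist.ReduceOp.SUM).

def all_reduce_sum(per_rank_values):
    """Stand-in for all_reduce(SUM): every rank ends up with the global sum."""
    total = sum(per_rank_values)
    return [total] * len(per_rank_values)

# Each DP rank counts its own non-fully-masked sequences (uneven on purpose,
# which is exactly the case a uniform-distribution assumption gets wrong).
local_counts = [3, 5]                         # dp_rank 0 has 3, dp_rank 1 has 5
global_batch_sizes = all_reduce_sum(local_counts)
# Every rank now uses the same global denominator despite local imbalance.
assert global_batch_sizes == [8, 8]
```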

@tongyx361 tongyx361 self-assigned this Dec 14, 2025
Collaborator

@tongyx361 tongyx361 left a comment


Loss/gradient aggregation in DP training (token losses -> DP rank losses -> global loss) is tricky and has not been thoroughly considered in the codebase so far.

For now, the dp_size is calculated as:

https://github.com/volcengine/verl/blob/7deb67ca177a38b060bca007198aecb3fa4431dd/verl/workers/engine/fsdp/transformer_impl.py#L478-L479

  1. Distributed training frameworks sometimes apply aggregation implicitly; e.g., FSDP by default mean-reduces the loss (sums the loss and divides by the reduce-scatter world size), which will make the loss/gradient in this PR's implementation shrink by dp_size from the targeted ground truth. If not considering the numerical scale, multiplying the loss by dp_size first is an acceptable workaround. Besides, as far as I know, for FSDP2 in torch>=2.8.0, we can resolve this by calling set_gradient_divide_factor(1), but I am not sure about other setups like FSDP1 and Megatron.
  2. The "mean" strategy assumes that the batch sizes are even between DP ranks, but this is not always the case, e.g., 1) if the seq_mask is in effect, each DP rank's sum is very likely to be uneven between DP ranks; 2) DP balance (balance_batch) might be optimized to allow dispatching with uneven batch sizes for better workload balance in the future, so the implementation using all_reduce(SUM) suggested by Gemini is indeed more robust (but still problematic, see 3).
  3. If Ulysses SP is enabled, the data will be all-gathered within each USP group of sp_size, which might cause the global batch_size to be multiplied by sp_size if simply summed up with all_reduce(SUM) as Gemini suggests (while the original implementation takes care of USP).
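Point 3 above can be checked with a small simulation, assuming hypothetical sizes: after the SP all-gather, every rank in a USP group sees the same sequences, so a naive world-wide all_reduce(SUM) overcounts by sp_size.

```python
# Simulated illustration of the Ulysses SP double-counting issue.
# Hypothetical numbers; not verl code.

sp_size = 2
# True local batch per DP group (before the SP all-gather).
true_local = [4, 6]
# After all-gather, every rank in an SP group holds its group's full local
# batch, so summing over all dp_size * sp_size ranks overcounts by sp_size.
naive_global = sum(b for b in true_local for _ in range(sp_size))
assert naive_global == sp_size * sum(true_local)   # 20, overcounted
# Dividing by sp_size (or reducing only over the DP group) recovers the truth.
corrected_global = naive_global // sp_size
assert corrected_global == sum(true_local)         # 10, the true global count
```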

cc @wuxibin89 , maybe we can further improve the aggregation logic in the future.

For this PR individually, I think it can be approved because it is at least not worse than the original implementation: it adds an option where the seq_mask does not affect the aggregation, avoiding the uneven case.

loss_mask: micro batch loss mask, (bs, response_length)
loss_agg_mode: method to aggregate the loss matrix into a scalar
dp_size: data parallel size
dp_size: data parallel size. When applying manual aggregation,
Collaborator


Is it possible to remove dp_size from the algorithm loss function?

Collaborator Author


dp_size was brought into loss by @wuxibin89

Collaborator

@tongyx361 tongyx361 Dec 15, 2025


This comment is kept unresolved to show why @wuxibin89 introduced the dp_size below.

@wuxibin89
Collaborator

wuxibin89 commented Dec 15, 2025

dp_size, batch_num_tokens and global_batch_size were introduced to make sure each DP group's loss is averaged across the global mini batch instead of over the local micro batch. This is meant to correct the contribution of micro batches to the gradient from different dp groups, which have different numbers of valid tokens and sequences. We want to make sure that each token and each sequence in a micro batch has an equal contribution to the gradient.

For example, we have dp_size=2 and num_micro_batches=2 (gradient is accumulated across 2 micro batches in each dp group, then averaged across the 2 dp groups).

  • dp_rank=0: [micro_batch_0_0, micro_batch_0_1]
  • dp_rank=1: [micro_batch_1_0, micro_batch_1_1]

Then for agg_loss, we have

  • dp_size: 2
  • batch_num_tokens: sum(micro_batch_0_0, micro_batch_0_1, micro_batch_1_0, micro_batch_1_1)
  • global_batch_size: sum(len(micro_batch_0_0), len(micro_batch_0_1), len(micro_batch_1_0), len(micro_batch_1_1))

@vermouth1992 created an example for explanation: https://gist.github.com/vermouth1992/6c273240765c4f223478081042bfcd4a
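The scheme described above can be checked numerically with made-up micro-batch token losses: dividing each rank's accumulated sum by the global batch_num_tokens, then undoing the framework's mean across dp ranks, gives every token equal weight in the final loss.

```python
# Numeric check of global-denominator aggregation, with hypothetical
# micro-batch token losses. Not verl code.

micro_batches = {
    0: [[1.0, 2.0], [3.0]],        # dp_rank 0: micro_batch_0_0, micro_batch_0_1
    1: [[4.0], [5.0, 6.0]],        # dp_rank 1: micro_batch_1_0, micro_batch_1_1
}
all_tokens = [t for mbs in micro_batches.values() for mb in mbs for t in mb]
batch_num_tokens = len(all_tokens)               # 6, summed across all ranks
true_mean = sum(all_tokens) / batch_num_tokens   # the target: global token mean

# Each rank accumulates sum(loss) / batch_num_tokens over its micro batches.
rank_losses = [
    sum(t for mb in mbs for t in mb) / batch_num_tokens
    for mbs in micro_batches.values()
]
# Frameworks mean-reduce across dp ranks, so multiply back by dp_size.
dp_size = len(rank_losses)
reduced = sum(rank_losses) / dp_size * dp_size
assert abs(reduced - true_mean) < 1e-9
```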

@tongyx361
Collaborator

tongyx361 commented Dec 15, 2025

Since the global loss aggregation logic is not compatible with the legacy model engine, which conducts the aggregation outside agg_loss and is going to be deprecated, we keep this PR from modifying the legacy model engine.

So the comments above are either resolved or avoided. @wuxibin89 @ISEEKYAN .

@tongyx361 tongyx361 merged commit 6a58521 into verl-project:main Dec 17, 2025
75 of 78 checks passed
@wuxibin89
Collaborator

wuxibin89 commented Jan 2, 2026


This PR makes the new model engine loss significantly smaller than expected.

```python
elif loss_agg_mode.startswith("seq-mean"):
    # TODO: Correct and unify the denominator logic.
    if global_batch_size is not None:
        seq_denominator = global_batch_size * dp_size
```
Collaborator


The seq_denominator is not right; it should be global_batch_size / dp_size.
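The suggested correction can be verified numerically under the assumption (stated earlier in this thread) that the framework mean-reduces the loss across dp ranks; the per-rank losses below are made up for illustration.

```python
# Hypothetical check of the corrected denominator. If the framework averages
# the loss across dp ranks, dividing each rank's per-sequence sum by
# global_batch_size / dp_size recovers the global sequence mean, while the
# merged denominator global_batch_size * dp_size shrinks it.

per_rank_seq_losses = [[1.0, 2.0], [3.0, 4.0, 5.0]]   # uneven ranks, made up
dp_size = len(per_rank_seq_losses)
all_seqs = [s for rank in per_rank_seq_losses for s in rank]
global_batch_size = len(all_seqs)                      # 5 sequences globally
target = sum(all_seqs) / global_batch_size             # global sequence mean

denom = global_batch_size / dp_size                    # corrected denominator
rank_losses = [sum(rank) / denom for rank in per_rank_seq_losses]
reduced = sum(rank_losses) / dp_size                   # framework mean across ranks
assert abs(reduced - target) < 1e-9

# The merged denominator makes the loss dp_size**2 times too small.
wrong = sum(sum(r) / (global_batch_size * dp_size) for r in per_rank_seq_losses) / dp_size
assert abs(wrong - target / dp_size**2) < 1e-9
```

This is consistent with the report above that the merged logic made the new model engine loss significantly smaller than expected.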

wuxibin89 added a commit that referenced this pull request Jan 2, 2026
jsfanfanfan pushed a commit to meituan-search/verl that referenced this pull request Jan 9, 2026
vyomakesh0728 added a commit to vyomakesh0728/verl that referenced this pull request Jan 22, 2026
sophiayyya pushed a commit to sophiayyya/verl that referenced this pull request Jan 25, 2026
y-a23 pushed a commit to y-a23/query that referenced this pull request Feb 5, 2026
KimperYang pushed a commit to KimperYang/TauVerl that referenced this pull request Mar 3, 2026