[algo] fix: Add seq mean mask denominator option #4510
Merged: tongyx361 merged 12 commits into verl-project:main from szrlee:add_seq_mean_mask_denominator_option on Dec 17, 2025.
Commits (12):
3f4658c  feat(agg_loss): add exclude_fully_masked_seq option for seq-mean deno… (szrlee)
3f2cb33  feat(ActorConfig): add exclude_fully_masked_seq field (szrlee)
6d9a296  feat(ppo_loss): propagate exclude_fully_masked_seq to global_batch_info (szrlee)
9915db1  feat(dp_actor): populate global_batch_info for agg_loss calls (szrlee)
de1fde6  feat(megatron_actor): populate global_batch_info for agg_loss calls (szrlee)
d3df7de  fix(agg_loss): use local batch size to avoid distributed training com… (szrlee)
db3c2cc  fix(agg_loss): remove dp_size multiplier when using local count fallback (szrlee)
8e95ebd  revert: remove exclude_fully_masked_seq feature (szrlee)
2af4aff  fix(agg_loss): use total batch size in denominator, never apply mask (szrlee)
ab58550  chore: remove PR_MESSAGE.md (szrlee)
d0e2a2b  feat: refactor agg_loss (tongyx361)
1985023  feat: remove the global aggregation from legacy model engine (tongyx361)
Changes from all commits (diff on `agg_loss`):

```diff
@@ -781,15 +781,21 @@ def agg_loss(
     """
     Aggregate the loss across global batch to ensure the loss is invariant to fsdp/megatron parallelism.
 
+    NOTE: ``dp_size``, ``batch_num_tokens``, and ``global_batch_size`` are only compatible with the new model engine
+    for now, while the legacy model engines conduct the aggregation outside ``agg_loss``.
+
     NOTE: The returned loss has different behaviors for different backend:
     - FSDP: the loss is directly used for backward.
     - Megatron: the loss should be scaled by `num_microbatches` and `cp_size` for pp schedule.
 
+    # TODO: Consider the numerical stability?
+
     Args:
         loss_mat: micro batch loss matrix, (bs, response_length)
         loss_mask: micro batch loss mask, (bs, response_length)
         loss_agg_mode: method to aggregate the loss matrix into a scalar
-        dp_size: data parallel size
+        dp_size: data parallel size. When appling manual aggregation,
+            scaling up the ``loss`` by ``dp_size`` can cancel out FSDP averaging.
         batch_num_tokens: number of valid tokens in global batch
         global_batch_size: global batch size
         loss_scale_factor: scale factor for "seq-mean-token-sum-norm" mode. If None, uses loss_mask.shape[-1].
@@ -799,30 +805,39 @@ def agg_loss(
         loss: `a scalar torch.Tensor`
             aggregated loss
     """
+    # NOTE: `masked_sum` is more robust than multiplying the `mask`.
     if loss_agg_mode == "token-mean":
         if batch_num_tokens is None:
             batch_num_tokens = loss_mask.sum()
         loss = verl_F.masked_sum(loss_mat, loss_mask) / batch_num_tokens * dp_size
-    elif loss_agg_mode == "seq-mean-token-sum":
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)  # token-sum
-        seq_mask = (torch.sum(loss_mask, dim=-1) > 0).float()  # exclude fully masked sequences
-        if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
-        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
-    elif loss_agg_mode == "seq-mean-token-mean":
-        seq_mask = torch.sum(loss_mask, dim=-1)  # per-sequence token count
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1) / (seq_mask + 1e-8)  # token-mean
-        seq_mask = (seq_mask > 0).float()  # exclude fully masked sequences
-        if global_batch_size is None:
-            global_batch_size = seq_mask.sum()
-        loss = verl_F.masked_sum(seq_losses, seq_mask) / global_batch_size * dp_size  # seq-mean
-    elif loss_agg_mode == "seq-mean-token-sum-norm":
-        seq_losses = torch.sum(loss_mat * loss_mask, dim=-1)
-        if loss_scale_factor is None:
-            loss_scale_factor = loss_mask.shape[-1]
-        loss = torch.sum(seq_losses) / loss_scale_factor
+    elif loss_agg_mode.startswith("seq-mean"):
+        # TODO: Correct and unify the denominator logic.
+        if global_batch_size is not None:
+            seq_denominator = global_batch_size * dp_size
+        else:  # The default logic which is only correct when the batch sizes are even.
+            local_bsz = loss_mat.shape[0]
+            seq_denominator = local_bsz
+
+        if loss_agg_mode.startswith("seq-mean-token-sum"):
+            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1)  # token-sum per sequence
+
+            if loss_agg_mode == "seq-mean-token-sum":
+                pass  # TODO: Add assertation.
+            elif loss_agg_mode == "seq-mean-token-sum-norm":
+                if loss_scale_factor is None:
+                    loss_scale_factor = loss_mask.shape[-1]
+                seq_losses = seq_losses / loss_scale_factor
+            else:
+                raise ValueError(f"Invalid {loss_agg_mode=}")
+        elif loss_agg_mode == "seq-mean-token-mean":
+            token_counts = torch.sum(loss_mask, dim=-1)  # per-sequence token count
+            # token-mean per sequence
+            seq_losses = verl_F.masked_sum(loss_mat, loss_mask, axis=-1) / (token_counts + 1e-8)
+        else:
+            raise ValueError(f"Invalid {loss_agg_mode=}")
+        loss = torch.sum(seq_losses) / seq_denominator  # seq-mean
     else:
-        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
+        raise ValueError(f"Invalid {loss_agg_mode=}")
 
     return loss
```
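To make the denominator change concrete, here is a standalone toy sketch of the `seq-mean-token-mean` path (plain PyTorch in place of `verl_F.masked_sum`, made-up numbers, `dp_size=1`, and the default `global_batch_size=None` fallback, so the new denominator is just the local batch size):

```python
import torch

# 3 sequences x 4 response tokens; the last sequence is fully masked.
loss_mat = torch.tensor([[0.5, 0.5, 0.5, 0.5],
                         [1.0, 1.0, 0.0, 0.0],
                         [9.9, 9.9, 9.9, 9.9]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 1.0],
                          [1.0, 1.0, 0.0, 0.0],
                          [0.0, 0.0, 0.0, 0.0]])

token_counts = loss_mask.sum(dim=-1)                                     # [4., 2., 0.]
seq_losses = (loss_mat * loss_mask).sum(dim=-1) / (token_counts + 1e-8)  # [0.5, 1.0, 0.0]

# Old behavior: exclude fully masked sequences from the denominator.
seq_present = (token_counts > 0).float()
old_loss = (seq_losses * seq_present).sum() / seq_present.sum()          # 1.5 / 2 = 0.75

# New behavior (global_batch_size=None): divide by the local batch size.
new_loss = seq_losses.sum() / loss_mat.shape[0]                          # 1.5 / 3 = 0.5

print(old_loss.item(), new_loss.item())
```

The fully masked sequence contributes nothing to the numerator in either version; only the denominator changes, so after the PR it dilutes the mean instead of being dropped, matching commit 2af4aff ("use total batch size in denominator, never apply mask"). When `global_batch_size` is passed (new model engine), the denominator becomes `global_batch_size * dp_size` instead of the local batch size.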
Review comments:
- is it possible that we remove the dp_size from the algorithm loss function?
- dp_size was brought into loss by @wuxibin89
- This comment is kept unresolved to show why @wuxibin89 introduced the `dp_size` below.
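For context on the `dp_size` question above: the updated docstring notes that scaling the loss by `dp_size` can cancel out FSDP averaging. Below is a minimal numeric sketch of that arithmetic for the `token-mean` path, with made-up numbers and FSDP's cross-rank gradient averaging modeled as a plain average over per-rank losses (no real process group):

```python
import torch

dp_size = 2
# Per-rank masked token sums and valid-token counts (hypothetical values).
rank_token_sums = [torch.tensor(6.0), torch.tensor(2.0)]
rank_token_counts = [torch.tensor(3.0), torch.tensor(1.0)]
batch_num_tokens = sum(rank_token_counts)  # global number of valid tokens = 4

# What agg_loss computes on each rank: local_sum / batch_num_tokens * dp_size
rank_losses = [s / batch_num_tokens * dp_size for s in rank_token_sums]

# FSDP averages gradients across ranks, i.e. divides by dp_size again,
# so the dp_size factor cancels and the effective objective is the global token-mean.
fsdp_effective = sum(rank_losses) / dp_size                   # (3.0 + 1.0) / 2 = 2.0
global_token_mean = sum(rank_token_sums) / batch_num_tokens   # 8 / 4 = 2.0
assert torch.isclose(fsdp_effective, global_token_mean)
```

Under this simplified model, dropping the `dp_size` factor from the loss would leave the effective objective at the global token-mean divided by `dp_size`, which is presumably why removing it from the algorithm loss function would require compensating for the averaging elsewhere.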