[BREAKING][skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy#1296

Merged
erictang000 merged 26 commits intoNovaSky-AI:mainfrom
justinvyu:token_mean_loss_reduction
Mar 31, 2026
Conversation

@justinvyu justinvyu commented Mar 9, 2026

This is a breaking change for the default token_mean loss behavior, as well as observed policy_loss metrics! See the "Differences in reported loss metric magnitudes" section below.

Summary

  • Change reduce_loss() to always return a simple masked sum ((loss * mask).sum()). To achieve different reduction strategies, we pre-scale the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
    • Scale the loss by the DP size before calling backward() to counteract the default data-parallel mean gradient all-reduce across workers, turning it into a sum.
  • Fix the token_mean loss reduction to take a mean across all tokens in the minibatch rather than averaging across micro-batches. The old reduction remains available via the token_mean_legacy config.

Loss reduction strategies

  • Option 1: token_mean

    • Average loss per token across the entire mini-batch.
    • This is the fixed version where the denominator is the total token count across the full mini-batch, so the gradient is independent of how the mini-batch is split into micro-batches.
  • Option 1b: token_mean_legacy

    • Compute token-mean loss within each micro-batch, then average across micro-batches.
    • This reproduces the token_mean behavior before this PR.
    • The problem: if micro-batches have different token counts, the effective weighting differs from a true global token mean. It is also less usable, since changing the micro-batch size changes the loss and the training dynamics.
    • Kept as a fallback in case of performance regressions — we should remove this down the line.
  • Option 2: sequence_mean

    • Compute per-token loss within each sequence, average across sequences.
    • This is unchanged and is just implemented via advantage normalization instead.
  • Option 3: seq_mean_token_sum_norm

    • Dr. GRPO style — normalize by a fixed constant to avoid any length-dependent weighting.
    • This is unchanged and is just implemented via advantage normalization instead.
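The strategies above can be sketched as advantage pre-scaling: each mode divides the advantages by a different denominator, so that a plain masked sum afterwards reproduces the intended reduction. A minimal illustrative sketch in plain Python (the function names and list-of-lists layout are hypothetical, not the actual skyrl code):

```python
def scale_advantages(seqs, strategy, norm_constant=8):
    """Pre-scale per-token advantages so that a plain sum of the resulting
    per-token losses reproduces the chosen reduction strategy.
    seqs: list of sequences, each a list of per-token advantages."""
    n_seqs = len(seqs)
    n_tokens = sum(len(s) for s in seqs)
    if strategy == "token_mean":
        # Every token weighted equally across the whole mini-batch.
        return [[a / n_tokens for a in s] for s in seqs]
    if strategy == "sequence_mean":
        # Every sequence weighted equally, regardless of its length.
        return [[a / (len(s) * n_seqs) for a in s] for s in seqs]
    if strategy == "seq_mean_token_sum_norm":
        # Dr. GRPO style: a fixed constant, independent of sequence length.
        return [[a / (norm_constant * n_seqs) for a in s] for s in seqs]
    raise ValueError(f"unknown strategy: {strategy}")

def reduce_loss(per_token_losses):
    # The new reduce_loss(): always a simple (masked) sum.
    return sum(x for s in per_token_losses for x in s)
```

Switching strategies only changes the pre-scaling; the reduction itself never changes.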

Mean all-reduce -> sum all-reduce

We need the loss to be summed across microbatches and data parallel workers:

  • DDP/FSDP default to a mean all-reduce for gradients across workers. This PR counteracts that by multiplying the loss by the DP world size, preserving a sum across data-parallel workers.
  • Megatron does a similar mean reduction across microbatches and workers, so we counteract it by multiplying by the number of microbatches and the DP size to achieve the sum.
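A toy numeric sketch of the counteraction (purely illustrative; dp_size and the gradient values are made up):

```python
# DDP/FSDP average gradient contributions across DP workers. Pre-multiplying
# each worker's loss by dp_size makes that mean all-reduce yield a sum.
dp_size = 4
per_worker_grads = [1.0, 2.0, 3.0, 4.0]  # one gradient contribution per worker

# Default behavior: mean all-reduce.
mean_reduced = sum(per_worker_grads) / dp_size

# With pre-scaling: each contribution is dp_size times larger,
# so the mean all-reduce recovers the sum.
scaled_grads = [g * dp_size for g in per_worker_grads]
sum_recovered = sum(scaled_grads) / dp_size
```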

Difference in reported loss metric magnitudes

You will observe that the reported loss metric has a different magnitude compared to your older runs. This is because the old token_mean implementation was somewhere between a true token mean and a sequence mean due to per-micro-batch normalization (e.g., micro_batch_size=1 was equivalent to sequence mean).

The new token_mean is a proper minibatch token mean, while sequence_mean weights every sequence equally regardless of length. Comparing the loss produced by the different reduction methods on the same advantages from a real run:

  token_mean (new):  0.322   — every token weighted equally across the mini-batch
  token_mean (old):  0.065   — mean of per-micro-batch token means, where micro_batch_size=4
  sequence_mean:     0.00098 — every sequence weighted equally

The old token_mean gave each micro-batch equal weight rather than each token, so its scale depended on how advantages were distributed across micro-batches. The new implementation is invariant to micro-batch size.

Note that token_mean_legacy still reports the old metrics, and the sequence_mean and seq_mean_token_sum_norm modes also match exactly. See this comment for more details.

Tinker compatibility

Here was the first attempt at fixing the loss reduction across microbatches: #909

That approach tracked total tokens and then did one big normalization at optim_step to get an average per-token loss. Instead, we decided to align with Tinker's approach of just summing the loss at the end and pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control over their loss reduction strategy, rather than having it happen inside our opaque forward_backward/optim_step implementation, which would require a configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:

client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible

The current PR aligns with Tinker semantics:

Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.

Example for loss_reduction="token_mean": move the 1/num_minibatch_tokens normalization into the advantage:

  loss = sum(-advantage_i * ratio_i for i in range(num_minibatch_tokens)) / num_minibatch_tokens
       = sum(-(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens))
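The identity above is easy to verify numerically (illustrative values; the ratios stand in for the PPO importance ratios):

```python
advantages = [0.5, -1.0, 2.0, 0.25]
ratios = [1.1, 0.9, 1.0, 1.2]  # stand-ins for PPO importance ratios
n = len(advantages)

# token_mean done as a post-hoc division of the summed loss...
loss_divide_after = sum(-a * r for a, r in zip(advantages, ratios)) / n

# ...equals token_mean done by folding 1/n into the advantages first.
loss_prescaled = sum(-(a / n) * r for a, r in zip(advantages, ratios))
```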

Learning curve comparisons before/after the PR

FSDP (wandb): (screenshot)

Megatron (wandb):

1.7B: (screenshot)

30B lora, master baseline: (screenshot)

30B lora, token_mean_legacy + fixed token_mean: (screenshot)

@justinvyu justinvyu marked this pull request as ready for review March 20, 2026 22:34

@justinvyu justinvyu changed the title [wip] loss reduction [skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy Mar 20, 2026
@erictang000 erictang000 (Collaborator) left a comment
this looks almost good to merge, super clean thanks for adding the token_mean_legacy path

just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...

… mini-batch reduction

- Report unscaled loss metrics (remove * loss_scale / * dp_size) in both FSDP and Megatron workers
- Rename reduce_metrics -> reduce_metrics_across_microbatches (sums _loss for gradient accumulation)
- Add reduce_metrics_across_minibatches in trainer_utils (averages _loss for logging)
- Use sum all-reduce for _loss keys across DP workers to reconstruct full mini-batch loss

Comment on lines +371 to +372
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
@justinvyu justinvyu commented Mar 25, 2026
Metrics fix 1: remove the dp_size multiplier in reported metrics. There's no average to correct for, since reduce_microbatch_metrics and all_reduce_metrics both sum the *_loss metrics.

# pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
all_metrics.pop("loss_fn_outputs", None)
reduced_metrics = reduce_metrics(all_metrics)
reduced_metrics = reduce_metrics_across_minibatches(all_metrics)
@justinvyu (Contributor Author):
Metrics fix 2: take an average across minibatches instead of still summing, because the loss-reduction normalization happens at the minibatch level. Across different minibatches we should just average; otherwise the reported loss scale is inflated by ~num_minibatches.

Comment on lines +292 to +294
if resolved_loss_name == "cross_entropy":
    loss = policy_loss
unscaled_loss = policy_loss
loss = unscaled_loss * grad_sum_correction_factor
@justinvyu (Contributor Author):
Q: should this affect the SFT case? SFT doesn't look at the normalized advantages either, similar to the critic-loss case.

Before this PR, the SFT case summed the negative log likelihoods within a microbatch, but still averaged over microbatches and DP workers.
Now we are summing the negative log likelihood across the entire minibatch. What's the desired behavior here?

@justinvyu (Contributor Author):

Reverting to the old behavior for now and we can tackle it in a follow-up. SFT loss reduction is broken already due to taking a sum within the microbatch but then a mean across microbatches/workers. The current behavior does not align with this comment: https://github.com/justinvyu/SkyRL/blob/c5feb83b38f4635c7fc705c2bb192a7d6ad16947/skyrl/backends/skyrl_train/utils/ppo_utils.py#L917

justinvyu commented Mar 27, 2026

To sanity check the difference in loss metric magnitudes, I dumped the raw advantages on the first step and manually calculated the loss with the different reduction methods on the same dumped data.

Using dumped advantage tensors from a real GRPO run to compare old vs. new:

With micro_batch_size=4 (128 micro-batches), old and new differ by ~5x — matching what's observed in the actual run:

Average: old=0.065  new=0.322  ratio=4.92

With micro_batch_size=1, the old token_mean reduces to sequence_mean (each sequence weighted equally). The old values match sequence_mean exactly:

token_mean old:  Average=-0.024
sequence_mean:   Average=-0.024  ratio=1.0000

With micro_batch_size=512 (1 micro-batch = full mini-batch), old and new converge:

Average: old=0.322  new=0.322  ratio=1.0000

The new token_mean value (0.322) is the same regardless of micro_batch_size — which is the correct behavior. The old value varied between -0.024 (at micro_batch_size=1, i.e. sequence_mean) and 0.322 (at micro_batch_size=mini_batch_size=512, i.e. fixed token_mean) depending on how micro-batches were formed.

token_mean_legacy reproduces the old behavior. Runs using token_mean won't be directly comparable to before, but the difference is analogous to comparing token_mean vs. sequence_mean — a different weighting, not a bug.
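The micro-batch dependence described above can be reproduced with a small sketch (made-up per-token values, not the dumped run data):

```python
def token_mean_new(microbatches):
    # Fixed: global token mean over the whole mini-batch.
    tokens = [t for mb in microbatches for t in mb]
    return sum(tokens) / len(tokens)

def token_mean_legacy(microbatches):
    # Old: per-micro-batch token means, then averaged across micro-batches.
    return sum(sum(mb) / len(mb) for mb in microbatches) / len(microbatches)

per_token_losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
even_split = [per_token_losses[:3], per_token_losses[3:]]    # equal token counts
uneven_split = [per_token_losses[:1], per_token_losses[1:]]  # unequal token counts

# The fixed reduction is invariant to how the mini-batch is split...
assert token_mean_new(even_split) == token_mean_new(uneven_split) == 3.5
# ...while the legacy reduction shifts when token counts differ per micro-batch.
assert token_mean_legacy(uneven_split) != 3.5
```

When micro-batches happen to have equal token counts, the two reductions coincide, which is why the old and new values converge at micro_batch_size = mini_batch_size.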

Comment on lines +718 to +721
# TODO: SFT path still averages metrics across microbatches and workers.
# This needs to be unified with the RL path which sums.
resolved_loss_name = loss_fn or self.cfg.algorithm.policy_loss_type
sum_loss_metrics = resolved_loss_name != "cross_entropy"
@justinvyu (Contributor Author):

need to follow up to unify the codepaths -- SFT loss is wrong right now

Comment on lines +886 to +891
grad_sum_correction_factor = self.mesh_rank.dp_size

# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * microbatch_weight
unscaled_loss = loss / grad_sum_correction_factor
@justinvyu (Contributor Author):

This part is a bit complicated in order to maintain KL/entropy loss parity:

  • Previously, the KL/entropy terms were per-token averages within the microbatch (see the masked_mean above). Then we took the average across microbatches and DP workers (same as the old loss).
  • We can't just sum them, because they were computed on the worker and were not pre-scaled the way the advantages were.
  • So, to maintain the averaging behavior, we multiply the terms by the microbatch weight (1/num_microbatches), and we don't apply the grad sum correction factor, which keeps the all-reduce a mean across DP workers.
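A sketch of the combination described in those bullets (illustrative numbers, not the skyrl implementation): only the pre-scaled policy loss gets the sum correction, while the KL/entropy terms stay on the averaging path.

```python
dp_size = 2
num_microbatches = 4
microbatch_weight = 1.0 / num_microbatches

policy_loss = 0.3       # masked sum over pre-scaled advantages (sum semantics)
kl_loss_term = 0.02     # per-token average within this microbatch (mean semantics)
entropy_loss_term = 0.01

# Scale only the policy loss so the DP mean all-reduce yields a sum for it;
# weight the KL/entropy terms so they remain averages across microbatches.
loss = policy_loss * dp_size + (kl_loss_term - entropy_loss_term) * microbatch_weight
unscaled_loss = loss / dp_size  # reported as a metric
```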

Comment on lines +352 to +357
grad_sum_correction_factor = num_microbatches * dp_size

# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term
unscaled_loss = loss / grad_sum_correction_factor
@justinvyu (Contributor Author):
This is similar to the FSDP case, except Megatron already divides by num_microbatches and dp_size internally (so there is no need to divide by num_microbatches here).

@erictang000 erictang000 (Collaborator) left a comment
looks great! thanks for all the work getting this in

@erictang000 erictang000 changed the title [skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy [BREAKING][skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy Mar 31, 2026
@erictang000 erictang000 merged commit bf243b8 into NovaSky-AI:main Mar 31, 2026
5 of 6 checks passed
erictang000 added a commit that referenced this pull request Mar 31, 2026
…ron (#1420)

Fixes `test_megatron_train[tp2_cp2_policy_seq_packing_no_entropy_loss]`
failing after #1296.

### Problem
The loss refactor in #1296 introduced two CP-specific bugs:
1. **Metrics double-counted across CP ranks**: `all_reduce_metrics` used
`get_data_parallel_group(with_context_parallel=True)`, which includes CP
ranks in the reduction group. With `sum_loss_metrics=True`, this
**sums** `policy_loss` across CP ranks. But since
`postprocess_packed_seqs` already gathers logprobs across CP before
computing the loss, all CP ranks produce identical metrics — so summing
doubles the value. This caused the ~2x discrepancy (`-28.43` FSDP vs
`-57.36` Megatron).
2. **Gradient correction factor ignores CP**:
`grad_sum_correction_factor` used `get_data_parallel_world_size()`
(without CP), but Megatron's `finalize_model_grads` averages gradients
across the full DP+CP group. The correction was therefore `1/CP_size`
too small.
### Fix
- Use `get_data_parallel_group(with_context_parallel=False)` for the
metrics all-reduce, since metrics are already complete on each CP rank.
- Use `get_data_parallel_world_size(with_context_parallel=True)` for the
gradient correction factor, matching the group that
`finalize_model_grads` reduces over.
Both changes are no-ops when CP=1.
@erictang000 erictang000 mentioned this pull request Apr 1, 2026