[BREAKING][skyrl-train] Implement loss reduction via advantage normalization and fix token_mean reduction strategy #1296
Conversation
erictang000
left a comment
This looks almost good to merge, super clean. Thanks for adding the `token_mean_legacy` path!
Just want to check my understanding + 1 question about the metrics code that I think I probably wrote on the old PR...
… mini-batch reduction
- Report unscaled loss metrics (remove `* loss_scale` / `* dp_size`) in both FSDP and Megatron workers
- Rename `reduce_metrics` -> `reduce_metrics_across_microbatches` (sums `_loss` for gradient accumulation)
- Add `reduce_metrics_across_minibatches` in trainer_utils (averages `_loss` for logging)
- Use sum all-reduce for `_loss` keys across DP workers to reconstruct full mini-batch loss
```python
"final_loss": unscaled_loss.detach().item(),
"policy_loss": policy_loss.detach().item(),
```
Metrics fix 1: remove the `dp_size` multiplier in reported metrics, since there's no average to correct for: `reduce_microbatch_metrics` and `all_reduce_metrics` both do sums for `*_loss` metrics.
```diff
  # pop out loss_fn_outputs since it's not a scalar metric and to avoid logging it
  all_metrics.pop("loss_fn_outputs", None)
- reduced_metrics = reduce_metrics(all_metrics)
+ reduced_metrics = reduce_metrics_across_minibatches(all_metrics)
```
Metrics fix 2: take an average across minibatches instead of still summing. This is because the loss-reduction normalization happens at the minibatch level. Across different minibatches we should just average; otherwise we'd inflate the reported loss scale by ~`num_minibatches`.
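A hedged sketch of the two-stage reduction (hypothetical helper shapes, not the actual SkyRL functions): `*_loss` keys are summed across microbatches to mirror gradient accumulation, then averaged across minibatches for logging.

```python
# Sketch (hypothetical helpers, not the actual SkyRL implementations).
def reduce_metrics_across_microbatches(per_microbatch):
    out = {}
    for key in per_microbatch[0]:
        vals = [m[key] for m in per_microbatch]
        # sum loss metrics (gradient accumulation sums), average everything else
        out[key] = sum(vals) if key.endswith("_loss") else sum(vals) / len(vals)
    return out

def reduce_metrics_across_minibatches(per_minibatch):
    # plain average for logging, including *_loss keys
    return {
        key: sum(m[key] for m in per_minibatch) / len(per_minibatch)
        for key in per_minibatch[0]
    }
```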
```diff
  if resolved_loss_name == "cross_entropy":
-     loss = policy_loss
+     unscaled_loss = policy_loss
+     loss = unscaled_loss * grad_sum_correction_factor
```
Q: should this affect the SFT case? SFT doesn't look at the normalized advantages either, similar to the critic loss case.
Before this PR, the SFT case summed the negative log likelihoods within a microbatch, but still averaged over microbatches and DP workers.
Now, we are summing negative log likelihood across the entire minibatch. What's the desired behavior here?
Reverting to the old behavior for now; we can tackle it in a follow-up. SFT loss reduction is already broken, since it takes a sum within the microbatch but then a mean across microbatches/workers. The current behavior does not align with this comment: https://github.com/justinvyu/SkyRL/blob/c5feb83b38f4635c7fc705c2bb192a7d6ad16947/skyrl/backends/skyrl_train/utils/ppo_utils.py#L917
To sanity check the difference in loss metric magnitudes, I dumped the raw advantages on the first step and manually calculated the loss with each reduction method on the same dumped data, using advantage tensors from a real GRPO run to compare old vs. new.
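A sketch of that sanity check on synthetic numbers (the real check used dumped advantage tensors): three sequences of different lengths, split into two microbatches, compared under each reduction method.

```python
# Sketch on synthetic data: per-token losses for 3 sequences of
# different lengths, split into two microbatches.
seqs = [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0], [6.0]]
microbatches = [seqs[:2], seqs[2:]]

# new token_mean: mean over every token in the minibatch
all_tokens = [t for s in seqs for t in s]
token_mean = sum(all_tokens) / len(all_tokens)          # 14 / 7 = 2.0

# sequence_mean: every sequence weighted equally
seq_mean = sum(sum(s) / len(s) for s in seqs) / len(seqs)  # (1 + 2 + 6) / 3 = 3.0

# legacy token_mean: per-microbatch token means, then averaged
def mb_token_mean(mb):
    toks = [t for s in mb for t in s]
    return sum(toks) / len(toks)

legacy = sum(mb_token_mean(mb) for mb in microbatches) / len(microbatches)
```

The legacy value lands between the true token mean and the sequence mean, which matches the observation below that the old `token_mean` was "somewhere between" the two.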
```python
# TODO: SFT path still averages metrics across microbatches and workers.
# This needs to be unified with the RL path which sums.
resolved_loss_name = loss_fn or self.cfg.algorithm.policy_loss_type
sum_loss_metrics = resolved_loss_name != "cross_entropy"
```
Need a follow-up to unify the codepaths; SFT loss is wrong right now.
```python
grad_sum_correction_factor = self.mesh_rank.dp_size
# ...
# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + (kl_loss_term - entropy_loss_term) * microbatch_weight
unscaled_loss = loss / grad_sum_correction_factor
```
This part is a bit complicated in order to maintain KL/entropy loss parity:

- Previously, the KL/entropy terms were per-token averages within the microbatch (see the `masked_mean` above). Then we took the average across microbatches and DP workers (same as the old loss).
- We can't just sum them, because they were computed on the worker and we didn't pre-scale them the same way we scaled the advantages.
- So, to maintain the averaging behavior, we scale the terms by the microbatch weight (`1/num_microbatches`), and we don't apply the grad sum correction factor, keeping the all-reduce a mean across DP workers.
```python
grad_sum_correction_factor = num_microbatches * dp_size
# ...
# NOTE: The KL and entropy loss terms are not pre-scaled,
# so we just average them across microbatches and DP workers.
loss = policy_loss * grad_sum_correction_factor + kl_loss_term - entropy_loss_term
unscaled_loss = loss / grad_sum_correction_factor
```
This is similar to the FSDP case, except Megatron already divides by `num_microbatches` and `dp_size` internally (so there's no need to divide by `num_microbatches` here).
erictang000
left a comment
looks great! thanks for all the work getting this in
…ron (#1420)

Fixes `test_megatron_train[tp2_cp2_policy_seq_packing_no_entropy_loss]` failing after #1296.

### Problem

The loss refactor in #1296 introduced two CP-specific bugs:

1. **Metrics double-counted across CP ranks**: `all_reduce_metrics` used `get_data_parallel_group(with_context_parallel=True)`, which includes CP ranks in the reduction group. With `sum_loss_metrics=True`, this **sums** `policy_loss` across CP ranks. But since `postprocess_packed_seqs` already gathers logprobs across CP before computing the loss, all CP ranks produce identical metrics, so summing doubles the value. This caused the ~2x discrepancy (`-28.43` FSDP vs `-57.36` Megatron).
2. **Gradient correction factor ignores CP**: `grad_sum_correction_factor` used `get_data_parallel_world_size()` (without CP), but Megatron's `finalize_model_grads` averages gradients across the full DP+CP group. The correction was therefore `1/CP_size` too small.

### Fix

- Use `get_data_parallel_group(with_context_parallel=False)` for the metrics all-reduce, since metrics are already complete on each CP rank.
- Use `get_data_parallel_world_size(with_context_parallel=True)` for the gradient correction factor, matching the group that `finalize_model_grads` reduces over.

Both changes are no-ops when CP=1.
This is a breaking change for the default `token_mean` loss behavior, as well as observed `policy_loss` metrics! See the "Difference in reported loss metric magnitudes" section below.

Summary

- `reduce_loss()` now always returns a simple masked sum (`(loss * mask).sum()`). To achieve different reduction strategies, we pre-scale the advantages before they enter the loss function, which also aligns with how Tinker's API handles it.
- The loss is scaled before `backward()` to counteract the default data parallel mean gradient all-reduce across workers, so that it effectively does a sum instead.
- Fix the `token_mean` loss reduction method to take a mean across all tokens in the minibatch rather than averaging across microbatches. The old loss reduction remains available via the `token_mean_legacy` config.

Loss reduction strategies
Option 1: token_mean

Option 1b: token_mean_legacy
The `token_mean` behavior before this PR.

Option 2: sequence_mean

Option 3: seq_mean_token_sum_norm
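As a hedged sketch (plain Python, not the SkyRL implementation), the first two strategies can be expressed as a pre-scale on the per-token advantages, so the loss function itself only ever returns a simple masked sum:

```python
# Sketch: express a reduction strategy as advantage pre-scaling.
def prescale(advantages, strategy):
    # advantages: list of sequences, each a list of per-token advantages
    if strategy == "token_mean":
        n = sum(len(s) for s in advantages)   # all minibatch tokens
        return [[a / n for a in s] for s in advantages]
    if strategy == "sequence_mean":
        n_seq = len(advantages)               # each sequence weighted equally
        return [[a / (len(s) * n_seq) for a in s] for s in advantages]
    raise ValueError(strategy)

def masked_sum_loss(advantages):
    # stand-in for the loss: on-policy ratio == 1, so loss_i = -advantage_i
    return sum(-a for s in advantages for a in s)
```

With `adv = [[1.0, 1.0], [3.0]]`, `token_mean` pre-scaling yields a summed loss of `-(1+1+3)/3`, and `sequence_mean` yields `-((1+1)/2 + 3)/2`, matching the definitions above.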
Mean all-reduce -> sum all-reduce
We need the loss to be summed across microbatches and data parallel workers:
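A pure-arithmetic sketch of the grad-sum correction: the data parallel all-reduce averages gradients over `dp_size` workers, so scaling each worker's loss (and hence its gradients) by `dp_size` turns that mean into an effective sum.

```python
# Sketch with hypothetical per-worker gradient values.
dp_size = 4
per_worker_grads = [0.1, 0.2, 0.3, 0.4]

ddp_mean = sum(per_worker_grads) / dp_size                        # what the mean all-reduce computes
corrected = sum(g * dp_size for g in per_worker_grads) / dp_size  # with loss scaling applied

# corrected equals the plain sum across workers
assert abs(corrected - sum(per_worker_grads)) < 1e-12
```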
Difference in reported loss metric magnitudes
You will observe that the reported loss metric has a different magnitude compared to your older runs. This is because the old `token_mean` implementation was somewhere between a true token mean and a sequence mean due to per-micro-batch normalization (e.g. `micro_batch_size=1` was equivalent to a sequence mean).
The new `token_mean` is a proper minibatch token mean, while `sequence_mean` weights every sequence equally regardless of length. Comparing the loss produced by the different reduction methods on the same advantages from a real run: the old `token_mean` gave each micro-batch equal weight rather than each token, so its scale depended on how advantages were distributed across micro-batches. The new implementation is invariant to micro-batch size.
Note that `token_mean_legacy` still reports the old metrics, and the `sequence_mean` and `seq_mean_token_sum_norm` modes also match exactly. See this comment for more details.

Tinker compatibility
Here was the first attempt at fixing the loss reduction across microbatches: #909
That method tracked total tokens and then did one big normalization at the `optim_step` in order to get an average per-token loss. But we decided to align with Tinker's way of just summing up the loss at the end, pushing any loss normalization into the user's advantage calculation.

The benefit is that users have full control over customizing their loss reduction strategy, rather than having it happen in our opaque `forward_backward` / `optim_step` implementation, which would require some configuration argument that diverges from Tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss. The current PR aligns with Tinker semantics:
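A minimal sketch of those semantics (assumed names, not the actual API): the training step only ever sums the masked per-token loss, and any normalization is folded into the advantages by the user beforehand.

```python
# Sketch: a forward_backward that just computes a plain masked sum.
def forward_backward(per_token_loss, mask):
    return sum(l * m for l, m in zip(per_token_loss, mask))

per_token_loss = [0.5, -1.0, 2.0]
mask = [1.0, 1.0, 0.0]
num_minibatch_tokens = sum(mask)

# The user pre-divides by the token count to recover a token-mean loss
# from the summed result.
prescaled = [l / num_minibatch_tokens for l in per_token_loss]
token_mean = forward_backward(prescaled, mask)
```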
Example for `loss_reduction="token_mean"`: fold the `1/num_minibatch_tokens` normalization into the advantage:

```
loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
     = sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )
```

Learning curve comparisons before/after the PR
FSDP (wandb)
Megatron (wandb)
1.7B:
30B lora:
master baseline:

`token_mean_legacy` + fixed `token_mean`: