[fix][train][step-wise] Broadcast step-wise advantage with each step's own response_mask#1507
Merged
CharlieFRuan merged 3 commits into NovaSky-AI:main on Apr 14, 2026
Conversation
In step-wise training, `advantages = last_step_advantages[traj_ids]` copied
the last step's already-masked tensor verbatim to every step of the
trajectory. Because each step has its own right-aligned response_mask at
its own token positions (prompts and responses vary in length across
turns), the scalar advantage landed at the last step's token positions —
which almost never overlap with the earlier steps' own response positions.
Earlier steps silently received advantages = 0 at their loss_mask=1
positions, so roughly (N-1)/N of step-samples contributed no gradient
signal.
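A minimal repro of the misalignment, with toy shapes and made-up token positions (not the actual SkyRL tensors):

```python
import torch

# Step 0's response occupies positions 3..5; step 1's (the last step)
# occupies positions 6..7 -- right-aligned masks at different slots.
mask_step0 = torch.tensor([0., 0., 0., 1., 1., 1., 0., 0.])
mask_step1 = torch.tensor([0., 0., 0., 0., 0., 0., 1., 1.])

# Old path: the estimator had already multiplied the scalar advantage by
# the *last* step's mask, and that tensor was copied to every step.
scalar_adv = 0.7
last_step_adv = scalar_adv * mask_step1            # nonzero only at 6..7
traj_ids = torch.tensor([0, 0])                    # both steps -> traj 0
advantages = last_step_adv.unsqueeze(0)[traj_ids]  # verbatim copy per step

# Step 0's loss reads positions 3..5, which the copy never touches:
print((advantages[0] * mask_step0).abs().sum())    # tensor(0.) -> no signal
print((advantages[1] * mask_step1).abs().sum())    # tensor(1.4000)
```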
Fix (sketched in code below):
- Call the estimator with an all-ones mask so it returns the scalar
advantage at every position (outcome-based estimators produce
`scalar * mask`).
- Broadcast by indexing into `traj_ids`, then multiply by each step's
own `response_mask` so values land at the right token positions.
- Reject `step_wise_trajectories=True` + `gae`/`reinforce++` in
`validate_cfg`: temporal estimators produce per-token advantages that
the scalar-broadcast model can't represent, and the step-wise path
only sees the last step's slice anyway (no cross-step credit).
- Add regression test `test_calc_advantages_and_returns_step_wise_broadcast`
with a realistic right-aligned batch, plus a parametrized
`test_validate_cfg_step_wise_estimator_compatibility` covering all
5 built-in estimators.
- Document the outcome-only constraint in the step-wise tutorial.
See NovaSky-AI#1492.
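A sketch of the corrected broadcast plus the config guard; the `estimator` signature and config field names here are illustrative assumptions, not the actual SkyRL API:

```python
import torch

OUTCOME_ESTIMATORS = {"grpo", "rloo", "maxrl"}  # scalar, outcome-based

def broadcast_step_wise_advantages(traj_scores, traj_ids, response_masks, estimator):
    """traj_scores: (num_traj,) trajectory-level rewards.
    traj_ids: (num_steps,) trajectory index of each step-sample.
    response_masks: (num_steps, seq_len) each step's right-aligned mask."""
    # All-ones mask: outcome-based estimators return `scalar * mask`, so
    # this yields the scalar advantage at *every* token position.
    ones = torch.ones(traj_scores.shape[0], response_masks.shape[1])
    traj_adv = estimator(traj_scores, mask=ones)   # (num_traj, seq_len)
    # Broadcast per trajectory, then re-mask with each step's own mask so
    # the values land at that step's token positions.
    return traj_adv[traj_ids] * response_masks

def validate_cfg(cfg):
    # Temporal estimators (gae, reinforce++) emit per-token advantages the
    # scalar-broadcast model can't represent, and the step-wise path only
    # sees the last step's slice -- reject the combination up front.
    if cfg.step_wise_trajectories and cfg.advantage_estimator not in OUTCOME_ESTIMATORS:
        raise ValueError(
            f"step_wise_trajectories=True requires an outcome-based "
            f"advantage estimator, got {cfg.advantage_estimator!r}"
        )
```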
Contributor
Code Review
This pull request fixes an issue with advantage broadcasting in step-wise training by ensuring trajectory-level rewards are correctly distributed across all steps using per-step response masks. It also adds configuration validation to restrict step-wise training to outcome-based advantage estimators (GRPO, RLOO, MAXRL), updates the documentation, and includes new regression tests. I have no feedback to provide.
SumanthRH approved these changes on Apr 13, 2026
Member
Thanks! Would be good to get an E2E run!
CharlieFRuan added a commit that referenced this pull request on Apr 19, 2026
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529). The original E2E fix's `pad_batch` change is dropped since #1529's prompt-based mini-batch boundaries removed the need to pad to `mini_batch_size * n_samples`.
- merge_stepwise_output() in trainer_utils.py collapses multi-turn step-wise GeneratorOutput sequences into single sequences when consecutive turns share a common prefix, reducing training cost from O(T^2) to O(T).
- trainer.py: call merge before extracting generator fields, update uids from merged trajectory_ids, emit generate/num_seq_{before,after}_merge.
- Add generator.merge_stepwise_output config flag.
- run_search.sh: MERGE_STEPWISE env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and validation asserts.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
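An illustrative-only sketch of the prefix-merge idea from the commit above, using plain token lists rather than the real `GeneratorOutput` (loss-masking of folded-in prompt tokens is omitted):

```python
def merge_consecutive_turns(turns):
    """turns: list of (prompt_tokens, response_tokens) in trajectory order."""
    merged = []
    for prompt, response in turns:
        if merged:
            prev_prompt, prev_response = merged[-1]
            prev_full = prev_prompt + prev_response
            # Turn t+1 repeats turn t's full sequence as its prompt prefix
            # -> fold it into one sequence instead of keeping both, so a
            # T-turn trajectory costs O(T) tokens instead of O(T^2).
            if prompt[: len(prev_full)] == prev_full:
                merged[-1] = (prev_prompt,
                              prev_response + prompt[len(prev_full):] + response)
                continue
        merged.append((prompt, response))
    return merged

# Turn 1's prompt extends turn 0's full sequence -> one merged sequence.
t0 = ([1, 2], [3])
t1 = ([1, 2, 3, 4], [5])  # token 4 = an observation inserted between turns
assert merge_consecutive_turns([t0, t1]) == [([1, 2], [3, 4, 5])]
```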
Testing
https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/reports/PR-1285-search-r1-PAD-left-only-baseline-after-PR-and-step-wise---VmlldzoxNjE3MTA0OQ
Since we will run into #1501 at HEAD, we make this hacky hot fix, mainly the `pad_batch()` change: CharlieFRuan@f2783a4. We compare against the red curve from #1285.
Note that the actual curve should converge slower than this after we merge in #1483, in which each mini batch will truly be one optimizer step. Currently we can have multiple optimizer steps because the number of turns can inflate the number of sequences in each mini batch.
The reward curve being higher is likely due to `num_mini_batches = len(data) // mini_batch_size`, which will no longer be an issue with the `pad_target` hacky fix above, or will be properly fixed after [fix][train] Prompt-based mini-batching for step-wise training #1483.
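A toy illustration of the floor division referenced above, with made-up numbers; that the trailing remainder is what padding (or #1483's prompt-based boundaries) addresses is an assumption here:

```python
data_len, mini_batch_size = 103, 16       # illustrative numbers only
num_mini_batches = data_len // mini_batch_size        # floor division -> 6
remainder = data_len - num_mini_batches * mini_batch_size
print(num_mini_batches, remainder)        # 6 mini-batches, 7 sequences left over
```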