
[fix][train][step-wise] Broadcast step-wise advantage with each step's own response_mask #1507

Merged
CharlieFRuan merged 3 commits into NovaSky-AI:main from CharlieFRuan:fix/step-wise-advantage-1492 on Apr 14, 2026

Conversation

@CharlieFRuan CharlieFRuan (Member) commented Apr 13, 2026

In step-wise training, advantages = last_step_advantages[traj_ids] copied the last step's already-masked tensor verbatim to every step of the trajectory. Because each step has its own right-aligned response_mask at its own token positions (prompts and responses vary in length across turns), the scalar advantage landed at the last step's token positions — which almost never overlap with the earlier step's own response positions. Earlier steps silently received advantages = 0 at their loss_mask=1 positions, so roughly (N-1)/N of step-samples are trained incorrectly.
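
A toy illustration of the misalignment (made-up tensors and variable names, not the trainer's actual code):

    import torch

    # One trajectory, two steps, right-aligned to max_len = 6.
    # Step 0's response sits at positions 2-3; step 1's at positions 4-5.
    step0_mask = torch.tensor([[0.0, 0.0, 1.0, 1.0, 0.0, 0.0]])
    step1_mask = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0]])

    # The last step's advantage tensor, already masked to *its* positions.
    last_step_adv = 0.7 * step1_mask  # [[0, 0, 0, 0, 0.7, 0.7]]

    # Old behavior: copy that tensor verbatim to step 0.
    adv_step0 = last_step_adv.clone()

    # Step 0's own response positions see advantage 0 -> trained incorrectly.
    print((adv_step0 * step0_mask).sum())  # tensor(0.)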

Fix:

  • Call the estimator with an all-ones mask so it returns the scalar advantage at every position (outcome-based estimators produce scalar * mask).
  • Broadcast by indexing into traj_ids, then multiply by each step's own response_mask so values land at the right token positions (see the sketch after this list).
  • Reject step_wise_trajectories=True + gae/reinforce++ in validate_cfg: temporal estimators produce per-token advantages that the scalar-broadcast model can't represent, and the step-wise path only sees the last step's slice anyway (no cross-step credit).
  • Add regression test test_calc_advantages_and_returns_step_wise_broadcast with a realistic right-aligned batch, plus a parametrized test_validate_cfg_step_wise_estimator_compatibility covering all 5 built-in estimators.
  • Document the outcome-only constraint in the step-wise tutorial.
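
Continuing the toy example above, a sketch of the fixed broadcast (hypothetical variable names; the real logic lives in the advantage-estimation path):

    import torch

    step_masks = torch.tensor([
        [0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # step 0's response positions
        [0.0, 0.0, 0.0, 0.0, 1.0, 1.0],  # step 1's response positions
    ])
    traj_ids = torch.tensor([0, 0])  # both steps belong to trajectory 0

    # 1. Call the estimator with an all-ones mask: outcome-based estimators
    #    return scalar * mask, so the scalar shows up at every position.
    last_step_adv = 0.7 * torch.ones(1, 6)  # one row per trajectory

    # 2. Broadcast to all steps via traj_ids, then re-mask with each step's
    #    own response_mask so values land at that step's token positions.
    advantages = last_step_adv[traj_ids] * step_masks
    # tensor([[0.0, 0.0, 0.7, 0.7, 0.0, 0.0],
    #         [0.0, 0.0, 0.0, 0.0, 0.7, 0.7]])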

Fixes #1492.

Testing

https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/reports/PR-1285-search-r1-PAD-left-only-baseline-after-PR-and-step-wise---VmlldzoxNjE3MTA0OQ

Since we will run into #1501 at HEAD, we apply this hacky hotfix, mainly the pad_batch() change: CharlieFRuan@f2783a4

    # Inside pad_batch(); assumes `math` is imported and `dp_size`
    # (the data-parallel size) is in scope.
    if self.cfg.generator.step_wise_trajectories:
        # Pad so the batch divides evenly into both the policy and critic
        # mini-batch sizes (scaled by n_samples_per_prompt) and dp_size.
        n_samples = self.cfg.generator.n_samples_per_prompt
        pad_target = max(
            self.cfg.trainer.policy_mini_batch_size * n_samples,
            self.cfg.trainer.critic_mini_batch_size * n_samples,
            dp_size,
        )
    else:
        pad_target = dp_size
    # Round batch_size up to the next multiple of pad_target.
    pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size
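
For example (made-up numbers): with training_input.batch_size = 37, policy_mini_batch_size = 4, critic_mini_batch_size = 2, n_samples_per_prompt = 5, and dp_size = 8, we get pad_target = max(20, 10, 8) = 20 and pad_size = ceil(37 / 20) * 20 - 37 = 3.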

We compare against the red curve from #1285.

Note that the actual curve should converge more slowly than this after we merge #1483, in which each mini-batch will truly be one optimizer step. Currently we can have multiple optimizer steps because the number of turns can inflate the number of sequences in each mini-batch.

[Figure: reward curve for this PR compared against the red baseline curve from #1285]

The reward curve being higher is likely due to:

  1. The fix we apply in this PR
  2. In the previous PR (red line), we silently dropped trailing data and only trained num_mini_batches = len(data) // mini_batch_size mini-batches; this is no longer an issue with the pad_target hacky fix above, and is properly fixed by [fix][train] Prompt-based mini-batching for step-wise training #1483


claude and others added 2 commits April 13, 2026 22:53
[fix][train][step-wise] Broadcast step-wise advantage with each step's own response_mask

@gemini-code-assist gemini-code-assist Bot (Contributor) left a comment

Code Review

This pull request fixes an issue with advantage broadcasting in step-wise training by ensuring trajectory-level rewards are correctly distributed across all steps using per-step response masks. It also adds configuration validation to restrict step-wise training to outcome-based advantage estimators (GRPO, RLOO, MAXRL), updates the documentation, and includes new regression tests. I have no feedback to provide.

@devin-ai-integration devin-ai-integration Bot (Contributor) left a comment

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

@SumanthRH SumanthRH (Member) commented

Thanks! Would be good to get an E2E run!

@CharlieFRuan CharlieFRuan merged commit 76b7187 into NovaSky-AI:main Apr 14, 2026
5 of 7 checks passed
@CharlieFRuan CharlieFRuan deleted the fix/step-wise-advantage-1492 branch April 14, 2026 04:46
CharlieFRuan added a commit that referenced this pull request Apr 19, 2026
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529).
The original E2E fix's `pad_batch` change is dropped since #1529's
prompt-based mini-batch boundaries removed the need to pad to
`mini_batch_size * n_samples`.

- merge_stepwise_output() in trainer_utils.py collapses multi-turn
  step-wise GeneratorOutput sequences into single sequences when
  consecutive turns share a common prefix, reducing training cost
  from O(T^2) to O(T).
- trainer.py: call merge before extracting generator fields, update
  uids from merged trajectory_ids, emit generate/num_seq_{before,after}_merge.
- Add generator.merge_stepwise_output config flag.
- run_search.sh: MERGE_STEPWISE env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and
  validation asserts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
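
A rough sketch of the prefix-merge idea from this commit (a simplified guess at the shapes involved; the actual merge_stepwise_output() also handles masks, rewards, and the partial-merge cases):

    def merge_consecutive_prefixes(step_token_ids: list[list[int]]) -> list[list[int]]:
        # Collapse consecutive step sequences where the next sequence starts
        # with the previous one; keeping only the longest sequence reduces
        # the trained tokens across T turns from O(T^2) toward O(T).
        merged: list[list[int]] = []
        for seq in step_token_ids:
            if merged and seq[: len(merged[-1])] == merged[-1]:
                merged[-1] = seq  # this turn extends the previous one
            else:
                merged.append(seq)  # non-prefix step: keep separately
        return merged

    # Three turns of one trajectory, each extending the previous one:
    steps = [[1, 2], [1, 2, 3, 4], [1, 2, 3, 4, 5]]
    print(merge_consecutive_prefixes(steps))  # [[1, 2, 3, 4, 5]]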
Merging this pull request closed: [train] Incorrect advantage assignment for step_wise_trajectores in RayPPOTrainer? #1492