
[train][step-wise] Add prefix-aware merging for step-wise training #1538

Merged
CharlieFRuan merged 2 commits into main from charlie/merge-stepwise-output on Apr 20, 2026

Conversation

@CharlieFRuan (Member) commented Apr 20, 2026

Co-authored-by: Deep Sheth deepsheth3@users.noreply.github.com

Identical to #1532, which messed up the commit authors and was therefore reverted and re-opened here.

Summary

This PR implements prefix-aware merging for step-wise training, guarded by a flag cfg.generator.merge_stepwise_output that defaults to False. During step-wise training, when consecutive steps within a trajectory share the same prefix (i.e. there is no re-tokenization drift or context management such as thinking-token stripping), we collapse them into a single GeneratorOutput entry. This can reduce the O(T²) training cost introduced by step-wise training (T being the number of turns).

  • merge_stepwise_output() in skyrl/train/utils/trainer_utils.py implements greedy merging: for consecutive turns in the same trajectory where prompt[i] + response[i] is a prefix of prompt[i+1], merge into one entry. Response tokens are concatenated with the observation delta (loss-masked to 0) between turns; per-token fields (loss_masks, rewards, rollout_logprobs) are aligned accordingly; per-turn fields (stop_reason, is_last_step, trajectory_id) take the last turn's value.
  • RayPPOTrainer.postprocess_generator_output calls merge_stepwise_output when generator.merge_stepwise_output=true, updates uids from the merged trajectory_ids, and logs generate/num_seq_{before,after}_merge.
    • Since uids may need to be modified, the signature of postprocess_generator_output is updated to return both generator_output and uids, which required changing the various call sites.
  • New generator.merge_stepwise_output config flag (default false).
  • examples/train/search/run_search.sh accepts MERGE_STEPWISE=true env var to pass the flag through.
  • 16 CPU-only unit tests in tests/train/test_merge_stepwise_output.py cover the three merge cases, partial merges, prefix mismatches, single-turn passthrough, per-trajectory scalar rewards, and required-field asserts.
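The greedy, prefix-aware merge described above can be sketched as follows. This is a simplified illustration with hypothetical field names, not the actual merge_stepwise_output implementation (which also handles rewards, rollout_logprobs, and the other per-turn fields):

```python
def merge_trajectory(turns):
    """Greedily merge consecutive turns of one trajectory.

    Each turn is a dict of token-id lists: "prompt", "response",
    "loss_mask" (one entry per response token), plus per-turn scalars.
    Simplified sketch; field names are illustrative.
    """
    merged = [dict(turns[0],
                   response=list(turns[0]["response"]),
                   loss_mask=list(turns[0]["loss_mask"]))]
    for turn in turns[1:]:
        cur = merged[-1]
        full = cur["prompt"] + cur["response"]
        if turn["prompt"][: len(full)] == full:
            # Observation delta: context tokens appended between turns that
            # the model did not generate -> loss-masked to 0.
            obs_delta = turn["prompt"][len(full):]
            cur["response"] += obs_delta + turn["response"]
            cur["loss_mask"] += [0] * len(obs_delta) + turn["loss_mask"]
            # Per-turn fields take the last turn's value.
            cur["stop_reason"] = turn["stop_reason"]
        else:
            # Prefix mismatch (e.g. re-tokenization drift): keep a new entry.
            merged.append(dict(turn,
                               response=list(turn["response"]),
                               loss_mask=list(turn["loss_mask"])))
    return merged
```

For example, a two-turn trajectory where the second prompt extends the first prompt-plus-response by one observation token collapses into a single entry with that token loss-masked out.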

Test plan

  • pytest tests/train/test_merge_stepwise_output.py — 16 passed
  • pytest tests/train/test_trainer_utils.py tests/train/test_prompt_mini_batch.py — 58 passed (existing tests unaffected)
  • E2E: Search-R1 step-wise GRPO run on Qwen2.5-3B-Instruct, 8×H100, MERGE_STEPWISE=true.

Curves

With precisely the same setup as #1529, we run:

```sh
MERGE_STEPWISE=true USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```

Below are the curves (light blue is this PR).

image
  • The number of sequences batch_num_seq decreased by roughly 500 at each step. With non-step-wise, it is 512 * 5 = 2560 sequences. We were not able to merge everything, primarily due to how Qwen2.5 retokenizes < think into <think (or something along those lines for the thinking tokens).
  • We also see that avg_num_tokens increases -- note that this field still records the per-turn average rather than the per-trajectory average.
image
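The re-tokenization drift mentioned above can be illustrated with a toy greedy tokenizer over a hypothetical three-token vocabulary (not Qwen2.5's real vocabulary): the model emits `<` and `think` as two tokens, but re-encoding the decoded text yields the single token `<think`, so the next turn's prompt IDs no longer start with the previous turn's IDs and the prefix check fails:

```python
# Toy vocabulary with hypothetical ids; "<think" also exists as one token.
VOCAB = {"<": 1, "think": 2, "<think": 3}
INV = {v: k for k, v in VOCAB.items()}

def decode(ids):
    return "".join(INV[i] for i in ids)

def encode(text):
    # Greedy longest-match tokenization, roughly how merged BPE vocabularies behave.
    ids, i = [], 0
    while i < len(text):
        for j in range(len(text), i, -1):
            if text[i:j] in VOCAB:
                ids.append(VOCAB[text[i:j]])
                i = j
                break
        else:
            raise ValueError("untokenizable text")
    return ids

generated = [1, 2]                     # model generated "<" + "think"
reencoded = encode(decode(generated))  # next turn re-tokenizes the chat history
print(reencoded)                       # [3] -- no longer a prefix match with [1, 2]
```

When the re-encoded history drifts like this, the merge falls back to keeping the turns as separate entries, which is why not all 512 * 5 sequences could be collapsed.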
  • Reward is pretty much identical
image
  • Timing-wise, we indeed see improved training time thanks to the smaller number of sequences.
  • This run was made when merging was still done in convert_to_training_input, hence the slight increase. The spikes are likely due to how we juggle, within the same event loop, all trajectories' interactions with the environment.

Co-authored-by: Claude Opus 4.7 (1M context) noreply@anthropic.com
Co-authored-by: Deep Sheth deepsheth3@users.noreply.github.com

🤖 Generated with Claude Code



@CharlieFRuan CharlieFRuan merged commit 880f0f7 into main Apr 20, 2026
11 of 12 checks passed
@devin-ai-integration (Bot) left a comment

✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 5 additional findings.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request implements prefix-aware merging for step-wise training, which collapses multi-turn sequences sharing a common prefix into single sequences to improve training efficiency. The changes include updating the postprocess_generator_output method signature to return both the generator output and updated UIDs across various trainers, adding the merge_stepwise_output configuration flag, and implementing the core merging logic in skyrl/train/generators/utils.py. A potential issue was identified in the merging logic where environment-specific rollout_metrics might be lost because they are overwritten during the concatenation of merged slices.

"rewards": out_rewards,
"loss_masks": out_loss_masks,
"stop_reasons": out_stop_reasons,
"rollout_metrics": gen_out.get("rollout_metrics", None),
Severity: medium

In _merge_single_trajectory, rollout_metrics is taken from the sliced gen_out. However, concatenate_generator_outputs (called at the end of merge_stepwise_output) will overwrite this field by re-calculating token statistics, which might cause environment-specific metrics to be lost in the returned GeneratorOutput. Although RayPPOTrainer already logs these metrics in the generate method, other potential callers of merge_stepwise_output might rely on the returned object containing the full metrics.

CharlieFRuan added a commit that referenced this pull request Apr 23, 2026
Before this PR, when training with Harbor, we rely on Harbor returning
an `all_messages` field which contains a string chat history. We then
re-tokenize it in SkyRL, compute loss masks, and feed it to the trainer
via `GeneratorOutput`.

This causes re-tokenization issues, and prevents us from doing fully
async training (which requires `logprobs` for algorithmic correction,
and `logprobs` will not match under re-tokenization drift).

This PR makes `HarborGenerator` perform step-wise training.

We rely on setting `collect_rollout_details: true` (already done in
`harbor_trial_config/default.yaml`). Harbor will then do per-turn
bookkeeping: for each LLM invocation (i.e. turn), Harbor records
`prompt_token_ids`, `completion_token_ids`, and `logprobs`.

Then, by setting the following configs in SkyRL, we can perform
step-wise training while merging when possible:

```sh
  generator.step_wise_trajectories=true \
  generator.merge_stepwise_output=true \
```

### Curve comparison


https://wandb.ai/sky-posttraining-uc-berkeley/harbor/reports/PR1542-Harbor-step-wise-training-in-SkyRL--VmlldzoxNjYzNDU0Mg

- Blue: this PR's sync training (with `token_mean`)
- Pink: before this PR (with `seq_mean_token_sum_norm`, everything else
the same)
- Red: this PR fully async (with `train_batch_size=mini_batch_size=16`)

<img width="681" height="307" alt="image"
src="https://github.com/user-attachments/assets/32065b96-5b28-4afd-98bf-a11d4f335d1e"
/>

<img width="679" height="305" alt="image"
src="https://github.com/user-attachments/assets/b20b9dd8-e5eb-4a79-86b4-3b66d66c2899"
/>

Besides, with `merge_stepwise_output`, we can shrink ~1000 sequences to
~300 sequences, improving training efficiency. For the sync run, it is 256
sequences (batch size 32 * 8) if everything can be merged, and the number of
unmerged sequences is roughly `256 * avg_num_turns`. For fully async it is
similar, except with 128 sequences (batch size 16 * 8). Related
PR: #1538. Note that this merging works much better than PR #1538's test on
Qwen2.5 for Search-R1 (which suffered heavily from re-tokenization drift
on `<think>`, which Qwen2.5 is not familiar with). Also, for cases where merging
fails, here is a report:
https://gist.github.com/CharlieFRuan/b91cecfe891f9458c455b6f5e2f6af1d
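As a quick arithmetic check of the sequence counts above (avg_num_turns is an illustrative value, not a measured one):

```python
group = 8                           # rollouts per prompt
sync_merged = 32 * group            # sync run: 256 sequences if fully merged
async_merged = 16 * group           # fully async: 128 sequences if fully merged
avg_num_turns = 4                   # hypothetical average turns per trajectory
sync_unmerged = sync_merged * avg_num_turns  # roughly 256 * avg_num_turns
print(sync_merged, async_merged, sync_unmerged)
```

With four turns per trajectory on average, the unmerged count lands near the ~1000 sequences observed before merging.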

<img width="1024" height="300" alt="image"
src="https://github.com/user-attachments/assets/901f7ec6-33da-41b6-b3b8-8b4e51fcd37b"
/>
