[train][step-wise] Add prefix-aware merging for step-wise training#1538
CharlieFRuan merged 2 commits into `main`.
Conversation
Code Review
This pull request implements prefix-aware merging for step-wise training, which collapses multi-turn sequences sharing a common prefix into single sequences to improve training efficiency. The changes include updating the postprocess_generator_output method signature to return both the generator output and updated UIDs across various trainers, adding the merge_stepwise_output configuration flag, and implementing the core merging logic in skyrl/train/generators/utils.py. A potential issue was identified in the merging logic where environment-specific rollout_metrics might be lost because they are overwritten during the concatenation of merged slices.
```python
"rewards": out_rewards,
"loss_masks": out_loss_masks,
"stop_reasons": out_stop_reasons,
"rollout_metrics": gen_out.get("rollout_metrics", None),
```
In _merge_single_trajectory, rollout_metrics is taken from the sliced gen_out. However, concatenate_generator_outputs (called at the end of merge_stepwise_output) will overwrite this field by re-calculating token statistics, which might cause environment-specific metrics to be lost in the returned GeneratorOutput. Although RayPPOTrainer already logs these metrics in the generate method, other potential callers of merge_stepwise_output might rely on the returned object containing the full metrics.
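One way to address the overwrite described above is to re-attach the environment-specific metrics after concatenation. The sketch below is hypothetical, not the actual SkyRL fix; `merge_fn` and `concat_fn` stand in for the merging and concatenation steps, and the dict-based `GeneratorOutput` shape is an assumption:

```python
def merge_preserving_metrics(gen_out, merge_fn, concat_fn):
    """Illustrative sketch: restore environment-specific rollout_metrics
    that the concatenation step would otherwise overwrite when it
    re-calculates token statistics."""
    # Snapshot env-specific metrics before they can be overwritten.
    env_metrics = dict(gen_out.get("rollout_metrics") or {})
    # concat_fn recomputes token statistics into rollout_metrics.
    merged = concat_fn(merge_fn(gen_out))
    recomputed = merged.get("rollout_metrics") or {}
    # Keep the recomputed token stats, but re-attach env-specific keys.
    merged["rollout_metrics"] = {**env_metrics, **recomputed}
    return merged
```

With this pattern, other callers of `merge_stepwise_output` would still see the full metrics in the returned object, not only in the trainer's logs.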
Before this PR, when training with Harbor, we relied on Harbor returning an `all_messages` field containing a string chat history. We then re-tokenized it in SkyRL, computed loss masks, and fed it to the trainer via `GeneratorOutput`. This causes re-tokenization issues and prevents us from doing fully async training (which requires `logprobs` for algorithmic correction, and `logprobs` will not match under re-tokenization drift).

This PR makes `HarborGenerator` perform step-wise training. We rely on setting `collect_rollout_details: true` (already done in `harbor_trial_config/default.yaml`). Harbor then does per-turn bookkeeping: for each LLM invocation (i.e. turn), Harbor records `prompt_token_ids`, `completion_token_ids`, and `logprobs`. Then, by setting the following configs in SkyRL, we can perform step-wise training while merging when possible:

```sh
generator.step_wise_trajectories=true \
generator.merge_stepwise_output=true \
```

### Curve comparison

https://wandb.ai/sky-posttraining-uc-berkeley/harbor/reports/PR1542-Harbor-step-wise-training-in-SkyRL--VmlldzoxNjYzNDU0Mg

- Blue: this PR's sync training (with `token_mean`)
- Pink: before this PR (with `seq_mean_token_sum_norm`, everything else the same)
- Red: this PR fully async (with `train_batch_size=mini_batch_size=16`)

<img width="681" height="307" alt="image" src="https://github.com/user-attachments/assets/32065b96-5b28-4afd-98bf-a11d4f335d1e" />
<img width="679" height="305" alt="image" src="https://github.com/user-attachments/assets/b20b9dd8-e5eb-4a79-86b4-3b66d66c2899" />

Besides, with `merge_stepwise_output`, we can shrink ~1000 sequences to ~300 sequences, improving training efficiency. For the sync run, it is 256 sequences (batch size 32 * 8) if everything can be merged; unmerged, the number of sequences is roughly `256 * avg_num_turns`. Fully async is similar, except it has 128 sequences (batch size 16 * 8).

Related PR: #1538.
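The prefix-aware merging used here can be sketched roughly as follows. This is a minimal illustration, not the actual `merge_stepwise_output` implementation; the dict layout and field names (`prompt_ids`, `response_ids`, `loss_mask`, `rewards`, `stop_reason`) are assumptions for the sketch:

```python
from typing import Dict, List


def merge_consecutive_steps(steps: List[Dict]) -> List[Dict]:
    """Greedily merge consecutive steps of one trajectory whenever
    prompt[i] + response[i] is an exact prefix of prompt[i+1]
    (i.e. no re-tokenization drift or context trimming).
    Illustrative sketch only, not the SkyRL implementation."""
    merged = [dict(steps[0])]
    for step in steps[1:]:
        cur = merged[-1]
        prev_full = cur["prompt_ids"] + cur["response_ids"]
        if step["prompt_ids"][: len(prev_full)] == prev_full:
            # Observation delta: new prompt tokens beyond the previous turn.
            obs_delta = step["prompt_ids"][len(prev_full):]
            cur["response_ids"] = cur["response_ids"] + obs_delta + step["response_ids"]
            # Observation tokens are loss-masked to 0 and carry no reward.
            cur["loss_mask"] = cur["loss_mask"] + [0] * len(obs_delta) + step["loss_mask"]
            cur["rewards"] = cur["rewards"] + [0.0] * len(obs_delta) + step["rewards"]
            # Per-turn fields take the last turn's value.
            cur["stop_reason"] = step["stop_reason"]
        else:
            # Prefix mismatch: fall back to a separate step-wise entry.
            merged.append(dict(step))
    return merged
```

Because merging is greedy over consecutive turns, a single mismatch only splits the trajectory at that point; turns before and after the mismatch can still merge among themselves.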
Note this merging is much better than PR #1538's test on Qwen2.5 for search-r1 (which suffered a lot from re-tokenization drift for `<think>`, which Qwen2.5 is not familiar with). Also, for cases where merging fails, here is a report: https://gist.github.com/CharlieFRuan/b91cecfe891f9458c455b6f5e2f6af1d <img width="1024" height="300" alt="image" src="https://github.com/user-attachments/assets/901f7ec6-33da-41b6-b3b8-8b4e51fcd37b" />
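The failure mode above comes down to a single exactness condition: a turn pair is mergeable only if the previous turn's tokens are an exact prefix of the next prompt. A minimal, hypothetical check (not the one used to generate the report above) makes this concrete:

```python
from typing import List


def is_mergeable(prev_prompt: List[int], prev_response: List[int],
                 next_prompt: List[int]) -> bool:
    """A turn pair is mergeable iff prev_prompt + prev_response is an
    exact prefix of next_prompt. Re-tokenization drift (e.g. a token
    boundary shifting inside `<think>`) changes even one token id at
    the seam and forces the pair into separate step-wise entries."""
    prev_full = prev_prompt + prev_response
    return next_prompt[: len(prev_full)] == prev_full
```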
Co-authored-by: Deep Sheth deepsheth3@users.noreply.github.com
Identical to #1532, which messed up the commit authors and was hence reverted and re-opened here.
Summary
This PR implements prefix-aware merging for step-wise training, guarded by a flag `cfg.generator.merge_stepwise_output` that defaults to `False`. During step-wise training, within a trajectory, when consecutive steps share the same prefix (i.e. no re-tokenization drift or context management like thinking-token stripping), we collapse them into a single `GeneratorOutput` entry. This can reduce the O(T²) training cost introduced by step-wise training (T being the number of turns).

- `merge_stepwise_output()` in `skyrl/train/utils/trainer_utils.py` implements greedy merging: for consecutive turns in the same trajectory where `prompt[i] + response[i]` is a prefix of `prompt[i+1]`, merge into one entry. Response tokens are concatenated with the observation delta (loss-masked to 0) between turns; per-token fields (`loss_masks`, `rewards`, `rollout_logprobs`) align accordingly; per-turn fields (`stop_reason`, `is_last_step`, `trajectory_id`) take the last turn's value.
- `RayPPOTrainer.postprocess_generator_output` calls `merge_stepwise_output` when `generator.merge_stepwise_output=true`, updates `uids` from the merged `trajectory_ids`, and logs `generate/num_seq_{before,after}_merge`.
- Since `uids` may need to be modified, update the signature of `postprocess_generator_output` to return both `generator_output` and `uids`, changing various caller sites.
- Add the `generator.merge_stepwise_output` config flag (default false).
- `examples/train/search/run_search.sh` accepts a `MERGE_STEPWISE=true` env var to pass the flag through.
- `tests/train/test_merge_stepwise_output.py` covers the three merge cases, partial merges, prefix mismatches, single-turn passthrough, per-trajectory scalar rewards, and required-field asserts.

Test plan
- `pytest tests/train/test_merge_stepwise_output.py` — 16 passed
- `pytest tests/train/test_trainer_utils.py tests/train/test_prompt_mini_batch.py` — 58 passed (existing tests unaffected)
- `MERGE_STEPWISE=true`.

Curves
With precisely the same setup as #1529, we do:
Below are the curves (light blue is this PR).
- `batch_num_seq` decreased by roughly 500 for each step. With non-step-wise, it is 512 * 5 = 2560 sequences. We were not able to merge everything, primarily due to how Qwen2.5 retokenizes `<think` into `<think` (or something along those lines for these thinking tokens).
- `convert_to_training_input`, and hence the slight increase. The spikes are likely due to how we are juggling, inside the same event loop, all trajectories' interactions with the environment.

Co-authored-by: Claude Opus 4.7 (1M context) noreply@anthropic.com
Co-authored-by: Deep Sheth deepsheth3@users.noreply.github.com
🤖 Generated with Claude Code