[train][step-wise] Add prefix-aware merging for step-wise training#1532
CharlieFRuan merged 2 commits into main from
Conversation
Code Review
This pull request introduces prefix-aware merging for stepwise trajectories, allowing multi-turn sequences to be collapsed into single sequences before training. The implementation includes new configuration parameters, integration into the training pipeline with associated metrics, and the core merging logic in trainer_utils.py. Comprehensive unit tests were added to cover various merging scenarios and edge cases. Feedback focused on removing redundant list() calls within the merging utility functions to improve code efficiency.
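The core merge idea can be sketched in simplified form. The turn dicts below carry only `prompt_ids`, `response_ids`, and `loss_mask`; the real `GeneratorOutput` has more per-token and per-turn fields, and `merge_trajectory` is an illustrative name, not the function in the codebase:

```python
def merge_trajectory(turns):
    """Greedily merge consecutive turns that share a token prefix.

    Each turn is a dict with "prompt_ids", "response_ids", "loss_mask".
    When prompt[i] + response[i] is a prefix of prompt[i+1], the turns
    collapse into one entry: the observation delta between them is
    appended with loss mask 0 so it never contributes to the loss.
    """
    merged = [dict(turns[0])]
    for turn in turns[1:]:
        cur = merged[-1]
        seen = cur["prompt_ids"] + cur["response_ids"]
        if turn["prompt_ids"][: len(seen)] == seen:
            obs_delta = turn["prompt_ids"][len(seen):]
            cur["response_ids"] = cur["response_ids"] + obs_delta + turn["response_ids"]
            cur["loss_mask"] = cur["loss_mask"] + [0] * len(obs_delta) + turn["loss_mask"]
        else:
            # Prefix mismatch (e.g. re-tokenization drift): start a new entry.
            merged.append(dict(turn))
    return merged
```

Two fully mergeable turns collapse into one sequence whose loss mask zeroes out the observation tokens; a prefix mismatch leaves the turns as separate entries.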
```python
if acc_logprobs is not None:
    acc_logprobs.extend(gen_out["rollout_logprobs"][i])
if acc_rewards_tokens is not None:
    acc_rewards_tokens.extend(list(gen_out["rewards"][i]))
```
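On the redundant `list()` point raised in review: `list.extend` already accepts any iterable, so the extra `list()` call only builds a throwaway copy. A minimal illustration:

```python
acc_rewards_tokens = [0.5]
rewards = (0.1, 0.2)  # extend() takes any iterable, not just lists

# Equivalent to acc_rewards_tokens.extend(list(rewards)), minus one copy.
acc_rewards_tokens.extend(rewards)
assert acc_rewards_tokens == [0.5, 0.1, 0.2]
```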
The previous ``sum([go[key] for go in gens], [])`` pattern repeatedly rebuilds the running concatenation, making the flattening step O(K² · L̄) in the number of GeneratorOutputs. Replace with an explicit extend loop (O(N_total)). No behavior change, no signature change.

Benchmarked with 8 trajectories per GeneratorOutput, 64k-token response_ids / loss_masks / rollout_logprobs, six flatten calls per concat:
- K=128 (1,024 trajectories): 0.6 ms -> 0.1 ms (7.6×)
- K=512 (4,096 trajectories): 8.6 ms -> 0.2 ms (49.8×)

The speedup grows quadratically with K, which matters when concat is called on per-trajectory chunks (e.g. the prefix-aware merging in #1532).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
## Summary

`concatenate_generator_outputs` used the `sum([go[key] for go in gens], [])` pattern to flatten each list-valued field, which is **O(K² · L̄)**: every `+` copies the running result. Replace with an explicit extend loop (**O(N_total)**). No signature change; no behavior change (existing test passes unmodified). Also move related tests to `tests/train/generators/test_generator_output_utils.py`.

## Benchmark

Config: 8 trajectories per GeneratorOutput, 64k-token `response_ids`/`loss_masks`/`rollout_logprobs`, six flatten calls per concat (as currently done).

| K (GeneratorOutputs) | Total trajectories | Old (sum, []) | New (extend) | Speedup |
|---:|---:|---:|---:|---:|
| 128 | 1,024 | 0.6 ms | 0.1 ms | **7.6×** |
| 512 | 4,096 | 8.6 ms | 0.2 ms | **49.8×** |

Speedup grows quadratically with K. This matters when concat is called on per-trajectory chunks: e.g. the prefix-aware merging work in #1532 calls `concatenate_generator_outputs` with K = number of trajectories (~2560 in a typical SearchR1 run), which would extrapolate to ~200 ms under the old path and sub-millisecond under the new one.

## Test plan

- [x] `pytest tests/train/generators/test_skyrl_gym_generator.py::test_generator_output_concatenation` passes unchanged (no signature or behavior change).
- [x] `pytest tests/train/generators/ tests/train/test_generator_postprocess.py tests/train/test_trainer_utils.py`: 137/137 pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
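The difference between the two flattening patterns can be sketched with plain lists (the dicts below stand in for `GeneratorOutput`; function names are illustrative):

```python
def flatten_quadratic(gens, key):
    # Old pattern: sum(..., []) copies the running result on every `+`,
    # so total work is O(K^2 * avg_len) in the number of chunks K.
    return sum([g[key] for g in gens], [])

def flatten_linear(gens, key):
    # New pattern: one growing list, each element copied exactly once,
    # so total work is O(N_total).
    out = []
    for g in gens:
        out.extend(g[key])
    return out

gens = [{"response_ids": [1, 2]}, {"response_ids": [3]}, {"response_ids": [4, 5]}]
assert flatten_quadratic(gens, "response_ids") == flatten_linear(gens, "response_ids")
```

Both produce the same flat list; only the asymptotics differ, which is why no test had to change.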
(Refers to lines 620-634)
🚩 Merged sequences could exceed max_seq_len causing silent truncation
After merging multiple turns into a single sequence, the combined prompt + response length could significantly exceed max_seq_len configured for training. The truncation happens silently in convert_prompts_responses_to_batch_tensors at skyrl/train/trainer.py:668. For multi-turn scenarios with long observations, this could truncate the final reward-bearing tokens, potentially degrading training signal. This isn't a bug per se (the truncation logic is pre-existing), but it's a behavioral consideration unique to the merge feature that users should be aware of. Logging a warning when sequences are truncated after merge could be valuable.
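A sketch of the suggested warning (the helper name, logger setup, and call site are illustrative, not from the codebase):

```python
import logging

logger = logging.getLogger(__name__)

def check_merged_seq_len(prompt_ids, response_ids, max_seq_len):
    """Warn when a merged prompt+response would be silently truncated downstream."""
    total = len(prompt_ids) + len(response_ids)
    if total > max_seq_len:
        logger.warning(
            "merged sequence length %d exceeds max_seq_len %d; "
            "%d trailing tokens (possibly reward-bearing) will be truncated",
            total, max_seq_len, total - max_seq_len,
        )
        return False
    return True
```

Such a check could run once per merged entry before batch tensor conversion, making the pre-existing truncation visible instead of silent.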
Identical to #1538, but this one messed up the commit message.
## Summary
This PR implements prefix-aware merging for step-wise training, guarded by a flag `cfg.generator.merge_stepwise_output` that defaults to `False`. During step-wise training, within a trajectory, when consecutive steps share the same prefix (i.e. no re-tokenization drift or context management like thinking-token stripping), we collapse them into a single `GeneratorOutput` entry. This can reduce the O(T²) training cost introduced by step-wise training (T being the number of turns).

- `merge_stepwise_output()` in `skyrl/train/utils/trainer_utils.py` implements greedy merging: for consecutive turns in the same trajectory where `prompt[i] + response[i]` is a prefix of `prompt[i+1]`, merge into one entry. Response tokens are concatenated with the observation delta (loss-masked to 0) between turns; per-token fields (`loss_masks`, `rewards`, `rollout_logprobs`) align accordingly; per-turn fields (`stop_reason`, `is_last_step`, `trajectory_id`) take the last turn's value.
- `RayPPOTrainer.postprocess_generator_output` calls `merge_stepwise_output` when `generator.merge_stepwise_output=true`, updates `uids` from the merged `trajectory_ids`, and logs `generate/num_seq_{before,after}_merge`.
- Since `uids` may need to be modified, update the signature of `postprocess_generator_output` to return both `generator_output` and `uids`, changing various caller places.
- Add the `generator.merge_stepwise_output` config flag (default false).
- `examples/train/search/run_search.sh` accepts a `MERGE_STEPWISE=true` env var to pass the flag through.
- Tests in `tests/train/test_merge_stepwise_output.py` cover the three merge cases, partial merges, prefix mismatches, single-turn passthrough, per-trajectory scalar rewards, and required-field asserts.

## Test plan
- `pytest tests/train/test_merge_stepwise_output.py`: 16 passed
- `pytest tests/train/test_trainer_utils.py tests/train/test_prompt_mini_batch.py`: 58 passed (existing tests unaffected)
- Run with `MERGE_STEPWISE=true`.

## Curves
With precisely the same setup as #1529, we do:
Below are the curves (light blue is this PR).
- `batch_num_seq` decreased by roughly 500 for each step. With non-step-wise, it is 512 * 5 = 2560 sequences. We were not able to merge everything, primarily due to how Qwen2.5 re-tokenizes `<think` into `<think` (or something along those lines for these thinking tokens).
- `convert_to_training_input` and hence the slight increase. The spikes are likely due to how we are juggling, inside the same event loop, all trajectories' interaction with the environment.

Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>
🤖 Generated with Claude Code