[train][step-wise] Add prefix-aware merging for step-wise training#1532

Merged
CharlieFRuan merged 2 commits into main from charlie/merge-stepwise-output on Apr 20, 2026

Conversation

@CharlieFRuan
Member

@CharlieFRuan CharlieFRuan commented Apr 19, 2026

Identical to #1538, but this one messed up the commit message.

Summary

This PR implements prefix-aware merging for step-wise training, guarded by a cfg.generator.merge_stepwise_output flag that defaults to False. During step-wise training, when consecutive steps within a trajectory share the same prefix (i.e. no re-tokenization drift and no context management such as thinking-token stripping), we collapse them into a single GeneratorOutput entry. This can reduce the O(T²) training cost introduced by step-wise training (T being the number of turns).

  • merge_stepwise_output() in skyrl/train/utils/trainer_utils.py implements greedy merging: for consecutive turns in the same trajectory where prompt[i] + response[i] is a prefix of prompt[i+1], merge into one entry. Response tokens are concatenated with the observation delta (loss-masked to 0) between turns; per-token fields (loss_masks, rewards, rollout_logprobs) are aligned accordingly; per-turn fields (stop_reason, is_last_step, trajectory_id) take the last turn's value.
  • RayPPOTrainer.postprocess_generator_output calls merge_stepwise_output when generator.merge_stepwise_output=true, updates uids from the merged trajectory_ids, and logs generate/num_seq_{before,after}_merge.
    • Since uids may need to be modified, the signature of postprocess_generator_output is updated to return both generator_output and uids, and the various call sites are changed accordingly.
  • New generator.merge_stepwise_output config flag (default false).
  • examples/train/search/run_search.sh accepts MERGE_STEPWISE=true env var to pass the flag through.
  • 16 CPU-only unit tests in tests/train/test_merge_stepwise_output.py cover the three merge cases, partial merges, prefix mismatches, single-turn passthrough, per-trajectory scalar rewards, and required-field asserts.
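To make the merging rule concrete, here is a minimal, self-contained sketch of the greedy prefix merge described above. The field names (prompt, response, loss_mask, stop_reason) and the merge_turns helper are simplified stand-ins; the real merge_stepwise_output in skyrl/train/utils/trainer_utils.py handles the full set of GeneratorOutput fields and validation.

```python
# Hypothetical, simplified sketch of the greedy prefix merge. Turns are
# dicts of token-id lists; we merge only when the next prompt is an exact
# extension of the previous prompt + response (no re-tokenization drift).

def merge_turns(turns):
    merged = []
    for turn in turns:
        if merged:
            prev = merged[-1]
            prev_full = prev["prompt"] + prev["response"]
            if turn["prompt"][: len(prev_full)] == prev_full:
                # Observation delta between turns is loss-masked to 0.
                obs_delta = turn["prompt"][len(prev_full):]
                prev["response"] += obs_delta + turn["response"]
                prev["loss_mask"] += [0] * len(obs_delta) + turn["loss_mask"]
                prev["stop_reason"] = turn["stop_reason"]  # last turn wins
                continue
        # Prefix mismatch (or first turn): start a new merged entry,
        # copying the lists so later merges don't mutate the input.
        merged.append({**turn,
                       "response": list(turn["response"]),
                       "loss_mask": list(turn["loss_mask"])})
    return merged
```

On a two-turn trajectory where turn 2's prompt equals turn 1's prompt + response plus an observation, this yields a single entry whose loss mask zeros out the observation tokens.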

Test plan

  • pytest tests/train/test_merge_stepwise_output.py — 16 passed
  • pytest tests/train/test_trainer_utils.py tests/train/test_prompt_mini_batch.py — 58 passed (existing tests unaffected)
  • E2E: Search-R1 step-wise GRPO run on Qwen2.5-3B-Instruct, 8×H100, MERGE_STEPWISE=true.

Curves

With precisely the same setup as #1529, we run:

MERGE_STEPWISE=true USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1

Below are the curves (light blue is this PR).

image
  • The number of sequences batch_num_seq decreased by roughly 500 at each step. Without merging, step-wise training yields 512 * 5 = 2560 sequences. We were not able to merge everything, primarily because of how Qwen2.5 re-tokenizes `< think` into `<think` (or something along those lines for the thinking tokens).
  • We also see that avg_num_tokens increases; note that this field still records the per-turn average rather than the per-trajectory average.
image
  • Reward is pretty much identical
image
  • Timing-wise, we indeed see improved training time thanks to the smaller number of sequences.
  • This run was done when merging still happened in convert_to_training_input, hence the slight increase. The spikes are likely due to all trajectories' environment interactions being juggled inside the same event loop.

Co-authored-by: Deep Sheth deepsheth3@users.noreply.github.com

🤖 Generated with Claude Code


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces prefix-aware merging for stepwise trajectories, allowing multi-turn sequences to be collapsed into single sequences before training. The implementation includes new configuration parameters, integration into the training pipeline with associated metrics, and the core merging logic in trainer_utils.py. Comprehensive unit tests were added to cover various merging scenarios and edge cases. Feedback focused on removing redundant list() calls within the merging utility functions to improve code efficiency.

Comment thread skyrl/train/utils/trainer_utils.py Outdated
Comment thread skyrl/train/utils/trainer_utils.py Outdated
Comment thread skyrl/train/utils/trainer_utils.py Outdated
Comment thread skyrl/train/utils/trainer_utils.py Outdated
if acc_logprobs is not None:
    acc_logprobs.extend(gen_out["rollout_logprobs"][i])
if acc_rewards_tokens is not None:
    acc_rewards_tokens.extend(list(gen_out["rewards"][i]))
Contributor


Severity: medium

The extend method accepts any iterable and will copy the elements. The explicit conversion to list() is redundant if the source is already a list.

Suggested change:
- acc_rewards_tokens.extend(list(gen_out["rewards"][i]))
+ acc_rewards_tokens.extend(gen_out["rewards"][i])
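For context on the suggestion: `list.extend` already iterates over and copies the elements of any iterable, so the inner `list()` only builds a throwaway copy. A quick illustration:

```python
# extend() accepts any iterable and copies its elements into the target
# list; wrapping an existing list in list() first just creates an extra
# temporary copy that is immediately discarded.
acc = [1]
src = [2, 3]
acc.extend(src)        # same result as acc.extend(list(src)), minus the copy
assert acc == [1, 2, 3]
assert src == [2, 3]   # the source list is unchanged either way
```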

Contributor

@devin-ai-integration devin-ai-integration Bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 5 additional findings.


devin-ai-integration[bot]

This comment was marked as resolved.

CharlieFRuan added a commit that referenced this pull request Apr 19, 2026
The previous ``sum([go[key] for go in gens], [])`` pattern repeatedly
rebuilds the running concatenation, making the flattening step
O(K^2 * L̄) in the number of GeneratorOutputs. Replace with an
explicit extend loop (O(N_total)). No behavior change, no signature
change.

Benchmarked with 8 trajectories per GeneratorOutput, 64k-token
response_ids / loss_masks / rollout_logprobs, six flatten calls
per concat:

  K=128 (1024 trajectories): 0.6ms -> 0.1ms  ( 7.6x)
  K=512 (4096 trajectories): 8.6ms -> 0.2ms  (49.8x)

The speedup grows quadratically with K, which matters when concat
is called on per-trajectory chunks (e.g. the prefix-aware merging
in #1532).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
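The flattening change this commit describes can be sketched as follows. Both helpers are illustrative stand-ins (`gens` represents a list of GeneratorOutput-like dicts; the function names are not from the repo):

```python
def flatten_quadratic(gens, key):
    # Old pattern: each `+` inside sum() copies the running result,
    # so total work is O(K^2 * average list length).
    return sum([go[key] for go in gens], [])

def flatten_linear(gens, key):
    # New pattern: a single growing list, O(total elements).
    out = []
    for go in gens:
        out.extend(go[key])
    return out
```

Both return the same flattened list; only the asymptotics differ, which is what the benchmark numbers above measure.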
CharlieFRuan added a commit that referenced this pull request Apr 20, 2026
CharlieFRuan added a commit that referenced this pull request Apr 20, 2026

## Summary

`concatenate_generator_outputs` used the `sum([go[key] for go in gens], [])` pattern to flatten each list-valued field, which is **O(K² · L̄)** — every `+` copies the running result. Replace with an explicit extend loop (**O(N_total)**). No signature change; no behavior change (existing test passes unmodified).

Also move related tests to
`tests/train/generators/test_generator_output_utils.py`

## Benchmark

Config: 8 trajectories per GeneratorOutput, 64k-token `response_ids`/`loss_masks`/`rollout_logprobs`, six flatten calls per concat (as currently done).

| K (GeneratorOutputs) | Total trajectories | Old (sum, []) | New (extend) | Speedup |
|---:|---:|---:|---:|---:|
| 128 | 1,024 | 0.6 ms | 0.1 ms | **7.6×** |
| 512 | 4,096 | 8.6 ms | 0.2 ms | **49.8×** |

Speedup grows quadratically with K. This matters when concat is called on per-trajectory chunks — e.g. the prefix-aware merging work in #1532 calls `concatenate_generator_outputs` with K = number of trajectories (~2560 in a typical SearchR1 run), which would extrapolate to ~200 ms under the old path and sub-millisecond under the new one.

## Test plan

- [x] `pytest tests/train/generators/test_skyrl_gym_generator.py::test_generator_output_concatenation` — passes unchanged (no signature or behavior change).
- [x] `pytest tests/train/generators/ tests/train/test_generator_postprocess.py tests/train/test_trainer_utils.py` — 137/137 pass.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@CharlieFRuan CharlieFRuan force-pushed the charlie/merge-stepwise-output branch from c4c24f4 to 6468c30 on April 20, 2026 01:27
Comment thread skyrl/train/trainer.py
Contributor


(Refers to lines 620-634)

🚩 Merged sequences could exceed max_seq_len causing silent truncation

After merging multiple turns into a single sequence, the combined prompt + response length could significantly exceed max_seq_len configured for training. The truncation happens silently in convert_prompts_responses_to_batch_tensors at skyrl/train/trainer.py:668. For multi-turn scenarios with long observations, this could truncate the final reward-bearing tokens, potentially degrading training signal. This isn't a bug per se (the truncation logic is pre-existing), but it's a behavioral consideration unique to the merge feature that users should be aware of. Logging a warning when sequences are truncated after merge could be valuable.
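The warning the reviewer suggests could look roughly like this. It is a hypothetical helper: the name warn_if_truncated, its parameters, and the call site are assumptions for illustration, not code from this PR.

```python
import logging

logger = logging.getLogger(__name__)

def warn_if_truncated(prompt_ids, response_ids, max_seq_len):
    """Return True if the merged sequence fits within max_seq_len;
    log a warning otherwise, since downstream truncation is silent."""
    total = len(prompt_ids) + len(response_ids)
    if total > max_seq_len:
        logger.warning(
            "Merged sequence of %d tokens exceeds max_seq_len=%d; the last "
            "%d tokens (possibly reward-bearing) will be truncated.",
            total, max_seq_len, total - max_seq_len,
        )
    return total <= max_seq_len
```

Calling this on each merged entry before convert_prompts_responses_to_batch_tensors would surface the truncation cases without changing the pre-existing truncation behavior.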


@CharlieFRuan CharlieFRuan merged commit 8180541 into main Apr 20, 2026
7 checks passed
@CharlieFRuan CharlieFRuan deleted the charlie/merge-stepwise-output branch April 20, 2026 21:17
@CharlieFRuan CharlieFRuan restored the charlie/merge-stepwise-output branch April 20, 2026 21:18