Add prefix-aware merging for step-wise training #1479
CharlieFRuan wants to merge 2 commits into NovaSky-AI:main from
Conversation
Implement merge_stepwise_output() that collapses multi-turn step-wise GeneratorOutput sequences into single sequences when turns share a common prefix, reducing training cost from O(T^2) to O(T).
- _merge_single_trajectory: GeneratorOutput in/out for a single trajectory
- merge_stepwise_output: splits by is_last_step, merges each, concatenates
- Add merge_stepwise_output config flag and generate/num_seq_before_merge, generate/num_seq_after_merge metrics
- 15 CPU-only unit tests covering all 3 merging cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Fix uids mismatch: update uids from merged trajectory_ids after merge
- Fix mini-batch divisibility: pad_batch now pads to mini_batch_size (not just dp_size) for step-wise training, fixing a pre-existing issue where variable step counts weren't divisible by mini_batch_size
- Add debug logging for prefix mismatch diagnostics (first 3 per trajectory)
- Add MERGE_STEPWISE env var to run_search.sh
- Add test_partial_merge_within_trajectory test case

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
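For illustration, a minimal sketch of the split-by-is_last_step step named above, assuming is_last_step is a boolean list marking the final turn of each trajectory (the helper name and signature are hypothetical, not the PR's API):

from typing import List

def split_trajectories(is_last_step: List[bool]) -> List[range]:
    # Group consecutive sequence indices into one index range per trajectory.
    spans, start = [], 0
    for i, last in enumerate(is_last_step):
        if last:
            spans.append(range(start, i + 1))
            start = i + 1
    return spans

# e.g. [False, False, True, False, True] -> [range(0, 3), range(3, 5)]

Each resulting range would then be handed to the per-trajectory merge, and the merged outputs concatenated back into a single batch.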
Code Review
This pull request introduces prefix-aware merging for step-wise trajectories, allowing multi-turn sequences to be collapsed into single sequences before training. Key changes include the implementation of the merging logic in trainer_utils.py, updates to the Trainer to handle merged outputs and adjust batch padding, and the addition of a MERGE_STEPWISE flag in the search execution script. Feedback was provided regarding code duplication in the _merge_single_trajectory function, where the initialization and reset logic for accumulators could be refactored into a helper function.
# Accumulator for the current merge group
acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0])
acc_response: List[int] = list(gen_out["response_ids"][0])
acc_loss_mask: List[int] = list(gen_out["loss_masks"][0])
acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None
acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if token_level_rewards else None
last = 0
The logic to initialize the accumulators for a merge group (lines 812-817) is nearly identical to the logic for resetting them upon a prefix mismatch (lines 853-858). To reduce code duplication and improve maintainability, consider extracting this logic into a nested helper function that can be called in both places.
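One possible shape for that helper, as a hypothetical sketch (the function name and tuple return flesh out the reviewer's suggestion; they are not code from the PR):

from typing import Any, Dict, List, Optional, Tuple

def _reset_accumulators(
    gen_out: Dict[str, Any], idx: int, has_logprobs: bool, token_level_rewards: bool
) -> Tuple[List[int], List[int], List[int], Optional[List[float]], Optional[List[float]]]:
    # (Re-)initialize the merge-group accumulators from sequence `idx`, usable both
    # at the start of a merge group and when a prefix mismatch forces a reset.
    return (
        list(gen_out["prompt_token_ids"][idx]),
        list(gen_out["response_ids"][idx]),
        list(gen_out["loss_masks"][idx]),
        list(gen_out["rollout_logprobs"][idx]) if has_logprobs else None,
        list(gen_out["rewards"][idx]) if token_level_rewards else None,
    )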
if self.cfg.generator.step_wise_trajectories:
    n_samples = self.cfg.generator.n_samples_per_prompt
    pad_target = max(
        self.cfg.trainer.policy_mini_batch_size * n_samples,
        self.cfg.trainer.critic_mini_batch_size * n_samples,
        dp_size,
    )
else:
    pad_target = dp_size
pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size
🔴 pad_batch uses max() instead of lcm() for step-wise pad target, breaking divisibility guarantee
When step_wise_trajectories=True, pad_target is computed as max(policy_mini_batch_size * n, critic_mini_batch_size * n, dp_size). Padding the batch to a multiple of max(A, B, C) does not guarantee divisibility by all three values — only lcm(A, B, C) would. This causes stage_chunks at skyrl/backends/skyrl_train/distributed/dispatch.py:190-192 to assert-fail (len(data) % mini_batch_size == 0) when policy_mini_batch_size != critic_mini_batch_size.
Concrete failure example
With policy_mini_batch_size=256, critic_mini_batch_size=384, n_samples=5:
- pad_target = max(1280, 1920, dp_size) = 1920
- Batch padded to e.g. 1920
- 1920 % 1280 = 640 ≠ 0 → policy stage_chunks asserts
Additionally, critic_mini_batch_size is unconditionally included even when self.has_critic is False, which can unnecessarily inflate pad_target above policy_mini_batch_size * n and break divisibility for the policy training step.
Prompt for agents
In pad_batch, the pad_target for step-wise training uses max() to combine policy_mini_batch_size * n_samples, critic_mini_batch_size * n_samples, and dp_size. However, max(A, B, C) does not guarantee the result is divisible by all three values — only lcm(A, B, C) does. Additionally, critic_mini_batch_size should only be included when self.has_critic is True.
Fix: Replace the max() call with math.lcm(), and only include critic_mini_batch_size when a critic model is configured. For example:
from math import lcm

pad_target = lcm(self.cfg.trainer.policy_mini_batch_size * n_samples, dp_size)
if self.has_critic:
    pad_target = lcm(pad_target, self.cfg.trainer.critic_mini_batch_size * n_samples)
This ensures the padded batch size is divisible by all required mini_batch_sizes for stage_chunks.
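A quick numeric check of the lcm-vs-max point, using the reviewer's example sizes (the values are illustrative, not from a real run):

import math

policy, critic, n, dp_size = 256, 384, 5, 8
bad = max(policy * n, critic * n, dp_size)        # 1920; 1920 % 1280 = 640 != 0
good = math.lcm(policy * n, critic * n, dp_size)  # 3840; divisible by 1280, 1920, and 8
assert bad % (policy * n) != 0
assert good % (policy * n) == 0 and good % (critic * n) == 0 and good % dp_size == 0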
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529). The original E2E fix's `pad_batch` change is dropped since #1529's prompt-based mini-batch boundaries removed the need to pad to `mini_batch_size * n_samples`.
- merge_stepwise_output() in trainer_utils.py collapses multi-turn step-wise GeneratorOutput sequences into single sequences when consecutive turns share a common prefix, reducing training cost from O(T^2) to O(T).
- trainer.py: call merge before extracting generator fields, update uids from merged trajectory_ids, emit generate/num_seq_{before,after}_merge.
- Add generator.merge_stepwise_output config flag.
- run_search.sh: MERGE_STEPWISE env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and validation asserts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Closed in favor of #1532
Summary
Adds prefix-aware merging for step-wise (multi-turn) training. When merge_stepwise_output=true, consecutive turns within a trajectory are merged into single sequences when prompt[i] + response[i] is a prefix of prompt[i+1]; see the sketch after this list. This reduces the number of training sequences and avoids redundant computation on shared prefixes.
- merge_stepwise_output() and _merge_single_trajectory() in trainer_utils.py
- merge_stepwise_output config flag (default false)
- MERGE_STEPWISE env var in run_search.sh
- generate/num_seq_before_merge and generate/num_seq_after_merge metrics logged to wandb
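A minimal sketch of that merge condition, assuming plain token-ID lists (is_prefix_continuation is an illustrative name, not the PR's API):

def is_prefix_continuation(prompt_i, response_i, prompt_next):
    # Turn i+1 can be merged onto turn i iff prompt[i] + response[i]
    # is a prefix of prompt[i+1].
    prev_full = prompt_i + response_i
    return len(prompt_next) >= len(prev_full) and prompt_next[: len(prev_full)] == prev_full

# Example: turn 2's prompt extends turn 1's prompt + response, so they merge.
assert is_prefix_continuation([1, 2], [3], [1, 2, 3, 4])
assert not is_prefix_continuation([1, 2], [3], [1, 9, 3, 4])

When the check passes, the merged sequence presumably keeps prompt[i] as its prompt and appends the remaining tokens of turn i+1 to the response, with loss masks concatenated so that only model-generated tokens contribute to the loss.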
Test plan
- tests/train/test_merge_stepwise_output.py tests pass
- search_r1 step-wise + merge enabled completes 100 steps successfully
- Merge metrics verified in wandb (generate/num_seq_before_merge, generate/num_seq_after_merge)

Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>
🤖 Generated with Claude Code