Add prefix-aware merging for step-wise training #1479

Closed

CharlieFRuan wants to merge 2 commits into NovaSky-AI:main from CharlieFRuan:prefix-aware-merge

Conversation


CharlieFRuan (Member) commented Apr 8, 2026

Closed in favor of #1532

Summary

Adds prefix-aware merging for step-wise (multi-turn) training. When merge_stepwise_output=true, consecutive turns within a trajectory are merged into single sequences when prompt[i] + response[i] is a prefix of prompt[i+1]. This reduces the number of training sequences and avoids redundant computation on shared prefixes.

  • Implements merge_stepwise_output() and _merge_single_trajectory() in trainer_utils.py
  • Per-token fields (loss_masks, logprobs, rewards) are concatenated, with zeros for observation-delta positions
  • Per-turn fields (stop_reason, is_last_step, trajectory_id) are taken from the last turn in each merge group
  • Greedy merging: when the prefix condition fails, the current group is flushed and a new one starts (see the sketch after this list)
  • Adds merge_stepwise_output config flag (default false)
  • Adds MERGE_STEPWISE env var to run_search.sh
  • Logs generate/num_seq_before_merge and generate/num_seq_after_merge metrics to wandb
  • 16 CPU-only unit tests covering all merge cases (assistant-only response, response with observation, combinations, multi-trajectory, partial merge, scalar rewards, etc.)
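A minimal sketch of the greedy prefix-merge loop described above, shown only to illustrate the merging condition. This is not the implementation in trainer_utils.py: the field names follow the GeneratorOutput keys visible in the review snippet further down, only three per-token fields are handled, and the helper name is made up.

  from typing import Any, Dict, List

  def _merge_single_trajectory_sketch(turns: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
      """Greedy prefix-aware merge for one trajectory (simplified sketch).

      Assumes each field is a list with one entry per turn, e.g.
      turns["prompt_token_ids"][i] is the token-id prompt of turn i.
      Per-token fields are concatenated; positions introduced by the next
      prompt (the observation delta) get a zero loss mask.
      """
      out = {k: [] for k in ("prompt_token_ids", "response_ids", "loss_masks")}

      # Accumulator for the current merge group, seeded from the first turn.
      acc_prompt = list(turns["prompt_token_ids"][0])
      acc_response = list(turns["response_ids"][0])
      acc_mask = list(turns["loss_masks"][0])

      def flush():
          out["prompt_token_ids"].append(acc_prompt)
          out["response_ids"].append(acc_response)
          out["loss_masks"].append(acc_mask)

      for i in range(1, len(turns["prompt_token_ids"])):
          next_prompt = list(turns["prompt_token_ids"][i])
          merged_so_far = acc_prompt + acc_response
          if next_prompt[: len(merged_so_far)] == merged_so_far:
              # prompt[i-1] + response[i-1] is a prefix of prompt[i]: keep merging.
              obs_delta = next_prompt[len(merged_so_far):]
              acc_response = acc_response + obs_delta + list(turns["response_ids"][i])
              acc_mask = acc_mask + [0] * len(obs_delta) + list(turns["loss_masks"][i])
          else:
              # Prefix check failed: flush the current group and start a new one.
              flush()
              acc_prompt = list(turns["prompt_token_ids"][i])
              acc_response = list(turns["response_ids"][i])
              acc_mask = list(turns["loss_masks"][i])
      flush()
      return out

A tiny usage example against the sketch (not the project's actual test data or GeneratorOutput schema): two turns collapse into one sequence, and the single new observation token is masked out of the loss.

  turns = {
      "prompt_token_ids": [[1, 2], [1, 2, 3, 4, 5]],  # turn 2 prompt extends turn 1 prompt + response
      "response_ids": [[3, 4], [6, 7]],
      "loss_masks": [[1, 1], [1, 1]],
  }
  merged = _merge_single_trajectory_sketch(turns)
  assert len(merged["prompt_token_ids"]) == 1           # two turns -> one sequence
  assert merged["response_ids"][0] == [3, 4, 5, 6, 7]   # observation delta [5] spliced in
  assert merged["loss_masks"][0] == [1, 1, 0, 1, 1]     # zero loss mask on the observation token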

Test plan

  • All 16 unit tests in tests/train/test_merge_stepwise_output.py pass
  • E2E run with search_r1 step-wise + merge enabled completes 100 steps successfully
  • Merge metrics logged to wandb (generate/num_seq_before_merge, generate/num_seq_after_merge)

Co-authored-by: Deep Sheth <deepsheth3@users.noreply.github.com>

🤖 Generated with Claude Code


CharlieFRuan and others added 2 commits April 8, 2026 20:48
Implement merge_stepwise_output() that collapses multi-turn step-wise
GeneratorOutput sequences into single sequences when turns share a
common prefix, reducing training cost from O(T^2) to O(T).

- _merge_single_trajectory: GeneratorOutput in/out for a single trajectory
- merge_stepwise_output: splits by is_last_step, merges each, concatenates
- Add merge_stepwise_output config flag and generate/num_seq_before_merge,
  generate/num_seq_after_merge metrics
- 15 CPU-only unit tests covering all 3 merging cases

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

- Fix uids mismatch: update uids from merged trajectory_ids after merge
- Fix mini-batch divisibility: pad_batch now pads to mini_batch_size
  (not just dp_size) for step-wise training, fixing a pre-existing
  issue where variable step counts weren't divisible by mini_batch_size
- Add debug logging for prefix mismatch diagnostics (first 3 per trajectory)
- Add MERGE_STEPWISE env var to run_search.sh
- Add test_partial_merge_within_trajectory test case

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces prefix-aware merging for step-wise trajectories, allowing multi-turn sequences to be collapsed into single sequences before training. Key changes include the implementation of the merging logic in trainer_utils.py, updates to the Trainer to handle merged outputs and adjust batch padding, and the addition of a MERGE_STEPWISE flag in the search execution script. Feedback was provided regarding code duplication in the _merge_single_trajectory function, where the initialization and reset logic for accumulators could be refactored into a helper function.

Comment on lines +811 to +817
# Accumulator for the current merge group
acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0])
acc_response: List[int] = list(gen_out["response_ids"][0])
acc_loss_mask: List[int] = list(gen_out["loss_masks"][0])
acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None
acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if token_level_rewards else None
last = 0

Severity: medium

The logic to initialize the accumulators for a merge group (lines 812-817) is nearly identical to the logic for resetting them upon a prefix mismatch (lines 853-858). To reduce code duplication and improve maintainability, consider extracting this logic into a nested helper function that can be called in both places.
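A hedged sketch of the suggested refactor, meant to sit inside _merge_single_trajectory where gen_out, has_logprobs, and token_level_rewards are already in scope; the helper name and the tuple-return shape are assumptions, not the reviewed code:

  def _read_acc(idx):
      # (Re)initialize the merge-group accumulators from turn `idx`;
      # call with 0 for the initial group and with the current index on a prefix mismatch.
      return (
          list(gen_out["prompt_token_ids"][idx]),
          list(gen_out["response_ids"][idx]),
          list(gen_out["loss_masks"][idx]),
          list(gen_out["rollout_logprobs"][idx]) if has_logprobs else None,
          list(gen_out["rewards"][idx]) if token_level_rewards else None,
      )

  acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens = _read_acc(0)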


@devin-ai-integration (Bot) left a comment


Devin Review found 1 potential issue.

View 4 additional findings in Devin Review.


Comment thread: skyrl/train/trainer.py
Comment on lines +915 to +924
if self.cfg.generator.step_wise_trajectories:
    n_samples = self.cfg.generator.n_samples_per_prompt
    pad_target = max(
        self.cfg.trainer.policy_mini_batch_size * n_samples,
        self.cfg.trainer.critic_mini_batch_size * n_samples,
        dp_size,
    )
else:
    pad_target = dp_size
pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size

🔴 pad_batch uses max() instead of lcm() for step-wise pad target, breaking divisibility guarantee

When step_wise_trajectories=True, pad_target is computed as max(policy_mini_batch_size * n, critic_mini_batch_size * n, dp_size). Padding the batch to a multiple of max(A, B, C) does not guarantee divisibility by all three values — only lcm(A, B, C) would. This causes stage_chunks at skyrl/backends/skyrl_train/distributed/dispatch.py:190-192 to assert-fail (len(data) % mini_batch_size == 0) when policy_mini_batch_size != critic_mini_batch_size.

Concrete failure example

With policy_mini_batch_size=256, critic_mini_batch_size=384, n_samples=5:

  • pad_target = max(1280, 1920, dp_size) = 1920
  • Batch padded to e.g. 1920
  • 1920 % 1280 = 640 ≠ 0 → policy stage_chunks asserts

Additionally, critic_mini_batch_size is unconditionally included even when self.has_critic is False, which can unnecessarily inflate pad_target above policy_mini_batch_size * n and break divisibility for the policy training step.

Prompt for agents
In pad_batch, the pad_target for step-wise training uses max() to combine policy_mini_batch_size * n_samples, critic_mini_batch_size * n_samples, and dp_size. However, max(A, B, C) does not guarantee the result is divisible by all three values — only lcm(A, B, C) does. Additionally, critic_mini_batch_size should only be included when self.has_critic is True.

Fix: Replace the max() call with math.lcm(), and only include critic_mini_batch_size when a critic model is configured. For example:

  from math import lcm
  pad_target = lcm(self.cfg.trainer.policy_mini_batch_size * n_samples, dp_size)
  if self.has_critic:
      pad_target = lcm(pad_target, self.cfg.trainer.critic_mini_batch_size * n_samples)

This ensures the padded batch size is divisible by all required mini_batch_sizes for stage_chunks.
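A quick standalone check of the max-vs-lcm point using the numbers from the failure example above (illustration only, not project code; math.lcm accepts multiple arguments on Python 3.9+, and the dp_size value here is assumed):

  import math

  policy = 256 * 5    # policy_mini_batch_size * n_samples = 1280
  critic = 384 * 5    # critic_mini_batch_size * n_samples = 1920
  dp_size = 8         # assumed for illustration

  bad = max(policy, critic, dp_size)        # 1920
  good = math.lcm(policy, critic, dp_size)  # 3840

  assert bad % policy != 0                  # 1920 % 1280 == 640 -> stage_chunks assert would fail
  assert good % policy == 0 and good % critic == 0 and good % dp_size == 0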

CharlieFRuan added a commit that referenced this pull request Apr 19, 2026
Rebase PR #1479 onto current main (post-PRs #1507/#1526/#1527/#1529).
The original E2E fix's `pad_batch` change is dropped since #1529's
prompt-based mini-batch boundaries removed the need to pad to
`mini_batch_size * n_samples`.

- merge_stepwise_output() in trainer_utils.py collapses multi-turn
  step-wise GeneratorOutput sequences into single sequences when
  consecutive turns share a common prefix, reducing training cost
  from O(T^2) to O(T).
- trainer.py: call merge before extracting generator fields, update
  uids from merged trajectory_ids, emit generate/num_seq_{before,after}_merge.
- Add generator.merge_stepwise_output config flag.
- run_search.sh: MERGE_STEPWISE env var.
- 16 CPU-only tests covering all 3 merging cases, partial merges, and
  validation asserts.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>