Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions examples/train/search/run_search.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ else
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false"
fi

: "${MERGE_STEPWISE:=false}"

STEP_WISE_ARGS=""
if [ "$STEP_WISE" = "true" ]; then
STEP_WISE_ARGS="generator.step_wise_trajectories=true"
Expand All @@ -56,6 +58,9 @@ if [ "$STEP_WISE" = "true" ]; then
echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically."
MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true"
fi
if [ "$MERGE_STEPWISE" = "true" ]; then
STEP_WISE_ARGS="$STEP_WISE_ARGS generator.merge_stepwise_output=true"
fi
fi

uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \
Expand Down
3 changes: 3 additions & 0 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,9 @@ class GeneratorConfig(BaseConfig):
"""Can differ from the trainer's ``rope_scaling``, useful for thinking models."""
rope_theta: Optional[float] = None
step_wise_trajectories: bool = False
merge_stepwise_output: bool = False
"""When True (and step_wise_trajectories is True), apply prefix-aware merging
to collapse multi-turn step-wise sequences into single sequences before training."""

def __post_init__(self):

Expand Down
1 change: 1 addition & 0 deletions skyrl/train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ generator:
rope_theta: ${trainer.rope_theta}

step_wise_trajectories: false
merge_stepwise_output: false

environment:
env_class: "gsm8k"
Expand Down
35 changes: 33 additions & 2 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
build_dataloader,
cleanup_old_checkpoints,
extract_step_from_path,
merge_stepwise_output,
run_on_each_node,
validate_consistency_for_latest_checkpoint,
validate_generator_output,
Expand Down Expand Up @@ -600,6 +601,22 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]:

def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch:
"""Converts lists to a padded batch of tensors for training"""
if self.cfg.generator.merge_stepwise_output:
num_seq_before_merge = len(generator_output["response_ids"])
generator_output = merge_stepwise_output(generator_output)
num_seq_after_merge = len(generator_output["response_ids"])
logger.info(
f"merge_stepwise_output: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences"
)
self.all_metrics.update(
{
"generate/num_seq_before_merge": num_seq_before_merge,
"generate/num_seq_after_merge": num_seq_after_merge,
}
)
# Update uids to match the merged output length
uids = [tid.instance_id for tid in generator_output["trajectory_ids"]]

prompt_ids: List[List[int]] = generator_output["prompt_token_ids"]
response_ids: List[List[int]] = generator_output["response_ids"]
rewards: List[List[float]] = generator_output["rewards"]
Expand Down Expand Up @@ -886,11 +903,25 @@ def dump_data(self, data: TrainingInputBatch, file_name: str):
data.save(data_save_dir / f"{file_name}.pkl")

def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
"""Pad the batch to be divisible by dp size"""
"""Pad the batch to be divisible by the training mini-batch size.

For step-wise training the batch size is variable (depends on how many
turns each trajectory takes), so we pad to the full mini_batch_size to
satisfy the divisibility requirement in stage_data.
"""
import math

dp_size = self.dispatch.get_lcm_dp_size()
pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
if self.cfg.generator.step_wise_trajectories:
n_samples = self.cfg.generator.n_samples_per_prompt
pad_target = max(
self.cfg.trainer.policy_mini_batch_size * n_samples,
self.cfg.trainer.critic_mini_batch_size * n_samples,
dp_size,
)
else:
pad_target = dp_size
pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size
Comment on lines +915 to +924
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 pad_batch uses max() instead of lcm() for step-wise pad target, breaking divisibility guarantee

When step_wise_trajectories=True, pad_target is computed as max(policy_mini_batch_size * n, critic_mini_batch_size * n, dp_size). Padding the batch to a multiple of max(A, B, C) does not guarantee divisibility by all three values — only lcm(A, B, C) would. This causes stage_chunks at skyrl/backends/skyrl_train/distributed/dispatch.py:190-192 to assert-fail (len(data) % mini_batch_size == 0) when policy_mini_batch_size != critic_mini_batch_size.

Concrete failure example

With policy_mini_batch_size=256, critic_mini_batch_size=384, n_samples=5:

  • pad_target = max(1280, 1920, dp_size) = 1920
  • Batch padded to e.g. 1920
  • 1920 % 1280 = 640 ≠ 0 → policy stage_chunks asserts

Additionally, critic_mini_batch_size is unconditionally included even when self.has_critic is False, which can unnecessarily inflate pad_target above policy_mini_batch_size * n and break divisibility for the policy training step.

Prompt for agents
In pad_batch, the pad_target for step-wise training uses max() to combine policy_mini_batch_size * n_samples, critic_mini_batch_size * n_samples, and dp_size. However, max(A, B, C) does not guarantee the result is divisible by all three values — only lcm(A, B, C) does. Additionally, critic_mini_batch_size should only be included when self.has_critic is True.

Fix: Replace the max() call with math.lcm(), and only include critic_mini_batch_size when a critic model is configured. For example:

  from math import lcm
  pad_target = lcm(self.cfg.trainer.policy_mini_batch_size * n_samples, dp_size)
  if self.has_critic:
      pad_target = lcm(pad_target, self.cfg.trainer.critic_mini_batch_size * n_samples)

This ensures the padded batch size is divisible by all required mini_batch_sizes for stage_chunks.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

new_tensors = {}
training_input.metadata["pad_size"] = pad_size
if pad_size == 0:
Expand Down
192 changes: 192 additions & 0 deletions skyrl/train/utils/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,198 @@ def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses:
)


def _is_prefix(maybe_prefix: List[int], candidate: List[int]) -> bool:
"""Check if maybe_prefix is a prefix of candidate (or equal)."""
if len(maybe_prefix) > len(candidate):
return False
return maybe_prefix == candidate[: len(maybe_prefix)]


def _slice_generator_output(generator_output: GeneratorOutput, indices: List[int]) -> GeneratorOutput:
"""Slice a GeneratorOutput to keep only the entries at the given indices."""
return {
"prompt_token_ids": [generator_output["prompt_token_ids"][i] for i in indices],
"response_ids": [generator_output["response_ids"][i] for i in indices],
"rewards": [generator_output["rewards"][i] for i in indices],
"loss_masks": [generator_output["loss_masks"][i] for i in indices],
"stop_reasons": (
[generator_output["stop_reasons"][i] for i in indices]
if generator_output.get("stop_reasons") is not None
else None
),
"rollout_metrics": generator_output.get("rollout_metrics"),
"rollout_logprobs": (
[generator_output["rollout_logprobs"][i] for i in indices]
if generator_output.get("rollout_logprobs") is not None
else None
),
"trajectory_ids": (
[generator_output["trajectory_ids"][i] for i in indices]
if generator_output.get("trajectory_ids") is not None
else None
),
"rollout_expert_indices": (
[generator_output["rollout_expert_indices"][i] for i in indices]
if generator_output.get("rollout_expert_indices") is not None
else None
),
"is_last_step": (
[generator_output["is_last_step"][i] for i in indices]
if generator_output.get("is_last_step") is not None
else None
),
}


def _merge_single_trajectory(gen_out: GeneratorOutput) -> GeneratorOutput:
"""Greedily merge turns of a single trajectory using prefix matching.

Takes a GeneratorOutput whose entries all belong to the same trajectory
and returns a merged GeneratorOutput (potentially fewer entries).
"""
n = len(gen_out["response_ids"])
token_level_rewards = n > 0 and isinstance(gen_out["rewards"][0], list)
has_logprobs = gen_out.get("rollout_logprobs") is not None
has_stop_reasons = gen_out.get("stop_reasons") is not None

# Per-field output accumulators (lists of per-entry values)
out_prompt_ids: List[List[int]] = []
out_response_ids: List[List[int]] = []
out_loss_masks: List[List[int]] = []
out_logprobs: Optional[List[List[float]]] = [] if has_logprobs else None
out_rewards: list = []
out_stop_reasons: Optional[List[str]] = [] if has_stop_reasons else None
out_trajectory_ids: list = []
out_is_last_step: List[bool] = []

# Accumulator for the current merge group
acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0])
acc_response: List[int] = list(gen_out["response_ids"][0])
acc_loss_mask: List[int] = list(gen_out["loss_masks"][0])
acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None
acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if token_level_rewards else None
last = 0
Comment on lines +811 to +817
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic to initialize the accumulators for a merge group (lines 812-817) is nearly identical to the logic for resetting them upon a prefix mismatch (lines 853-858). To reduce code duplication and improve maintainability, consider extracting this logic into a nested helper function that can be called in both places.


def flush():
nonlocal acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens, last
out_prompt_ids.append(acc_prompt)
out_response_ids.append(acc_response)
out_loss_masks.append(acc_loss_mask)
if has_logprobs:
out_logprobs.append(acc_logprobs)
out_rewards.append(acc_rewards_tokens if token_level_rewards else gen_out["rewards"][last])
if has_stop_reasons:
out_stop_reasons.append(gen_out["stop_reasons"][last])
out_trajectory_ids.append(gen_out["trajectory_ids"][last])
out_is_last_step.append(gen_out["is_last_step"][last])

prefix_failures_logged = 0
for i in range(1, n):
full_merged = acc_prompt + acc_response

if not _is_prefix(full_merged, list(gen_out["prompt_token_ids"][i])):
if prefix_failures_logged < 3:
next_prompt = list(gen_out["prompt_token_ids"][i])
# Find the first divergence point
diverge_idx = next(
(j for j in range(min(len(full_merged), len(next_prompt))) if full_merged[j] != next_prompt[j]),
min(len(full_merged), len(next_prompt)),
)
logger.warning(
f"Prefix mismatch at turn {i}: "
f"len(full_merged)={len(full_merged)}, len(next_prompt)={len(next_prompt)}, "
f"first_diverge_idx={diverge_idx}, "
f"full_merged[{diverge_idx-2}:{diverge_idx+3}]={full_merged[max(0,diverge_idx-2):diverge_idx+3]}, "
f"next_prompt[{diverge_idx-2}:{diverge_idx+3}]={next_prompt[max(0,diverge_idx-2):diverge_idx+3]}"
)
prefix_failures_logged += 1
flush()
acc_prompt = list(gen_out["prompt_token_ids"][i])
acc_response = list(gen_out["response_ids"][i])
acc_loss_mask = list(gen_out["loss_masks"][i])
acc_logprobs = list(gen_out["rollout_logprobs"][i]) if has_logprobs else None
acc_rewards_tokens = list(gen_out["rewards"][i]) if token_level_rewards else None
last = i
continue

obs_delta = list(gen_out["prompt_token_ids"][i][len(full_merged):])

acc_response.extend(obs_delta)
acc_loss_mask.extend([0] * len(obs_delta))
if acc_logprobs is not None:
acc_logprobs.extend([0.0] * len(obs_delta))
if acc_rewards_tokens is not None:
acc_rewards_tokens.extend([0.0] * len(obs_delta))

acc_response.extend(gen_out["response_ids"][i])
acc_loss_mask.extend(gen_out["loss_masks"][i])
if acc_logprobs is not None:
acc_logprobs.extend(gen_out["rollout_logprobs"][i])
if acc_rewards_tokens is not None:
acc_rewards_tokens.extend(list(gen_out["rewards"][i]))

last = i

flush()

return {
"prompt_token_ids": out_prompt_ids,
"response_ids": out_response_ids,
"rewards": out_rewards,
"loss_masks": out_loss_masks,
"stop_reasons": out_stop_reasons,
"rollout_metrics": gen_out.get("rollout_metrics"),
"rollout_logprobs": out_logprobs,
"trajectory_ids": out_trajectory_ids,
"rollout_expert_indices": None,
"is_last_step": out_is_last_step,
}


def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput:
    """Merge step-wise GeneratorOutput entries using prefix-aware merging.

    Consecutive turns of the same trajectory where prompt[i] + response[i]
    is a prefix of prompt[i+1] are collapsed into one entry:

    - the prompt comes from the first turn of the merge group
    - responses are concatenated, with the observation deltas in between
    - per-token fields (loss_masks, rewards, logprobs) are concatenated,
      with 0 filled in for observation-delta positions
    - per-turn fields (stop_reason, is_last_step, trajectory_id) come from
      the last turn of the group

    If the prefix condition breaks between two turns, the current group is
    flushed and a fresh group begins (greedy merging).

    Args:
        generator_output: Step-wise GeneratorOutput with trajectory_ids and is_last_step.

    Returns:
        Merged GeneratorOutput with one entry per merged group.
    """
    total = len(generator_output["response_ids"])
    # step_wise=True also checks trajectory_ids, is_last_step, and that
    # entries of a trajectory are contiguous.
    validate_generator_output(total, generator_output, step_wise=True)
    assert generator_output.get("rollout_expert_indices") is None, (
        "rollout_expert_indices not supported for prefix-aware merging"
    )

    # Each True in is_last_step closes one trajectory; contiguity is already
    # guaranteed by the validation above, so boundaries fully partition it.
    flags = generator_output["is_last_step"]
    merged_slices: List[GeneratorOutput] = []
    start = 0
    for end, is_last in enumerate(flags):
        if is_last:
            piece = _slice_generator_output(generator_output, list(range(start, end + 1)))
            merged_slices.append(_merge_single_trajectory(piece))
            start = end + 1

    # concatenate_generator_outputs re-aggregates rollout_metrics and validates
    return concatenate_generator_outputs(merged_slices)


def build_dataloader(
cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False
) -> StatefulDataLoader:
Expand Down
Loading
Loading