diff --git a/examples/train/search/run_search.sh b/examples/train/search/run_search.sh index 92714d0d05..08f0859ef2 100755 --- a/examples/train/search/run_search.sh +++ b/examples/train/search/run_search.sh @@ -48,6 +48,8 @@ else MULTI_TURN_ARGS="generator.use_conversation_multi_turn=false" fi +: "${MERGE_STEPWISE:=false}" + STEP_WISE_ARGS="" if [ "$STEP_WISE" = "true" ]; then STEP_WISE_ARGS="generator.step_wise_trajectories=true" @@ -56,6 +58,9 @@ if [ "$STEP_WISE" = "true" ]; then echo "WARNING: STEP_WISE=true requires USE_CONVERSATION_MULTI_TURN=true. Enabling it automatically." MULTI_TURN_ARGS="generator.use_conversation_multi_turn=true generator.append_eos_token_after_stop_str_in_multi_turn=true" fi + if [ "$MERGE_STEPWISE" = "true" ]; then + STEP_WISE_ARGS="$STEP_WISE_ARGS generator.merge_stepwise_output=true" + fi fi uv run --isolated --frozen --extra fsdp -m skyrl.train.entrypoints.main_base \ diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 6f5f145418..f30491cd4e 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -518,6 +518,9 @@ class GeneratorConfig(BaseConfig): """Can differ from the trainer's ``rope_scaling``, useful for thinking models.""" rope_theta: Optional[float] = None step_wise_trajectories: bool = False + merge_stepwise_output: bool = False + """When True (and step_wise_trajectories is True), apply prefix-aware merging + to collapse multi-turn step-wise sequences into single sequences before training.""" def __post_init__(self): diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index f2a52006e3..1a372db856 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -381,6 +381,7 @@ generator: rope_theta: ${trainer.rope_theta} step_wise_trajectories: false + merge_stepwise_output: false environment: env_class: "gsm8k" diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index f01eed964c..896aac2d52 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -71,6 +71,7 @@ build_dataloader, cleanup_old_checkpoints, extract_step_from_path, + merge_stepwise_output, run_on_each_node, validate_consistency_for_latest_checkpoint, validate_generator_output, @@ -600,6 +601,22 @@ def sync_policy_weights_to_inference_engines(self) -> List[ObjectRef]: def convert_to_training_input(self, generator_output: GeneratorOutput, uids: List[str]) -> TrainingInputBatch: """Converts lists to a padded batch of tensors for training""" + if self.cfg.generator.merge_stepwise_output: + num_seq_before_merge = len(generator_output["response_ids"]) + generator_output = merge_stepwise_output(generator_output) + num_seq_after_merge = len(generator_output["response_ids"]) + logger.info( + f"merge_stepwise_output: {num_seq_before_merge} sequences -> {num_seq_after_merge} sequences" + ) + self.all_metrics.update( + { + "generate/num_seq_before_merge": num_seq_before_merge, + "generate/num_seq_after_merge": num_seq_after_merge, + } + ) + # Update uids to match the merged output length + uids = [tid.instance_id for tid in generator_output["trajectory_ids"]] + prompt_ids: List[List[int]] = generator_output["prompt_token_ids"] response_ids: List[List[int]] = generator_output["response_ids"] rewards: List[List[float]] = generator_output["rewards"] @@ -886,11 +903,25 @@ def dump_data(self, data: TrainingInputBatch, file_name: str): data.save(data_save_dir / f"{file_name}.pkl") def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: - """Pad the batch to be divisible by dp size""" + """Pad the batch to be divisible by the training mini-batch size. + + For step-wise training the batch size is variable (depends on how many + turns each trajectory takes), so we pad to the full mini_batch_size to + satisfy the divisibility requirement in stage_data. + """ import math dp_size = self.dispatch.get_lcm_dp_size() - pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size + if self.cfg.generator.step_wise_trajectories: + n_samples = self.cfg.generator.n_samples_per_prompt + pad_target = max( + self.cfg.trainer.policy_mini_batch_size * n_samples, + self.cfg.trainer.critic_mini_batch_size * n_samples, + dp_size, + ) + else: + pad_target = dp_size + pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size new_tensors = {} training_input.metadata["pad_size"] = pad_size if pad_size == 0: diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index 7fe3e53fb5..795309f484 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -744,6 +744,198 @@ def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: ) +def _is_prefix(maybe_prefix: List[int], candidate: List[int]) -> bool: + """Check if maybe_prefix is a prefix of candidate (or equal).""" + if len(maybe_prefix) > len(candidate): + return False + return maybe_prefix == candidate[: len(maybe_prefix)] + + +def _slice_generator_output(generator_output: GeneratorOutput, indices: List[int]) -> GeneratorOutput: + """Slice a GeneratorOutput to keep only the entries at the given indices.""" + return { + "prompt_token_ids": [generator_output["prompt_token_ids"][i] for i in indices], + "response_ids": [generator_output["response_ids"][i] for i in indices], + "rewards": [generator_output["rewards"][i] for i in indices], + "loss_masks": [generator_output["loss_masks"][i] for i in indices], + "stop_reasons": ( + [generator_output["stop_reasons"][i] for i in indices] + if generator_output.get("stop_reasons") is not None + else None + ), + "rollout_metrics": generator_output.get("rollout_metrics"), + "rollout_logprobs": ( + [generator_output["rollout_logprobs"][i] for i in indices] + if generator_output.get("rollout_logprobs") is not None + else None + ), + "trajectory_ids": ( + [generator_output["trajectory_ids"][i] for i in indices] + if generator_output.get("trajectory_ids") is not None + else None + ), + "rollout_expert_indices": ( + [generator_output["rollout_expert_indices"][i] for i in indices] + if generator_output.get("rollout_expert_indices") is not None + else None + ), + "is_last_step": ( + [generator_output["is_last_step"][i] for i in indices] + if generator_output.get("is_last_step") is not None + else None + ), + } + + +def _merge_single_trajectory(gen_out: GeneratorOutput) -> GeneratorOutput: + """Greedily merge turns of a single trajectory using prefix matching. + + Takes a GeneratorOutput whose entries all belong to the same trajectory + and returns a merged GeneratorOutput (potentially fewer entries). + """ + n = len(gen_out["response_ids"]) + token_level_rewards = n > 0 and isinstance(gen_out["rewards"][0], list) + has_logprobs = gen_out.get("rollout_logprobs") is not None + has_stop_reasons = gen_out.get("stop_reasons") is not None + + # Per-field output accumulators (lists of per-entry values) + out_prompt_ids: List[List[int]] = [] + out_response_ids: List[List[int]] = [] + out_loss_masks: List[List[int]] = [] + out_logprobs: Optional[List[List[float]]] = [] if has_logprobs else None + out_rewards: list = [] + out_stop_reasons: Optional[List[str]] = [] if has_stop_reasons else None + out_trajectory_ids: list = [] + out_is_last_step: List[bool] = [] + + # Accumulator for the current merge group + acc_prompt: List[int] = list(gen_out["prompt_token_ids"][0]) + acc_response: List[int] = list(gen_out["response_ids"][0]) + acc_loss_mask: List[int] = list(gen_out["loss_masks"][0]) + acc_logprobs: Optional[List[float]] = list(gen_out["rollout_logprobs"][0]) if has_logprobs else None + acc_rewards_tokens: Optional[List[float]] = list(gen_out["rewards"][0]) if token_level_rewards else None + last = 0 + + def flush(): + nonlocal acc_prompt, acc_response, acc_loss_mask, acc_logprobs, acc_rewards_tokens, last + out_prompt_ids.append(acc_prompt) + out_response_ids.append(acc_response) + out_loss_masks.append(acc_loss_mask) + if has_logprobs: + out_logprobs.append(acc_logprobs) + out_rewards.append(acc_rewards_tokens if token_level_rewards else gen_out["rewards"][last]) + if has_stop_reasons: + out_stop_reasons.append(gen_out["stop_reasons"][last]) + out_trajectory_ids.append(gen_out["trajectory_ids"][last]) + out_is_last_step.append(gen_out["is_last_step"][last]) + + prefix_failures_logged = 0 + for i in range(1, n): + full_merged = acc_prompt + acc_response + + if not _is_prefix(full_merged, list(gen_out["prompt_token_ids"][i])): + if prefix_failures_logged < 3: + next_prompt = list(gen_out["prompt_token_ids"][i]) + # Find the first divergence point + diverge_idx = next( + (j for j in range(min(len(full_merged), len(next_prompt))) if full_merged[j] != next_prompt[j]), + min(len(full_merged), len(next_prompt)), + ) + logger.warning( + f"Prefix mismatch at turn {i}: " + f"len(full_merged)={len(full_merged)}, len(next_prompt)={len(next_prompt)}, " + f"first_diverge_idx={diverge_idx}, " + f"full_merged[{diverge_idx-2}:{diverge_idx+3}]={full_merged[max(0,diverge_idx-2):diverge_idx+3]}, " + f"next_prompt[{diverge_idx-2}:{diverge_idx+3}]={next_prompt[max(0,diverge_idx-2):diverge_idx+3]}" + ) + prefix_failures_logged += 1 + flush() + acc_prompt = list(gen_out["prompt_token_ids"][i]) + acc_response = list(gen_out["response_ids"][i]) + acc_loss_mask = list(gen_out["loss_masks"][i]) + acc_logprobs = list(gen_out["rollout_logprobs"][i]) if has_logprobs else None + acc_rewards_tokens = list(gen_out["rewards"][i]) if token_level_rewards else None + last = i + continue + + obs_delta = list(gen_out["prompt_token_ids"][i][len(full_merged):]) + + acc_response.extend(obs_delta) + acc_loss_mask.extend([0] * len(obs_delta)) + if acc_logprobs is not None: + acc_logprobs.extend([0.0] * len(obs_delta)) + if acc_rewards_tokens is not None: + acc_rewards_tokens.extend([0.0] * len(obs_delta)) + + acc_response.extend(gen_out["response_ids"][i]) + acc_loss_mask.extend(gen_out["loss_masks"][i]) + if acc_logprobs is not None: + acc_logprobs.extend(gen_out["rollout_logprobs"][i]) + if acc_rewards_tokens is not None: + acc_rewards_tokens.extend(list(gen_out["rewards"][i])) + + last = i + + flush() + + return { + "prompt_token_ids": out_prompt_ids, + "response_ids": out_response_ids, + "rewards": out_rewards, + "loss_masks": out_loss_masks, + "stop_reasons": out_stop_reasons, + "rollout_metrics": gen_out.get("rollout_metrics"), + "rollout_logprobs": out_logprobs, + "trajectory_ids": out_trajectory_ids, + "rollout_expert_indices": None, + "is_last_step": out_is_last_step, + } + + +def merge_stepwise_output(generator_output: GeneratorOutput) -> GeneratorOutput: + """Merge step-wise GeneratorOutput entries using prefix-aware merging. + + For consecutive turns within the same trajectory where + prompt[i] + response[i] is a prefix of prompt[i+1], + merges them into a single entry with: + - prompt from the first turn in the merge group + - response tokens concatenated with observation deltas (obs_delta) in between + - Per-token fields (loss_masks, rewards, logprobs) concatenated, with default + values (0) for obs_delta positions + - Per-turn fields (stop_reason, is_last_step, trajectory_id) taken from the + last turn in the merge group + + When the prefix condition fails between two consecutive turns, the current + merge group is flushed and a new group starts (greedy merging). + + Args: + generator_output: Step-wise GeneratorOutput with trajectory_ids and is_last_step. + + Returns: + Merged GeneratorOutput with one entry per merged group. + """ + num_samples = len(generator_output["response_ids"]) + # step_wise=True validates trajectory_ids, is_last_step, and contiguous ordering + validate_generator_output(num_samples, generator_output, step_wise=True) + assert generator_output.get("rollout_expert_indices") is None, ( + "rollout_expert_indices not supported for prefix-aware merging" + ) + + # Split into per-trajectory GeneratorOutputs using is_last_step boundaries + # (contiguity is guaranteed by validate_generator_output with step_wise=True) + is_last_step = generator_output["is_last_step"] + trajectory_slices: List[GeneratorOutput] = [] + start = 0 + for i in range(num_samples): + if is_last_step[i]: + trajectory_slices.append(_slice_generator_output(generator_output, list(range(start, i + 1)))) + start = i + 1 + + merged_slices = [_merge_single_trajectory(s) for s in trajectory_slices] + # concatenate_generator_outputs re-aggregates rollout_metrics and validates + return concatenate_generator_outputs(merged_slices) + + def build_dataloader( cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False ) -> StatefulDataLoader: diff --git a/tests/train/test_merge_stepwise_output.py b/tests/train/test_merge_stepwise_output.py new file mode 100644 index 0000000000..7b32b22dcb --- /dev/null +++ b/tests/train/test_merge_stepwise_output.py @@ -0,0 +1,564 @@ +""" +CPU-only tests for merge_stepwise_output (prefix-aware merging). + +uv run --isolated --extra dev pytest tests/train/test_merge_stepwise_output.py -v +""" + +import pytest + +from skyrl.train.generators.base import GeneratorOutput, TrajectoryID +from skyrl.train.utils.trainer_utils import merge_stepwise_output + + +def _make_tid(instance_id: str, rep: int = 0) -> TrajectoryID: + return TrajectoryID(instance_id=instance_id, repetition_id=rep) + + +# ─────────────────────────────────────────────────────────────────── +# Core merging cases +# ─────────────────────────────────────────────────────────────────── + + +def test_case1_response_only_assistant(): + """Case 1: response only contains assistant-generated tokens. + + Turn 1: prompt=[O1], response=[A1] + Turn 2: prompt=[O1, A1, O2], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, A2] + obs_delta = [O2] + """ + tid = _make_tid("traj_1") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20], [40, 41]], + "rewards": [[1.0], [0.0, 5.0]], + "loss_masks": [[1], [1, 1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.5], [-0.3, -0.4]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = A1 + obs_delta(O2) + A2(two tokens) + assert merged["response_ids"] == [[20, 30, 40, 41]] + # loss_mask: A1=1, O2=0, A2_tok1=1, A2_tok2=1 + assert merged["loss_masks"] == [[1, 0, 1, 1]] + # logprobs: A1=-0.5, O2=0.0, A2_tok1=-0.3, A2_tok2=-0.4 + assert merged["rollout_logprobs"] == [[-0.5, 0.0, -0.3, -0.4]] + # rewards: A1=1.0, O2=0.0, A2_tok1=0.0, A2_tok2=5.0 + assert merged["rewards"] == [[1.0, 0.0, 0.0, 5.0]] + assert merged["stop_reasons"] == ["eos"] + assert merged["trajectory_ids"] == [tid] + assert merged["is_last_step"] == [True] + + +def test_case2_response_contains_observation(): + """Case 2: response contains both assistant and observation tokens. + + Turn 1: prompt=[O1], response=[A1, O2] + Turn 2: prompt=[O1, A1, O2], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, A2] + obs_delta = [] (empty, since prompt+response of turn 1 == prompt of turn 2) + """ + tid = _make_tid("traj_2") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20, 30], [40]], + "rewards": [[0.0, 0.0], [7.0]], + "loss_masks": [[1, 0], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.1, -0.2], [-0.3]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = [A1, O2] + [] + [A2] = [A1, O2, A2] + assert merged["response_ids"] == [[20, 30, 40]] + assert merged["loss_masks"] == [[1, 0, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2, -0.3]] + assert merged["rewards"] == [[0.0, 0.0, 7.0]] + assert merged["stop_reasons"] == ["eos"] + assert merged["is_last_step"] == [True] + + +def test_case3_combination(): + """Case 3: response has obs tokens AND there's extra obs_delta in prompt. + + Turn 1: prompt=[O1], response=[A1, O2] + Turn 2: prompt=[O1, A1, O2, O2_5], response=[A2] + + Merged: prompt=[O1], response=[A1, O2, O2_5, A2] + obs_delta = [O2_5] + """ + tid = _make_tid("traj_3") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30, 35]], + "response_ids": [[20, 30], [40]], + "rewards": [[0.0, 0.0], [9.0]], + "loss_masks": [[1, 0], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-0.1, -0.2], [-0.5]], + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[10]] + # response = [A1, O2] + obs_delta[O2_5] + [A2] + assert merged["response_ids"] == [[20, 30, 35, 40]] + assert merged["loss_masks"] == [[1, 0, 0, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2, 0.0, -0.5]] + assert merged["rewards"] == [[0.0, 0.0, 0.0, 9.0]] + assert merged["is_last_step"] == [True] + + +# ─────────────────────────────────────────────────────────────────── +# Multi-turn and multi-trajectory +# ─────────────────────────────────────────────────────────────────── + + +def test_three_turns(): + """Three-turn trajectory merging (Case 1 pattern repeated). + + Turn 1: prompt=[1], response=[2] + Turn 2: prompt=[1,2,3], response=[4] + Turn 3: prompt=[1,2,3,4,5], response=[6] + + Merged: prompt=[1], response=[2, 3, 4, 5, 6] + """ + tid = _make_tid("traj_multi") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1], [1, 2, 3], [1, 2, 3, 4, 5]], + "response_ids": [[2], [4], [6]], + "rewards": [[0.0], [0.0], [10.0]], + "loss_masks": [[1], [1], [1]], + "stop_reasons": ["continue", "continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0]], + "trajectory_ids": [tid, tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["prompt_token_ids"] == [[1]] + # resp[0]=2, obs_delta=3, resp[1]=4, obs_delta=5, resp[2]=6 + assert merged["response_ids"] == [[2, 3, 4, 5, 6]] + assert merged["loss_masks"] == [[1, 0, 1, 0, 1]] + assert merged["rollout_logprobs"] == [[-1.0, 0.0, -2.0, 0.0, -3.0]] + assert merged["rewards"] == [[0.0, 0.0, 0.0, 0.0, 10.0]] + assert merged["is_last_step"] == [True] + assert merged["stop_reasons"] == ["eos"] + + +def test_multiple_trajectories(): + """Two separate trajectories in the same batch, each with 2 turns.""" + tid_a = _make_tid("A") + tid_b = _make_tid("B") + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [10], # A turn 1 + [10, 20, 30], # A turn 2 + [100], # B turn 1 + [100, 200, 300], # B turn 2 + ], + "response_ids": [ + [20], # A turn 1 + [40], # A turn 2 + [200], # B turn 1 + [400], # B turn 2 + ], + "rewards": [[0.0], [1.0], [0.0], [2.0]], + "loss_masks": [[1], [1], [1], [1]], + "stop_reasons": ["continue", "eos", "continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid_a, tid_a, tid_b, tid_b], + "rollout_expert_indices": None, + "is_last_step": [False, True, False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 2 + # Trajectory A merged + assert merged["prompt_token_ids"][0] == [10] + assert merged["response_ids"][0] == [20, 30, 40] + assert merged["loss_masks"][0] == [1, 0, 1] + # Trajectory B merged + assert merged["prompt_token_ids"][1] == [100] + assert merged["response_ids"][1] == [200, 300, 400] + assert merged["loss_masks"][1] == [1, 0, 1] + assert merged["is_last_step"] == [True, True] + assert merged["stop_reasons"] == ["eos", "eos"] + + +def test_mixed_trajectories_and_turns(): + """Mix of single-turn and multi-turn trajectories in one batch.""" + tid_a = _make_tid("A") + tid_b = _make_tid("B") # single turn + tid_c = _make_tid("C") + + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [1], # A turn 1 + [1, 2, 3], # A turn 2 + [1, 2, 3, 4, 5], # A turn 3 + [50], # B single turn + [60], # C turn 1 + [60, 70, 80], # C turn 2 + ], + "response_ids": [ + [2], # A + [4], # A + [6], # A + [51], # B + [70], # C + [90], # C + ], + "rewards": [[0.0], [0.0], [10.0], [3.0], [0.0], [7.0]], + "loss_masks": [[1], [1], [1], [1], [1], [1]], + "stop_reasons": ["c", "c", "eos", "eos", "c", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0], [-4.0], [-5.0], [-6.0]], + "trajectory_ids": [tid_a, tid_a, tid_a, tid_b, tid_c, tid_c], + "rollout_expert_indices": None, + "is_last_step": [False, False, True, True, False, True], + } + + merged = merge_stepwise_output(gen_out) + + # 3 entries: A merged, B single, C merged + assert len(merged["response_ids"]) == 3 + + # A: prompt=[1], response=[2,3,4,5,6] + assert merged["prompt_token_ids"][0] == [1] + assert merged["response_ids"][0] == [2, 3, 4, 5, 6] + assert merged["loss_masks"][0] == [1, 0, 1, 0, 1] + assert merged["rollout_logprobs"][0] == [-1.0, 0.0, -2.0, 0.0, -3.0] + assert merged["rewards"][0] == [0.0, 0.0, 0.0, 0.0, 10.0] + + # B: unchanged + assert merged["prompt_token_ids"][1] == [50] + assert merged["response_ids"][1] == [51] + + # C: prompt=[60], response=[70,80,90] + assert merged["prompt_token_ids"][2] == [60] + assert merged["response_ids"][2] == [70, 80, 90] + assert merged["loss_masks"][2] == [1, 0, 1] + + assert merged["is_last_step"] == [True, True, True] + assert merged["stop_reasons"] == ["eos", "eos", "eos"] + + +# ─────────────────────────────────────────────────────────────────── +# Edge cases / optional fields +# ─────────────────────────────────────────────────────────────────── + + +def test_single_turn_passthrough(): + """Single-turn trajectory is passed through unchanged.""" + tid = _make_tid("single") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1, 2, 3]], + "response_ids": [[4, 5]], + "rewards": [[0.5, 0.6]], + "loss_masks": [[1, 1]], + "stop_reasons": ["eos"], + "rollout_metrics": {"some_metric": 1.0}, + "rollout_logprobs": [[-0.1, -0.2]], + "trajectory_ids": [tid], + "rollout_expert_indices": None, + "is_last_step": [True], + } + + merged = merge_stepwise_output(gen_out) + + assert merged["prompt_token_ids"] == [[1, 2, 3]] + assert merged["response_ids"] == [[4, 5]] + assert merged["loss_masks"] == [[1, 1]] + assert merged["rollout_logprobs"] == [[-0.1, -0.2]] + assert merged["rewards"] == [[0.5, 0.6]] + assert merged["is_last_step"] == [True] + # rollout_metrics is re-aggregated by concatenate_generator_outputs + assert merged["rollout_metrics"] is not None + + +def test_per_trajectory_scalar_rewards(): + """Per-trajectory scalar rewards (List[float]) are handled correctly.""" + tid = _make_tid("scalar_rew") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20], [40]], + "rewards": [0.0, 5.0], # scalar per turn + "loss_masks": [[1], [1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert len(merged["response_ids"]) == 1 + assert merged["response_ids"] == [[20, 30, 40]] + # Scalar reward: use the last turn's value + assert merged["rewards"] == [5.0] + assert merged["rollout_logprobs"] is None + + +def test_no_logprobs_no_stop_reasons(): + """Works correctly when rollout_logprobs and stop_reasons are None.""" + tid = _make_tid("no_lp") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1], [1, 2, 3]], + "response_ids": [[2], [4]], + "rewards": [[0.0], [1.0]], + "loss_masks": [[1], [1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert merged["rollout_logprobs"] is None + assert merged["stop_reasons"] is None + assert merged["response_ids"] == [[2, 3, 4]] + assert merged["loss_masks"] == [[1, 0, 1]] + + +def test_empty_obs_delta(): + """Case 2 where obs_delta is empty (prompt+response == next prompt exactly).""" + tid = _make_tid("empty_delta") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20, 30]], + "response_ids": [[20, 30], [40, 50]], + "rewards": [[0.0, 0.0], [1.0, 2.0]], + "loss_masks": [[1, 0], [1, 1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + assert merged["prompt_token_ids"] == [[10]] + # No obs_delta, just concat responses + assert merged["response_ids"] == [[20, 30, 40, 50]] + assert merged["loss_masks"] == [[1, 0, 1, 1]] + assert merged["rewards"] == [[0.0, 0.0, 1.0, 2.0]] + + +def test_prefix_mismatch_no_merge(): + """If prefix condition fails, turns are kept separate.""" + tid = _make_tid("no_merge") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [99, 88]], # second prompt doesn't share prefix + "response_ids": [[20], [40]], + "rewards": [[0.0], [1.0]], + "loss_masks": [[1], [1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + # No merging happened, output has same number of entries + assert len(merged["response_ids"]) == 2 + assert merged["prompt_token_ids"] == [[10], [99, 88]] + assert merged["response_ids"] == [[20], [40]] + assert merged["is_last_step"] == [False, True] + + +def test_prefix_of_prompt_plus_response_but_not_prompt_alone_no_merge(): + """prompt[i]+response[i] is a prefix of prompt[i+1]+response[i+1] but NOT + of prompt[i+1] alone. This must NOT merge. + + Turn 1: prompt=[10], response=[20, 30] + → full = [10, 20, 30] + Turn 2: prompt=[10, 20], response=[30, 40] + → prompt[1]+response[1] = [10, 20, 30, 40] + + full ([10,20,30]) IS a prefix of prompt[1]+response[1] ([10,20,30,40]), + but is NOT a prefix of prompt[1] ([10,20]) since full is longer. + This would imply response tokens from turn 1 overlap with response tokens + of turn 2, which is malformed — we intentionally refuse to merge. + """ + tid = _make_tid("overlap") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[10], [10, 20]], + "response_ids": [[20, 30], [30, 40]], + "rewards": [[0.0, 0.0], [1.0, 2.0]], + "loss_masks": [[1, 0], [1, 1]], + "stop_reasons": ["continue", "eos"], + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, True], + } + + merged = merge_stepwise_output(gen_out) + + # No merging: kept as 2 separate entries + assert len(merged["response_ids"]) == 2 + assert merged["prompt_token_ids"] == [[10], [10, 20]] + assert merged["response_ids"] == [[20, 30], [30, 40]] + assert merged["is_last_step"] == [False, True] + + +def test_partial_merge_within_trajectory(): + """4 turns where prefix breaks mid-trajectory, producing 2 merged sequences. + + Turn 1: prompt=[1], response=[2] → prefix OK for turn 2 + Turn 2: prompt=[1,2,3], response=[4] → prefix BREAKS for turn 3 + (re-tokenization changed [2,3] to [23]) + Turn 3: prompt=[1,23,4,5], response=[6] → prefix OK for turn 4 + Turn 4: prompt=[1,23,4,5,6,7], response=[8] + + Turns 1+2 merge into one sequence, turns 3+4 merge into another. + Result: 4 turns → 2 sequences. + """ + tid = _make_tid("partial") + gen_out: GeneratorOutput = { + "prompt_token_ids": [ + [1], # turn 1 + [1, 2, 3], # turn 2: prompt[0]+resp[0]=[1,2] is prefix of [1,2,3] ✓ + [1, 23, 4, 5], # turn 3: prompt[1]+resp[1]=[1,2,3,4] is NOT prefix of [1,23,4,5] ✗ + [1, 23, 4, 5, 6, 7], # turn 4: prompt[2]+resp[2]=[1,23,4,5,6] is prefix ✓ + ], + "response_ids": [ + [2], # turn 1 + [4], # turn 2 + [6], # turn 3 + [8], # turn 4 + ], + "rewards": [[0.0], [0.0], [0.0], [10.0]], + "loss_masks": [[1], [1], [1], [1]], + "stop_reasons": ["c", "c", "c", "eos"], + "rollout_metrics": None, + "rollout_logprobs": [[-1.0], [-2.0], [-3.0], [-4.0]], + "trajectory_ids": [tid, tid, tid, tid], + "rollout_expert_indices": None, + "is_last_step": [False, False, False, True], + } + + merged = merge_stepwise_output(gen_out) + + # 4 turns → 2 merged sequences + assert len(merged["response_ids"]) == 2 + + # First merged group: turns 1+2 + # prompt=[1], response=[2] + obs_delta=[3] + [4] = [2, 3, 4] + assert merged["prompt_token_ids"][0] == [1] + assert merged["response_ids"][0] == [2, 3, 4] + assert merged["loss_masks"][0] == [1, 0, 1] + assert merged["rollout_logprobs"][0] == [-1.0, 0.0, -2.0] + assert merged["rewards"][0] == [0.0, 0.0, 0.0] + assert merged["is_last_step"][0] is False + + # Second merged group: turns 3+4 + # prompt=[1,23,4,5], response=[6] + obs_delta=[7] + [8] = [6, 7, 8] + assert merged["prompt_token_ids"][1] == [1, 23, 4, 5] + assert merged["response_ids"][1] == [6, 7, 8] + assert merged["loss_masks"][1] == [1, 0, 1] + assert merged["rollout_logprobs"][1] == [-3.0, 0.0, -4.0] + assert merged["rewards"][1] == [0.0, 0.0, 10.0] + assert merged["is_last_step"][1] is True + + assert merged["stop_reasons"] == ["c", "eos"] + assert merged["trajectory_ids"] == [tid, tid] + + +# ─────────────────────────────────────────────────────────────────── +# Assertion / validation tests +# ─────────────────────────────────────────────────────────────────── + + +def test_asserts_trajectory_ids_required(): + """Raises when trajectory_ids is None.""" + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": None, + "rollout_expert_indices": None, + "is_last_step": [True], + } + with pytest.raises(AssertionError, match="trajectory_ids"): + merge_stepwise_output(gen_out) + + +def test_asserts_is_last_step_required(): + """Raises when is_last_step is None.""" + tid = _make_tid("x") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid], + "rollout_expert_indices": None, + "is_last_step": None, + } + with pytest.raises(AssertionError, match="is_last_step"): + merge_stepwise_output(gen_out) + + +def test_asserts_no_expert_indices(): + """Raises when rollout_expert_indices is present.""" + tid = _make_tid("x") + gen_out: GeneratorOutput = { + "prompt_token_ids": [[1]], + "response_ids": [[2]], + "rewards": [[1.0]], + "loss_masks": [[1]], + "stop_reasons": None, + "rollout_metrics": None, + "rollout_logprobs": None, + "trajectory_ids": [tid], + "rollout_expert_indices": [[[[1, 2]]]], + "is_last_step": [True], + } + with pytest.raises(AssertionError, match="rollout_expert_indices not supported"): + merge_stepwise_output(gen_out)