diff --git a/docs/content/docs/tutorials/step-wise-training.mdx b/docs/content/docs/tutorials/step-wise-training.mdx index b2e5125372..af3810e202 100644 --- a/docs/content/docs/tutorials/step-wise-training.mdx +++ b/docs/content/docs/tutorials/step-wise-training.mdx @@ -12,6 +12,16 @@ However, re-tokenization has two fundamental limitations: **Step-wise training addresses both problems.** Instead of producing one `(prompt, response)` pair per trajectory, it decomposes each multi-turn trajectory into N separate training samples (one per LLM turn), using the **exact token IDs and logprobs from the inference engine** (via vLLM's `return_token_ids`). Each step's prompt is the full context the model saw at that turn, and the response is exactly the tokens the model generated. Because each turn is an independent sample, context management operations between turns are naturally supported — there is no requirement that turn N+1's prompt be a prefix extension of turn N's full sequence. +### Quick start + +To see how `SkyRLGymGenerator` supports step-wise training, you can run it with the search-r1 example. + +```bash +USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh +``` + +This page will also guide you how to implement step-wise training for your custom generator. + ### Impact on Training When step-wise is enabled, a batch of T trajectories with an average of M turns per trajectory produces T×M training samples (sequences). This means: diff --git a/examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh b/examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh index 1159a8774e..6a8c8fea9a 100644 --- a/examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh +++ b/examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh @@ -28,7 +28,7 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \ generator.inference_engine.tensor_parallel_size=1 \ trainer.epochs=20 \ trainer.eval_batch_size=1024 \ - trainer.eval_before_train=false \ + trainer.eval_before_train=true \ trainer.eval_interval=5 \ trainer.update_epochs_per_batch=1 \ trainer.train_batch_size=256 \ diff --git a/skyrl/train/evaluate.py b/skyrl/train/evaluate.py index e6b67673b9..938f0847ed 100644 --- a/skyrl/train/evaluate.py +++ b/skyrl/train/evaluate.py @@ -181,7 +181,7 @@ async def evaluate_step_wise( concat_all_envs.append(traj_id_to_input[traj_id.instance_id]["env_class"]) concat_env_extras.append(traj_id_to_input[traj_id.instance_id]["env_extras"]) concat_uids.append(traj_id.instance_id) - # validate_generator_output(generator_input, generator_output) + validate_generator_output(generator_input, generator_output, step_wise=True) generator_outputs.append(generator_output) concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs) diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index 1f84a8c4a5..dd18e73ee5 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -712,8 +712,11 @@ async def generate( if generator_output["rollout_metrics"] is not None: self.all_metrics.update(generator_output["rollout_metrics"]) - if not self.cfg.generator.step_wise_trajectories: - validate_generator_output(len(input_batch["prompts"]), generator_output) + validate_generator_output( + len(input_batch["prompts"]), + generator_output, + step_wise=self.cfg.generator.step_wise_trajectories, + ) return generator_output diff --git a/skyrl/train/utils/trainer_utils.py b/skyrl/train/utils/trainer_utils.py index b2c27e5a1c..7fe3e53fb5 100644 --- a/skyrl/train/utils/trainer_utils.py +++ b/skyrl/train/utils/trainer_utils.py @@ -593,20 +593,25 @@ def zero_variance_filter(rewards: List[float], uids: List[str]) -> List[int]: return [i for i, uid in enumerate(uids) if uid in kept_uids_set] -def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput): +def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput, step_wise: bool = False): """Validate the generator output. Args: num_prompts: Number of input prompts used to produce this output. generator_output: The generated output batch to validate. + step_wise: If True, validate step-wise specific fields (is_last_step, trajectory_ids, + contiguous ordering). In step-wise mode, num_responses may exceed num_prompts + because each trajectory is expanded into multiple per-turn samples. """ if len(generator_output["response_ids"]) <= 0: raise RuntimeError("No outputs generated") - # check that input prompts, response ids, and prompt token ids are all the same length num_responses = len(generator_output["response_ids"]) num_prompt_tokens = len(generator_output["prompt_token_ids"]) - assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})" + + if not step_wise: + assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})" + assert ( num_responses == num_prompt_tokens ), f"Mismatch between responses ({num_responses}) and prompt_token_ids ({num_prompt_tokens})" @@ -660,6 +665,84 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu not isinstance(reward, list) for reward in rewards ), "rewards must be `List[float]` or `List[List[float]]`" + if step_wise: + _validate_step_wise_fields(generator_output, num_responses) + + +def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: int): + """Validate step-wise specific fields in the generator output. + + Checks that is_last_step and trajectory_ids are present, correctly sized, + contiguously ordered, and that is_last_step boundaries align with trajectory_id changes. + + The contiguity check is critical: the trainer's advantage broadcast uses + ``cumsum(shifted_is_last_step)`` to map each step to its trajectory, which + silently produces wrong results if steps from the same trajectory are interleaved + with steps from other trajectories. + + For more, see https://docs.skyrl.ai/docs/tutorials/step-wise-training#generatoroutput-format + """ + assert ( + generator_output.get("is_last_step") is not None + ), "step_wise=True but `is_last_step` is missing from generator output" + assert ( + generator_output.get("trajectory_ids") is not None + ), "step_wise=True but `trajectory_ids` is missing from generator output" + + is_last_step = generator_output["is_last_step"] + trajectory_ids = generator_output["trajectory_ids"] + + assert ( + len(is_last_step) == num_responses + ), f"is_last_step length ({len(is_last_step)}) must equal response_ids length ({num_responses})" + assert ( + len(trajectory_ids) == num_responses + ), f"trajectory_ids length ({len(trajectory_ids)}) must equal response_ids length ({num_responses})" + + assert ( + is_last_step[-1] is True + ), "is_last_step[-1] must be True (the last sample must be the final step of a trajectory)" + + num_trajectories = sum(1 for x in is_last_step if x) + assert num_trajectories >= 1, "is_last_step must contain at least one True value" + + # Validate contiguous ordering: all steps of the same trajectory must be adjacent. + seen_trajectory_ids = set() + prev_tid = None + for i, tid in enumerate(trajectory_ids): + tid_key = tid.to_string() if hasattr(tid, "to_string") else str(tid) + if tid_key != prev_tid: + assert tid_key not in seen_trajectory_ids, ( + f"Non-contiguous trajectory at index {i}: trajectory '{tid_key}' appeared before " + f"(at earlier indices), then a different trajectory, then again here. " + f"Step-wise training requires all steps of the same trajectory to be adjacent." + ) + if prev_tid is not None: + seen_trajectory_ids.add(prev_tid) + prev_tid = tid_key + if prev_tid is not None: + seen_trajectory_ids.add(prev_tid) + + # Validate is_last_step aligns with trajectory boundaries (both directions) + for i in range(num_responses - 1): + tid_cur = trajectory_ids[i].to_string() if hasattr(trajectory_ids[i], "to_string") else str(trajectory_ids[i]) + tid_next = ( + trajectory_ids[i + 1].to_string() + if hasattr(trajectory_ids[i + 1], "to_string") + else str(trajectory_ids[i + 1]) + ) + if tid_cur != tid_next: + assert is_last_step[i] is True, ( + f"Trajectory boundary at index {i} ('{tid_cur}' → '{tid_next}') " + f"but is_last_step[{i}] is False. Must be True at trajectory boundaries." + ) + else: + assert is_last_step[i] is not True, ( + f"is_last_step[{i}] is True but trajectory continues " + f"(trajectory '{tid_cur}' at index {i} and {i+1}). " + f"is_last_step must only be True at the final step of a trajectory." + ) + def build_dataloader( cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False diff --git a/tests/train/test_trainer_utils.py b/tests/train/test_trainer_utils.py index 0177ffbd5f..fa6039f01f 100644 --- a/tests/train/test_trainer_utils.py +++ b/tests/train/test_trainer_utils.py @@ -13,7 +13,7 @@ import pytest import ray -from skyrl.train.generators.base import GeneratorInput, GeneratorOutput +from skyrl.train.generators.base import GeneratorInput, GeneratorOutput, TrajectoryID from skyrl.train.utils.trainer_utils import ( build_dataloader, calculate_per_dataset_metrics, @@ -922,3 +922,143 @@ def test_validate_generator_output_invalid_rewards(): generator_output["rewards"] = [[0.5, 0.6], [0.7, 0.8]] validate_generator_output(len(input_batch["prompts"]), generator_output) + + +# ============================================================ +# Step-wise validation tests +# ============================================================ + + +def _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3), contiguous=True): + """Helper to build a step-wise GeneratorOutput for testing.""" + items = [] + for traj_idx in range(n_trajectories): + n_steps = steps_per_traj[traj_idx] + tid = TrajectoryID(instance_id=str(traj_idx), repetition_id=0) + for step in range(n_steps): + is_last = step == n_steps - 1 + prompt = list(range(10 + traj_idx * 100, 10 + traj_idx * 100 + 3 + step)) + resp = list(range(50 + traj_idx * 100 + step * 10, 50 + traj_idx * 100 + step * 10 + 3)) + reward = [0.0, 0.0, float(traj_idx + 1) if is_last else 0.0] + items.append((prompt, resp, reward, [1, 1, 1], is_last, tid)) + + if not contiguous: + max_steps = max(steps_per_traj) + reordered = [] + for step in range(max_steps): + for traj_idx in range(n_trajectories): + if step < steps_per_traj[traj_idx]: + idx = sum(steps_per_traj[:traj_idx]) + step + reordered.append(items[idx]) + items = reordered + + prompt_token_ids, response_ids, rewards, loss_masks = [], [], [], [] + is_last_step, trajectory_ids = [], [] + for prompt, resp, reward, mask, is_last, tid in items: + prompt_token_ids.append(prompt) + response_ids.append(resp) + rewards.append(reward) + loss_masks.append(mask) + is_last_step.append(is_last) + trajectory_ids.append(tid) + + return { + "prompt_token_ids": prompt_token_ids, + "response_ids": response_ids, + "rewards": rewards, + "loss_masks": loss_masks, + "stop_reasons": ["complete"] * len(response_ids), + "rollout_metrics": {}, + "rollout_logprobs": None, + "is_last_step": is_last_step, + "trajectory_ids": trajectory_ids, + } + + +def test_validate_stepwise_valid(): + """Valid step-wise output should pass validation.""" + output = _make_stepwise_output(n_trajectories=3, steps_per_traj=(1, 2, 3)) + validate_generator_output(num_prompts=3, generator_output=output, step_wise=True) + + +def test_validate_stepwise_single_step_trajectories(): + """All single-step trajectories should pass.""" + output = _make_stepwise_output(n_trajectories=4, steps_per_traj=(1, 1, 1, 1)) + validate_generator_output(num_prompts=4, generator_output=output, step_wise=True) + + +def test_validate_stepwise_missing_is_last_step(): + """Missing is_last_step should fail.""" + output = _make_stepwise_output() + del output["is_last_step"] + with pytest.raises(AssertionError, match="is_last_step.*missing"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_missing_trajectory_ids(): + """Missing trajectory_ids should fail.""" + output = _make_stepwise_output() + del output["trajectory_ids"] + with pytest.raises(AssertionError, match="trajectory_ids.*missing"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_is_last_step_length_mismatch(): + """is_last_step length mismatch should fail.""" + output = _make_stepwise_output() + output["is_last_step"] = output["is_last_step"][:-1] + with pytest.raises(AssertionError, match="is_last_step length"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_last_element_not_true(): + """is_last_step[-1] must be True.""" + output = _make_stepwise_output() + output["is_last_step"][-1] = False + with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_non_contiguous(): + """Non-contiguous trajectory ordering should fail.""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2), contiguous=False) + with pytest.raises(AssertionError, match="Non-contiguous trajectory"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_boundary_without_is_last(): + """Trajectory boundary where is_last_step is False should fail.""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2)) + # Traj 0 has steps at indices 0,1 and traj 1 at 2,3. Corrupt boundary. + output["is_last_step"][1] = False + with pytest.raises(AssertionError, match="Trajectory boundary at index 1"): + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_no_true_in_is_last_step(): + """is_last_step with no True values should fail.""" + output = _make_stepwise_output(n_trajectories=1, steps_per_traj=(3,)) + output["is_last_step"] = [False, False, False] + with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"): + validate_generator_output(num_prompts=1, generator_output=output, step_wise=True) + + +def test_validate_stepwise_num_prompts_not_checked(): + """In step-wise mode, num_prompts != num_responses is allowed (expansion).""" + output = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3)) + # 5 step-samples from 2 prompts + validate_generator_output(num_prompts=2, generator_output=output, step_wise=True) + + +def test_validate_stepwise_multiple_is_last_step_true_per_trajectory(): + """Multiple is_last_step=True within a single trajectory should fail. + + A trajectory with 3 steps should have is_last_step=[False, False, True], + not [True, True, True]. Having multiple True values would corrupt the + cumsum(shifted_is_last_step) advantage broadcast. + """ + output = _make_stepwise_output(n_trajectories=1, steps_per_traj=(3,)) + # Corrupt: mark all steps as last + output["is_last_step"] = [True, True, True] + with pytest.raises(AssertionError, match="is_last_step.*True.*trajectory continues"): + validate_generator_output(num_prompts=1, generator_output=output, step_wise=True)