Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/content/docs/tutorials/step-wise-training.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ However, re-tokenization has two fundamental limitations:

**Step-wise training addresses both problems.** Instead of producing one `(prompt, response)` pair per trajectory, it decomposes each multi-turn trajectory into N separate training samples (one per LLM turn), using the **exact token IDs and logprobs from the inference engine** (via vLLM's `return_token_ids`). Each step's prompt is the full context the model saw at that turn, and the response is exactly the tokens the model generated. Because each turn is an independent sample, context management operations between turns are naturally supported — there is no requirement that turn N+1's prompt be a prefix extension of turn N's full sequence.

### Quick start

To see how `SkyRLGymGenerator` supports step-wise training, you can run it with the search-r1 example.

```bash
USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh
```

This page also walks you through implementing step-wise training for your own custom generator.

### Impact on Training

When step-wise is enabled, a batch of T trajectories with an average of M turns per trajectory produces T×M training samples (sequences). This means:
Expand Down
2 changes: 1 addition & 1 deletion examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_base \
generator.inference_engine.tensor_parallel_size=1 \
trainer.epochs=20 \
trainer.eval_batch_size=1024 \
trainer.eval_before_train=false \
trainer.eval_before_train=true \
trainer.eval_interval=5 \
trainer.update_epochs_per_batch=1 \
trainer.train_batch_size=256 \
Expand Down
2 changes: 1 addition & 1 deletion skyrl/train/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def evaluate_step_wise(
concat_all_envs.append(traj_id_to_input[traj_id.instance_id]["env_class"])
concat_env_extras.append(traj_id_to_input[traj_id.instance_id]["env_extras"])
concat_uids.append(traj_id.instance_id)
# validate_generator_output(generator_input, generator_output)
validate_generator_output(generator_input, generator_output, step_wise=True)
Comment thread
CharlieFRuan marked this conversation as resolved.
Comment thread
CharlieFRuan marked this conversation as resolved.
generator_outputs.append(generator_output)
concat_generator_outputs: GeneratorOutput = concatenate_generator_outputs(generator_outputs)

Expand Down
7 changes: 5 additions & 2 deletions skyrl/train/trainer.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Step-wise avg_response_length divides by total steps instead of number of trajectories

At skyrl/train/trainer.py:679-683, the step-wise avg_response_length calculation sums response lengths only for steps where is_last_step=True, but divides by len(response_ids) (total number of steps across all trajectories). For example, with 2 trajectories having 3 and 2 steps respectively, and last-step response lengths of 10 and 8: numerator = 18, denominator = 5, result = 3.6 — instead of the correct 9.0. The denominator should be the number of trajectories (i.e., sum(1 for x in generator_output["is_last_step"] if x)). Compare with the non-step-wise version at line 685-687 which correctly uses len(response_ids) as denominator when summing all responses.

(Refers to lines 679-683)

Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will fix in a separate PR

Original file line number Diff line number Diff line change
Expand Up @@ -712,8 +712,11 @@ async def generate(
if generator_output["rollout_metrics"] is not None:
self.all_metrics.update(generator_output["rollout_metrics"])

if not self.cfg.generator.step_wise_trajectories:
validate_generator_output(len(input_batch["prompts"]), generator_output)
validate_generator_output(
len(input_batch["prompts"]),
generator_output,
step_wise=self.cfg.generator.step_wise_trajectories,
)

return generator_output

Expand Down
89 changes: 86 additions & 3 deletions skyrl/train/utils/trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,25 @@ def zero_variance_filter(rewards: List[float], uids: List[str]) -> List[int]:
return [i for i, uid in enumerate(uids) if uid in kept_uids_set]


def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput):
def validate_generator_output(num_prompts: int, generator_output: GeneratorOutput, step_wise: bool = False):
"""Validate the generator output.

Args:
num_prompts: Number of input prompts used to produce this output.
generator_output: The generated output batch to validate.
step_wise: If True, validate step-wise specific fields (is_last_step, trajectory_ids,
contiguous ordering). In step-wise mode, num_responses may exceed num_prompts
because each trajectory is expanded into multiple per-turn samples.
"""
if len(generator_output["response_ids"]) <= 0:
raise RuntimeError("No outputs generated")

# check that input prompts, response ids, and prompt token ids are all the same length
num_responses = len(generator_output["response_ids"])
num_prompt_tokens = len(generator_output["prompt_token_ids"])
assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})"

if not step_wise:
assert num_prompts == num_responses, f"Mismatch between prompts ({num_prompts}) and responses ({num_responses})"

assert (
num_responses == num_prompt_tokens
), f"Mismatch between responses ({num_responses}) and prompt_token_ids ({num_prompt_tokens})"
Expand Down Expand Up @@ -660,6 +665,84 @@ def validate_generator_output(num_prompts: int, generator_output: GeneratorOutpu
not isinstance(reward, list) for reward in rewards
), "rewards must be `List[float]` or `List[List[float]]`"

if step_wise:
_validate_step_wise_fields(generator_output, num_responses)


def _validate_step_wise_fields(generator_output: GeneratorOutput, num_responses: int):
"""Validate step-wise specific fields in the generator output.

Checks that is_last_step and trajectory_ids are present, correctly sized,
contiguously ordered, and that is_last_step boundaries align with trajectory_id changes.

The contiguity check is critical: the trainer's advantage broadcast uses
``cumsum(shifted_is_last_step)`` to map each step to its trajectory, which
silently produces wrong results if steps from the same trajectory are interleaved
with steps from other trajectories.

For more, see https://docs.skyrl.ai/docs/tutorials/step-wise-training#generatoroutput-format
"""
assert (
generator_output.get("is_last_step") is not None
), "step_wise=True but `is_last_step` is missing from generator output"
assert (
generator_output.get("trajectory_ids") is not None
), "step_wise=True but `trajectory_ids` is missing from generator output"

is_last_step = generator_output["is_last_step"]
trajectory_ids = generator_output["trajectory_ids"]

assert (
len(is_last_step) == num_responses
), f"is_last_step length ({len(is_last_step)}) must equal response_ids length ({num_responses})"
assert (
len(trajectory_ids) == num_responses
), f"trajectory_ids length ({len(trajectory_ids)}) must equal response_ids length ({num_responses})"

assert (
is_last_step[-1] is True
), "is_last_step[-1] must be True (the last sample must be the final step of a trajectory)"

num_trajectories = sum(1 for x in is_last_step if x)
assert num_trajectories >= 1, "is_last_step must contain at least one True value"

# Validate contiguous ordering: all steps of the same trajectory must be adjacent.
seen_trajectory_ids = set()
prev_tid = None
for i, tid in enumerate(trajectory_ids):
tid_key = tid.to_string() if hasattr(tid, "to_string") else str(tid)
if tid_key != prev_tid:
assert tid_key not in seen_trajectory_ids, (
f"Non-contiguous trajectory at index {i}: trajectory '{tid_key}' appeared before "
f"(at earlier indices), then a different trajectory, then again here. "
f"Step-wise training requires all steps of the same trajectory to be adjacent."
)
if prev_tid is not None:
seen_trajectory_ids.add(prev_tid)
prev_tid = tid_key
if prev_tid is not None:
seen_trajectory_ids.add(prev_tid)

# Validate is_last_step aligns with trajectory boundaries (both directions)
for i in range(num_responses - 1):
tid_cur = trajectory_ids[i].to_string() if hasattr(trajectory_ids[i], "to_string") else str(trajectory_ids[i])
tid_next = (
trajectory_ids[i + 1].to_string()
if hasattr(trajectory_ids[i + 1], "to_string")
else str(trajectory_ids[i + 1])
)
if tid_cur != tid_next:
assert is_last_step[i] is True, (
f"Trajectory boundary at index {i} ('{tid_cur}' → '{tid_next}') "
f"but is_last_step[{i}] is False. Must be True at trajectory boundaries."
)
else:
assert is_last_step[i] is not True, (
f"is_last_step[{i}] is True but trajectory continues "
f"(trajectory '{tid_cur}' at index {i} and {i+1}). "
f"is_last_step must only be True at the final step of a trajectory."
)


def build_dataloader(
cfg: SkyRLTrainConfig, dataset: PromptDataset, is_train=True, is_fully_async=False
Expand Down
142 changes: 141 additions & 1 deletion tests/train/test_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pytest
import ray

from skyrl.train.generators.base import GeneratorInput, GeneratorOutput
from skyrl.train.generators.base import GeneratorInput, GeneratorOutput, TrajectoryID
from skyrl.train.utils.trainer_utils import (
build_dataloader,
calculate_per_dataset_metrics,
Expand Down Expand Up @@ -922,3 +922,143 @@ def test_validate_generator_output_invalid_rewards():

generator_output["rewards"] = [[0.5, 0.6], [0.7, 0.8]]
validate_generator_output(len(input_batch["prompts"]), generator_output)


# ============================================================
# Step-wise validation tests
# ============================================================


def _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3), contiguous=True):
    """Build a synthetic step-wise generator output dict for the validation tests.

    Trajectory ``t`` contributes ``steps_per_traj[t]`` step samples that all share one
    ``TrajectoryID``. Only the final step of a trajectory is flagged in ``is_last_step``
    and carries a non-zero reward (``t + 1``) on its last token.

    With ``contiguous=False`` the samples are interleaved round-robin across
    trajectories (step 0 of each trajectory, then step 1 of each, ...).
    """
    samples = []
    for t in range(n_trajectories):
        step_count = steps_per_traj[t]
        traj_id = TrajectoryID(instance_id=str(t), repetition_id=0)
        prompt_start = 10 + t * 100
        for s in range(step_count):
            final = s == step_count - 1
            resp_start = 50 + t * 100 + s * 10
            samples.append(
                (
                    list(range(prompt_start, prompt_start + 3 + s)),  # prompt grows by one token per step
                    list(range(resp_start, resp_start + 3)),
                    [0.0, 0.0, float(t + 1) if final else 0.0],
                    [1, 1, 1],
                    final,
                    traj_id,
                )
            )

    if not contiguous:
        # Index of each trajectory's first sample in the contiguous layout.
        offsets = [sum(steps_per_traj[:t]) for t in range(n_trajectories)]
        samples = [
            samples[offsets[t] + s]
            for s in range(max(steps_per_traj))
            for t in range(n_trajectories)
            if s < steps_per_traj[t]
        ]

    # Transpose the per-sample tuples into one list per output field.
    columns = [[sample[field] for sample in samples] for field in range(6)]
    prompt_token_ids, response_ids, rewards, loss_masks, is_last_step, trajectory_ids = columns

    return {
        "prompt_token_ids": prompt_token_ids,
        "response_ids": response_ids,
        "rewards": rewards,
        "loss_masks": loss_masks,
        "stop_reasons": ["complete"] * len(response_ids),
        "rollout_metrics": {},
        "rollout_logprobs": None,
        "is_last_step": is_last_step,
        "trajectory_ids": trajectory_ids,
    }


def test_validate_stepwise_valid():
    """A well-formed step-wise batch (trajectories of 1, 2 and 3 steps) validates cleanly."""
    batch = _make_stepwise_output(n_trajectories=3, steps_per_traj=(1, 2, 3))
    validate_generator_output(3, batch, step_wise=True)


def test_validate_stepwise_single_step_trajectories():
    """Four trajectories of exactly one step each are accepted."""
    n_traj = 4
    batch = _make_stepwise_output(n_trajectories=n_traj, steps_per_traj=(1,) * n_traj)
    validate_generator_output(n_traj, batch, step_wise=True)


def test_validate_stepwise_missing_is_last_step():
    """Validation rejects a step-wise batch that has no `is_last_step` key."""
    batch = _make_stepwise_output()
    batch.pop("is_last_step")
    with pytest.raises(AssertionError, match="is_last_step.*missing"):
        validate_generator_output(2, batch, step_wise=True)


def test_validate_stepwise_missing_trajectory_ids():
    """Validation rejects a step-wise batch that has no `trajectory_ids` key."""
    batch = _make_stepwise_output()
    batch.pop("trajectory_ids")
    with pytest.raises(AssertionError, match="trajectory_ids.*missing"):
        validate_generator_output(2, batch, step_wise=True)


def test_validate_stepwise_is_last_step_length_mismatch():
    """`is_last_step` being one entry shorter than the responses is rejected."""
    batch = _make_stepwise_output()
    batch["is_last_step"].pop()  # drop the final flag so the lengths disagree
    with pytest.raises(AssertionError, match="is_last_step length"):
        validate_generator_output(2, batch, step_wise=True)


def test_validate_stepwise_last_element_not_true():
    """The final sample of the batch must be flagged as a last step."""
    batch = _make_stepwise_output()
    flags = batch["is_last_step"]
    flags[-1] = False
    with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"):
        validate_generator_output(2, batch, step_wise=True)


def test_validate_stepwise_non_contiguous():
    """Interleaving steps of two trajectories is reported as non-contiguous."""
    interleaved = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2), contiguous=False)
    with pytest.raises(AssertionError, match="Non-contiguous trajectory"):
        validate_generator_output(2, interleaved, step_wise=True)


def test_validate_stepwise_boundary_without_is_last():
    """A trajectory change whose preceding step is not flagged as last is rejected."""
    batch = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 2))
    # Layout: trajectory 0 occupies indices 0-1, trajectory 1 occupies 2-3.
    # Clearing the flag at index 1 breaks the boundary between them.
    boundary_index = 1
    batch["is_last_step"][boundary_index] = False
    with pytest.raises(AssertionError, match="Trajectory boundary at index 1"):
        validate_generator_output(2, batch, step_wise=True)


def test_validate_stepwise_no_true_in_is_last_step():
    """A batch whose `is_last_step` flags are all False is rejected."""
    batch = _make_stepwise_output(n_trajectories=1, steps_per_traj=(3,))
    batch["is_last_step"] = [False] * 3
    with pytest.raises(AssertionError, match="is_last_step\\[-1\\] must be True"):
        validate_generator_output(1, batch, step_wise=True)


def test_validate_stepwise_num_prompts_not_checked():
    """Step-wise mode tolerates more samples than prompts (per-turn expansion)."""
    # Two prompts expand into 2 + 3 = 5 step samples.
    expanded = _make_stepwise_output(n_trajectories=2, steps_per_traj=(2, 3))
    validate_generator_output(2, expanded, step_wise=True)


def test_validate_stepwise_multiple_is_last_step_true_per_trajectory():
    """More than one is_last_step=True inside a single trajectory is rejected.

    For a 3-step trajectory the only valid flags are [False, False, True].
    Flagging every step as last ([True, True, True]) would corrupt the
    cumsum(shifted_is_last_step) advantage broadcast.
    """
    batch = _make_stepwise_output(n_trajectories=1, steps_per_traj=(3,))
    # Corrupt the flags: every step claims to be the final one.
    batch["is_last_step"] = [True] * 3
    with pytest.raises(AssertionError, match="is_last_step.*True.*trajectory continues"):
        validate_generator_output(1, batch, step_wise=True)
Loading