Merged
2 changes: 2 additions & 0 deletions docs/content/docs/tutorials/step-wise-training.mdx
@@ -183,6 +183,8 @@ for i, step_output in enumerate(trajectory_steps):
rewards = [0.0] * len(step_output.response_ids)
```

Note that SkyRL currently only supports trajectory-level rewards for step-wise training. The reward should therefore be placed on the last token of the last step; rewards assigned to any earlier step are ignored. The last step's reward is used to estimate the advantage, which is then broadcast to the trajectory's earlier steps. Because of this, you should only use outcome-based advantage estimators (set `cfg.trainer.algorithm.advantage_estimator` to `grpo`, `rloo`, or `maxrl`); `reinforce++` and `gae` are rejected at config validation time.
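
A minimal sketch of reward placement under this constraint, building on the loop shown above (the `all_rewards` accumulator and the `1.0` outcome reward are hypothetical, for illustration only):

```python
all_rewards = []
for i, step_output in enumerate(trajectory_steps):
    # Per-step rewards are ignored in step-wise training, so leave them at zero.
    rewards = [0.0] * len(step_output.response_ids)
    if i == len(trajectory_steps) - 1:
        # The trajectory-level outcome reward goes on the last token of the last step.
        rewards[-1] = 1.0
    all_rewards.append(rewards)
```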

### 4. Ensure Contiguous Ordering

All steps of trajectory A must appear before any steps of trajectory B in the output lists:
30 changes: 21 additions & 9 deletions skyrl/train/trainer.py
@@ -801,11 +801,23 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
             is_last_step = data["is_last_step"].bool()
             index = np.array(data.metadata["uids"])
             values = data["values"]
-            # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
-            # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
+            # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
+            # We use the last step of each trajectory to compute advantages and broadcast them to
+            # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
+            # We pass an all-ones mask here so the estimator returns the scalar advantage at every
+            # position. The real per-step `response_mask` is re-applied on broadcast below.
+            # Shapes:
+            #   traj_ids, (batch_size,): trajectory id per step (cumsum of shifted is_last_step)
+            #   last_step_advantages/returns,
+            #     (num_traj, seqlen): scalar advantage/return per trajectory at every position
+            #   last_step_advantages/returns[traj_ids],
+            #     (batch_size, seqlen): broadcast to every step of the owning trajectory
+            #   response_mask_float,
+            #     (batch_size, seqlen): per-step response mask
+            last_step_response_mask = data["response_mask"][is_last_step]
             last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                 token_level_rewards=token_level_rewards[is_last_step],
-                response_mask=data["response_mask"][is_last_step],
+                response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
                 index=index[is_last_step.cpu().numpy()],
                 adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                 values=values[is_last_step] if values is not None else None,
@@ -814,16 +826,16 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
                 lambd=self.cfg.trainer.algorithm.lambd,
                 grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
             )
-            # Broadcast each trajectory's advantage and return to all steps of each trajectory.
             traj_ids = (
                 torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
             )
-            num_groups = traj_ids[-1].item() + 1
-            assert num_groups == len(
+            num_traj = traj_ids[-1].item() + 1
+            assert num_traj == len(
                 last_step_advantages
-            ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
-            advantages = last_step_advantages[traj_ids]
-            returns = last_step_returns[traj_ids]
+            ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
+            response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
+            advantages = last_step_advantages[traj_ids] * response_mask_float
+            returns = last_step_returns[traj_ids] * response_mask_float
         else:
             advantages, returns = ppo_utils.compute_advantages_and_returns(
                 token_level_rewards=token_level_rewards,
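To make the shape bookkeeping in the new comments concrete, here is a self-contained sketch of the same broadcast trick on toy tensors (two trajectories with two steps each, seqlen 3; all values are made up for illustration):

```python
import torch

# Four steps belonging to two trajectories (A: rows 0-1, B: rows 2-3).
is_last_step = torch.tensor([False, True, False, True])
# Scalar advantage per trajectory, repeated at every position: shape (num_traj, seqlen).
last_step_advantages = torch.tensor([[1.0, 1.0, 1.0],
                                     [-1.0, -1.0, -1.0]])
# Right-aligned per-step response mask: shape (batch_size, seqlen).
response_mask = torch.tensor([[1, 1, 1],
                              [0, 0, 1],
                              [0, 1, 1],
                              [0, 0, 1]])

# The cumulative sum of the shifted is_last_step flags assigns each step the id of
# its owning trajectory.
traj_ids = torch.cat([torch.tensor([False]), is_last_step[:-1]]).int().cumsum(dim=0)
# traj_ids == tensor([0, 0, 1, 1])

# Indexing with traj_ids broadcasts each trajectory's row to all of its steps, and
# multiplying by the per-step mask keeps the scalar only on response tokens.
advantages = last_step_advantages[traj_ids] * response_mask.float()
# advantages == [[ 1.,  1.,  1.],
#                [ 0.,  0.,  1.],
#                [ 0., -1., -1.],
#                [ 0.,  0., -1.]]
```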
14 changes: 14 additions & 0 deletions skyrl/train/utils/utils.py
@@ -274,6 +274,20 @@ def validate_cfg(cfg: SkyRLTrainConfig):
f"Must be one of {available_advantage_estimators}"
)

# Step-wise training collapses each trajectory to a single scalar advantage that is broadcast
# uniformly to every step's response tokens. This only makes sense for outcome-based estimators.
# Temporal estimators (GAE, REINFORCE++) produce per-token advantages, which the broadcast
# discards. Reject the combination explicitly.
if cfg.generator.step_wise_trajectories and cfg.trainer.algorithm.advantage_estimator in ("gae", "reinforce++"):
raise ValueError(
f"advantage_estimator={cfg.trainer.algorithm.advantage_estimator!r} is not supported with "
f"step_wise_trajectories=True. The step-wise branch collapses each trajectory to a single "
f"scalar advantage, which discards the per-token temporal structure these estimators produce, "
f"and the estimator only sees the last step's slice — there is no cross-step temporal "
f"connection. Use an outcome-based estimator (grpo, rloo, maxrl) or disable "
f"step_wise_trajectories."
)

assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"token_mean_legacy",
1 change: 1 addition & 0 deletions tests/train/step_wise/__init__.py
@@ -0,0 +1 @@
# CPU tests for step-wise training.
52 changes: 52 additions & 0 deletions tests/train/step_wise/test_config.py
@@ -0,0 +1,52 @@
"""CPU tests for step-wise training config validation.

Run:
uv run --isolated --extra dev --extra skyrl-train pytest tests/train/step_wise/
"""

from unittest.mock import patch

import pytest

from skyrl.train.utils.utils import validate_cfg
from tests.train.util import example_dummy_config


@pytest.fixture
def dummy_config():
return example_dummy_config()


@pytest.mark.parametrize(
("estimator", "should_raise"),
[
("gae", True),
("reinforce++", True),
("grpo", False),
("rloo", False),
("maxrl", False),
],
)
@patch("skyrl.train.utils.utils.validate_batch_sizes", new=lambda cfg: None)
@patch("skyrl.train.utils.utils.validate_generator_cfg", new=lambda cfg: None)
def test_validate_cfg_step_wise_estimator_compatibility(dummy_config, estimator, should_raise):
"""``validate_cfg`` must reject step-wise training with temporal estimators (GAE, REINFORCE++)
and accept it with outcome-based estimators (GRPO, RLOO, MAXRL).

Step-wise training collapses each trajectory to a single scalar advantage broadcast uniformly
to every step's response tokens. The temporal credit assignment that GAE / REINFORCE++ produce
is lost in that collapse, so we refuse the combination at startup.

``validate_batch_sizes`` and ``validate_generator_cfg`` are patched to no-ops so the test
exercises only the step-wise compatibility check on the minimal dummy config.
"""
dummy_config.generator.step_wise_trajectories = True
dummy_config.trainer.algorithm.advantage_estimator = estimator
if estimator == "gae":
dummy_config.trainer.critic.model.path = "dummy-critic-path"

if should_raise:
with pytest.raises(ValueError, match="not supported with step_wise_trajectories"):
validate_cfg(dummy_config)
else:
validate_cfg(dummy_config)
84 changes: 84 additions & 0 deletions tests/train/test_trainer.py
@@ -168,6 +168,90 @@ def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config):
)


def test_calc_advantages_and_returns_step_wise_broadcast(dummy_config):
"""Regression test for the step-wise advantage broadcast across trajectories.

See https://github.com/NovaSky-AI/SkyRL/issues/1492.
"""
dummy_config.generator.step_wise_trajectories = True
dummy_config.trainer.algorithm.advantage_estimator = "grpo"
dummy_config.trainer.algorithm.grpo_norm_by_std = False

trainer = RayPPOTrainer(
cfg=dummy_config,
tracker=None,
tokenizer=None,
train_dataset=DummyDataset(),
eval_dataset=DummyDataset(),
inference_engine_client=None,
generator=dummy_generator,
)

# Two trajectories (A, B), each with two steps (one intermediate, one last).
# Response-level tensors are right-aligned within (batch, max_response) — see
# ``convert_prompts_responses_to_batch_tensors`` in ``skyrl/train/dataset/preprocess.py``.
# Intermediate and last steps have different response lengths so their mask tails live at
# different positions; this is what exposes the broadcast bug.
#
# row traj step resp_len response_mask
# ─── ──── ────────── ──────── ──────────────────
# 0 A intermed. 4 [0, 0, 1, 1, 1, 1]
# 1 A last 1 [0, 0, 0, 0, 0, 1]
# 2 B intermed. 3 [0, 0, 0, 1, 1, 1]
# 3 B last 2 [0, 0, 0, 0, 1, 1]
batch_size, seqlen = 4, 6
response_mask = torch.tensor(
[
[0, 0, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1],
[0, 0, 0, 1, 1, 1],
[0, 0, 0, 0, 1, 1],
],
dtype=torch.int32,
)
# Reward lives at the last token of each trajectory's last step — i.e., at the tail
# position where that step's response_mask is 1. Traj A -> 2.0, Traj B -> 0.0.
rewards = torch.zeros(batch_size, seqlen)
rewards[1, -1] = 2.0
is_last_step = torch.tensor([False, True, False, True])

data = TrainingInputBatch(
{
"sequences": torch.zeros(batch_size, seqlen, dtype=torch.long),
"attention_mask": torch.ones(batch_size, seqlen, dtype=torch.int32),
"loss_mask": response_mask.clone(),
"response_mask": response_mask,
"rewards": rewards,
"values": torch.zeros(batch_size, seqlen),
"is_last_step": is_last_step,
},
)
# Both trajectories share a GRPO group so the group has the 2 samples needed to produce a mean.
data.metadata = {
"uids": np.array(["grp0", "grp0", "grp0", "grp0"]),
"response_length": seqlen,
"avg_response_length": (4 + 1 + 3 + 2) / 4,
}

data = trainer.compute_advantages_and_returns(data)

# GRPO without std normalization: group mean = (2.0 + 0.0) / 2 = 1.0, so
# scalar_A = 2.0 - 1.0 = 1.0 and scalar_B = 0.0 - 1.0 = -1.0.
# Each step's advantages must equal `scalar * response_mask` for THAT step, so the
# full advantage tensor is the row-wise product of the per-trajectory scalar and the
# right-aligned per-step response mask. Returns equal advantages for GRPO.
expected_advantages = torch.tensor(
[
[0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
[0.0, 0.0, 0.0, 0.0, -1.0, -1.0],
]
)
assert torch.allclose(data["advantages"], expected_advantages)
assert torch.allclose(data["returns"], expected_advantages)


def test_micro_batches_accumulated_initialized():
"""Test that _micro_batches_accumulated is initialized to 0 in worker __init__."""
