Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/content/docs/tutorials/step-wise-training.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ for i, step_output in enumerate(trajectory_steps):
rewards = [0.0] * len(step_output.response_ids)
```

Note that SkyRL currently supports only trajectory-level rewards for step-wise training. The reward must therefore be placed on the last token of the trajectory's last step; rewards on all other steps are ignored. The last step's reward is used to estimate the trajectory's advantage, which is then broadcast to every earlier step of the same trajectory. Because of this, you should use only outcome-based advantage estimators (set `cfg.trainer.algorithm.advantage_estimator` to `grpo`, `rloo`, or `maxrl`); `reinforce++` and `gae` are rejected at config validation time.

### 4. Ensure Contiguous Ordering

All steps of trajectory A must appear before any steps of trajectory B in the output lists:
Expand Down
30 changes: 21 additions & 9 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,11 +801,23 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
is_last_step = data["is_last_step"].bool()
index = np.array(data.metadata["uids"])
values = data["values"]
# Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
# NOTE(Charlie): so we ignore per-step rewards in step-wise training.
# Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
# We use the last step of each trajectory to compute advantages and broadcast them to
# all steps of that trajectory, so we ignore per-step rewards in step-wise training.
# We pass an all-ones mask here so the estimator returns the scalar advantage at every
# position. The real per-step `response_mask` is re-applied on broadcast below. See issue #1492.
# Shapes:
# traj_ids, (batch_size,): trajectory id per step (cumsum of shifted is_last_step)
# last_step_advantages/returns,
# (num_traj, seqlen): scalar advantage/return per trajectory at every position
# last_step_advantages/returns[traj_ids],
# (batch_size, seqlen): broadcast to every step of the owning trajectory
# response_mask_float,
# (batch_size, seqlen): per-step response mask
last_step_response_mask = data["response_mask"][is_last_step]
last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
token_level_rewards=token_level_rewards[is_last_step],
response_mask=data["response_mask"][is_last_step],
response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
index=index[is_last_step.cpu().numpy()],
adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
values=values[is_last_step] if values is not None else None,
Expand All @@ -814,16 +826,16 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
lambd=self.cfg.trainer.algorithm.lambd,
grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
)
# Broadcast each trajectory's advantage and return to all steps of each trajectory.
traj_ids = (
torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
)
num_groups = traj_ids[-1].item() + 1
assert num_groups == len(
num_traj = traj_ids[-1].item() + 1
assert num_traj == len(
last_step_advantages
), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
advantages = last_step_advantages[traj_ids]
returns = last_step_returns[traj_ids]
), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
advantages = last_step_advantages[traj_ids] * response_mask_float
returns = last_step_returns[traj_ids] * response_mask_float
else:
advantages, returns = ppo_utils.compute_advantages_and_returns(
token_level_rewards=token_level_rewards,
Expand Down
16 changes: 16 additions & 0 deletions skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,22 @@ def validate_cfg(cfg: SkyRLTrainConfig):
f"Must be one of {available_advantage_estimators}"
)

# Step-wise training collapses each trajectory to a single scalar advantage that is broadcast
# uniformly to every step's response tokens. This only makes sense for outcome-based estimators
# that return `scalar * response_mask`. Temporal estimators (GAE, REINFORCE++) produce per-token
# advantages via backward discounted returns, which the broadcast discards — and in step-wise
# mode they only see the last step's slice anyway, so there is no cross-step credit assignment
# to preserve. Reject the combination explicitly rather than silently producing wrong gradients.
if cfg.generator.step_wise_trajectories and cfg.trainer.algorithm.advantage_estimator in ("gae", "reinforce++"):
raise ValueError(
f"advantage_estimator={cfg.trainer.algorithm.advantage_estimator!r} is not supported with "
f"step_wise_trajectories=True. The step-wise branch collapses each trajectory to a single "
f"scalar advantage, which discards the per-token temporal structure these estimators produce, "
f"and the estimator only sees the last step's slice — there is no cross-step temporal "
f"connection. Use an outcome-based estimator (grpo, rloo, maxrl) or disable "
f"step_wise_trajectories."
)

assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"token_mean_legacy",
Expand Down
1 change: 1 addition & 0 deletions tests/train/step_wise/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# CPU tests for step-wise training.
52 changes: 52 additions & 0 deletions tests/train/step_wise/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""CPU tests for step-wise training config validation.

Run:
uv run --isolated --extra dev --extra skyrl-train pytest tests/train/step_wise/
"""

from unittest.mock import patch

import pytest

from skyrl.train.utils.utils import validate_cfg
from tests.train.util import example_dummy_config


@pytest.fixture
def dummy_config():
    """Provide a fresh dummy config for each test that requests this fixture."""
    cfg = example_dummy_config()
    return cfg


@pytest.mark.parametrize(
    ("estimator", "should_raise"),
    [
        ("gae", True),
        ("reinforce++", True),
        ("grpo", False),
        ("rloo", False),
        ("maxrl", False),
    ],
)
@patch("skyrl.train.utils.utils.validate_batch_sizes", new=lambda cfg: None)
@patch("skyrl.train.utils.utils.validate_generator_cfg", new=lambda cfg: None)
def test_validate_cfg_step_wise_estimator_compatibility(dummy_config, estimator, should_raise):
    """Check which advantage estimators ``validate_cfg`` accepts in step-wise mode.

    With ``step_wise_trajectories`` enabled, the temporal estimators (``gae``,
    ``reinforce++``) are expected to raise ``ValueError``, while the outcome-based
    estimators (``grpo``, ``rloo``, ``maxrl``) are expected to validate cleanly.

    ``validate_batch_sizes`` and ``validate_generator_cfg`` are replaced with no-op
    lambdas so that only the step-wise compatibility check is exercised against the
    minimal dummy config.
    """
    dummy_config.trainer.algorithm.advantage_estimator = estimator
    dummy_config.generator.step_wise_trajectories = True
    if estimator == "gae":
        # NOTE(review): presumably GAE needs a critic path so that validation reaches
        # the step-wise check instead of failing earlier -- confirm against validate_cfg.
        dummy_config.trainer.critic.model.path = "dummy-critic-path"

    if not should_raise:
        # Accepted combination: validation must complete without raising.
        validate_cfg(dummy_config)
        return

    with pytest.raises(ValueError, match="not supported with step_wise_trajectories"):
        validate_cfg(dummy_config)
84 changes: 84 additions & 0 deletions tests/train/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,90 @@ def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config):
)


def test_calc_advantages_and_returns_step_wise_broadcast(dummy_config):
    """Each step must receive its trajectory's advantage under that step's own mask.

    Regression test for the step-wise advantage broadcast across trajectories;
    see https://github.com/NovaSky-AI/SkyRL/issues/1492.
    """
    dummy_config.generator.step_wise_trajectories = True
    dummy_config.trainer.algorithm.grpo_norm_by_std = False
    dummy_config.trainer.algorithm.advantage_estimator = "grpo"

    trainer = RayPPOTrainer(
        cfg=dummy_config,
        tracker=None,
        tokenizer=None,
        train_dataset=DummyDataset(),
        eval_dataset=DummyDataset(),
        inference_engine_client=None,
        generator=dummy_generator,
    )

    # Layout: two trajectories (A, B), two steps each (intermediate, then last).
    # Response-level tensors are right-aligned within (batch, max_response) -- see
    # ``convert_prompts_responses_to_batch_tensors`` in ``skyrl/train/dataset/preprocess.py``.
    # The response lengths differ between steps on purpose: the mask tails then sit at
    # different positions, which is exactly what exposes the broadcast bug.
    #
    #   row  traj  step       resp_len  response_mask
    #   ---  ----  ---------  --------  ------------------
    #   0    A     intermed.  4         [0, 0, 1, 1, 1, 1]
    #   1    A     last       1         [0, 0, 0, 0, 0, 1]
    #   2    B     intermed.  3         [0, 0, 0, 1, 1, 1]
    #   3    B     last       2         [0, 0, 0, 0, 1, 1]
    step_masks = torch.tensor(
        [
            [0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 1],
            [0, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 1, 1],
        ],
        dtype=torch.int32,
    )
    batch_size, seqlen = step_masks.shape  # (4, 6)
    response_lengths = [4, 1, 3, 2]

    # The reward sits on the final token of each trajectory's last step, i.e. at the tail
    # position where that step's mask is 1. Trajectory A earns 2.0, trajectory B earns 0.0.
    token_rewards = torch.zeros(batch_size, seqlen)
    token_rewards[1, -1] = 2.0
    last_step_flags = torch.tensor([False, True, False, True])

    batch = TrainingInputBatch(
        {
            "sequences": torch.zeros(batch_size, seqlen, dtype=torch.long),
            "attention_mask": torch.ones(batch_size, seqlen, dtype=torch.int32),
            "loss_mask": step_masks.clone(),
            "response_mask": step_masks,
            "rewards": token_rewards,
            "values": torch.zeros(batch_size, seqlen),
            "is_last_step": last_step_flags,
        },
    )
    # All rows share one GRPO group, so the group holds the two samples (one per
    # trajectory) needed to produce a mean.
    batch.metadata = {
        "uids": np.array(["grp0"] * batch_size),
        "response_length": seqlen,
        "avg_response_length": sum(response_lengths) / len(response_lengths),
    }

    result = trainer.compute_advantages_and_returns(batch)

    # GRPO without std normalization: group mean = (2.0 + 0.0) / 2 = 1.0, hence
    # scalar_A = 2.0 - 1.0 = 1.0 and scalar_B = 0.0 - 1.0 = -1.0.
    # Every step's advantages must equal `scalar * response_mask` for THAT step, so the
    # expected tensor is the row-wise product of the owning trajectory's scalar and the
    # right-aligned per-step mask. For GRPO the returns equal the advantages.
    per_row_scalar = torch.tensor([[1.0], [1.0], [-1.0], [-1.0]])
    expected_advantages = per_row_scalar * step_masks.to(per_row_scalar.dtype)

    assert torch.allclose(result["advantages"], expected_advantages)
    assert torch.allclose(result["returns"], expected_advantages)


def test_micro_batches_accumulated_initialized():
"""Test that _micro_batches_accumulated is initialized to 0 in worker __init__."""

Expand Down
Loading