diff --git a/docs/content/docs/tutorials/step-wise-training.mdx b/docs/content/docs/tutorials/step-wise-training.mdx
index ac1eccf22b..9459beff6b 100644
--- a/docs/content/docs/tutorials/step-wise-training.mdx
+++ b/docs/content/docs/tutorials/step-wise-training.mdx
@@ -183,6 +183,8 @@ for i, step_output in enumerate(trajectory_steps):
     rewards = [0.0] * len(step_output.response_ids)
 ```
 
+Note that SkyRL currently only supports trajectory-level rewards for step-wise training: place the trajectory's reward at the last token of its last step; rewards on all other steps are ignored. The last step's reward is used to estimate a single advantage, which is then broadcast to the trajectory's earlier steps. Consequently, only outcome-based advantage estimators are supported (set `cfg.trainer.algorithm.advantage_estimator` to `grpo`, `rloo`, or `maxrl`); `reinforce++` and `gae` are rejected at config validation time.
+
 ### 4. Ensure Contiguous Ordering
 
 All steps of trajectory A must appear before any steps of trajectory B in the output lists:
diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py
index 3b1ed89d0a..808ae8ac20 100644
--- a/skyrl/train/trainer.py
+++ b/skyrl/train/trainer.py
@@ -801,11 +801,23 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
             is_last_step = data["is_last_step"].bool()
             index = np.array(data.metadata["uids"])
             values = data["values"]
-            # Use the last step of each trajectory to compute advantages. Compatible with any advantage estimator
-            # NOTE(Charlie): so we ignore per-step rewards in step-wise training.
+            # Step-wise only supports outcome-based estimators (GRPO, RLOO, MAXRL); ensured by `validate_cfg`.
+            # We use the last step of each trajectory to compute advantages and broadcast them to
+            # all steps of that trajectory, so we ignore per-step rewards in step-wise training.
+            # We pass an all-ones mask here so the estimator returns the scalar advantage at every
+            # position. The real per-step `response_mask` is re-applied on broadcast below.
+            # Shapes:
+            #   traj_ids, (batch_size,): trajectory id per step (cumsum of shifted is_last_step)
+            #   last_step_advantages/returns,
+            #     (num_traj, seqlen): scalar advantage/return per trajectory at every position
+            #   last_step_advantages/returns[traj_ids],
+            #     (batch_size, seqlen): broadcast to every step of the owning trajectory
+            #   response_mask_float,
+            #     (batch_size, seqlen): per-step response mask
+            last_step_response_mask = data["response_mask"][is_last_step]
             last_step_advantages, last_step_returns = ppo_utils.compute_advantages_and_returns(
                 token_level_rewards=token_level_rewards[is_last_step],
-                response_mask=data["response_mask"][is_last_step],
+                response_mask=torch.ones_like(last_step_response_mask, dtype=torch.float),
                 index=index[is_last_step.cpu().numpy()],
                 adv_estimator=self.cfg.trainer.algorithm.advantage_estimator,
                 values=values[is_last_step] if values is not None else None,
@@ -814,16 +826,21 @@ def compute_advantages_and_returns(self, data: TrainingInputBatch) -> TrainingIn
                 lambd=self.cfg.trainer.algorithm.lambd,
                 grpo_norm_by_std=self.cfg.trainer.algorithm.grpo_norm_by_std,
            )
-            # Broadcast each trajectory's advantage and return to all steps of each trajectory.
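+            # Worked example of the broadcast indexing (hypothetical batch, not real data):
+            #   is_last_step = [F, T, F, T]  -> two trajectories of two steps each
+            #   shift right:   [F, F, T, F]; cumsum -> traj_ids = [0, 0, 1, 1]
+            #   last_step_advantages[traj_ids] then repeats trajectory 0's scalar row for
+            #   rows 0-1 and trajectory 1's for rows 2-3, before the per-step mask is applied.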
             traj_ids = (
                 torch.cat([torch.tensor([False], device=is_last_step.device), is_last_step[:-1]]).int().cumsum(dim=0)
             )
-            num_groups = traj_ids[-1].item() + 1
-            assert num_groups == len(
+            num_traj = traj_ids[-1].item() + 1
+            assert num_traj == len(
                 last_step_advantages
-            ), f"number of groups {num_groups} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
-            advantages = last_step_advantages[traj_ids]
-            returns = last_step_returns[traj_ids]
+            ), f"num_traj {num_traj} doesn't match the number of trajectories as given by `is_last_step` {len(last_step_advantages)}. The `is_last_step` tensor is likely malformed"
+            response_mask_float = data["response_mask"].to(last_step_advantages.dtype)
+            advantages = last_step_advantages[traj_ids] * response_mask_float
+            returns = last_step_returns[traj_ids] * response_mask_float
         else:
             advantages, returns = ppo_utils.compute_advantages_and_returns(
                 token_level_rewards=token_level_rewards,
diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py
index f6ef6a2dc4..bf8a22e4f1 100644
--- a/skyrl/train/utils/utils.py
+++ b/skyrl/train/utils/utils.py
@@ -274,6 +274,20 @@ def validate_cfg(cfg: SkyRLTrainConfig):
             f"Must be one of {available_advantage_estimators}"
         )
 
+    # Step-wise training collapses each trajectory to a single scalar advantage that is broadcast
+    # uniformly to every step's response tokens. This only makes sense for outcome-based estimators.
+    # Temporal estimators (GAE, REINFORCE++) produce per-token advantages, which the broadcast
+    # discards. Reject the combination explicitly.
+    if cfg.generator.step_wise_trajectories and cfg.trainer.algorithm.advantage_estimator in ("gae", "reinforce++"):
+        raise ValueError(
+            f"advantage_estimator={cfg.trainer.algorithm.advantage_estimator!r} is not supported with "
+            f"step_wise_trajectories=True. The step-wise branch collapses each trajectory to a single "
+            f"scalar advantage, which discards the per-token temporal structure these estimators produce, "
+            f"and the estimator only sees the last step's slice, so there is no cross-step temporal "
+            f"connection. Use an outcome-based estimator (grpo, rloo, maxrl) or disable "
+            f"step_wise_trajectories."
+        )
+
     assert cfg.trainer.algorithm.loss_reduction in (
         "token_mean",
         "token_mean_legacy",
diff --git a/tests/train/step_wise/__init__.py b/tests/train/step_wise/__init__.py
new file mode 100644
index 0000000000..712c952258
--- /dev/null
+++ b/tests/train/step_wise/__init__.py
@@ -0,0 +1 @@
+# CPU tests for step-wise training.
diff --git a/tests/train/step_wise/test_config.py b/tests/train/step_wise/test_config.py
new file mode 100644
index 0000000000..136b8614f8
--- /dev/null
+++ b/tests/train/step_wise/test_config.py
@@ -0,0 +1,58 @@
+"""CPU tests for step-wise training config validation.
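+
+For example, the following combination is rejected at startup (illustrative snippet;
+``cfg`` here stands for any loaded SkyRL train config, not an object defined in this file):
+
+    cfg.generator.step_wise_trajectories = True
+    cfg.trainer.algorithm.advantage_estimator = "gae"  # -> ValueError from validate_cfg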
+
+Run:
+    uv run --isolated --extra dev --extra skyrl-train pytest tests/train/step_wise/
+"""
+
+from unittest.mock import patch
+
+import pytest
+
+from skyrl.train.utils.utils import validate_cfg
+from tests.train.util import example_dummy_config
+
+
+@pytest.fixture
+def dummy_config():
+    return example_dummy_config()
+
+
+@pytest.mark.parametrize(
+    ("estimator", "should_raise"),
+    [
+        ("gae", True),
+        ("reinforce++", True),
+        ("grpo", False),
+        ("rloo", False),
+        ("maxrl", False),
+    ],
+)
+@patch("skyrl.train.utils.utils.validate_batch_sizes", new=lambda cfg: None)
+@patch("skyrl.train.utils.utils.validate_generator_cfg", new=lambda cfg: None)
+def test_validate_cfg_step_wise_estimator_compatibility(dummy_config, estimator, should_raise):
+    """``validate_cfg`` must reject step-wise training with temporal estimators (GAE, REINFORCE++)
+    and accept it with outcome-based estimators (GRPO, RLOO, MAXRL).
+
+    Step-wise training collapses each trajectory to a single scalar advantage broadcast uniformly
+    to every step's response tokens. The temporal credit assignment that GAE / REINFORCE++ produce
+    is lost in that collapse, so we refuse the combination at startup.
+
+    ``validate_batch_sizes`` and ``validate_generator_cfg`` are patched to no-ops so the test
+    exercises only the step-wise compatibility check on the minimal dummy config.
+    """
+    dummy_config.generator.step_wise_trajectories = True
+    dummy_config.trainer.algorithm.advantage_estimator = estimator
+    if estimator == "gae":
+        dummy_config.trainer.critic.model.path = "dummy-critic-path"
+
+    if should_raise:
+        with pytest.raises(ValueError, match="not supported with step_wise_trajectories"):
+            validate_cfg(dummy_config)
+    else:
+        validate_cfg(dummy_config)
diff --git a/tests/train/test_trainer.py b/tests/train/test_trainer.py
index dea5054b6b..2bb81e1c6b 100644
--- a/tests/train/test_trainer.py
+++ b/tests/train/test_trainer.py
@@ -168,6 +168,90 @@ def test_calc_advantages_and_returns(mock_compute_adv_and_ret, dummy_config):
     )
 
 
+
+def test_calc_advantages_and_returns_step_wise_broadcast(dummy_config):
+    """Regression test for the step-wise advantage broadcast across trajectories.
+
+    See https://github.com/NovaSky-AI/SkyRL/issues/1492.
+    """
+    dummy_config.generator.step_wise_trajectories = True
+    dummy_config.trainer.algorithm.advantage_estimator = "grpo"
+    dummy_config.trainer.algorithm.grpo_norm_by_std = False
+
+    trainer = RayPPOTrainer(
+        cfg=dummy_config,
+        tracker=None,
+        tokenizer=None,
+        train_dataset=DummyDataset(),
+        eval_dataset=DummyDataset(),
+        inference_engine_client=None,
+        generator=dummy_generator,
+    )
+
+    # Two trajectories (A, B), each with two steps (one intermediate, one last).
+    # Response-level tensors are right-aligned within (batch, max_response) — see
+    # ``convert_prompts_responses_to_batch_tensors`` in ``skyrl/train/dataset/preprocess.py``.
+    # Intermediate and last steps have different response lengths so their mask tails live at
+    # different positions; this is what exposes the broadcast bug.
+    #
+    # row  traj  step        resp_len  response_mask
+    # ───  ────  ──────────  ────────  ──────────────────
+    # 0    A     intermed.   4         [0, 0, 1, 1, 1, 1]
+    # 1    A     last        1         [0, 0, 0, 0, 0, 1]
+    # 2    B     intermed.   3         [0, 0, 0, 1, 1, 1]
+    # 3    B     last        2         [0, 0, 0, 0, 1, 1]
+    batch_size, seqlen = 4, 6
+    response_mask = torch.tensor(
+        [
+            [0, 0, 1, 1, 1, 1],
+            [0, 0, 0, 0, 0, 1],
+            [0, 0, 0, 1, 1, 1],
+            [0, 0, 0, 0, 1, 1],
+        ],
+        dtype=torch.int32,
+    )
+    # Reward lives at the last token of each trajectory's last step — i.e., at the tail
+    # position where that step's response_mask is 1. Traj A -> 2.0, Traj B -> 0.0.
+    rewards = torch.zeros(batch_size, seqlen)
+    rewards[1, -1] = 2.0
+    is_last_step = torch.tensor([False, True, False, True])
+
+    data = TrainingInputBatch(
+        {
+            "sequences": torch.zeros(batch_size, seqlen, dtype=torch.long),
+            "attention_mask": torch.ones(batch_size, seqlen, dtype=torch.int32),
+            "loss_mask": response_mask.clone(),
+            "response_mask": response_mask,
+            "rewards": rewards,
+            "values": torch.zeros(batch_size, seqlen),
+            "is_last_step": is_last_step,
+        },
+    )
+    # Both trajectories share a GRPO group so the group has the 2 samples needed to produce a mean.
+    data.metadata = {
+        "uids": np.array(["grp0", "grp0", "grp0", "grp0"]),
+        "response_length": seqlen,
+        "avg_response_length": (4 + 1 + 3 + 2) / 4,
+    }
+
+    data = trainer.compute_advantages_and_returns(data)
+
+    # GRPO without std normalization: group mean = (2.0 + 0.0) / 2 = 1.0, so
+    # scalar_A = 2.0 - 1.0 = 1.0 and scalar_B = 0.0 - 1.0 = -1.0.
+    # Each step's advantages must equal `scalar * response_mask` for THAT step, so the
+    # full advantage tensor is the row-wise product of the per-trajectory scalar and the
+    # right-aligned per-step response mask. Returns equal advantages for GRPO.
+    expected_advantages = torch.tensor(
+        [
+            [0.0, 0.0, 1.0, 1.0, 1.0, 1.0],
+            [0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
+            [0.0, 0.0, 0.0, -1.0, -1.0, -1.0],
+            [0.0, 0.0, 0.0, 0.0, -1.0, -1.0],
+        ]
+    )
+    assert torch.allclose(data["advantages"], expected_advantages)
+    assert torch.allclose(data["returns"], expected_advantages)
+
 
 def test_micro_batches_accumulated_initialized():
     """Test that _micro_batches_accumulated is initialized to 0 in worker __init__."""