Merged
2 changes: 1 addition & 1 deletion docs/content/docs/tutorials/step-wise-training.mdx
@@ -29,7 +29,7 @@ When step-wise is enabled, a batch of T trajectories with an average of M turns
- **Each mini-batch contains the same number of sequences** (`policy_mini_batch_size * n_samples`), but those sequences are now step-samples rather than full trajectories. The effective number of trajectories per mini-batch is reduced. The number of mini-batches (and hence optimizer steps) per training batch increases by the average number of turns — so if you have `train_batch_size=mini_batch_size=32` with an average of 3 turns, you get 3 optimizer steps instead of 1 for each training step. It is also possible that a mini-batch boundary falls mid-trajectory.
- **Advantages are computed on last steps only**, then broadcast to all steps of the same trajectory. This is mathematically equivalent to non-step-wise advantage computation for GRPO.
- **Training time grows as O(T²) vs O(T)**, since each trajectory of T turns becomes T sequences to forward (each with a growing prompt prefix), as opposed to 1 sequence. SkyRL will support prefix-aware merging of per-step sequences when the prefix matches (WIP), which brings the cost back to O(T) in the common case.
-- **Metrics** like `generate/avg_sequence_length` are per-turn rather than per-trajectory.
+- **Metrics** like `generate/avg_num_tokens` and `generate/avg_response_length` are per-turn rather than per-trajectory, since each training sample is a single turn.
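The O(T²) growth noted above can be made concrete with a short sketch. This is a hypothetical illustration, not SkyRL code: assuming T turns of roughly equal length, step t's sequence carries the prefix of turns 1..t, so total forwarded tokens scale as T(T+1)/2.

```python
# Hypothetical illustration (not SkyRL code) of why step-wise decomposition
# forwards O(T^2) tokens: step t's sequence carries the prefix of turns 1..t.
def forwarded_tokens(num_turns: int, turn_len: int) -> tuple[int, int]:
    # Step-wise: one sequence per turn, each with a growing prompt prefix.
    stepwise = sum(t * turn_len for t in range(1, num_turns + 1))
    # Non-step-wise: one sequence covering the whole trajectory.
    full_trajectory = num_turns * turn_len
    return stepwise, full_trajectory
```

For 3 turns of 10 tokens each, this gives 60 forwarded tokens step-wise versus 30 for the single full-trajectory sequence; prefix-aware merging would recover the latter cost.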

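The last-step advantage broadcast can be sketched as follows. This is a simplified, hypothetical version (function name and inputs are assumptions, not SkyRL's API, and all trajectories are treated as a single GRPO group for brevity):

```python
import statistics

def broadcast_last_step_advantages(trajectory_ids, is_last_step, rewards):
    """Compute a GRPO-style advantage from each trajectory's last-step reward,
    then copy it to every step-sample of that trajectory."""
    # Each trajectory's reward is taken from its last step only.
    traj_reward = {
        tid: r for tid, last, r in zip(trajectory_ids, is_last_step, rewards) if last
    }
    # Normalize across the group of trajectories (simplified: one group).
    values = list(traj_reward.values())
    mean = statistics.mean(values)
    std = statistics.pstdev(values) or 1.0  # guard against zero std
    advantages = {tid: (r - mean) / std for tid, r in traj_reward.items()}
    # Broadcast: every step-sample inherits its trajectory's advantage.
    return [advantages[tid] for tid in trajectory_ids]
```

Because every step-sample of a trajectory receives the same value, the per-sequence loss weighting matches what the non-step-wise computation would assign to the whole trajectory.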
Some algorithms have their behavior altered by step-wise decomposition, since each turn is now treated as its own sequence:

12 changes: 3 additions & 9 deletions skyrl/train/trainer.py
@@ -676,15 +676,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
             training_input.metadata["trajectory_ids"] = [
                 trajectory_id.to_string() for trajectory_id in generator_output["trajectory_ids"]
             ]
-            training_input.metadata["avg_response_length"] = sum(
-                len(sample_response_ids)
-                for sample_response_ids, is_last_step in zip(response_ids, generator_output["is_last_step"])
-                if is_last_step
-            ) / len(response_ids)
-        else:
-            training_input.metadata["avg_response_length"] = sum(
-                len(sample_response_ids) for sample_response_ids in response_ids
-            ) / len(response_ids)
+        training_input.metadata["avg_response_length"] = sum(
+            len(sample_response_ids) for sample_response_ids in response_ids
+        ) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)