[train] Add validation for step-wise GeneratorOutput #1281
CharlieFRuan merged 1 commit into NovaSky-AI:main
Conversation
Code Review
The pull request introduces crucial validation for step-wise generator outputs, which was previously missing. This addresses potential silent errors in training results due to malformed or non-contiguous trajectory data, particularly important for the advantage broadcast mechanism. The changes involve modifying validate_generator_output to handle step-wise specific checks and extracting these into a new helper function _validate_step_wise_fields. Comprehensive unit tests have been added to cover various validation scenarios, significantly improving the robustness of the training pipeline.
🟡 Step-wise avg_response_length divides by total steps instead of number of trajectories
At skyrl/train/trainer.py:679-683, the step-wise avg_response_length calculation sums response lengths only for steps where is_last_step=True, but divides by len(response_ids) (total number of steps across all trajectories). For example, with 2 trajectories having 3 and 2 steps respectively, and last-step response lengths of 10 and 8: numerator = 18, denominator = 5, result = 3.6 — instead of the correct 9.0. The denominator should be the number of trajectories (i.e., sum(1 for x in generator_output["is_last_step"] if x)). Compare with the non-step-wise version at line 685-687 which correctly uses len(response_ids) as denominator when summing all responses.
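The mismatch can be reproduced with a small standalone example using the numbers from the comment above (the `response_ids` contents are illustrative; only the lengths matter):

```python
# Two trajectories with 3 and 2 steps; only last steps carry the final response.
response_ids = [[0] * 4, [0] * 7, [0] * 10,   # trajectory A: 3 steps, last has length 10
                [0] * 5, [0] * 8]             # trajectory B: 2 steps, last has length 8
is_last_step = [False, False, True, False, True]

# Numerator: response lengths summed only over last steps (10 + 8 = 18).
last_lengths = [len(r) for r, last in zip(response_ids, is_last_step) if last]

buggy = sum(last_lengths) / len(response_ids)                  # divides by total steps (5)
fixed = sum(last_lengths) / sum(1 for x in is_last_step if x)  # divides by #trajectories (2)

print(buggy)  # 3.6
print(fixed)  # 9.0
```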
will fix in a separate PR
Fixes comment #1281 (comment). There was no good reason to only account for `is_last_step`, since response IDs are not cumulative in step-wise training (unlike input tokens, which are cumulative).
Alternatively, we could ignore step order entirely and avoid the `cumsum`
trick in the advantage broadcast.
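That order-independent alternative can be sketched as follows (hypothetical values, not the trainer's actual code): broadcast each trajectory's advantage to its steps by id lookup, which works even when steps are interleaved across trajectories.

```python
# Steps interleaved across trajectories "a" and "b" -- fine for this approach,
# since the mapping goes through trajectory ids rather than batch positions.
trajectory_ids = ["b", "a", "b", "a", "a"]
traj_advantages = {"a": 1.0, "b": -1.0}  # one advantage per trajectory (made-up values)

per_step_advantages = [traj_advantages[t] for t in trajectory_ids]
print(per_step_advantages)  # [-1.0, 1.0, -1.0, 1.0, 1.0]
```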
### Summary
Previously, `validate_generator_output()` was **skipped entirely** when
`step_wise_trajectories=True`:
```python
if not self.cfg.generator.step_wise_trajectories:
    validate_generator_output(len(input_batch["prompts"]), generator_output)
```
This meant step-wise generator outputs had no validation at all —
malformed `is_last_step`, missing `trajectory_ids`, or non-contiguous
trajectory ordering would silently produce wrong training results.
The non-contiguous case is particularly dangerous: the trainer's
advantage broadcast uses a `cumsum` trick that assumes all steps of the
same trajectory are adjacent in the batch. If steps are interleaved
across trajectories, advantages are silently mapped to the wrong steps
with no error.
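The failure mode can be sketched with a minimal version of the `cumsum` trick (an assumed shape, not the exact trainer code): a shifted cumulative sum of `is_last_step` yields a per-step trajectory index, but only when all steps of a trajectory are adjacent in the batch.

```python
from itertools import accumulate

def traj_index(is_last_step):
    """Shifted cumsum: the index increments on the step AFTER each last step."""
    sums = list(accumulate(int(x) for x in is_last_step))
    return [0] + sums[:-1]

# Contiguous batch [A, A, A, B, B]: steps map to the right trajectories.
contiguous = [False, False, True, False, True]
print(traj_index(contiguous))    # [0, 0, 0, 1, 1]

# Same steps interleaved as [A, B, A, B, A] (flags reordered with them):
# the first four steps all land in trajectory 0, and nothing raises an error.
interleaved = [False, False, False, True, True]
print(traj_index(interleaved))   # [0, 0, 0, 0, 1]
```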
### Changes
**`skyrl/train/utils/trainer_utils.py`**
- Added `step_wise: bool = False` parameter to
`validate_generator_output()` (backward compatible — existing callers
are unaffected)
- Extracted `_validate_step_wise_fields()` for step-wise specific
checks:
- `is_last_step` and `trajectory_ids` are present and correctly sized
- `is_last_step[-1]` is `True` (last sample must be a trajectory's final
step)
- **Contiguous ordering**: all steps of the same trajectory are adjacent
(catches the silent `cumsum` bug)
- **Boundary alignment**: `is_last_step[i]` is `True` wherever (and only
when) `trajectory_ids` changes between consecutive samples
- In step-wise mode, `num_prompts != num_responses` is allowed (step
expansion is expected)
**`skyrl/train/trainer.py`**
- Changed from skipping validation to calling with `step_wise=True`:
```python
validate_generator_output(
    len(input_batch["prompts"]),
    generator_output,
    step_wise=self.cfg.generator.step_wise_trajectories,
)
```
**`tests/train/test_trainer_utils.py`**
- 9 new tests covering all step-wise validation cases
### Test plan
- [x] `pytest tests/train/test_trainer_utils.py` — all 44 tests pass (35
existing + 9 new)
- [x] Existing non-step-wise validation tests unaffected (backward
compatible `step_wise=False` default)
- [x] New tests cover: valid output, single-step trajectories, missing
fields, length mismatches, non-contiguous ordering, boundary
misalignment, all-False `is_last_step`
### E2E test
Ran the multi-turn gsm8k example E2E and verified that it is indeed
multi-turn, since `generate/batch_num_seq` is ~6800 rather than 2560 (512 * 5).
```bash
# Run training (script defaults to 1 GPU, override for 8 GPU + step-wise multi-turn)
bash examples/train/turn_level_rewards/run_gsm8k_multi_turn.sh \
generator.step_wise_trajectories=true \
generator.use_conversation_multi_turn=true \
    generator.max_turns=5
```