[train][step-wise] Three correctness/efficiency fixes for step-wise training#1539
CharlieFRuan wants to merge 3 commits into main from
Conversation
For the step-wise path, `agent_loop_state.input_ids` is only refreshed inside the loop by `apply_chat_template(chat_history, ...)`. The previous layout checked `len(input_ids) > max_input_length` BEFORE that refresh, so at turn N the check ran against the stale input_ids left over from turn N-1's re-tokenize — which reflected chat_history through turn N-2. Re-tokenizing then produced the correct (now-bigger) prompt for turn N and we generated anyway, recording a step whose `prompt_ids` could exceed `max_input_length`. Observed in Search-R1 step-wise runs as `generate/batch_padded_seq_len` spikes up to ~10k vs ~4k in the non-step-wise path. Fix: re-tokenize first, then check length. Adds a regression test that drives chat_history growth through a mock apply_chat_template so the current-turn prompt length exceeds max_input_length; the test expects exactly one emitted step and no step's prompt exceeding the limit.
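The reordering can be illustrated with a minimal sketch. This is a hypothetical reduction of the step-wise loop, not the actual SkyRL-Gym code: `run_turn` and `FakeTokenizer` are illustrative names, and the fake tokenizer maps one character to one token so prompt length equals message length.

```python
# Hypothetical sketch: re-tokenize the current chat history first,
# then length-check the fresh prompt (not the stale one from turn N-1).

def run_turn(chat_history, tokenizer, max_input_length):
    # Fix: refresh input_ids from the *current* chat_history before checking.
    input_ids = tokenizer.apply_chat_template(chat_history)
    if len(input_ids) > max_input_length:
        return None  # stop: this turn's prompt is already over budget
    return input_ids

class FakeTokenizer:
    # One token per character, so prompt length == total content length.
    def apply_chat_template(self, history):
        return [ord(c) for msg in history for c in msg["content"]]

tok = FakeTokenizer()
history = [{"role": "user", "content": "x" * 10}]
assert run_turn(history, tok, max_input_length=16) is not None
history.append({"role": "tool", "content": "y" * 20})  # obs grows the prompt
# The over-length prompt is now caught at turn N, not a turn late.
assert run_turn(history, tok, max_input_length=16) is None
```

With the check before the refresh, the second call would have validated the 10-token prompt from the previous turn and generated anyway.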
In `agent_loop`, when `use_conversation_multi_turn=False` and the final response does not already end with the tokenizer's EOS, we manually append one so downstream parsing sees a terminated sequence. That token was not actually produced by the model, so it has no real logprob — training on it pushes gradients toward a fabricated target. Switch its loss_mask entry from 1 to 0 (the TODO that flagged this case). The token still ends up in response_ids so the sequence terminates cleanly, but it no longer contributes to the loss. Updates the two tests that asserted the appended EOS was masked-in.
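A minimal sketch of the masking change, with hypothetical names (`finalize_response` is illustrative, not the repo's function):

```python
# Sketch of the EOS-append fix: the manually appended EOS token stays in
# response_ids but gets loss_mask=0, since the model never produced it.
eos_token_id = 2  # assumed tokenizer EOS id for this example

def finalize_response(response_ids, loss_mask, eos_token_id):
    if not response_ids or response_ids[-1] != eos_token_id:
        response_ids = response_ids + [eos_token_id]
        loss_mask = loss_mask + [0]  # was 1 before the fix
    return response_ids, loss_mask

ids, mask = finalize_response([5, 7, 9], [1, 1, 1], eos_token_id)
assert ids[-1] == eos_token_id and mask[-1] == 0  # appended EOS, masked out
# A response that already ends with EOS is left untouched.
ids2, mask2 = finalize_response([5, 2], [1, 1], eos_token_id)
assert ids2 == [5, 2] and mask2 == [1, 1]
```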
For step-wise training, each turn emits its own training sample — there is no
requirement for the response to embed the observation that followed the
model's generation. Concatenating obs onto `response_ids` (the prior behavior)
kept the obs out of the loss (mask=0) but still forced the training forward
pass to compute logits over those positions, inflating per-step sequence
length by thousands of tokens for retrieval-heavy envs.
With gen-only responses, obs appears in the next turn's prompt (capped at
`max_input_length` by the preceding length-check fix), so per-step sequences
stay at most `max_input_length + max_generate_length`. Observed
`batch_padded_seq_len: 9434` with obs-in-response vs capped ~4596 without.
Adds `get_turn_gen_only_*` helpers on `TurnOutput` so the step-wise branch
calls clean accessors rather than slicing the obs-padded forms.
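The contrast between the obs-padded and gen-only accessors can be sketched as follows. This is a simplified stand-in for `TurnOutput`, with only the two fields needed to show the difference; the real class carries more state.

```python
# Hypothetical reduction of TurnOutput showing obs-padded vs gen-only views.
from dataclasses import dataclass
from typing import List

@dataclass
class TurnOutput:
    output_ids: List[int]  # tokens the model generated this turn
    obs_ids: List[int]     # environment observation appended after generation

    def get_turn_response_ids(self) -> List[int]:
        # Prior step-wise behavior: obs rides along in response_ids
        # (mask=0, but the forward pass still computes logits over it).
        return self.output_ids + self.obs_ids

    def get_turn_gen_only_response_ids(self) -> List[int]:
        # Gen-only: obs is dropped here; it reappears in the next
        # turn's (length-capped) prompt instead.
        return list(self.output_ids)

turn = TurnOutput(output_ids=[1, 2, 3], obs_ids=[9] * 4000)  # retrieval-heavy obs
assert len(turn.get_turn_response_ids()) == 4003
assert turn.get_turn_gen_only_response_ids() == [1, 2, 3]
```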
`merge_stepwise_output`'s `obs_delta = prompt_{i}[len(full_merged):]` branch
recovers the obs tokens on merge, so prefix-aware merging still works.
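The prefix-aware recovery of obs tokens on merge can be sketched as below. `merge_two_steps` is a hypothetical two-step reduction; the repo's `merge_stepwise_output` handles more cases.

```python
# Sketch: the next turn's prompt extends the merged sequence so far, so
# whatever lies beyond len(merged_ids) is exactly the obs inserted between
# turns -- even when per-step responses are gen-only.

def merge_two_steps(merged_ids, next_prompt_ids, next_response_ids):
    assert next_prompt_ids[: len(merged_ids)] == merged_ids, "prefix mismatch"
    obs_delta = next_prompt_ids[len(merged_ids):]
    return merged_ids + obs_delta + next_response_ids

prompt1, resp1 = [1, 2], [3, 4]
merged = prompt1 + resp1          # turn 1: gen-only response
prompt2 = merged + [8, 9]         # turn 2 prompt = merged sequence + obs tokens
resp2 = [5]
assert merge_two_steps(merged, prompt2, resp2) == [1, 2, 3, 4, 8, 9, 5]
```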
Advantage compute is unchanged: it sums token-level rewards, and the single
non-zero reward still lands at the last output token position.
Code Review
This pull request introduces 'gen-only' variants for step-wise generation in SkyRLGymGenerator, ensuring that observation tokens are excluded from training samples to maintain consistent sequence lengths. It also moves the input length validation to occur after re-tokenization and masks post-hoc appended EOS tokens in the loss mask. A potential IndexError was identified in get_turn_gen_only_rollout_expert_indices when handling empty expert indices with an added EOS.
```python
def get_turn_gen_only_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
    """Like `get_turn_rollout_expert_indices()` but without the trailing observation pads."""
    if self.rollout_expert_indices is None:
        return None
    if not self.rollout_expert_indices:
        return self.rollout_expert_indices
    indices = list(self.rollout_expert_indices)
    if self.added_eos:
        layer_num = len(self.rollout_expert_indices[0])
        topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0
        indices.append([[0] * topk for _ in range(layer_num)])
    return indices
```
The implementation of `get_turn_gen_only_rollout_expert_indices` contains a potential IndexError if `self.rollout_expert_indices` is an empty list (e.g., in certain mock or edge cases where the model generates no tokens). Accessing `self.rollout_expert_indices[0]` at line 169 will fail. While line 165 provides a safety check, it returns the empty list immediately, which would cause a length mismatch with `response_ids` if `added_eos` is True. Consider ensuring that `layer_num` and `topk` can be determined or return None if the shape is ambiguous.
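One possible shape for the guard the reviewer suggests, written here as a free function over hypothetical arguments rather than the repo's method: when the nested shape cannot be determined and an EOS pad would be needed, return None instead of an empty list that would mismatch `response_ids`.

```python
# Sketch of the reviewer's suggestion (assumed signature, not the repo's code).
from typing import List, Optional

def gen_only_rollout_expert_indices(
    rollout_expert_indices: Optional[List[List[List[int]]]],
    added_eos: bool,
) -> Optional[List[List[List[int]]]]:
    if rollout_expert_indices is None:
        return None
    if not rollout_expert_indices:
        # Shape (layer_num, topk) is ambiguous here; signal "no expert
        # indices" rather than return a list shorter than response_ids.
        return None if added_eos else rollout_expert_indices
    indices = list(rollout_expert_indices)
    if added_eos:
        layer_num = len(rollout_expert_indices[0])
        topk = len(rollout_expert_indices[0][0]) if layer_num > 0 else 0
        indices.append([[0] * topk for _ in range(layer_num)])
    return indices

assert gen_only_rollout_expert_indices([], True) is None          # no IndexError
assert gen_only_rollout_expert_indices([[[1, 2]]], True) == [[[1, 2]], [[0, 0]]]
```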
Running here: https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/runs/3e4ndo6d?nw=nwusercharlieruan
Summary
Three independent fixes to `skyrl/train/generators/skyrl_gym_generator.py::agent_loop` that compose cleanly. Each is its own commit for easy review.

1. Run the length check after re-tokenizing chat history
In the step-wise path, `agent_loop_state.input_ids` is only refreshed inside the loop via `apply_chat_template(chat_history, ...)`. The prior layout checked `len(input_ids) > max_input_length` before that refresh, so at turn N the check ran against the stale `input_ids` left over from turn N-1's re-tokenize (which reflected chat_history through turn N-2). Re-tokenizing then produced the correct — and now potentially over-length — prompt for turn N, and we generated anyway, emitting a step whose `prompt_ids` could exceed `max_input_length`.

Observed in Search-R1 step-wise runs as `generate/batch_padded_seq_len` spiking up to ~10k vs ~4k in the non-step-wise path.

Fix: re-tokenize first, then check length. Non-step-wise paths are unaffected (the guard was already consistent there).
Adds a regression test that drives `chat_history` growth through a mock `apply_chat_template`, so the current-turn prompt length exceeds `max_input_length`; the test asserts exactly one emitted step and no step's `prompt_ids` exceeding the limit. Verified the test fails on the pre-fix code and passes on the fix.

2. Mask post-hoc appended EOS out of the loss
In `agent_loop`, when `use_conversation_multi_turn=False` and the final response does not already end with the tokenizer's EOS, we manually append one so downstream parsing sees a terminated sequence. That token was not produced by the model, so it has no real logprob — training on it pushes gradients toward a fabricated target.

Switch its `loss_mask` entry from 1 to 0 (the TODO comment that flagged this case). The token still lives in `response_ids` so the sequence terminates cleanly, but it no longer contributes to the loss.

Updates the two tests that asserted the appended EOS was masked-in.
3. Drop `obs_ids` from per-step `response_ids` (step-wise)

For step-wise training, each turn emits its own training sample. Prior behavior concatenated obs onto `response_ids = output_ids + obs_ids`. The obs tokens had `loss_mask=0`, but the training forward pass still computed logits over those positions, inflating per-step sequence length by thousands of tokens for retrieval-heavy envs (e.g. Search-R1 retrieved-doc blobs).

With gen-only responses, obs moves to the next turn's prompt (capped at `max_input_length` by the Task 1 fix), so per-step sequences stay at most `max_input_length + max_generate_length` ≈ 4596. Observed `batch_padded_seq_len: 9434` pre-fix vs consistently under 3k post-fix.

Adds `get_turn_gen_only_*` helpers on `TurnOutput` so the step-wise branch uses clean accessors rather than inline slicing of obs-padded forms.

`merge_stepwise_output` (#1538) is compatible: its `obs_delta = prompt_{i}[len(full_merged):]` branch recovers the obs tokens on merge. Advantage compute is unchanged — it sums token-level rewards, and the single non-zero reward still lands at the last output token position.

Tests
All tests pass on top of current main (with #1538 merged):
```bash
UV_CACHE_DIR=/mnt/local_storage/uv-cache uv run --frozen --extra dev --extra fsdp pytest \
tests/train/test_trainer_utils.py \
tests/train/test_generator_postprocess.py \
tests/train/test_prompt_mini_batch.py \
tests/train/generators/test_generator_output_utils.py \
tests/train/generators/test_skyrl_gym_generator.py \
-v
102 passed
```
This includes the 16 merge-stepwise tests from #1538 and a new regression test (`test_step_wise_trajectories_length_check_uses_current_prompt`) for the Task 1 fix.

Validation: Search-R1 step-wise run

Reproduced the validation command from PR #1529 on the Qwen2.5-3B Search-R1 recipe:
Reproduced the validation command from PR #1529 on the Qwen2.5-3B Search-R1 recipe:
```bash
USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true MERGE_STEPWISE=true \
bash examples/train/search/run_search.sh \
generator.inference_engine.num_engines=8 \
generator.inference_engine.tensor_parallel_size=1 \
trainer.project_name=skyrl-search-padding \
trainer.run_name=stepwise-opts-v3-merge \
trainer.ckpt_interval=5 trainer.eval_interval=999
```
`batch_padded_seq_len` comparison (same recipe, sync step-wise, Qwen2.5-3B)

The pre-fix step-3 spike at 9434 is exactly the Issue-2 long-obs case: one batch where a retrieved-doc observation pushed `response_ids = output + obs` to ~5000+ tokens, plus a 4000-token prompt. Post-fix, all 10+ observed steps stay under 3k.

Training-health signals

- `reward/avg_raw_reward`: 0.27–0.32 across all three runs, matching the reference run's early band.
- `policy_kl` ~1.5e-3, `is_ratio_mean` ~1.00 across all steps — no policy drift.
- `avg_response_length`: pre-fix 1000–9000 (variable with obs); post-fix ~95–100 gen-only; +MERGE ~130–145 (merged-pair responses include the obs_delta between consecutive turns).

All consistent with the mechanics.

WandB project: `sky-posttraining-uc-berkeley/skyrl-search-padding`; run names `stepwise-opts-v2-fresh` (Tasks 1+2+loss_mask) and `stepwise-opts-v3-merge` (+MERGE).

Test plan
Regression test (`test_step_wise_trajectories_length_check_uses_current_prompt`) fails on pre-fix, passes on fix.

🤖 Generated with Claude Code