[train][step-wise] Three correctness/efficiency fixes for step-wise training#1539

Open
CharlieFRuan wants to merge 3 commits into main from charlie/stepwise-training-opts

Conversation

Member

@CharlieFRuan CharlieFRuan commented Apr 20, 2026

Running here: https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/runs/3e4ndo6d?nw=nwusercharlieruan

Summary

Three independent fixes to skyrl/train/generators/skyrl_gym_generator.py::agent_loop that compose cleanly. Each is its own commit for easy review.

1. Run the length check after re-tokenizing chat history

In the step-wise path, agent_loop_state.input_ids is only refreshed inside the loop via apply_chat_template(chat_history, ...). The prior layout checked len(input_ids) > max_input_length before that refresh, so at turn N the check ran against the stale input_ids left over from turn N-1's re-tokenize (which reflected chat_history through turn N-2). Re-tokenizing then produced the correct — and now potentially over-length — prompt for turn N, and we generated anyway, emitting a step whose prompt_ids could exceed max_input_length.

Observed in Search-R1 step-wise runs as generate/batch_padded_seq_len spiking up to ~10k vs ~4k in the non-step-wise path.

Fix: re-tokenize first, then check length. Non-step-wise paths are unaffected (the guard was already consistent there).

Adds a regression test that drives chat_history growth through a mock apply_chat_template so that the current-turn prompt length exceeds max_input_length; the test asserts exactly one emitted step and that no step's prompt_ids exceed the limit. Verified the test fails on the pre-fix code and passes on the fix.
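A minimal sketch of the corrected ordering. The helper name `step_wise_turn` and the exact call shapes are illustrative assumptions (the real `agent_loop` is more involved); the point is only that re-tokenization happens before the length guard:

```python
def step_wise_turn(agent_loop_state, tokenizer, max_input_length):
    """Sketch of one step-wise turn: refresh input_ids from chat_history
    FIRST, so the length check sees the current turn's prompt, not the
    stale prompt left over from the previous turn's re-tokenize."""
    input_ids = tokenizer.apply_chat_template(
        agent_loop_state.chat_history, add_generation_prompt=True
    )
    agent_loop_state.input_ids = input_ids

    # The guard now runs against the freshly tokenized prompt, so an
    # over-length turn is stopped instead of being emitted as a step.
    if len(input_ids) > max_input_length:
        return None
    return input_ids
```

Pre-fix, the `if` ran before the `apply_chat_template` call, so the newest observation in chat_history was never counted by the guard.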

2. Mask post-hoc appended EOS out of the loss

In agent_loop, when use_conversation_multi_turn=False and the final response does not already end with the tokenizer's EOS, we manually append one so downstream parsing sees a terminated sequence. That token was not produced by the model, so it has no real logprob — training on it pushes gradients toward a fabricated target.

Switch its loss_mask entry from 1 to 0 (resolving the TODO comment that flagged this case). The token still lives in response_ids so the sequence terminates cleanly, but it no longer contributes to the loss.

Updates the two tests that asserted the appended EOS was masked-in.
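The fix amounts to appending the EOS with a zero loss-mask entry. A hedged sketch with an assumed helper name `append_eos_masked` (not the actual agent_loop code):

```python
def append_eos_masked(response_ids, loss_mask, eos_token_id):
    """Append EOS if the response does not already end with it, but mask it
    out of the loss (0, not 1): the token was never sampled by the model,
    so it has no real logprob to train against."""
    if not response_ids or response_ids[-1] != eos_token_id:
        response_ids = response_ids + [eos_token_id]
        loss_mask = loss_mask + [0]  # pre-fix this was 1; fabricated target now excluded
    return response_ids, loss_mask
```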

3. Drop obs_ids from per-step response_ids (step-wise)

For step-wise training, each turn emits its own training sample. Prior behavior concatenated obs onto response_ids = output_ids + obs_ids. The obs tokens had loss_mask=0, but the training forward pass still computed logits over those positions, inflating per-step sequence length by thousands of tokens for retrieval-heavy envs (e.g. Search-R1 retrieved-doc blobs).

With gen-only responses, obs moves to the next turn's prompt (capped at max_input_length by the Task 1 fix), so per-step sequences are capped at max_input_length + max_generate_length ≈ 4596 tokens. Observed batch_padded_seq_len: 9434 pre-fix vs consistently under 3k post-fix.

Adds get_turn_gen_only_* helpers on TurnOutput so the step-wise branch uses clean accessors rather than inline slicing of obs-padded forms.

merge_stepwise_output (#1538) is compatible: its obs_delta = prompt_{i}[len(full_merged):] branch recovers the obs tokens on merge. Advantage compute is unchanged — it sums token-level rewards, and the single non-zero reward still lands at the last output token position.
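The shape change can be illustrated with a stand-in `TurnOutput`. The class and method names here are modeled on the accessors the PR describes but are not the real SkyRL definitions; only the pre- vs post-fix shapes matter:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class TurnOutput:
    output_ids: List[int]  # tokens the model generated this turn
    obs_ids: List[int]     # environment observation appended after the turn

    def get_turn_response_ids(self) -> List[int]:
        # Pre-fix behavior: obs concatenated onto the response. The obs had
        # loss_mask=0, but the training forward pass still computed logits
        # over these positions.
        return self.output_ids + self.obs_ids

    def get_turn_gen_only_response_ids(self) -> List[int]:
        # Post-fix behavior: gen-only response; the obs instead lands in the
        # next turn's prompt, where the length check caps it.
        return list(self.output_ids)
```

For a retrieval-heavy env, `obs_ids` can run to thousands of tokens, which is exactly the per-step inflation the gen-only accessor removes.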

Tests

All tests pass on top of current main (with #1538 merged):

```bash
UV_CACHE_DIR=/mnt/local_storage/uv-cache uv run --frozen --extra dev --extra fsdp pytest \
  tests/train/test_trainer_utils.py \
  tests/train/test_generator_postprocess.py \
  tests/train/test_prompt_mini_batch.py \
  tests/train/generators/test_generator_output_utils.py \
  tests/train/generators/test_skyrl_gym_generator.py \
  -v
```

102 passed

This includes the 16 merge-stepwise tests from #1538 and a new regression test (test_step_wise_trajectories_length_check_uses_current_prompt) for the Task 1 fix.

Validation: Search-R1 step-wise run

Reproduced the validation command from PR #1529 on the Qwen2.5-3B Search-R1 recipe:

```bash
USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true MERGE_STEPWISE=true \
bash examples/train/search/run_search.sh \
generator.inference_engine.num_engines=8 \
generator.inference_engine.tensor_parallel_size=1 \
trainer.project_name=skyrl-search-padding \
trainer.run_name=stepwise-opts-v3-merge \
trainer.ckpt_interval=5 trainer.eval_interval=999
```

batch_padded_seq_len comparison (same recipe, sync step-wise, Qwen2.5-3B)

| run | step 1 | step 2 | step 3 | step 4 | step 5 | max over 10+ steps |
| --- | --- | --- | --- | --- | --- | --- |
| pre-fix (baseline) | 3077 | 2714 | 9434 | 2692 | 2744 | 9434 |
| with Tasks 1+2+loss_mask | 2510 | 2828 | 2919 | 3092 | 3897 | 3897 |
| with Tasks 1+2+loss_mask + MERGE | 2882 | 2517 | 2822 | 2843 | 2738 | 2920 |

The pre-fix step-3 spike at 9434 is exactly the long-obs case addressed by Fix 3: one batch where a retrieved-doc observation pushed response_ids = output + obs to ~5000+ tokens, plus a 4000-token prompt. Post-fix, all 10+ observed steps stay under 3k.

Training-health signals

  • Step times: pre-fix ~210s avg; post-fix (with MERGE) ~190s avg.
  • reward/avg_raw_reward: 0.27–0.32 across all three runs, matching the reference run's early band.
  • policy_kl ~1.5e-3, is_ratio_mean ~1.00 across all steps — no policy drift.
  • avg_response_length: pre-fix 1000–9000 (variable with obs); post-fix ~95–100 gen-only; +MERGE ~130–145 (merged-pair responses include the obs_delta between consecutive turns). All consistent with the mechanics.

WandB project: sky-posttraining-uc-berkeley/skyrl-search-padding; run names stepwise-opts-v2-fresh (Tasks 1+2+loss_mask) and stepwise-opts-v3-merge (+MERGE).

Test plan

  • CPU test suite (102 tests) passes on the rebased branch, including all 16 merge-stepwise tests.
  • New regression test (test_step_wise_trajectories_length_check_uses_current_prompt) fails on pre-fix, passes on fix.
  • Search-R1 step-wise run: `batch_padded_seq_len` capped under 3k across 11+ steps (vs pre-fix 9434 spike).
  • Search-R1 step-wise + MERGE run: same cap, ~190s step time, no divergence from reference reward trajectory.

🤖 Generated with Claude Code




@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces 'gen-only' variants for step-wise generation in SkyRLGymGenerator, ensuring that observation tokens are excluded from training samples to maintain consistent sequence lengths. It also moves the input length validation to occur after re-tokenization and masks post-hoc appended EOS tokens in the loss mask. A potential IndexError was identified in get_turn_gen_only_rollout_expert_indices when handling empty expert indices with an added EOS.

Comment on lines +161 to +172:

```python
def get_turn_gen_only_rollout_expert_indices(self) -> Optional[List[List[List[int]]]]:
    """Like `get_turn_rollout_expert_indices()` but without the trailing observation pads."""
    if self.rollout_expert_indices is None:
        return None
    if not self.rollout_expert_indices:
        return self.rollout_expert_indices
    indices = list(self.rollout_expert_indices)
    if self.added_eos:
        layer_num = len(self.rollout_expert_indices[0])
        topk = len(self.rollout_expert_indices[0][0]) if layer_num > 0 else 0
        indices.append([[0] * topk for _ in range(layer_num)])
    return indices
```

medium

The implementation of get_turn_gen_only_rollout_expert_indices contains a potential IndexError if self.rollout_expert_indices is an empty list (e.g., in certain mock or edge cases where the model generates no tokens). Accessing self.rollout_expert_indices[0] at line 169 will fail. While line 165 provides a safety check, it returns the empty list immediately, which would cause a length mismatch with response_ids if added_eos is True. Consider ensuring that layer_num and topk can be determined or return None if the shape is ambiguous.
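One way to address the flagged edge case, sketched under the assumption that `rollout_expert_indices` is a nested list of shape (turns, layers, topk); the standalone helper name is hypothetical, and this is a reviewer-style suggestion rather than code from the PR:

```python
def gen_only_expert_indices(rollout_expert_indices, added_eos):
    """Variant that avoids indexing [0] on an empty list: if the list is
    empty and an EOS was appended, the (layers, topk) shape cannot be
    inferred, so return None instead of an empty list that would mismatch
    the length of response_ids."""
    if rollout_expert_indices is None:
        return None
    if not rollout_expert_indices:
        return None if added_eos else rollout_expert_indices
    indices = list(rollout_expert_indices)
    if added_eos:
        layer_num = len(rollout_expert_indices[0])
        topk = len(rollout_expert_indices[0][0]) if layer_num > 0 else 0
        # Zero-filled placeholder entry for the appended EOS position.
        indices.append([[0] * topk for _ in range(layer_num)])
    return indices
```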

@devin-ai-integration devin-ai-integration Bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 5 additional findings.
