[train] Make TrainingInputBatch PAD only to the left, so response tensors are right-aligned #1285

CharlieFRuan merged 3 commits into main
Conversation
/gemini review
Code Review
The pull request refactors the token padding strategy in skyrl/train/dataset/preprocess.py to use unified left-padding for combined prompt-response sequences and right-alignment for response-specific tensors (rewards, loss masks, logprobs). This change aims to improve padding efficiency and better align with model output processing. The run_search.sh script was updated to be more flexible, introducing configurable multi-turn and step-wise training options via environment variables, and the previously separate run_search_conversation_format.sh script was removed, consolidating functionality. Corresponding tests were updated and new tests were added to validate the new padding and alignment logic, including a warning mechanism for exceeding max_seq_len without truncation.
```diff
 Convert prompts and responses to batch tensors for training.

-This function concatenates all prompts and responses to the following format:
+Each sequence is laid out as a single left-padded block:
```
this in general makes sense to me, but can you test the forward pass for megatron in test_megatron_worker::test_forward to make sure things look ok for the logic there that removes padding?
just running a basic forward pass for any model should suffice
made a megatron gsm8k run after rebasing, seems normal. Added to the PR description. I'll address some R3 things before merging
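Not the Megatron worker's actual unpadding logic, but a minimal sketch (hypothetical helper name, assuming the layout from this PR) of what "removing padding" means once every row is a single left-padded block: the attention mask counts the real tokens, and they all sit at the right end of the row.

```python
import torch

def unpad_left_padded(sequences: torch.Tensor, attention_mask: torch.Tensor):
    """Return a list of 1-D tensors holding only the real tokens of each row."""
    unpadded = []
    for row, mask in zip(sequences, attention_mask):
        seq_len = int(mask.sum())        # number of real tokens in this row
        unpadded.append(row[-seq_len:])  # real tokens sit at the right end of the row
    return unpadded

# Example: two rows padded to width 5 with pad id 0; the non-zero tokens are prompt + response.
seqs = torch.tensor([[0, 0, 11, 12, 13],
                     [0, 21, 22, 23, 24]])
mask = (seqs != 0).long()
print([t.tolist() for t in unpad_left_padded(seqs, mask)])  # [[11, 12, 13], [21, 22, 23, 24]]
```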
GPU-Megatron test (passed): https://github.com/NovaSky-AI/SkyRL/actions/runs/23264218126

Failure unrelated to this PR:
…ensors be right-aligned (#1285)

## Summary

Fix cross-sample padding inflation in `convert_prompts_responses_to_batch_tensors` by replacing the two-segment padding layout with a unified left-padded layout.

**Before:** sequences padded to `max(prompt_lens) + max(response_lens)`. When the longest prompt and longest response come from different samples (common in step-wise training), this can approach **2x max_seq_len**.

**After:** sequences padded to `max(prompt_i + response_i)` — the tightest bound that preserves every real token. **No tokens are truncated.**

## Problem

The old layout pads prompts and responses as independent segments:

```
| [PAD..PAD PROMPT] | [RESPONSE PAD..PAD] |
|<-- max_prompt --->|<--- max_response -->|
```

In step-wise training, prompt and response lengths are anti-correlated across steps:

- Early turns: short prompt (5K), long response (50K)
- Late turns: long prompt (55K), short response (4K)

The padded `seq_len = 55K + 50K = 105K` far exceeds the actual `max_seq_len = 60K`. With 61,440 step-samples, this inflates the `sequences` tensor from ~75 GB to ~103 GB (for max_seq_len=80K) — pure padding waste.

## Solution

Eliminate the two-segment layout. Each sequence is now a single left-padded block:

```
| [PAD..PAD PROMPT RESPONSE] |
|<------- max_total -------->|
```

Where `max_total = max(prompt_i + response_i)`. The response is always at the end of the sequence, so the existing model forward pass slicing (`log_probs[:, -num_actions-1:-1]`) works unchanged.

### Response data alignment change: left-aligned → right-aligned

Because response tokens are now at the **end** of each sequence (with variable-length prompts before them), the response logprobs extracted by the model are **right-aligned** within the `(batch, max_response)` slice. Response-level tensors (action_mask, rewards, loss_masks, rollout_logprobs) are correspondingly right-aligned to match:

```
Old (left-aligned):  [resp_tok, resp_tok, resp_tok, PAD, PAD]
New (right-aligned): [PAD, PAD, resp_tok, resp_tok, resp_tok]
```

All downstream consumers use masked operations (`masked_mean(loss * loss_mask, loss_mask)`, `scores.unsqueeze(-1) * response_mask`, etc.), which are alignment-agnostic. The `loss_fn_outputs` extraction for the Tinker API path uses `action_mask.sum()` + `[:valid_len]` and needs a follow-up adjustment for that specific code path (it is not used in the standard RL training loop — it's popped at `trainer.py:1088`).

## Changes

| File | Change |
|------|--------|
| `skyrl/train/dataset/preprocess.py` | Unified left-pad layout, right-aligned response data, optional `max_seq_len` warning |
| `skyrl/train/trainer.py` | Pass `max_seq_len` from config to the padding function |
| `tests/train/dataset/test_preprocess.py` | 8 new tests for unified layout, right-alignment, anti-correlation, no-mutation, backward compat |

## Test plan

- [x] All 12 unit tests pass (4 existing updated + 8 new)
- [x] Verify a step-wise training run produces the same loss curves (right-alignment changes the tensor layout but not masked loss values)
- [x] Verify non-step-wise training is unaffected (max_total = max_prompt + max_response when lengths are not anti-correlated)

## Curves

GSM8K CI run:

<img width="1431" height="278" alt="image" src="https://github.com/user-attachments/assets/5cc4ea54-f6df-498d-ae7e-c3cf243610fa" />

https://wandb.ai/sky-posttraining-uc-berkeley/skyrl-search-padding/reports/Untitled-Report--VmlldzoxNjE3MTA0OQ

Ran with 8xH100s:

- Baseline from previous PRs (blue) -- without TIS
- Non-step-wise search r1 (red) -- with TIS, ran with:

```bash
USE_CONVERSATION_MULTI_TURN=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```

- Step-wise search r1 (brown) -- with TIS, ran with:

```bash
USE_CONVERSATION_MULTI_TURN=true STEP_WISE=true bash examples/train/search/run_search.sh \
  generator.inference_engine.num_engines=8 \
  generator.inference_engine.tensor_parallel_size=1
```

<img width="559" height="252" alt="image" src="https://github.com/user-attachments/assets/0a287d07-5f5f-471e-a02e-570905ad468a" />

The step-wise training time is much worse for now (roughly 4x; it scales with the number of turns) and can hopefully be improved after #1277.

<img width="839" height="500" alt="image" src="https://github.com/user-attachments/assets/2d1cbd65-1a25-4e7e-8c60-d5b221a97800" />

GSM8K + Megatron CI run (purple is with this PR after rebasing): https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k_ci_megatron/runs/uoaga4uz?nw=nwusercharlieruan

<img width="1464" height="271" alt="image" src="https://github.com/user-attachments/assets/6129afb9-1550-465b-92f8-d8d12063142c" />

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
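For illustration, here is a minimal sketch of the unified layout described above. The function name and signature are hypothetical, not the actual `convert_prompts_responses_to_batch_tensors` implementation: sequences are left-padded to `max(prompt_i + response_i)`, and the response-level mask is right-aligned within its own `(batch, max_response)` block.

```python
import torch

def pad_to_batch(prompts, responses, pad_token_id=0):
    """Hypothetical sketch of the unified left-padded layout (not the preprocess.py code)."""
    max_total = max(len(p) + len(r) for p, r in zip(prompts, responses))
    max_resp = max(len(r) for r in responses)
    sequences, attention_masks, action_masks = [], [], []
    for p, r in zip(prompts, responses):
        n_pad = max_total - len(p) - len(r)
        sequences.append([pad_token_id] * n_pad + p + r)              # [PAD..PAD PROMPT RESPONSE]
        attention_masks.append([0] * n_pad + [1] * (len(p) + len(r)))
        # response-level tensors are right-aligned in their own (batch, max_resp) block
        action_masks.append([0] * (max_resp - len(r)) + [1] * len(r))
    return (torch.tensor(sequences),
            torch.tensor(attention_masks),
            torch.tensor(action_masks))

# Anti-correlated lengths: padded width is max(2+4, 4+1) = 6,
# not max_prompt + max_response = 4 + 4 = 8 as in the old two-segment layout.
seqs, attn, act = pad_to_batch([[11, 12], [21, 22, 23, 24]], [[31, 32, 33, 34], [41]])
print(seqs.shape)  # torch.Size([2, 6])
```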
After PR #1285 unified the padding layout to left-pad (right-aligned response tensors), the `loss_fn_outputs` extraction in the Tinker API path still sliced with `[:valid_len]` (left-aligned), returning padding/prompt logprobs instead of the actual response logprobs.

Change `[:valid_len]` to `[-valid_len:]` at all 4 affected sites in both the FSDP worker and the Megatron worker (SFT and RL paths each). Add CPU unit tests that verify right-aligned slicing returns the correct response values and would fail with the old left-aligned slicing.

Fixes #1304

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
…#1367)

## Summary

- Fix `loss_fn_outputs` extraction using `[:valid_len]` (left-aligned) on right-aligned tensors, which returned padding/prompt logprobs instead of actual response values
- Change to `[-valid_len:]` at all 4 affected sites: FSDP worker SFT path, FSDP worker RL path, Megatron worker SFT path, Megatron worker RL path

## Details

After #1285 unified the padding layout to left-pad (right-aligned response tensors), the `loss_fn_outputs` extraction in the Tinker API path still used `[:valid_len]`, which grabs values from the left (padding region) instead of the right (actual response region).

Example with `action_mask = [0, 0, 1, 1, 1]` and `action_log_probs = [pad, pad, real_0, real_1, real_2]`:

- `[:3]` → `[pad, pad, real_0]` (WRONG)
- `[-3:]` → `[real_0, real_1, real_2]` (CORRECT)

**Affected entrypoints**: the Tinker API (`skyrl.tinker.api`) — the standard RL training loop discards `loss_fn_outputs` and is unaffected. Note that #1285 does not affect the existing bug in the tinker entrypoint, since it calls [_to_training_batch()](https://github.com/NovaSky-AI/SkyRL/blob/954b2ee2d75fd20adf97e1729e6d1ca8342e820c/skyrl/backends/skyrl_train_backend.py#L286-L290), which also does right padding.

Fixes #1304

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
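A minimal sketch of the slicing difference described above, with illustrative tensors rather than the worker code:

```python
import torch

action_mask = torch.tensor([0, 0, 1, 1, 1])
action_log_probs = torch.tensor([-9.0, -9.0, -0.1, -0.2, -0.3])  # -9.0 marks padding positions

valid_len = int(action_mask.sum())
wrong = action_log_probs[:valid_len]    # tensor([-9.0, -9.0, -0.1]) -> padding leaks in
fixed = action_log_probs[-valid_len:]   # tensor([-0.1, -0.2, -0.3]) -> actual response values
```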
…s own response_mask (#1507)

In step-wise training, `advantages = last_step_advantages[traj_ids]` copied the last step's already-masked tensor verbatim to every step of the trajectory. Because each step has its own right-aligned response_mask at its own token positions (prompts and responses vary in length across turns), the scalar advantage landed at the last step's token positions — which almost never overlap with an earlier step's own response positions. Earlier steps silently received advantages = 0 at their loss_mask=1 positions, so roughly (N-1)/N of step-samples were trained incorrectly.

Fix:

- Call the estimator with an all-ones mask so it returns the scalar advantage at every position (outcome-based estimators produce `scalar * mask`).
- Broadcast by indexing into `traj_ids`, then multiply by each step's own `response_mask` so values land at the right token positions.
- Reject `step_wise_trajectories=True` + `gae`/`reinforce++` in `validate_cfg`: temporal estimators produce per-token advantages that the scalar-broadcast model can't represent, and the step-wise path only sees the last step's slice anyway (no cross-step credit).
- Add regression test `test_calc_advantages_and_returns_step_wise_broadcast` with a realistic right-aligned batch, plus a parametrized `test_validate_cfg_step_wise_estimator_compatibility` covering all 5 built-in estimators.
- Document the outcome-only constraint in the step-wise tutorial.

Fixes #1492.

### Testing

Since we will run into #1501 at HEAD, we use this hacky hot fix, mainly the `pad_batch()` change: CharlieFRuan@f2783a4

```py
if self.cfg.generator.step_wise_trajectories:
    n_samples = self.cfg.generator.n_samples_per_prompt
    pad_target = max(
        self.cfg.trainer.policy_mini_batch_size * n_samples,
        self.cfg.trainer.critic_mini_batch_size * n_samples,
        dp_size,
    )
else:
    pad_target = dp_size
pad_size = math.ceil(training_input.batch_size / pad_target) * pad_target - training_input.batch_size
```

We compare against the red curve from #1285.

Note that the actual curve should converge more slowly than this after we merge #1483, in which each mini batch will truly be one optimizer step. Currently we can take multiple optimizer steps per mini batch because the number of turns can inflate the number of sequences in each mini batch.

<img width="650" height="287" alt="image" src="https://github.com/user-attachments/assets/bd9ac452-36dd-4135-8203-e9e582509c50" />

The reward curve being higher is likely due to:

1) The fix we apply in this PR
2) In the previous PR (red line), we silently dropped data and only trained `num_mini_batches = len(data) // mini_batch_size`, which is no longer an issue with the `pad_target` hacky fix above and will be properly fixed after #1483

---------

Co-authored-by: Claude <noreply@anthropic.com>
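As a rough illustration of the broadcast fix (names and shapes are made up, and the estimator call is simplified to a per-trajectory scalar), the scalar advantage is indexed by `traj_ids` and then localized with each step's own right-aligned `response_mask`:

```python
import torch

traj_ids = torch.tensor([0, 0, 1])                 # 3 step-samples drawn from 2 trajectories
trajectory_scalar_adv = torch.tensor([0.7, -0.3])  # one outcome-based scalar per trajectory

response_mask = torch.tensor([                     # right-aligned; positions differ per step
    [0, 0, 1, 1, 1],
    [0, 0, 0, 1, 1],
    [0, 1, 1, 1, 1],
], dtype=torch.float)

# Broadcast the trajectory-level scalar to every step, then place it on that step's own tokens.
advantages = trajectory_scalar_adv[traj_ids].unsqueeze(-1) * response_mask
print(advantages)
# tensor([[ 0.0,  0.0,  0.7,  0.7,  0.7],
#         [ 0.0,  0.0,  0.0,  0.7,  0.7],
#         [ 0.0, -0.3, -0.3, -0.3, -0.3]])
```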