[fix][train] Prompt-based mini-batching for step-wise training #1528

CharlieFRuan wants to merge 2 commits into NovaSky-AI:main
Conversation
Code Review
This pull request centralizes the batch padding logic by introducing a new utility function, pad_training_input_batch, in training_batch.py. This replaces duplicate implementations in skyrl_train_backend.py and trainer.py, ensuring consistent handling of tensors, TensorList objects, and metadata across the codebase. Comprehensive tests have also been added to verify the padding behavior and input immutability. One issue was identified in the new function where it mutates the input batch's metadata even when the padding size is zero, which contradicts the expected behavior and could cause side effects.
```python
if pad_size == 0:
    if unpadded_batch.metadata is None:
        unpadded_batch.metadata = {}
    unpadded_batch.metadata["pad_size"] = 0
    return unpadded_batch
```
The current implementation mutates the input `unpadded_batch` when `pad_size` is 0 by adding or modifying the `pad_size` key in its metadata. This contradicts the docstring, which states that the original batch is returned in this case, and it can lead to unexpected side effects if the batch is reused elsewhere. Since `data.metadata.get("pad_size", 0)` is already used in other parts of the codebase (e.g., in `trainer.py`) to handle a missing key, it is safer to simply return the original batch without modification:
```python
if pad_size == 0:
    return unpadded_batch
```
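Folding the early return into the helper gives the shape below. This is a minimal sketch: the `Batch` dataclass is a simplified stand-in for the real `TrainingInputBatch` (which also carries tensors and `TensorList` fields), and padding by repeating the last sequence is an illustrative assumption.

```python
from dataclasses import dataclass
from typing import Optional

# Simplified stand-in for the real TrainingInputBatch; an assumption for
# illustration only (the actual class also holds tensors and TensorList fields).
@dataclass
class Batch:
    sequences: list
    metadata: Optional[dict] = None

def pad_training_input_batch(unpadded_batch: Batch, multiple: int) -> Batch:
    """Pad so len(sequences) is a multiple of `multiple`.

    When pad_size == 0 the input is returned untouched: callers read
    metadata.get("pad_size", 0), so a missing key is handled, and the
    caller's batch is never mutated.
    """
    pad_size = -len(unpadded_batch.sequences) % multiple
    if pad_size == 0:
        return unpadded_batch
    # Build a new batch instead of mutating the input (padding strategy here
    # -- repeating the last sequence -- is an illustrative assumption).
    return Batch(
        sequences=unpadded_batch.sequences + [unpadded_batch.sequences[-1]] * pad_size,
        metadata={**(unpadded_batch.metadata or {}), "pad_size": pad_size},
    )
```

The design point is that the zero-pad path returns the exact input object, so callers holding a reference to the batch never observe a metadata change.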
Summary
Step-wise training decomposes multi-turn trajectories into one training sequence per LLM turn, producing a variable number of sequences per prompt. This broke the old fixed-size mini-batching.
This PR shifts mini-batching from sequence units to prompt units. Each mini-batch now contains sequences for exactly `policy_mini_batch_size` prompts, regardless of how many sequences those prompts generated. This ensures the number of optimizer steps is always `train_batch_size / policy_mini_batch_size * update_epochs_per_batch`. Depends on #1527 (extract `pad_training_input_batch`).

Key changes
- `compute_prompt_mini_batch_boundaries()` (`skyrl/train/dataset/preprocess.py`): walks a flat `uids` list, detects prompt boundaries by consecutive-equal groups, and slices them into `(start, end)` boundary pairs for each mini-batch. Asserts uid contiguity (a uid cannot re-appear after a gap). Asserts `len(unique_uids) == train_batch_size`. For non-step-wise, asserts boundaries are uniform (backward compatible).
- `MeshDispatch.stage_chunks()` (`dispatch.py`): accepts `mini_batch_boundaries` instead of computing fixed-size chunks. Each mini-batch is individually padded to `dp_size` using `pad_training_input_batch()`.
- `_normalize_advantages()` and `_execute_training_step()` (`trainer.py`): iterate over boundary pairs instead of fixed-size slicing.
- `apply_loss_reduction_to_advantages_minibatch()` (`ppo_utils.py`): uses ceiling division for `num_micro_batches` so ragged last micro-batches (from variable mini-batch sizes) are handled correctly.
- `validate_batch_sizes()` (`utils.py`): relaxes the `policy_mini_batch_size % micro_batch_size == 0` check for step-wise mode (the actual per-mini-batch sequence count is variable).
- `WorkerDispatch.stage_data()` (`worker_dispatch.py`): passes boundaries through to `stage_chunks`.

Backward compatibility
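The boundary computation described above can be sketched as follows. This is only an illustration of the consecutive-equal grouping and slicing logic; the real `compute_prompt_mini_batch_boundaries` in `skyrl/train/dataset/preprocess.py` also asserts `len(unique_uids) == train_batch_size` and, for non-step-wise runs, that all boundaries are uniform.

```python
from itertools import groupby

def prompt_mini_batch_boundaries(uids, prompts_per_mini_batch):
    """Slice a flat uid list into per-mini-batch (start, end) index pairs.

    Sketch only; the real helper carries additional assertions described
    in the PR summary.
    """
    # One consecutive-equal group per prompt's run of sequences.
    group_order = [uid for uid, _ in groupby(uids)]
    group_sizes = [len(list(g)) for _, g in groupby(uids)]
    # Contiguity: a uid must not reappear after a gap.
    assert len(group_order) == len(set(group_order)), "uid reappeared after a gap"

    boundaries, start = [], 0
    for i in range(0, len(group_sizes), prompts_per_mini_batch):
        # Each mini-batch spans prompts_per_mini_batch prompts, however many
        # sequences each of those prompts produced.
        end = start + sum(group_sizes[i : i + prompts_per_mini_batch])
        boundaries.append((start, end))
        start = end
    return boundaries
```

In the step-wise case the `(start, end)` spans vary in size but the number of boundary pairs (and hence optimizer steps) depends only on the prompt count; in the non-step-wise case every span has the same size, matching the old fixed-size slicing.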
For non-step-wise training, where each prompt has exactly `n_samples_per_prompt` sequences, boundaries remain uniform, identical to the original fixed-size slicing. An assertion in `compute_prompt_mini_batch_boundaries` verifies this.

Test plan
- `tests/train/test_prompt_mini_batch.py`: unit tests for `compute_prompt_mini_batch_boundaries` (non-step-wise, step-wise, contiguity assertion, boundary uniformity parametrized), `MeshDispatch.stage_chunks` (padding, loss_mask zeros, variable sizes), and optimizer step count invariance.
- `tests/backends/skyrl_train/test_train_batch.py`: field-exhaustive `pad_training_input_batch` tests (from "[qoc] Extract `pad_batch()` into a helper to `training_batch.py`" #1527).

🤖 Generated with Claude Code