[fix][train] Prompt-based mini-batching for step-wise training #1528

Closed
CharlieFRuan wants to merge 2 commits into NovaSky-AI:main from CharlieFRuan:prompt-based-mini-batching-v2

Conversation

CharlieFRuan (Member) commented Apr 17, 2026

Summary

Step-wise training decomposes multi-turn trajectories into one training sequence per LLM turn, producing a variable number of sequences per prompt. This broke the old fixed-size mini-batching in two ways:

  1. Crash: Total sequence count wasn't always divisible by the fixed mini-batch size.
  2. LR schedule distortion: Variable turn counts caused variable numbers of optimizer steps per training batch.

This PR shifts mini-batching from sequence units to prompt units. Each mini-batch now contains sequences for exactly policy_mini_batch_size prompts, regardless of how many sequences those prompts generated. This ensures the number of optimizer steps is always train_batch_size / policy_mini_batch_size * update_epochs_per_batch.

Depends on #1527 (extract pad_training_input_batch).
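
To make the arithmetic concrete, here is a worked example (the numbers are hypothetical, not taken from the PR):

    # Hypothetical numbers, for illustration only.
    train_batch_size = 8             # prompts per training batch
    policy_mini_batch_size = 2       # prompts per mini-batch (the new unit)
    update_epochs_per_batch = 1

    # Step-wise: each prompt yields one sequence per LLM turn, e.g.
    sequences_per_prompt = [3, 1, 4, 2, 2, 5, 1, 3]   # 21 sequences total

    # Old sequence-based slicing with a fixed mini-batch of 4 sequences:
    # 21 % 4 != 0 -> crash; and the optimizer step count would track the
    # total turn count, which varies batch to batch (LR distortion).

    # Prompt-based slicing: 8 prompts / 2 prompts per mini-batch, so
    optimizer_steps = (train_batch_size // policy_mini_batch_size) * update_epochs_per_batch
    assert optimizer_steps == 4      # constant, whatever the turn counts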

Key changes

  • compute_prompt_mini_batch_boundaries() (skyrl/train/dataset/preprocess.py): walks a flat uids list, detects prompt boundaries as groups of consecutive equal uids, and slices those groups into (start, end) boundary pairs, one pair per mini-batch. Asserts uid contiguity (a uid cannot reappear after a gap) and len(unique_uids) == train_batch_size. For non-step-wise runs, asserts boundaries are uniform (backward compatible). A sketch follows after this list.
  • MeshDispatch.stage_chunks() (dispatch.py): accepts mini_batch_boundaries instead of computing fixed-size chunks. Each mini-batch is individually padded to dp_size using pad_training_input_batch().
  • _normalize_advantages() and _execute_training_step() (trainer.py): iterate over boundary pairs instead of fixed-size slicing.
  • apply_loss_reduction_to_advantages_minibatch() (ppo_utils.py): uses ceiling division for num_micro_batches so ragged last micro-batches (a byproduct of variable mini-batch sizes) are handled correctly (also sketched below).
  • validate_batch_sizes() (utils.py): relaxes policy_mini_batch_size % micro_batch_size == 0 check for step-wise mode (the actual per-mini-batch sequence count is variable).
  • WorkerDispatch.stage_data() (worker_dispatch.py): passes boundaries through to stage_chunks.
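
A minimal sketch of the boundary computation, assuming uids is a flat, prompt-contiguous list of per-sequence prompt ids; the actual implementation in skyrl/train/dataset/preprocess.py may differ in details (e.g., it additionally checks the prompt count against train_batch_size):

    from itertools import groupby
    from typing import List, Tuple

    def compute_prompt_mini_batch_boundaries(
        uids: List[str], policy_mini_batch_size: int
    ) -> List[Tuple[int, int]]:
        # One group of consecutive equal uids per prompt.
        group_sizes = [len(list(g)) for _, g in groupby(uids)]
        # Contiguity: a uid must not reappear after a gap.
        assert len(set(uids)) == len(group_sizes), "uids are not contiguous"
        assert len(group_sizes) % policy_mini_batch_size == 0

        boundaries, start = [], 0
        for i in range(0, len(group_sizes), policy_mini_batch_size):
            end = start + sum(group_sizes[i : i + policy_mini_batch_size])
            boundaries.append((start, end))
            start = end
        return boundaries

    # 4 prompts with variable turn counts, 2 prompts per mini-batch:
    uids = ["a", "a", "a", "b", "c", "c", "d", "d", "d", "d"]
    print(compute_prompt_mini_batch_boundaries(uids, 2))  # [(0, 4), (4, 10)]

And the ceiling-division change for micro-batching, again as a hypothetical sketch (micro_batch_size is assumed here to be a per-mini-batch sequence count):

    import math

    def num_micro_batches(mini_batch_len: int, micro_batch_size: int) -> int:
        # Ceiling division so a ragged last micro-batch is still processed.
        return math.ceil(mini_batch_len / micro_batch_size)

    assert num_micro_batches(10, 4) == 3  # micro-batches of sizes 4, 4, 2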

Backward compatibility

For non-step-wise training, where each prompt has exactly n_samples_per_prompt sequences, boundaries remain uniform — identical to the original fixed-size slicing. An assertion in compute_prompt_mini_batch_boundaries verifies this.
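
Using the sketch above, a quick check of the non-step-wise case (values hypothetical):

    # Every prompt has exactly n_samples_per_prompt = 2 sequences, so
    # prompt-based boundaries reduce to the old fixed-size slices of
    # width policy_mini_batch_size * n_samples_per_prompt = 4.
    uids = ["a", "a", "b", "b", "c", "c", "d", "d"]
    assert compute_prompt_mini_batch_boundaries(uids, 2) == [(0, 4), (4, 8)]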

Test plan

  • tests/train/test_prompt_mini_batch.py: unit tests for compute_prompt_mini_batch_boundaries (non-step-wise, step-wise, contiguity assertion, boundary uniformity; parametrized), MeshDispatch.stage_chunks (padding, loss_mask zeros, variable sizes), and optimizer step count invariance (a hypothetical sketch follows this list).
  • tests/backends/skyrl_train/test_train_batch.py: field-exhaustive pad_training_input_batch tests (from #1527, "[qoc] Extract pad_batch() into a helper to training_batch.py").
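
A hypothetical shape for the step-count invariance test, reusing the sketch above (the real test file may be organized differently):

    import pytest

    @pytest.mark.parametrize("turns", [[1, 1, 1, 1], [3, 1, 4, 2]])
    def test_optimizer_step_count_invariance(turns):
        # The number of mini-batches (hence optimizer steps per epoch)
        # depends only on the prompt count, never on per-prompt turn counts.
        uids = [f"p{i}" for i, n in enumerate(turns) for _ in range(n)]
        boundaries = compute_prompt_mini_batch_boundaries(uids, 2)
        assert len(boundaries) == len(turns) // 2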

🤖 Generated with Claude Code



gemini-code-assist (Bot) left a comment


Code Review

This pull request centralizes the batch padding logic by introducing a new utility function, pad_training_input_batch, in training_batch.py. This replaces duplicate implementations in skyrl_train_backend.py and trainer.py, ensuring consistent handling of tensors, TensorList objects, and metadata across the codebase. Comprehensive tests have also been added to verify the padding behavior and input immutability. One issue was identified in the new function where it mutates the input batch's metadata even when the padding size is zero, which contradicts the expected behavior and could cause side effects.

Comment on lines +499 to +503
    if pad_size == 0:
        if unpadded_batch.metadata is None:
            unpadded_batch.metadata = {}
        unpadded_batch.metadata["pad_size"] = 0
        return unpadded_batch
Contributor


Severity: medium

The current implementation mutates the input unpadded_batch when pad_size is 0 by adding or modifying the pad_size key in its metadata. This contradicts the docstring which states it returns the original batch in this case, and it can lead to unexpected side effects if the batch is reused elsewhere. Since data.metadata.get("pad_size", 0) is used in other parts of the codebase (e.g., in trainer.py) to handle missing keys, it is safer to simply return the original batch without modification.

    if pad_size == 0:
        return unpadded_batch

devin-ai-integration (Bot) left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 5 additional findings.

