[train] Extract pad_batch() to training_batch.py #1523

Closed
CharlieFRuan wants to merge 1 commit into main from pad-batch-cpu-assert

Conversation

@CharlieFRuan (Member) commented Apr 16, 2026

Closed by #1527

  • Moves pad_batch out of RayPPOTrainer into a module-level function in training_batch.py so that dispatch-level callers can share it.
  • Adds a mode kwarg: train_batch (callers own the full batch and want uids/trajectory_ids metadata extended with synthetic pad entries) vs mini_batch (callers pad a transient slice and must not touch parent metadata that would not correspond to the slice anyway).
  • Asserts the batch lives on CPU. Both real callers already stage on CPU, and padding allocates/concatenates — it's not something we want to do on the GPU hot path.
  • Allows pad_size > batch_size by cycling row 0 (regression: mini-batch padding can see mb_size=1, dp_size=4 → pad_size=3, and the old tensor[:pad_size] silently returned a shorter slice).
  • Handles TensorList fields (pixel_values, image_grid_thw) via cyclic cloning, matching the VLM path introduced in #1498 ([train][multimodal][3/3] Trainer changes to extract multi-modal outputs from GeneratorOutput).
  • Adds is_last_step to the TrainingInput TypedDict (it's already used everywhere; this makes the schema match reality).
  • Field-exhaustive unit tests mirror test_generator_output_concatenation: they enumerate TrainingInput.__annotations__ and fail loudly if a new field is added without updating pad_batch(). They also cover the pad_size > batch_size edge case, the CPU-only assertion, both modes, and input immutability.
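
The pad_size > batch_size regression and the two modes described above can be illustrated with a minimal sketch. It is not the code from this PR: plain Python lists stand in for CPU tensors so that it runs without dependencies, and the names pad_batch_sketch, cyclic_pad_rows, METADATA_KEYS, and the "pad_{i}" placeholder IDs are hypothetical. It also assumes that "cycling" means picking row i % batch_size, and it folds the metadata into the same dict as the data fields purely for brevity.

```python
import copy
from typing import Dict, List

# Hypothetical names for this sketch; not the real TrainingInput schema.
METADATA_KEYS = ("uids", "trajectory_ids")
VALID_MODES = ("train_batch", "mini_batch")


def cyclic_pad_rows(rows: List, pad_size: int) -> List:
    """Return pad_size cloned rows, wrapping back to row 0 when rows run out."""
    if not rows:
        raise ValueError("cannot pad an empty batch")
    return [copy.deepcopy(rows[i % len(rows)]) for i in range(pad_size)]


def pad_batch_sketch(
    batch: Dict[str, List], pad_size: int, mode: str = "train_batch"
) -> Dict[str, List]:
    """Return a new padded batch; the input dict and its lists are not mutated."""
    if mode not in VALID_MODES:
        raise ValueError(f"unknown mode: {mode!r}")
    padded = {}
    for key, rows in batch.items():
        if key in METADATA_KEYS:
            # train_batch: extend metadata with synthetic pad entries.
            # mini_batch: leave metadata exactly as it was.
            extra = (
                [f"pad_{i}" for i in range(pad_size)]
                if mode == "train_batch"
                else []
            )
        else:
            extra = cyclic_pad_rows(rows, pad_size)
        padded[key] = list(rows) + extra
    return padded


# The regression case: one row, three pad rows requested.
rows = [[1, 2, 3]]
assert len(rows[:3]) == 1                  # a plain slice comes up short
assert len(cyclic_pad_rows(rows, 3)) == 3  # cyclic indexing does not
```

With mb_size=1 and pad_size=3, pad_batch_sketch({"sequences": [[1, 2, 3]], "uids": ["a"]}, 3, mode="mini_batch") yields four sequence rows while leaving uids at its original single entry; mode="train_batch" extends uids to four entries as well.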

@CharlieFRuan changed the title from "[train] Extract pad_batch() to training_batch.py with mini_batch mode…" to "[train] Extract pad_batch() to training_batch.py" on Apr 16, 2026