def pad_training_input_batch(unpadded_batch: TrainingInputBatch, pad_size: int) -> TrainingInputBatch:
    """Append ``pad_size`` padding rows to ``unpadded_batch``.

    Returns a newly allocated TrainingInputBatch when ``pad_size > 0``. When
    ``pad_size == 0`` the *original* batch object is returned, with
    ``metadata["pad_size"]`` set to 0 (callers may rely on identity here).

    Padding rows clone row 0 of each field so shapes/dtypes remain valid for the
    model, except ``loss_mask``, which is zero-filled so padding rows can never
    contribute to the loss. Metadata is deep-copied; ``"uids"`` gains synthetic
    ``"pad{i}"`` ids and ``"is_last_step"`` gains ``True`` entries for padding rows.

    Args:
        unpadded_batch: CPU-resident batch to pad. Fields may be regular tensors,
            TensorList instances, or None (None fields are passed through).
        pad_size: number of padding rows to append; must be >= 0.

    Raises:
        AssertionError: if the batch is not on CPU, ``pad_size`` is negative, or
            a field to pad is empty. (Callers/tests depend on AssertionError here.)
    """
    # TODO(Charlie): This incurs 2x CPU memory usage when pad_size > 0. Optimize when needed.
    # Padding allocates and concatenates; it should not happen on the GPU hot path.
    assert unpadded_batch.device is None or unpadded_batch.device == torch.device(
        "cpu"
    ), f"pad_batch expects batch on CPU, got device={unpadded_batch.device}"
    assert pad_size >= 0, f"pad_size must be >= 0, got {pad_size}"

    # Special case: nothing to pad. Record pad_size and hand back the same object.
    if pad_size == 0:
        if unpadded_batch.metadata is None:
            unpadded_batch.metadata = {}
        unpadded_batch.metadata["pad_size"] = 0
        return unpadded_batch

    # Pad each field according to its type.
    new_tensors = {}
    for key, tensor in unpadded_batch.items():
        if tensor is None:
            new_tensors[key] = None
            continue

        if isinstance(tensor, TensorList):
            assert len(tensor) > 0, f"Cannot pad empty TensorList field {key!r}"
            # Clone so padding rows never alias the original row-0 tensor.
            padding = TensorList([tensor[0].clone() for _ in range(pad_size)])
            new_tensors[key] = TensorList.cat([tensor, padding])
        elif key == "loss_mask":
            # Zero mask ensures padding rows don't count towards the loss.
            additional_dims = tensor.shape[1:]
            padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
            new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
        else:
            # Copy row 0 `pad_size` times. Values are loss-masked, so only a valid
            # shape/dtype is needed. Integer-list (advanced) indexing already
            # returns a fresh copy, so no extra .clone() is required.
            assert tensor.shape[0] > 0, f"Cannot pad empty tensor field {key!r}"
            padding_tensor = tensor[[0] * pad_size]
            new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)

    # Rebuild metadata: extend the per-row entries, deep-copy everything else so
    # the input batch's metadata is never mutated.
    new_metadata = {}
    old_metadata = unpadded_batch.metadata or {}
    for key, value in old_metadata.items():
        if key == "uids":
            new_metadata["uids"] = value + [f"pad{i}" for i in range(pad_size)]
        elif key == "is_last_step":
            # Padding rows are marked as last-step entries.
            new_metadata["is_last_step"] = value + [True] * pad_size
        else:
            new_metadata[key] = copy.deepcopy(value)
    new_metadata["pad_size"] = pad_size

    new_batch = TrainingInputBatch(new_tensors)
    new_batch.metadata = new_metadata
    return new_batch
the model - padding_tensor = tensor[torch.arange(pad_size) % tensor.shape[0]].clone() - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - - padded = TrainingInputBatch(new_tensors) - padded.metadata = batch.metadata - logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})") - return padded, pad_size + if pad_size > 0: + logger.info( + f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})" + ) + return pad_training_input_batch(batch, pad_size), pad_size def _extract_metrics(self, data: dict) -> dict[str, float]: """Extract training metrics from dispatch return dict. diff --git a/skyrl/train/trainer.py b/skyrl/train/trainer.py index ed48434aad..82fdd78b11 100644 --- a/skyrl/train/trainer.py +++ b/skyrl/train/trainer.py @@ -1,4 +1,3 @@ -import copy import math import os import shutil @@ -27,7 +26,11 @@ from skyrl.backends.skyrl_train.inference_engines.utils import ( get_sampling_params_for_backend, ) -from skyrl.backends.skyrl_train.training_batch import TensorList, TrainingInputBatch +from skyrl.backends.skyrl_train.training_batch import ( + TensorList, + TrainingInputBatch, + pad_training_input_batch, +) from skyrl.backends.skyrl_train.utils import ppo_utils from skyrl.backends.skyrl_train.utils.io import io from skyrl.backends.skyrl_train.utils.ppo_utils import ( @@ -690,7 +693,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis ) / len(response_ids) logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}") - training_input = self.pad_batch(training_input) + dp_size = self.dispatch.get_lcm_dp_size() + pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size + training_input = pad_training_input_batch(training_input, pad_size) logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}") return training_input @@ -910,50 +915,6 @@ def 
dump_data(self, data: TrainingInputBatch, file_name: str): data_save_dir.mkdir(parents=True, exist_ok=True) data.save(data_save_dir / f"{file_name}.pkl") - def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: - """Pad the batch to be divisible by dp size""" - import math - - dp_size = self.dispatch.get_lcm_dp_size() - pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size - new_tensors = {} - training_input.metadata["pad_size"] = pad_size - if pad_size == 0: - return training_input - for key, tensor in training_input.items(): - if tensor is not None: - if isinstance(tensor, TensorList): - n = len(tensor) - pad_indices = [i % n for i in range(pad_size)] - padding = TensorList([tensor[i].clone() for i in pad_indices]) - new_tensors[key] = TensorList.cat([tensor, padding]) - elif key == "loss_mask": - # ensures that padding tensors don't count towards the loss - additional_dims = tensor.shape[1:] - padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device) - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - else: - # ensures all padding tensors are in a valid format by cloning `pad_size` from the original input - n = tensor.shape[0] - pad_indices = torch.arange(pad_size, device=tensor.device) % n - padding_tensor = tensor[pad_indices].clone() - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - - new_training_input = TrainingInputBatch(new_tensors) - new_training_input.metadata = {} - for key, value in training_input.metadata.items(): - if key == "uids": - new_training_input.metadata["uids"] = training_input.metadata["uids"] + [ - f"pad{i}" for i in range(pad_size) - ] - elif key == "is_last_step": - new_training_input.metadata["is_last_step"] = training_input.metadata["is_last_step"] + [ - True for i in range(pad_size) - ] - else: - new_training_input.metadata[key] = copy.deepcopy(value) - return new_training_input - @torch.no_grad() 
# ---------------------------------------------------------------------------
# Tests for pad_training_input_batch
# ---------------------------------------------------------------------------


# Canonical field set of the TrainingInput TypedDict. Adding a field to
# TrainingInput without updating both `pad_training_input_batch()` and this set
# makes `test_pad_batch_typeddict_matches_expected_fields` fail, forcing the
# author to decide how the new field should be padded.
EXPECTED_TRAINING_INPUT_FIELDS = {
    "sequences",
    "attention_mask",
    "loss_mask",
    "response_mask",
    "action_log_probs",
    "base_action_log_probs",
    "values",
    "returns",
    "advantages",
    "kl",
    "rewards",
    "rollout_logprobs",
    "rollout_expert_indices",
    "pixel_values",
    "image_grid_thw",
}


def _make_full_training_batch(batch_size: int = 4, seq_len: int = 5) -> TrainingInputBatch:
    """Build a TrainingInputBatch populated with every TrainingInput field.

    ``pixel_values`` and ``image_grid_thw`` use variable per-row shapes to exercise TensorList.
    """
    torch.manual_seed(0)
    tensors = {
        "sequences": torch.arange(batch_size * seq_len).reshape(batch_size, seq_len).long(),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "loss_mask": torch.ones(batch_size, seq_len, dtype=torch.float),
        "response_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "action_log_probs": torch.randn(batch_size, seq_len),
        "base_action_log_probs": torch.randn(batch_size, seq_len),
        "values": torch.randn(batch_size, seq_len),
        "returns": torch.randn(batch_size, seq_len),
        "advantages": torch.randn(batch_size, seq_len),
        "kl": torch.randn(batch_size, seq_len),
        "rewards": torch.randn(batch_size, seq_len),
        "rollout_logprobs": torch.randn(batch_size, seq_len),
        "rollout_expert_indices": torch.randint(0, 8, (batch_size, seq_len, 2, 3), dtype=torch.long),
        "pixel_values": TensorList([torch.randn(i + 1, 3) for i in range(batch_size)]),  # batch_size * (i + 1) * 3
        "image_grid_thw": TensorList([torch.tensor([[1, 2, 3]]) for _ in range(batch_size)]),  # batch_size * 1 * 3
    }
    batch = TrainingInputBatch(tensors)
    batch.metadata = {
        "uids": [f"u{i}" for i in range(batch_size)],
        "is_last_step": [i % 2 == 1 for i in range(batch_size)],  # [False, True, False, True, ...]
        "response_length": seq_len,
    }
    return batch


def test_pad_batch_typeddict_matches_expected_fields():
    """
    Guard: if TrainingInput gains a new field, bump EXPECTED_TRAINING_INPUT_FIELDS and make sure it is
    well handled by pad_training_input_batch().
    """
    declared_fields = set(TrainingInput.__annotations__.keys())
    assert declared_fields == EXPECTED_TRAINING_INPUT_FIELDS, (
        "TrainingInput fields changed. Update EXPECTED_TRAINING_INPUT_FIELDS AND make sure "
        "pad_training_input_batch() handles the new field."
    )


def test_pad_batch_all_fields():
    """Comprehensive test: pad_training_input_batch pads every field correctly.

    Verifies: all tensor fields have correct batch dim, original rows are untouched,
    loss_mask padding is zero, other tensor padding is row-0 clones, TensorList padding
    is row-0 clones, metadata (uids, is_last_step) is extended correctly, and
    the input batch is not mutated.
    """
    batch_size, seq_len, pad_size = 4, 5, 3
    batch = _make_full_training_batch(batch_size=batch_size, seq_len=seq_len)

    # Sanity: the test fixture must exercise every TrainingInput field.
    assert (
        set(batch.keys()) == EXPECTED_TRAINING_INPUT_FIELDS
    ), "Test fixture is missing TrainingInput fields; update _make_full_training_batch."

    # Snapshot input metadata before padding so we can verify immutability later.
    uids_before = list(batch.metadata["uids"])
    is_last_step_before = list(batch.metadata["is_last_step"])

    padded = pad_training_input_batch(batch, pad_size=pad_size)
    assert padded.batch_size == batch_size + pad_size
    assert set(padded.keys()) == EXPECTED_TRAINING_INPUT_FIELDS, "Padded batch is missing TrainingInput fields"

    # --- Tensor fields: every field present and grown by pad_size ---
    for key in EXPECTED_TRAINING_INPUT_FIELDS:
        field = padded[key]
        assert field is not None, f"Field {key!r} became None after padding"
        assert len(field) == batch_size + pad_size, f"Field {key!r} has wrong batch dim"

    # loss_mask: original rows untouched, padding rows all-zero.
    assert torch.equal(padded["loss_mask"][:batch_size], batch["loss_mask"])
    assert torch.all(padded["loss_mask"][batch_size:] == 0)

    # Regular tensor fields (not loss_mask, not TensorList): original rows untouched,
    # padding rows are copies of row 0.
    plain_tensor_keys = EXPECTED_TRAINING_INPUT_FIELDS - {"loss_mask", "pixel_values", "image_grid_thw"}
    for key in plain_tensor_keys:
        assert torch.equal(padded[key][:batch_size], batch[key]), f"Original rows changed for {key!r}"
        for row in range(batch_size, batch_size + pad_size):
            assert torch.equal(padded[key][row], batch[key][0]), f"Padding row {row} of {key!r} is not row 0"

    # TensorList fields (pixel_values, image_grid_thw): original rows untouched,
    # padding rows are clones of row 0.
    for key in ("pixel_values", "image_grid_thw"):
        padded_list = padded[key]
        assert isinstance(padded_list, TensorList)
        for row in range(batch_size):
            assert torch.equal(padded_list[row], batch[key][row]), f"Original TensorList row {row} changed for {key!r}"
        for row in range(batch_size, batch_size + pad_size):
            assert torch.equal(padded_list[row], batch[key][0]), f"TensorList padding row {row} of {key!r} is not row 0"

    # --- Metadata ---
    assert padded.metadata["uids"] == [f"u{i}" for i in range(batch_size)] + [f"pad{i}" for i in range(pad_size)]
    assert padded.metadata["is_last_step"] == [i % 2 == 1 for i in range(batch_size)] + [True] * pad_size
    assert padded.metadata["pad_size"] == pad_size
    assert padded.metadata["response_length"] == seq_len

    # --- Input immutability ---
    assert batch.metadata["uids"] == uids_before, "pad_training_input_batch mutated input uids"
    assert (
        batch.metadata["is_last_step"] == is_last_step_before
    ), "pad_training_input_batch mutated input is_last_step"


def test_pad_batch_zero_pad_size_returns_same_batch():
    # pad_size == 0 must return the very same object, not a copy.
    batch = _make_full_training_batch(batch_size=3, seq_len=5)
    padded = pad_training_input_batch(batch, pad_size=0)
    assert padded is batch


def test_pad_batch_rejects_non_cpu_device():
    # Padding is CPU-only by contract; a CUDA-resident batch must be rejected.
    if not torch.cuda.is_available():
        pytest.skip("no CUDA available")
    batch = _make_full_training_batch(batch_size=2, seq_len=5)
    cuda_data = {
        "sequences": batch["sequences"].cuda(),
        "loss_mask": batch["loss_mask"].cuda(),
    }
    cuda_batch = TrainingInputBatch(cuda_data)
    cuda_batch.metadata = {"uids": ["u0", "u1"]}
    with pytest.raises(AssertionError, match="expects batch on CPU"):
        pad_training_input_batch(cuda_batch, pad_size=1)


def test_pad_batch_preserves_none_fields():
    # None fields must survive padding as None; the rest grow by pad_size.
    batch = TrainingInputBatch(
        {
            "sequences": torch.arange(12).reshape(3, 4).long(),
            "loss_mask": torch.ones(3, 4, dtype=torch.float),
            "values": None,
        }
    )
    batch.metadata = {"uids": ["a", "b", "c"]}
    padded = pad_training_input_batch(batch, pad_size=1)
    assert padded["values"] is None
    assert padded.batch_size == 4