# Keys that pad_batch() pads with constant fill values rather than cloned real data.
# - loss_mask: zero-filled so padding rows do not contribute to the loss.
# - is_last_step: one/True-filled so padding rows do not inflate per-trajectory counts
#   (a padding row belongs to no real trajectory; treating it as its own terminal step
#   keeps cumulative trajectory-id indexing consistent with the real rows).
_PAD_BATCH_CONSTANT_KEYS = ("loss_mask", "is_last_step")

# Fill value used for each constant-padded key (see _PAD_BATCH_CONSTANT_KEYS for the
# rationale). Kept next to the key tuple so both must be updated together.
_PAD_BATCH_FILL_VALUES = {"loss_mask": 0, "is_last_step": 1}


def pad_batch(
    batch: "TrainingInputBatch",
    pad_size: int,
    *,
    mode: Literal["train_batch", "mini_batch"],
) -> "TrainingInputBatch":
    """Pad ``batch`` with ``pad_size`` dummy rows so that downstream slicing is well-defined.

    Padding strategy per field:
      - ``loss_mask``: zeros (so padding rows do not contribute to the loss).
      - ``is_last_step``: ones/True (so cumulative trajectory-id indexing is consistent).
      - Everything else (including ``TensorList`` fields like ``pixel_values``): existing
        rows are repeated cyclically — padding slot ``i`` is a clone of real row
        ``i % batch_size``. Because of the cycling, ``pad_size`` is allowed to exceed the
        batch size (e.g. per-mini-batch padding where mb_size=1 and dp_size=4). Since
        ``loss_mask=0`` on these rows, their values don't affect the loss and only need
        to be shape/dtype-valid.

    Args:
        batch: Batch to pad. Must be on CPU.
        pad_size: Number of rows to append. ``0`` is allowed and returns a copy of the
            input unchanged (except for recording ``metadata["pad_size"] = 0``).
        mode: Determines how metadata is handled.
            - ``"train_batch"``: The caller owns the full training batch. ``uids`` and
              ``trajectory_ids`` in metadata are extended with synthetic ``"pad{i}"``
              entries so they stay aligned with the new batch size. Downstream boundary
              computation, metric extraction, and advantage normalization need this
              alignment.
            - ``"mini_batch"``: The caller is padding a transient mini-batch slice whose
              metadata still references the parent batch (e.g. ``stage_chunks``). We do
              not mutate ``uids``/``trajectory_ids`` here because the parent's lists
              wouldn't line up with a sliced window anyway, and the per-chunk workers
              don't rely on them.

    Returns:
        A new ``TrainingInputBatch`` with ``batch_size + pad_size`` rows and
        ``metadata["pad_size"] = pad_size`` recorded. The input batch is never mutated.
    """
    assert pad_size >= 0, f"pad_size must be >= 0, got {pad_size}"
    assert mode in ("train_batch", "mini_batch"), f"unknown pad_batch mode: {mode!r}"
    # Padding allocates and concatenates; it must happen on the main-process CPU staging
    # area, not on GPU workers where we'd be materializing extra memory on the hot path.
    assert batch.device is None or batch.device == torch.device("cpu"), (
        f"pad_batch expects batch on CPU, got device={batch.device}"
    )

    if pad_size == 0:
        # Still record pad_size so downstream code has a uniform view.
        new_batch = batch.__class__(dict(batch))
        new_batch.metadata = copy.deepcopy(batch.metadata) if batch.metadata is not None else {}
        new_batch.metadata["pad_size"] = 0
        return new_batch

    new_tensors: Dict[str, Any] = {}
    for key, value in batch.items():
        if value is None:
            new_tensors[key] = None
            continue

        if isinstance(value, TensorList):
            n = len(value)
            assert n > 0, f"Cannot pad empty TensorList field {key!r}"
            # Cyclic clone of real rows so this works even when pad_size > len(value).
            padding = TensorList([value[i % n].clone() for i in range(pad_size)])
            new_tensors[key] = TensorList.cat([value, padding])
            continue

        assert isinstance(value, torch.Tensor), (
            f"pad_batch expected Tensor or TensorList for field {key!r}, got {type(value).__name__}"
        )

        if key in _PAD_BATCH_CONSTANT_KEYS:
            # Constant fill: 0 for loss_mask, 1/True for is_last_step (see module constants).
            fill = _PAD_BATCH_FILL_VALUES[key]
            padding_tensor = torch.full(
                (pad_size, *value.shape[1:]), fill, dtype=value.dtype, device=value.device
            )
        else:
            # Cyclic clone of real rows so that this works even when pad_size > len(value)
            # (e.g. pad_size=3, batch_size=1 under dp_size=4).
            n = value.shape[0]
            assert n > 0, f"Cannot pad empty tensor field {key!r}"
            pad_indices = torch.arange(pad_size, device=value.device) % n
            padding_tensor = value[pad_indices].clone()
        new_tensors[key] = torch.cat([value, padding_tensor], dim=0)

    new_batch = batch.__class__(new_tensors)

    # Metadata handling differs by mode.
    old_metadata = batch.metadata if batch.metadata is not None else {}
    if mode == "train_batch":
        new_metadata: Dict[str, Any] = {}
        if "uids" in old_metadata:
            new_metadata["uids"] = list(old_metadata["uids"]) + [f"pad{i}" for i in range(pad_size)]
        if "trajectory_ids" in old_metadata:
            new_metadata["trajectory_ids"] = list(old_metadata["trajectory_ids"]) + [
                f"pad{i}" for i in range(pad_size)
            ]
        for key, value in old_metadata.items():
            if key not in ("uids", "trajectory_ids"):
                new_metadata[key] = copy.deepcopy(value)
    else:
        # mini_batch: the caller is padding a transient slice. Copy metadata as-is; do NOT
        # touch uids/trajectory_ids because they wouldn't correspond to this slice anyway.
        new_metadata = copy.deepcopy(old_metadata)

    new_metadata["pad_size"] = pad_size
    new_batch.metadata = new_metadata
    return new_batch
padding: {len(training_input['sequences'])}") return training_input @@ -913,51 +915,6 @@ def dump_data(self, data: TrainingInputBatch, file_name: str): data_save_dir.mkdir(parents=True, exist_ok=True) data.save(data_save_dir / f"{file_name}.pkl") - def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch: - """Pad the batch to be divisible by dp size""" - import math - - dp_size = self.dispatch.get_lcm_dp_size() - pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size - new_tensors = {} - training_input.metadata["pad_size"] = pad_size - if pad_size == 0: - return training_input - for key, tensor in training_input.items(): - if tensor is not None: - if isinstance(tensor, TensorList): - n = len(tensor) - pad_indices = [i % n for i in range(pad_size)] - padding = TensorList([tensor[i].clone() for i in pad_indices]) - new_tensors[key] = TensorList.cat([tensor, padding]) - elif key == "is_last_step": - additional_dims = tensor.shape[1:] - padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device) - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - elif key == "loss_mask": - # ensures that padding tensors don't count towards the loss - additional_dims = tensor.shape[1:] - padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device) - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - else: - # ensures all padding tensors are in a valid format by cloning `pad_size` from the original input - n = tensor.shape[0] - pad_indices = torch.arange(pad_size, device=tensor.device) % n - padding_tensor = tensor[pad_indices].clone() - new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0) - - new_training_input = TrainingInputBatch(new_tensors) - new_training_input.metadata = {} - new_training_input.metadata["uids"] = training_input.metadata["uids"] + [f"pad{i}" for i in range(pad_size)] - if "trajectory_ids" in 
training_input.metadata: - new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [ - f"pad{i}" for i in range(pad_size) - ] - for key, value in training_input.metadata.items(): - if key not in ["uids", "trajectory_ids"]: - new_training_input.metadata[key] = copy.deepcopy(value) - return new_training_input - @torch.no_grad() def fwd_logprobs_values_reward( self, diff --git a/tests/backends/skyrl_train/test_train_batch.py b/tests/backends/skyrl_train/test_train_batch.py index 364463af2a..4ff8dcf32d 100644 --- a/tests/backends/skyrl_train/test_train_batch.py +++ b/tests/backends/skyrl_train/test_train_batch.py @@ -5,7 +5,13 @@ import ray import torch -from skyrl.backends.skyrl_train.training_batch import TensorBatch, TensorList +from skyrl.backends.skyrl_train.training_batch import ( + TensorBatch, + TensorList, + TrainingInput, + TrainingInputBatch, + pad_batch, +) def test_train_batch_initialization(): @@ -520,3 +526,222 @@ def test_tensor_batch_none_tensor_list(): pickled = pickle.dumps(batch) unpickled = pickle.loads(pickled) assert unpickled["pixel_values"] is None + + +# --------------------------------------------------------------------------- +# Tests for pad_batch +# --------------------------------------------------------------------------- + + +# The canonical set of TrainingInput fields. Adding a new field to the TrainingInput TypedDict +# without updating `pad_batch()` and this set will make `test_pad_batch_covers_all_fields` fail, +# forcing the author to decide how the new field should be padded. 
# Canonical snapshot of the TrainingInput field names. If TrainingInput gains a field,
# the guard test below fails until both this set and pad_batch() are updated, forcing a
# deliberate decision about how the new field should be padded.
EXPECTED_TRAINING_INPUT_FIELDS = {
    "sequences",
    "attention_mask",
    "loss_mask",
    "response_mask",
    "action_log_probs",
    "base_action_log_probs",
    "values",
    "returns",
    "advantages",
    "kl",
    "rewards",
    "rollout_logprobs",
    "rollout_expert_indices",
    "pixel_values",
    "image_grid_thw",
    "is_last_step",
}


def _make_full_training_batch(batch_size: int = 4, seq_len: int = 5) -> TrainingInputBatch:
    """Build a TrainingInputBatch populated with every TrainingInput field.

    The values are deterministic and distinctive so tests can compare slices of the padded
    result against the original rows. ``pixel_values`` and ``image_grid_thw`` use variable
    per-row shapes to exercise the TensorList branch.
    """
    torch.manual_seed(0)
    data = {
        "sequences": torch.arange(batch_size * seq_len).reshape(batch_size, seq_len).long(),
        "attention_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "loss_mask": torch.ones(batch_size, seq_len, dtype=torch.float),
        "response_mask": torch.ones(batch_size, seq_len, dtype=torch.long),
        "action_log_probs": torch.randn(batch_size, seq_len),
        "base_action_log_probs": torch.randn(batch_size, seq_len),
        "values": torch.randn(batch_size, seq_len),
        "returns": torch.randn(batch_size, seq_len),
        "advantages": torch.randn(batch_size, seq_len),
        "kl": torch.randn(batch_size, seq_len),
        "rewards": torch.randn(batch_size, seq_len),
        "rollout_logprobs": torch.randn(batch_size, seq_len),
        # rollout_expert_indices is 4D (batch, seq, layer, topk)
        "rollout_expert_indices": torch.randint(0, 8, (batch_size, seq_len, 2, 3), dtype=torch.long),
        "pixel_values": TensorList([torch.randn(i + 1, 3) for i in range(batch_size)]),
        "image_grid_thw": TensorList([torch.tensor([[1, 2, 3]]) for _ in range(batch_size)]),
        "is_last_step": torch.tensor([False, True, False, True], dtype=torch.bool)[:batch_size],
    }
    batch = TrainingInputBatch(data)
    batch.metadata = {
        "uids": [f"u{i}" for i in range(batch_size)],
        "trajectory_ids": [f"t{i}" for i in range(batch_size)],
        "response_length": seq_len,
    }
    return batch


def test_pad_batch_typeddict_matches_expected_fields():
    """Guard: if TrainingInput gains a new field, bump EXPECTED_TRAINING_INPUT_FIELDS AND pad_batch."""
    typed_dict_fields = set(TrainingInput.__annotations__.keys())
    assert typed_dict_fields == EXPECTED_TRAINING_INPUT_FIELDS, (
        "TrainingInput fields changed. Update EXPECTED_TRAINING_INPUT_FIELDS AND make sure "
        "pad_batch() handles the new field. This mirrors the pattern used by "
        "test_generator_output_concatenation to keep pad_batch() in sync with the schema."
    )


def test_pad_batch_covers_all_fields():
    """pad_batch must produce correctly-shaped padding for every field in TrainingInput."""
    batch = _make_full_training_batch(batch_size=4, seq_len=5)
    # Sanity: the test fixture itself must exercise every field.
    assert set(batch.keys()) == EXPECTED_TRAINING_INPUT_FIELDS, (
        "Test fixture is missing TrainingInput fields; update _make_full_training_batch."
    )

    padded = pad_batch(batch, pad_size=3, mode="train_batch")
    assert padded.batch_size == 4 + 3
    # Every original field must still be present (and non-None) with correct batch dim.
    for key in EXPECTED_TRAINING_INPUT_FIELDS:
        value = padded[key]
        assert value is not None, f"Field {key!r} became None after padding"
        assert len(value) == 4 + 3, f"Field {key!r} has wrong batch dim after padding"


def test_pad_batch_loss_mask_is_zero_on_padding():
    """Padding rows must carry loss_mask == 0 so they never contribute to the loss."""
    batch = _make_full_training_batch(batch_size=4, seq_len=5)
    padded = pad_batch(batch, pad_size=2, mode="train_batch")
    # Original rows untouched, padding rows all-zero.
    assert torch.equal(padded["loss_mask"][:4], batch["loss_mask"])
    assert torch.all(padded["loss_mask"][4:] == 0)


def test_pad_batch_is_last_step_is_true_on_padding():
    """Padding rows must carry is_last_step == True (each pad row is its own terminal step)."""
    batch = _make_full_training_batch(batch_size=4, seq_len=5)
    padded = pad_batch(batch, pad_size=2, mode="train_batch")
    assert torch.equal(padded["is_last_step"][:4], batch["is_last_step"])
    assert torch.all(padded["is_last_step"][4:])


def test_pad_batch_other_fields_cycle_from_real_rows():
    """Non-constant fields are padded by cycling real rows, even when pad_size > batch_size."""
    batch = _make_full_training_batch(batch_size=4, seq_len=5)
    padded = pad_batch(batch, pad_size=5, mode="train_batch")  # pad_size > batch_size
    # Sequences should be original[0..3] + cycling(original[0..3])[:5]
    assert torch.equal(padded["sequences"][:4], batch["sequences"])
    cycle_idx = torch.arange(5) % 4
    assert torch.equal(padded["sequences"][4:], batch["sequences"][cycle_idx])


def test_pad_batch_pad_size_larger_than_batch_size():
    """Regression: mini_batch=1, dp_size=4 -> pad_size=3. Must not slice off the end."""
    batch = _make_full_training_batch(batch_size=1, seq_len=5)
    padded = pad_batch(batch, pad_size=3, mode="mini_batch")
    assert padded.batch_size == 4
    # All real-data fields should be cycles of row 0.
    for key in ("sequences", "advantages", "response_mask"):
        real = batch[key]
        for i in range(4):
            assert torch.equal(padded[key][i], real[0])
    # loss_mask still zero on padding rows.
    assert torch.all(padded["loss_mask"][1:] == 0)


def test_pad_batch_tensor_list_handles_pad_size_larger_than_batch():
    """TensorList fields (VLM): cyclic clone works even when pad_size > batch_size."""
    batch = _make_full_training_batch(batch_size=2, seq_len=5)
    padded = pad_batch(batch, pad_size=5, mode="train_batch")
    pv = padded["pixel_values"]
    assert isinstance(pv, TensorList)
    assert len(pv) == 2 + 5
    # First 2 are originals, next 5 cycle: 0,1,0,1,0
    for i in range(2):
        assert torch.equal(pv[i], batch["pixel_values"][i])
    expected_cycle = [0, 1, 0, 1, 0]
    for i, src in enumerate(expected_cycle):
        assert torch.equal(pv[2 + i], batch["pixel_values"][src])


def test_pad_batch_train_batch_mode_extends_metadata_uids():
    """train_batch mode appends synthetic pad{i} entries to uids/trajectory_ids."""
    batch = _make_full_training_batch(batch_size=3, seq_len=5)
    padded = pad_batch(batch, pad_size=2, mode="train_batch")
    assert padded.metadata["uids"] == ["u0", "u1", "u2", "pad0", "pad1"]
    assert padded.metadata["trajectory_ids"] == ["t0", "t1", "t2", "pad0", "pad1"]
    assert padded.metadata["pad_size"] == 2
    # Other metadata keys are deep-copied through.
    assert padded.metadata["response_length"] == 5


def test_pad_batch_train_batch_mode_does_not_mutate_input_metadata():
    """pad_batch must be pure with respect to its input batch's metadata."""
    batch = _make_full_training_batch(batch_size=3, seq_len=5)
    original_uids = list(batch.metadata["uids"])
    _ = pad_batch(batch, pad_size=2, mode="train_batch")
    assert batch.metadata["uids"] == original_uids, "pad_batch must not mutate the input batch"


def test_pad_batch_mini_batch_mode_does_not_extend_uids():
    """In mini_batch mode, uids/trajectory_ids are passed through unchanged.

    Rationale: mini_batch mode runs on a transient slice whose metadata still references the
    parent's uid list (which doesn't correspond to the slice anyway), so extending it would
    produce nonsense.
    """
    batch = _make_full_training_batch(batch_size=3, seq_len=5)
    padded = pad_batch(batch, pad_size=2, mode="mini_batch")
    assert padded.metadata["uids"] == ["u0", "u1", "u2"]
    assert padded.metadata["trajectory_ids"] == ["t0", "t1", "t2"]
    assert padded.metadata["pad_size"] == 2


def test_pad_batch_zero_pad_size_is_noop_but_records_pad_size():
    """pad_size=0 leaves content unchanged but still records metadata["pad_size"] = 0."""
    batch = _make_full_training_batch(batch_size=3, seq_len=5)
    padded = pad_batch(batch, pad_size=0, mode="train_batch")
    assert padded.batch_size == 3
    assert padded.metadata["pad_size"] == 0
    # Content unchanged
    assert torch.equal(padded["sequences"], batch["sequences"])


def test_pad_batch_rejects_non_cpu_device():
    """pad_batch must refuse GPU-resident batches (padding belongs on the CPU staging path)."""
    if not torch.cuda.is_available():
        pytest.skip("no CUDA available")
    batch = _make_full_training_batch(batch_size=2, seq_len=5)
    # Build a CUDA-only batch.
    cuda_data = {
        "sequences": batch["sequences"].cuda(),
        "loss_mask": batch["loss_mask"].cuda(),
    }
    cuda_batch = TrainingInputBatch(cuda_data)
    cuda_batch.metadata = {"uids": ["u0", "u1"]}
    with pytest.raises(AssertionError, match="expects batch on CPU"):
        pad_batch(cuda_batch, pad_size=1, mode="train_batch")


def test_pad_batch_rejects_unknown_mode():
    """Only "train_batch" and "mini_batch" modes are accepted."""
    batch = _make_full_training_batch(batch_size=2, seq_len=5)
    with pytest.raises(AssertionError, match="unknown pad_batch mode"):
        pad_batch(batch, pad_size=1, mode="bogus")  # type: ignore[arg-type]


def test_pad_batch_preserves_none_fields():
    """Fields that are None in the input stay None after padding."""
    batch = TrainingInputBatch(
        {
            "sequences": torch.arange(12).reshape(3, 4).long(),
            "loss_mask": torch.ones(3, 4, dtype=torch.float),
            "values": None,
        }
    )
    batch.metadata = {"uids": ["a", "b", "c"]}
    padded = pad_batch(batch, pad_size=1, mode="train_batch")
    assert padded["values"] is None
    assert padded.batch_size == 4