Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 118 additions & 2 deletions skyrl/backends/skyrl_train/training_batch.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Defines interfaces for training data."""

import copy
import io
import pickle
from typing import Any, Dict, Generic, List, Optional, TypedDict, TypeVar
from typing import Any, Dict, Generic, List, Literal, Optional, TypedDict, TypeVar

import numpy as np
import torch
from jaxtyping import Float, Integer
from jaxtyping import Bool, Float, Integer

DictType = TypeVar("DictType")

Expand Down Expand Up @@ -471,6 +472,7 @@ class TrainingInput(TypedDict, total=False):
rollout_expert_indices: Optional[Integer[torch.Tensor, "batch_size seq_len layer_num topk"]]
pixel_values: Optional[TensorList] # list of [num_patches_i, dim]
image_grid_thw: Optional[TensorList] # list of [num_images_i, 3]
is_last_step: Optional[Bool[torch.Tensor, "batch_size"]]


class TrainingInputBatch(TensorBatch[TrainingInput]):
Expand All @@ -483,3 +485,117 @@ class TrainingOutputBatch(TensorBatch[Dict[str, torch.Tensor]]):
"""Training output data"""

pass


# Keys that pad_batch() pads with constants rather than cloned real data.
# - loss_mask: zero-filled so padding rows do not contribute to the loss.
# - is_last_step: True-filled so padding rows do not inflate per-trajectory counts
#   (a padding row belongs to no real trajectory; treating it as its own terminal step
#   keeps cumulative trajectory-id indexing consistent with the real rows).
_PAD_BATCH_CONSTANT_KEYS = ("loss_mask", "is_last_step")

# Fill value for each constant-padded key; ``torch.full`` casts it to the field's dtype
# (so ``1`` becomes ``True`` for a bool ``is_last_step``).
_PAD_BATCH_FILL_VALUES = {"loss_mask": 0, "is_last_step": 1}


def pad_batch(
    batch: "TrainingInputBatch",
    pad_size: int,
    *,
    mode: Literal["train_batch", "mini_batch"],
) -> "TrainingInputBatch":
    """Pad ``batch`` with ``pad_size`` dummy rows so that downstream slicing is well-defined.

    Padding strategy per field:
    - ``loss_mask``: zeros (so padding rows do not contribute to the loss).
    - ``is_last_step``: ones/True (so cumulative trajectory-id indexing is consistent).
    - Everything else (including ``TensorList`` fields like ``pixel_values``): real rows
      are cloned cyclically (row ``i % batch_size`` for the i-th padding row), so
      ``pad_size`` is allowed to exceed the batch size (e.g. per-mini-batch padding where
      mb_size=1 and dp_size=4). Since ``loss_mask=0`` on these rows, their values don't
      affect the loss and only need to be shape/dtype-valid.

    Args:
        batch: Batch to pad. Must be on CPU.
        pad_size: Number of rows to append. ``0`` is allowed and returns a shallow copy
            of the input (except for recording ``metadata["pad_size"] = 0``).
        mode: Determines how metadata is handled.
            - ``"train_batch"``: The caller owns the full training batch. ``uids`` and
              ``trajectory_ids`` in metadata are extended with synthetic ``"pad{i}"``
              entries so they stay aligned with the new batch size. Downstream boundary
              computation, metric extraction, and advantage normalization need this
              alignment.
            - ``"mini_batch"``: The caller is padding a transient mini-batch slice whose
              metadata still references the parent batch (e.g. ``stage_chunks``). We do
              not mutate ``uids``/``trajectory_ids`` here because the parent's lists
              wouldn't line up with a sliced window anyway, and the per-chunk workers
              don't rely on them.

    Returns:
        A new ``TrainingInputBatch`` with ``batch_size + pad_size`` rows and
        ``metadata["pad_size"] = pad_size`` recorded.

    Raises:
        ValueError: If ``pad_size`` is negative, ``mode`` is unknown, ``batch`` is not on
            CPU, or a non-``None`` field has zero rows (nothing to clone padding from).
        TypeError: If a field value is neither ``None``, a ``Tensor``, nor a ``TensorList``.
    """
    # Validate with real exceptions rather than `assert` so the checks survive `python -O`.
    if pad_size < 0:
        raise ValueError(f"pad_size must be >= 0, got {pad_size}")
    if mode not in ("train_batch", "mini_batch"):
        raise ValueError(f"unknown pad_batch mode: {mode!r}")
    # Padding allocates and concatenates; it must happen on the main-process CPU staging area,
    # not on GPU workers where we'd be materializing extra memory on the hot path.
    if batch.device is not None and batch.device != torch.device("cpu"):
        raise ValueError(f"pad_batch expects batch on CPU, got device={batch.device}")

    if pad_size == 0:
        # Still record pad_size so downstream code has a uniform view.
        new_batch = batch.__class__(dict(batch))
        new_batch.metadata = copy.deepcopy(batch.metadata) if batch.metadata is not None else {}
        new_batch.metadata["pad_size"] = 0
        return new_batch

    new_tensors: Dict[str, Any] = {}
    for key, value in batch.items():
        if value is None:
            new_tensors[key] = None
            continue

        if isinstance(value, TensorList):
            n = len(value)
            if n == 0:
                raise ValueError(f"Cannot pad empty TensorList field {key!r}")
            # Cyclic clone so this works even when pad_size > n.
            pad_indices = [i % n for i in range(pad_size)]
            padding = TensorList([value[i].clone() for i in pad_indices])
            new_tensors[key] = TensorList.cat([value, padding])
            continue

        if not isinstance(value, torch.Tensor):
            raise TypeError(
                f"pad_batch expected Tensor or TensorList for field {key!r}, got {type(value).__name__}"
            )

        trailing_dims = tuple(value.shape[1:])  # () for 1-D fields
        if key in _PAD_BATCH_CONSTANT_KEYS:
            # Constant fill: 0 for loss_mask, 1/True for is_last_step (see rationale above).
            padding_tensor = torch.full(
                (pad_size, *trailing_dims),
                _PAD_BATCH_FILL_VALUES[key],
                dtype=value.dtype,
                device=value.device,
            )
        else:
            # Cyclic clone of real rows so that this works even when pad_size > len(value)
            # (e.g. pad_size=3, batch_size=1 under dp_size=4).
            n = value.shape[0]
            if n == 0:
                raise ValueError(f"Cannot pad empty tensor field {key!r}")
            pad_indices = torch.arange(pad_size, device=value.device) % n
            padding_tensor = value[pad_indices].clone()
        new_tensors[key] = torch.cat([value, padding_tensor], dim=0)

    new_batch = batch.__class__(new_tensors)

    # Metadata handling differs by mode (see docstring).
    old_metadata = batch.metadata if batch.metadata is not None else {}
    if mode == "train_batch":
        new_metadata: Dict[str, Any] = {}
        if "uids" in old_metadata:
            new_metadata["uids"] = list(old_metadata["uids"]) + [f"pad{i}" for i in range(pad_size)]
        if "trajectory_ids" in old_metadata:
            new_metadata["trajectory_ids"] = list(old_metadata["trajectory_ids"]) + [
                f"pad{i}" for i in range(pad_size)
            ]
        for key, value in old_metadata.items():
            if key not in ("uids", "trajectory_ids"):
                new_metadata[key] = copy.deepcopy(value)
    else:
        # mini_batch: the caller is padding a transient slice. Copy metadata as-is; do NOT touch
        # uids/trajectory_ids because they wouldn't correspond to this slice anyway.
        new_metadata = copy.deepcopy(old_metadata)

    new_metadata["pad_size"] = pad_size
    new_batch.metadata = new_metadata
    return new_batch
51 changes: 4 additions & 47 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from skyrl.backends.skyrl_train.inference_engines.utils import (
get_sampling_params_for_backend,
)
from skyrl.backends.skyrl_train.training_batch import TensorList, TrainingInputBatch
from skyrl.backends.skyrl_train.training_batch import TensorList, TrainingInputBatch, pad_batch
from skyrl.backends.skyrl_train.utils import ppo_utils
from skyrl.backends.skyrl_train.utils.io import io
from skyrl.backends.skyrl_train.utils.ppo_utils import (
Expand Down Expand Up @@ -694,7 +694,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)
dp_size = self.dispatch.get_lcm_dp_size()
pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
training_input = pad_batch(training_input, pad_size, mode="train_batch")
logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

return training_input
Expand Down Expand Up @@ -913,51 +915,6 @@ def dump_data(self, data: TrainingInputBatch, file_name: str):
data_save_dir.mkdir(parents=True, exist_ok=True)
data.save(data_save_dir / f"{file_name}.pkl")

def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
    """Pad the batch to be divisible by dp size.

    Appends dummy rows so ``batch_size`` becomes the next multiple of
    ``self.dispatch.get_lcm_dp_size()``. Padding rows get ``loss_mask=0``
    (so they never count towards the loss) and ``is_last_step`` filled with
    ones; every other field is filled by cyclically cloning real rows, so
    padding works even when pad_size exceeds the batch size.

    Records ``metadata["pad_size"]`` and, when padding occurs, extends
    ``metadata["uids"]`` (and ``trajectory_ids`` if present) with synthetic
    ``"pad{i}"`` entries so metadata stays aligned with the padded batch.
    """
    import math

    dp_size = self.dispatch.get_lcm_dp_size()
    # Round batch_size up to the next multiple of dp_size; pad_size is the gap.
    pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
    new_tensors = {}
    # NOTE(review): this mutates the *input* batch's metadata in place, and the
    # pad_size==0 path then returns the same object while the padded path
    # returns a new one — callers holding the original see different aliasing
    # depending on pad_size. Confirm this asymmetry is intended.
    training_input.metadata["pad_size"] = pad_size
    if pad_size == 0:
        return training_input
    for key, tensor in training_input.items():
        # NOTE(review): fields whose value is None are silently dropped from
        # new_tensors rather than carried over as None — verify downstream code
        # tolerates the missing keys.
        if tensor is not None:
            if isinstance(tensor, TensorList):
                # Cyclic clone (index i % n) so padding works when pad_size > n.
                n = len(tensor)
                pad_indices = [i % n for i in range(pad_size)]
                padding = TensorList([tensor[i].clone() for i in pad_indices])
                new_tensors[key] = TensorList.cat([tensor, padding])
            elif key == "is_last_step":
                # Pad with ones so each padding row reads as its own terminal step.
                additional_dims = tensor.shape[1:]
                padding_tensor = torch.ones(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
                new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
            elif key == "loss_mask":
                # ensures that padding tensors don't count towards the loss
                additional_dims = tensor.shape[1:]
                padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
                new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
            else:
                # ensures all padding tensors are in a valid format by cloning `pad_size` from the original input
                n = tensor.shape[0]
                pad_indices = torch.arange(pad_size, device=tensor.device) % n
                padding_tensor = tensor[pad_indices].clone()
                new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)

    new_training_input = TrainingInputBatch(new_tensors)
    new_training_input.metadata = {}
    # Extend uids with synthetic ids so metadata length matches the padded batch.
    # NOTE(review): assumes "uids" is always present in metadata (KeyError otherwise).
    new_training_input.metadata["uids"] = training_input.metadata["uids"] + [f"pad{i}" for i in range(pad_size)]
    if "trajectory_ids" in training_input.metadata:
        new_training_input.metadata["trajectory_ids"] = training_input.metadata["trajectory_ids"] + [
            f"pad{i}" for i in range(pad_size)
        ]
    for key, value in training_input.metadata.items():
        if key not in ["uids", "trajectory_ids"]:
            # Deep-copy remaining metadata so the new batch doesn't alias the old one.
            new_training_input.metadata[key] = copy.deepcopy(value)
    return new_training_input

@torch.no_grad()
def fwd_logprobs_values_reward(
self,
Expand Down
Loading
Loading