Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions skyrl/backends/skyrl_train/training_batch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Defines interfaces for training data."""

import copy
import io
import pickle
from typing import Any, Dict, Generic, List, Optional, TypedDict, TypeVar
Expand Down Expand Up @@ -483,3 +484,60 @@ class TrainingOutputBatch(TensorBatch[Dict[str, torch.Tensor]]):
"""Training output data"""

pass


def pad_training_input_batch(unpadded_batch: TrainingInputBatch, pad_size: int) -> TrainingInputBatch:
    """Append ``pad_size`` padding rows to ``unpadded_batch``.

    Returns a newly allocated :class:`TrainingInputBatch` when ``pad_size > 0``;
    when ``pad_size == 0`` the original batch is returned (with
    ``metadata["pad_size"]`` set to 0), so callers must not assume a fresh object.

    Padding strategy per field:
      - ``loss_mask``: zero rows, so padding never contributes to the loss.
      - ``TensorList`` fields: clones of row 0 (valid shapes/dtypes only).
      - other tensors: copies of row 0 — values are irrelevant because the
        loss mask zeroes them out.
      - ``None`` fields are preserved as ``None``.

    Metadata handling: ``uids`` gets synthetic ``"pad{i}"`` ids, ``is_last_step``
    is extended with ``True``, everything else is deep-copied, and
    ``metadata["pad_size"]`` records how many rows were added.

    Args:
        unpadded_batch: CPU-resident batch to pad. Must be on CPU (asserted).
        pad_size: Number of rows to append; must be >= 0.

    Raises:
        AssertionError: if the batch is not on CPU, ``pad_size`` is negative,
            or a field to pad is empty.
    """
    # TODO(Charlie): This incurs 2x CPU memory usage when pad_size > 0. Optimize when needed.
    # Padding allocates and concatenates; it should not happen on GPU hot path.
    assert unpadded_batch.device is None or unpadded_batch.device == torch.device(
        "cpu"
    ), f"pad_batch expects batch on CPU, got device={unpadded_batch.device}"
    assert pad_size >= 0, f"pad_size must be >= 0, got {pad_size}"

    # Handle the special case of no padding: annotate metadata and return as-is.
    if pad_size == 0:
        if unpadded_batch.metadata is None:
            unpadded_batch.metadata = {}
        unpadded_batch.metadata["pad_size"] = 0
        return unpadded_batch

    # Pad each tensor depending on its type.
    new_tensors = {}
    for key, tensor in unpadded_batch.items():
        if tensor is None:
            new_tensors[key] = None
            continue

        if isinstance(tensor, TensorList):
            assert len(tensor) > 0, f"Cannot pad empty TensorList field {key!r}"
            padding = TensorList([tensor[0].clone() for _ in range(pad_size)])
            new_tensors[key] = TensorList.cat([tensor, padding])
        elif key == "loss_mask":
            # Ensures that padding tensors don't count towards the loss
            additional_dims = tensor.shape[1:]
            padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
            new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)
        else:
            # Copy row 0 `pad_size` times. Loss masked so values don't affect the loss. Just need valid shape/dtype.
            assert tensor.shape[0] > 0, f"Cannot pad empty tensor field {key!r}"
            pad_indices = [0] * pad_size
            padding_tensor = tensor[pad_indices].clone()
            new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)

    # Update metadata as well. Build a fresh dict so the input batch is not mutated.
    new_metadata = {}
    old_metadata = unpadded_batch.metadata or {}
    for key, value in old_metadata.items():
        if key == "uids":
            new_metadata["uids"] = value + [f"pad{i}" for i in range(pad_size)]
        elif key == "is_last_step":
            new_metadata["is_last_step"] = value + [True for _ in range(pad_size)]
        else:
            new_metadata[key] = copy.deepcopy(value)
    new_metadata["pad_size"] = pad_size

    new_batch = TrainingInputBatch(new_tensors)
    new_batch.metadata = new_metadata

    return new_batch
29 changes: 9 additions & 20 deletions skyrl/backends/skyrl_train_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
from skyrl.backends.skyrl_train.inference_servers.server_group import ServerGroup
from skyrl.backends.skyrl_train.inference_servers.utils import build_vllm_cli_args
from skyrl.backends.skyrl_train.inference_servers.vllm_router import VLLMRouter
from skyrl.backends.skyrl_train.training_batch import TrainingInputBatch
from skyrl.backends.skyrl_train.training_batch import (
TrainingInputBatch,
pad_training_input_batch,
)
from skyrl.backends.skyrl_train.workers.worker import PPORayActorGroup
from skyrl.backends.skyrl_train.workers.worker_dispatch import WorkerDispatch
from skyrl.env_vars import _SKYRL_USE_NEW_INFERENCE, SKYRL_RAY_PG_TIMEOUT_IN_S
Expand Down Expand Up @@ -448,25 +451,11 @@ def _pad_batch(
dp_size = self._dispatch.get_lcm_dp_size()
alignment = dp_size * micro_batch_size if micro_batch_size else dp_size
pad_size = (alignment - batch.batch_size % alignment) % alignment
if pad_size == 0:
return batch, 0

new_tensors = {}
for key, tensor in batch.items():
if tensor is not None:
if key == "loss_mask":
# Padding entries must not contribute to the loss
additional_dims = tensor.shape[1:]
padding_tensor = torch.zeros(pad_size, *additional_dims, dtype=tensor.dtype, device=tensor.device)
else:
# Clone real data so shapes/dtypes are valid for the model
padding_tensor = tensor[torch.arange(pad_size) % tensor.shape[0]].clone()
new_tensors[key] = torch.cat([tensor, padding_tensor], dim=0)

padded = TrainingInputBatch(new_tensors)
padded.metadata = batch.metadata
logger.info(f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})")
return padded, pad_size
if pad_size > 0:
logger.info(
f"Padded batch from {batch.batch_size} to {batch.batch_size + pad_size} (alignment={alignment})"
)
return pad_training_input_batch(batch, pad_size), pad_size

def _extract_metrics(self, data: dict) -> dict[str, float]:
"""Extract training metrics from dispatch return dict.
Expand Down
55 changes: 8 additions & 47 deletions skyrl/train/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import math
import os
import shutil
Expand Down Expand Up @@ -27,7 +26,11 @@
from skyrl.backends.skyrl_train.inference_engines.utils import (
get_sampling_params_for_backend,
)
from skyrl.backends.skyrl_train.training_batch import TensorList, TrainingInputBatch
from skyrl.backends.skyrl_train.training_batch import (
TensorList,
TrainingInputBatch,
pad_training_input_batch,
)
from skyrl.backends.skyrl_train.utils import ppo_utils
from skyrl.backends.skyrl_train.utils.io import io
from skyrl.backends.skyrl_train.utils.ppo_utils import (
Expand Down Expand Up @@ -690,7 +693,9 @@ def convert_to_training_input(self, generator_output: GeneratorOutput, uids: Lis
) / len(response_ids)

logger.info(f"Number of sequences before padding: {len(training_input['sequences'])}")
training_input = self.pad_batch(training_input)
dp_size = self.dispatch.get_lcm_dp_size()
pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
training_input = pad_training_input_batch(training_input, pad_size)
logger.info(f"Number of sequences after padding: {len(training_input['sequences'])}")

return training_input
Expand Down Expand Up @@ -910,50 +915,6 @@ def dump_data(self, data: TrainingInputBatch, file_name: str):
data_save_dir.mkdir(parents=True, exist_ok=True)
data.save(data_save_dir / f"{file_name}.pkl")

def pad_batch(self, training_input: TrainingInputBatch) -> TrainingInputBatch:
    """Pad the batch to be divisible by dp size"""
    import math

    # Round batch_size up to the next multiple of the LCM data-parallel size.
    dp_size = self.dispatch.get_lcm_dp_size()
    pad_size = math.ceil(training_input.batch_size / dp_size) * dp_size - training_input.batch_size
    # NOTE: the input batch's metadata is annotated in place before any copy.
    training_input.metadata["pad_size"] = pad_size
    if pad_size == 0:
        return training_input

    padded_tensors = {}
    for field, value in training_input.items():
        if value is None:
            # Absent fields are simply left out of the padded batch.
            continue
        if isinstance(value, TensorList):
            # Cycle through existing rows and clone them as padding.
            total = len(value)
            extra = TensorList([value[idx % total].clone() for idx in range(pad_size)])
            padded_tensors[field] = TensorList.cat([value, extra])
        elif field == "loss_mask":
            # Zero padding so the padded rows never contribute to the loss.
            trailing = value.shape[1:]
            extra = torch.zeros(pad_size, *trailing, dtype=value.dtype, device=value.device)
            padded_tensors[field] = torch.cat([value, extra], dim=0)
        else:
            # Clone real rows (cycled) so shapes/dtypes stay valid for the model.
            rows = torch.arange(pad_size, device=value.device) % value.shape[0]
            padded_tensors[field] = torch.cat([value, value[rows].clone()], dim=0)

    result = TrainingInputBatch(padded_tensors)
    result.metadata = {}
    for meta_key, meta_value in training_input.metadata.items():
        if meta_key == "uids":
            # Synthetic ids for padded rows.
            result.metadata["uids"] = meta_value + [f"pad{i}" for i in range(pad_size)]
        elif meta_key == "is_last_step":
            result.metadata["is_last_step"] = meta_value + [True] * pad_size
        else:
            result.metadata[meta_key] = copy.deepcopy(meta_value)
    return result

@torch.no_grad()
def fwd_logprobs_values_reward(
self,
Expand Down
179 changes: 178 additions & 1 deletion tests/backends/skyrl_train/test_train_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
import ray
import torch

from skyrl.backends.skyrl_train.training_batch import TensorBatch, TensorList
from skyrl.backends.skyrl_train.training_batch import (
TensorBatch,
TensorList,
TrainingInput,
TrainingInputBatch,
pad_training_input_batch,
)


def test_train_batch_initialization():
Expand Down Expand Up @@ -520,3 +526,174 @@ def test_tensor_batch_none_tensor_list():
pickled = pickle.dumps(batch)
unpickled = pickle.loads(pickled)
assert unpickled["pixel_values"] is None


# ---------------------------------------------------------------------------
# Tests for pad_training_input_batch
# ---------------------------------------------------------------------------


# The canonical set of TrainingInput fields. Adding a new field to the TrainingInput TypedDict
# without updating `pad_training_input_batch()` and this set will make
# `test_pad_batch_typeddict_matches_expected_fields` fail, forcing the author to decide how the
# new field should be padded.
# NOTE: `pixel_values` and `image_grid_thw` are TensorList fields; the rest are plain tensors.
EXPECTED_TRAINING_INPUT_FIELDS = {
    "sequences",
    "attention_mask",
    "loss_mask",
    "response_mask",
    "action_log_probs",
    "base_action_log_probs",
    "values",
    "returns",
    "advantages",
    "kl",
    "rewards",
    "rollout_logprobs",
    "rollout_expert_indices",
    "pixel_values",
    "image_grid_thw",
}


def _make_full_training_batch(batch_size: int = 4, seq_len: int = 5) -> TrainingInputBatch:
    """Build a TrainingInputBatch populated with every TrainingInput field.

    ``pixel_values`` and ``image_grid_thw`` use variable per-row shapes to exercise TensorList.
    """
    # Seed first; the random calls below must stay in this exact order so the
    # generated tensors are reproducible across test runs.
    torch.manual_seed(0)
    fields = {}
    fields["sequences"] = torch.arange(batch_size * seq_len).reshape(batch_size, seq_len).long()
    fields["attention_mask"] = torch.ones(batch_size, seq_len, dtype=torch.long)
    fields["loss_mask"] = torch.ones(batch_size, seq_len, dtype=torch.float)
    fields["response_mask"] = torch.ones(batch_size, seq_len, dtype=torch.long)
    fields["action_log_probs"] = torch.randn(batch_size, seq_len)
    fields["base_action_log_probs"] = torch.randn(batch_size, seq_len)
    fields["values"] = torch.randn(batch_size, seq_len)
    fields["returns"] = torch.randn(batch_size, seq_len)
    fields["advantages"] = torch.randn(batch_size, seq_len)
    fields["kl"] = torch.randn(batch_size, seq_len)
    fields["rewards"] = torch.randn(batch_size, seq_len)
    fields["rollout_logprobs"] = torch.randn(batch_size, seq_len)
    fields["rollout_expert_indices"] = torch.randint(0, 8, (batch_size, seq_len, 2, 3), dtype=torch.long)
    # Variable per-row shapes: row i has (i + 1) x 3 pixel values.
    fields["pixel_values"] = TensorList([torch.randn(i + 1, 3) for i in range(batch_size)])
    # Fixed 1 x 3 grid per row.
    fields["image_grid_thw"] = TensorList([torch.tensor([[1, 2, 3]]) for _ in range(batch_size)])

    batch = TrainingInputBatch(fields)
    batch.metadata = {
        "uids": [f"u{i}" for i in range(batch_size)],
        "is_last_step": [i % 2 == 1 for i in range(batch_size)],  # [False, True, False, True, ...]
        "response_length": seq_len,
    }
    return batch


def test_pad_batch_typeddict_matches_expected_fields():
    """
    Guard: if TrainingInput gains a new field, bump EXPECTED_TRAINING_INPUT_FIELDS and make sure it is
    well handled by pad_training_input_batch().
    """
    declared_fields = set(TrainingInput.__annotations__)
    assert declared_fields == EXPECTED_TRAINING_INPUT_FIELDS, (
        "TrainingInput fields changed. Update EXPECTED_TRAINING_INPUT_FIELDS AND make sure "
        "pad_training_input_batch() handles the new field."
    )


def test_pad_batch_all_fields():
    """Comprehensive test: pad_training_input_batch pads every field correctly.

    Verifies: all tensor fields have correct batch dim, original rows are untouched,
    loss_mask padding is zero, other tensor padding is row-0 clones, TensorList padding
    is row-0 clones, metadata (uids, is_last_step) is extended correctly, and
    the input batch is not mutated.
    """
    batch_size, seq_len, pad_size = 4, 5, 3
    batch = _make_full_training_batch(batch_size=batch_size, seq_len=seq_len)

    # Sanity: the test fixture must exercise every TrainingInput field.
    assert (
        set(batch.keys()) == EXPECTED_TRAINING_INPUT_FIELDS
    ), "Test fixture is missing TrainingInput fields; update _make_full_training_batch."

    # Snapshot input metadata before padding to verify immutability.
    original_uids = list(batch.metadata["uids"])
    original_is_last_step = list(batch.metadata["is_last_step"])

    padded = pad_training_input_batch(batch, pad_size=pad_size)
    assert padded.batch_size == batch_size + pad_size
    assert set(padded.keys()) == EXPECTED_TRAINING_INPUT_FIELDS, "Padded batch is missing TrainingInput fields"

    # --- Tensor fields ---
    # Every field must be present and grow by exactly pad_size rows.
    for key in EXPECTED_TRAINING_INPUT_FIELDS:
        value = padded[key]
        assert value is not None, f"Field {key!r} became None after padding"
        assert len(value) == batch_size + pad_size, f"Field {key!r} has wrong batch dim"

    # loss_mask: original rows untouched, padding rows all-zero.
    assert torch.equal(padded["loss_mask"][:batch_size], batch["loss_mask"])
    assert torch.all(padded["loss_mask"][batch_size:] == 0)

    # Regular tensor fields (not loss_mask, not TensorList): original rows untouched,
    # padding rows are copies of row 0.
    regular_tensor_keys = EXPECTED_TRAINING_INPUT_FIELDS - {"loss_mask", "pixel_values", "image_grid_thw"}
    for key in regular_tensor_keys:
        assert torch.equal(padded[key][:batch_size], batch[key]), f"Original rows changed for {key!r}"
        for i in range(batch_size, batch_size + pad_size):
            assert torch.equal(padded[key][i], batch[key][0]), f"Padding row {i} of {key!r} is not row 0"

    # TensorList fields (pixel_values, image_grid_thw): original rows untouched,
    # padding rows are clones of row 0.
    for key in ("pixel_values", "image_grid_thw"):
        tl = padded[key]
        assert isinstance(tl, TensorList)
        for i in range(batch_size):
            assert torch.equal(tl[i], batch[key][i]), f"Original TensorList row {i} changed for {key!r}"
        for i in range(batch_size, batch_size + pad_size):
            assert torch.equal(tl[i], batch[key][0]), f"TensorList padding row {i} of {key!r} is not row 0"

    # --- Metadata ---
    # uids gain synthetic "pad{i}" entries; is_last_step is extended with True.
    assert padded.metadata["uids"] == [f"u{i}" for i in range(batch_size)] + [f"pad{i}" for i in range(pad_size)]
    assert padded.metadata["is_last_step"] == [i % 2 == 1 for i in range(batch_size)] + [True] * pad_size
    assert padded.metadata["pad_size"] == pad_size
    assert padded.metadata["response_length"] == seq_len

    # --- Input immutability ---
    assert batch.metadata["uids"] == original_uids, "pad_training_input_batch mutated input uids"
    assert (
        batch.metadata["is_last_step"] == original_is_last_step
    ), "pad_training_input_batch mutated input is_last_step"


def test_pad_batch_zero_pad_size_returns_same_batch():
    """A pad_size of 0 must hand back the very same batch object, not a copy."""
    original = _make_full_training_batch(batch_size=3, seq_len=5)
    result = pad_training_input_batch(original, pad_size=0)
    assert result is original


def test_pad_batch_rejects_non_cpu_device():
    """Padding a CUDA-resident batch must trip the CPU-only assertion."""
    if not torch.cuda.is_available():
        pytest.skip("no CUDA available")
    source = _make_full_training_batch(batch_size=2, seq_len=5)
    gpu_batch = TrainingInputBatch(
        {
            "sequences": source["sequences"].cuda(),
            "loss_mask": source["loss_mask"].cuda(),
        }
    )
    gpu_batch.metadata = {"uids": ["u0", "u1"]}
    with pytest.raises(AssertionError, match="expects batch on CPU"):
        pad_training_input_batch(gpu_batch, pad_size=1)


def test_pad_batch_preserves_none_fields():
    """Fields explicitly set to None must remain None after padding."""
    source = TrainingInputBatch(
        {
            "sequences": torch.arange(12).reshape(3, 4).long(),
            "loss_mask": torch.ones(3, 4, dtype=torch.float),
            "values": None,
        }
    )
    source.metadata = {"uids": ["a", "b", "c"]}
    result = pad_training_input_batch(source, pad_size=1)
    assert result["values"] is None
    assert result.batch_size == 4
Loading