Minibatch impl #364

Merged · 10 commits · Apr 6, 2023
234 changes: 234 additions & 0 deletions tests/test_minibatch.py
@@ -0,0 +1,234 @@
import unittest
from dataclasses import dataclass, is_dataclass

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from trlx.pipeline import MiniBatchIterator
from trlx.pipeline.offline_pipeline import (
ILQLRolloutStorage,
ILQLSeq2SeqRolloutStorage,
PromptPipeline,
)


@dataclass
class DataclassBatch:
query_tensors: torch.Tensor
response_tensors: torch.Tensor
logprobs: torch.Tensor
values: torch.Tensor
rewards: torch.Tensor


class DummyDataset(Dataset, DataclassBatch):
def __init__(self, num_samples):
self.query_tensors = torch.randn(num_samples, 64)
self.response_tensors = torch.randn(num_samples, 64)
self.logprobs = torch.randn(num_samples, 1)
self.values = torch.randn(num_samples, 1)
self.rewards = torch.randn(num_samples, 1)

def __len__(self):
return len(self.query_tensors)

def __getitem__(self, idx) -> DataclassBatch:
return DataclassBatch(
query_tensors=self.query_tensors[idx],
response_tensors=self.response_tensors[idx],
logprobs=self.logprobs[idx],
values=self.values[idx],
rewards=self.rewards[idx],
)


def collate_fn(batch):
return DataclassBatch(
query_tensors=torch.stack([sample.query_tensors for sample in batch]),
response_tensors=torch.stack([sample.response_tensors for sample in batch]),
logprobs=torch.stack([sample.logprobs for sample in batch]),
values=torch.stack([sample.values for sample in batch]),
rewards=torch.stack([sample.rewards for sample in batch]),
)


class BaseTestMiniBatchIterator(unittest.TestCase):
def check_mini_batch(self, mb, expected_mini_batch_size):
if is_dataclass(mb):
mb = mb.__dict__
for key, value in mb.items():
self.assertEqual(value.size(0), expected_mini_batch_size)


class TestMiniBatchDL(BaseTestMiniBatchIterator):
def test_batch(self):
batch = DataclassBatch(
torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
)
self.assertTrue(is_dataclass(batch))
self.assertTrue(all(isinstance(v, torch.Tensor) for v in batch.__dict__.values()))

def test_minibatch_iterator(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 4)

def test_minibatch_iterator_with_undivisible_mbsize(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

iterator = MiniBatchIterator(dummy_dataloader, mb_size=3, num_mb=3)

for minibatches in iterator:
for minibatch in minibatches[:-1]:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 3)

# last minibatch has only 2 samples
minibatch = minibatches[-1]
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)

def test_minibatch_iterator_with_remainder(self):
# Create Dummy Dataset and DataLoader
dummy_dataset = DummyDataset(36)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)

for i in range(4):
minibatches = next(iterator)
for minibatch in minibatches[:-1]:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)

# last iteration has only 2 minibatches
minibatches = next(iterator)
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
self.check_mini_batch(minibatch, 2)

def test_minibatch_iterator_with_smaller_dataset(self):
# Create Dummy Dataset and DataLoader with size smaller than batch size
dummy_dataset = DummyDataset(6)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)

minibatches = next(iterator)

for minibatch in minibatches:
self.assertIsInstance(minibatch, DataclassBatch)
self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))

with self.assertRaises(StopIteration):
minibatches = next(iterator)

def test_minibatch_content(self):
dummy_dataset = DummyDataset(32)
dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)

iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)

idx = 0
for minibatches in iterator:
for minibatch in minibatches:
for key in minibatch.__dict__.keys():
original_data = getattr(dummy_dataset, key)
start_idx = idx * minibatch.__dict__[key].size(0)
end_idx = start_idx + minibatch.__dict__[key].size(0)
expected_data = original_data[start_idx:end_idx]

# Check if the tensor content in the minibatch is consistent with the original dataset
self.assertTrue(torch.all(torch.eq(minibatch.__dict__[key], expected_data)))
idx += 1

# Test if the iterator covered all the samples in the dataset
self.assertEqual(idx * iterator.mb_size, len(dummy_dataset))


class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
def test_minibatch_iterator_with_prompt_pipeline(self):
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Create prompts
prompts = ["This is a test prompt."] * 32

prompt_pipeline = PromptPipeline(prompts, max_prompt_length=20, tokenizer=tokenizer)

prompt_dataloader = prompt_pipeline.create_loader(batch_size=8, shuffle=True)

iterator = MiniBatchIterator(prompt_dataloader, mb_size=4, num_mb=2)
for minibatches in iterator:
for minibatch in minibatches:
self.assertTrue("input_ids" in minibatch)
self.assertTrue("attention_mask" in minibatch)
self.assertTrue(isinstance(minibatch["input_ids"], torch.Tensor))
self.assertTrue(isinstance(minibatch["attention_mask"], torch.Tensor))
self.check_mini_batch(minibatch, 4)


class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
def create_dummy_tensors(self, num_samples):
input_ids = torch.randint(0, 100, (num_samples, 10))
attention_mask = torch.randint(0, 2, (num_samples, 10))
rewards = torch.randn(num_samples, 1)
states_ixs = torch.randint(0, 100, (num_samples, 1))
actions_ixs = torch.randint(0, 100, (num_samples, 1))
dones = torch.randint(0, 2, (num_samples, 1), dtype=torch.bool)

return input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones

def test_minibatch_iterator_with_ilql_rollout_storage(self):
# Create dummy data
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)

# Create ILQLRolloutStorage instance
ilql_rollout_storage = ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)

ilql_dataloader = ilql_rollout_storage.create_loader(batch_size=8)

iterator = MiniBatchIterator(ilql_dataloader, mb_size=4, num_mb=2)

for minibatches in iterator:
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.check_mini_batch(minibatch, expected_mini_batch_size=4)

def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
# Create dummy data
input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
decoder_input_ids = torch.randint(0, 100, (32, 10))

# Create ILQLSeq2SeqRolloutStorage instance
ilql_seq2seq_rollout_storage = ILQLSeq2SeqRolloutStorage(
input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
)

ilql_seq2seq_dataloader = ilql_seq2seq_rollout_storage.create_loader(batch_size=8)

iterator = MiniBatchIterator(ilql_seq2seq_dataloader, mb_size=4, num_mb=2)

for minibatches in iterator:
self.assertEqual(len(minibatches), 2)
for minibatch in minibatches:
self.check_mini_batch(minibatch, expected_mini_batch_size=4)


if __name__ == "__main__":
unittest.main()
34 changes: 34 additions & 0 deletions tests/test_trainers.py
@@ -2,6 +2,7 @@
import tempfile
import unittest
from typing import List, Mapping
from unittest.mock import patch

import trlx.utils.logging as logging
from trlx.data.configs import (
@@ -132,3 +133,36 @@ def test_save_checkpoint(self):
if total_steps % interval != 0:
self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}")))
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint")))

def test_accumulate_context(self):
config = self.get_default_config()
trainer = self.get_trainer(config)
trainer.accelerator.gradient_accumulation_steps = 3

def run_test(mb_count, num_mb, total_steps, should_call_no_sync):
trainer.mb_count = mb_count
trainer.num_mb = num_mb
trainer.config.train.total_steps = total_steps

with patch.object(trainer.accelerator, "no_sync") as no_sync_tracker:
with patch("contextlib.nullcontext") as nullcontext_tracker:
with trainer._accumulate():
pass

self.assertEqual(no_sync_tracker.called, should_call_no_sync)
self.assertEqual(nullcontext_tracker.called, not should_call_no_sync)

# Test case 1: the context manager should call accelerator.no_sync
run_test(mb_count=1, num_mb=2, total_steps=4, should_call_no_sync=True)

# Test case 2: the context manager should sync because next mb_count is 3 (corresponds with gradient accumulation)
run_test(mb_count=2, num_mb=2, total_steps=4, should_call_no_sync=False)

        # Test case 3: the context manager should sync because the next mb_count reaches the final step, even though it is not divisible by 3
run_test(mb_count=3, num_mb=1, total_steps=4, should_call_no_sync=False)

# Test case 4: the context manager should call accelerator.no_sync
run_test(mb_count=3, num_mb=1, total_steps=6, should_call_no_sync=True)

        # Test case 5: the context manager should sync because the next mb_count is 28 and 28 // num_mb reaches total_steps, so it is the last step
run_test(mb_count=27, num_mb=4, total_steps=7, should_call_no_sync=False)
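
The test above exercises `trainer._accumulate()`, which is not part of this diff. For orientation, here is a minimal sketch, consistent with the five test cases, of what such a helper plausibly does; the attribute names (`mb_count`, `num_mb`, `accelerator`, `model`, `config.train.total_steps`) are taken from the test setup, and the function itself is an assumption rather than the actual trlx implementation.

```python
import contextlib


def accumulate_context(trainer):
    """Hypothetical sketch of the behaviour test_accumulate_context checks.

    Gradients are synchronized only when the upcoming minibatch closes a
    gradient-accumulation window or is the final optimization step;
    otherwise the DDP gradient all-reduce is skipped via accelerator.no_sync.
    """
    trainer.mb_count += 1
    at_accumulation_boundary = trainer.mb_count % trainer.accelerator.gradient_accumulation_steps == 0
    at_final_step = trainer.mb_count // trainer.num_mb >= trainer.config.train.total_steps
    if at_accumulation_boundary or at_final_step:
        # Let gradients sync across processes on this backward pass
        return contextlib.nullcontext()
    # Intermediate minibatch: skip gradient synchronization to save communication
    return trainer.accelerator.no_sync(trainer.model)
```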
5 changes: 5 additions & 0 deletions trlx/data/configs.py
@@ -196,6 +196,9 @@ class TrainConfig:

:param seed: Random seed
:type seed: int

:param minibatch_size: Size of model input during one forward pass. Must divide batch size
[Review comment from a collaborator on this line]
(very pedantic) nit, feel free to ignore: usually I've heard this called micro batch, with minibatch referring to what we usually call "a batch" (to distinguish from a single batch of the whole dataset)

:type minibatch_size: int
"""

total_steps: int
@@ -224,6 +227,8 @@

seed: int = 1000

minibatch_size: Optional[int] = None

@classmethod
def from_dict(cls, config: Dict[str, Any]):
return cls(**config)
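
As an illustration of the "must divide batch size" requirement documented above, the arithmetic below shows how a minibatch count per optimizer step would follow from the two settings; the `num_mb` computation is an assumption for illustration, not code from this diff.

```python
# Illustrative only (assumption, not code from this PR)
batch_size = 8        # train.batch_size
minibatch_size = 4    # train.minibatch_size, must divide batch_size
assert batch_size % minibatch_size == 0, "minibatch_size must divide batch_size"

num_mb = batch_size // minibatch_size  # minibatches processed per optimizer step
print(num_mb)  # 2
```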
73 changes: 73 additions & 0 deletions trlx/pipeline/__init__.py
@@ -1,15 +1,20 @@
import random
import sys
from abc import abstractmethod, abstractstaticmethod
from dataclasses import is_dataclass
from typing import Any, Callable, Dict, Iterable

from torch.utils.data import DataLoader, Dataset
from transformers.tokenization_utils_base import BatchEncoding

from trlx.data import GeneralElement, RLElement
from trlx.utils import logging

# specifies a dictionary of architectures
_DATAPIPELINE: Dict[str, any] = {} # registry

logger = logging.get_logger(__name__)


def register_datapipeline(name):
"""Decorator used register a CARP architecture
@@ -95,3 +100,71 @@ def create_loader(
:type prep_fn: Callable
"""
pass


class MiniBatchIterator:
"""
A custom iterator for generating mini-batches from a PyTorch DataLoader.
"""

def __init__(self, data_loader, mb_size, num_mb):
"""
Initializes the MiniBatchIterator.

Args:
data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from.
mb_size (int): The size of each mini-batch.
num_mb (int): The number of mini-batches to generate for each iteration.
"""
self.data_loader = data_loader
self.data_loader_iter = iter(data_loader)
self.mb_size = mb_size
self.num_mb = num_mb

def __iter__(self):
return self

def __next__(self):
batch = next(self.data_loader_iter)
minibatches = []

for mbi in range(self.num_mb):
sliced_data = {}
batch_dict = batch
if is_dataclass(batch):
batch_dict = batch.__dict__
for key, value in batch_dict.items():
start_idx = mbi * self.mb_size
end_idx = (mbi + 1) * self.mb_size
sliced_data[key] = value[start_idx:end_idx]

if len(sliced_data[key]) == 0:
logger.warning(
"WARNING: MiniBatchIterator generated a minibatch with 0 elements. "
"This may be due to the wrong mb_size and/or num_mb or the last batch"
"in the dataset being smaller."
)
sliced_data.pop(key)
break
elif len(sliced_data[key]) < self.mb_size:
logger.warning(
"WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. "
"This may be due to the wrong mb_size and/or num_mb or the last batch in the dataset "
"being smaller."
)
if not sliced_data:
break

if isinstance(batch, BatchEncoding):
minibatch = BatchEncoding(sliced_data)
elif is_dataclass(batch):
minibatch = batch.__class__(**sliced_data)
            else:
                # Fall back to the plain sliced mapping (e.g. an ordinary dict batch)
                minibatch = sliced_data

minibatches.append(minibatch)

if not minibatches:
raise StopIteration

return minibatches
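
As a usage illustration, here is a minimal sketch of driving `MiniBatchIterator` with a dataclass batch, mirroring the tests in `tests/test_minibatch.py`; the `Batch` dataclass and `collate` function below are stand-ins introduced for the example, not code from this PR.

```python
from dataclasses import dataclass

import torch
from torch.utils.data import DataLoader

from trlx.pipeline import MiniBatchIterator


@dataclass
class Batch:
    input_ids: torch.Tensor


def collate(samples):
    # Stack per-sample tensors into one batch tensor
    return Batch(input_ids=torch.stack([s.input_ids for s in samples]))


data = [Batch(input_ids=torch.randint(0, 100, (10,))) for _ in range(32)]
loader = DataLoader(data, batch_size=8, collate_fn=collate)

# Each outer iteration yields num_mb minibatches of mb_size rows,
# carved out of a single DataLoader batch of 8.
for minibatches in MiniBatchIterator(loader, mb_size=4, num_mb=2):
    for mb in minibatches:
        assert mb.input_ids.shape == (4, 10)
```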