Minibatch impl (#364)
This PR adds gradient accumulation and minibatching to the accelerate trainers.
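
In rough terms, each DataLoader batch is split into num_mb minibatches of mb_size samples, gradients are accumulated across them, and the optimizer steps once per full batch. A minimal sketch of that pattern, with model, optimizer, compute_loss, and dataloader as placeholders (they are not part of this PR):

from trlx.pipeline import MiniBatchIterator

for minibatches in MiniBatchIterator(dataloader, mb_size=4, num_mb=2):
    optimizer.zero_grad()
    for mb in minibatches:
        loss = compute_loss(model, mb)
        (loss / len(minibatches)).backward()  # accumulate averaged gradients across minibatches
    optimizer.step()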

* fixes "exp" not implemented for half precision error

* added minibatching

* fix num_mb name

* fix minibatch indexing

* fixing style

* fixing style

* fixing style

* Minibatch iterator (#403)

* Add minibatch iterator

* Add tests

* Avoid gradient synchronization when accumulating (#396)

* Avoid gradient synchronization when accumulating

* Fix accumulation to account for dataloader

* Add some tests

---------

Co-authored-by: Enxhell <[email protected]>
Dahoas and eluzhnica committed Apr 6, 2023
1 parent 9fd1f0a commit 565c316
Showing 5 changed files with 391 additions and 7 deletions.
234 changes: 234 additions & 0 deletions tests/test_minibatch.py
@@ -0,0 +1,234 @@
import unittest
from dataclasses import dataclass, is_dataclass

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from trlx.pipeline import MiniBatchIterator
from trlx.pipeline.offline_pipeline import (
    ILQLRolloutStorage,
    ILQLSeq2SeqRolloutStorage,
    PromptPipeline,
)


@dataclass
class DataclassBatch:
    query_tensors: torch.Tensor
    response_tensors: torch.Tensor
    logprobs: torch.Tensor
    values: torch.Tensor
    rewards: torch.Tensor


class DummyDataset(Dataset, DataclassBatch):
    def __init__(self, num_samples):
        self.query_tensors = torch.randn(num_samples, 64)
        self.response_tensors = torch.randn(num_samples, 64)
        self.logprobs = torch.randn(num_samples, 1)
        self.values = torch.randn(num_samples, 1)
        self.rewards = torch.randn(num_samples, 1)

    def __len__(self):
        return len(self.query_tensors)

    def __getitem__(self, idx) -> DataclassBatch:
        return DataclassBatch(
            query_tensors=self.query_tensors[idx],
            response_tensors=self.response_tensors[idx],
            logprobs=self.logprobs[idx],
            values=self.values[idx],
            rewards=self.rewards[idx],
        )


def collate_fn(batch):
    return DataclassBatch(
        query_tensors=torch.stack([sample.query_tensors for sample in batch]),
        response_tensors=torch.stack([sample.response_tensors for sample in batch]),
        logprobs=torch.stack([sample.logprobs for sample in batch]),
        values=torch.stack([sample.values for sample in batch]),
        rewards=torch.stack([sample.rewards for sample in batch]),
    )


class BaseTestMiniBatchIterator(unittest.TestCase):
    def check_mini_batch(self, mb, expected_mini_batch_size):
        if is_dataclass(mb):
            mb = mb.__dict__
        for key, value in mb.items():
            self.assertEqual(value.size(0), expected_mini_batch_size)


class TestMiniBatchDL(BaseTestMiniBatchIterator):
    def test_batch(self):
        batch = DataclassBatch(
            torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
        )
        self.assertTrue(is_dataclass(batch))
        self.assertTrue(all(isinstance(v, torch.Tensor) for v in batch.__dict__.values()))

    def test_minibatch_iterator(self):
        # Create Dummy Dataset and DataLoader
        dummy_dataset = DummyDataset(32)
        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
        for minibatches in iterator:
            for minibatch in minibatches:
                self.assertIsInstance(minibatch, DataclassBatch)
                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
                self.check_mini_batch(minibatch, 4)

    def test_minibatch_iterator_with_undivisible_mbsize(self):
        # Create Dummy Dataset and DataLoader
        dummy_dataset = DummyDataset(32)
        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

        iterator = MiniBatchIterator(dummy_dataloader, mb_size=3, num_mb=3)

        for minibatches in iterator:
            for minibatch in minibatches[:-1]:
                self.assertIsInstance(minibatch, DataclassBatch)
                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
                self.check_mini_batch(minibatch, 3)

            # last minibatch has only 2 samples
            minibatch = minibatches[-1]
            self.assertIsInstance(minibatch, DataclassBatch)
            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
            self.check_mini_batch(minibatch, 2)

    def test_minibatch_iterator_with_remainder(self):
        # Create Dummy Dataset and DataLoader
        dummy_dataset = DummyDataset(36)
        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)

        for i in range(4):
            minibatches = next(iterator)
            for minibatch in minibatches[:-1]:
                self.assertIsInstance(minibatch, DataclassBatch)
                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
                self.check_mini_batch(minibatch, 2)

        # last iteration has only 2 minibatches
        minibatches = next(iterator)
        self.assertEqual(len(minibatches), 2)
        for minibatch in minibatches:
            self.assertIsInstance(minibatch, DataclassBatch)
            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
            self.check_mini_batch(minibatch, 2)

    def test_minibatch_iterator_with_smaller_dataset(self):
        # Create Dummy Dataset and DataLoader with size smaller than batch size
        dummy_dataset = DummyDataset(6)
        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)

        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)

        minibatches = next(iterator)

        for minibatch in minibatches:
            self.assertIsInstance(minibatch, DataclassBatch)
            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))

        with self.assertRaises(StopIteration):
            minibatches = next(iterator)

    def test_minibatch_content(self):
        dummy_dataset = DummyDataset(32)
        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)

        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)

        idx = 0
        for minibatches in iterator:
            for minibatch in minibatches:
                for key in minibatch.__dict__.keys():
                    original_data = getattr(dummy_dataset, key)
                    start_idx = idx * minibatch.__dict__[key].size(0)
                    end_idx = start_idx + minibatch.__dict__[key].size(0)
                    expected_data = original_data[start_idx:end_idx]

                    # Check if the tensor content in the minibatch is consistent with the original dataset
                    self.assertTrue(torch.all(torch.eq(minibatch.__dict__[key], expected_data)))
                idx += 1

        # Test if the iterator covered all the samples in the dataset
        self.assertEqual(idx * iterator.mb_size, len(dummy_dataset))


class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
    def test_minibatch_iterator_with_prompt_pipeline(self):
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

        # Create prompts
        prompts = ["This is a test prompt."] * 32

        prompt_pipeline = PromptPipeline(prompts, max_prompt_length=20, tokenizer=tokenizer)

        prompt_dataloader = prompt_pipeline.create_loader(batch_size=8, shuffle=True)

        iterator = MiniBatchIterator(prompt_dataloader, mb_size=4, num_mb=2)
        for minibatches in iterator:
            for minibatch in minibatches:
                self.assertTrue("input_ids" in minibatch)
                self.assertTrue("attention_mask" in minibatch)
                self.assertTrue(isinstance(minibatch["input_ids"], torch.Tensor))
                self.assertTrue(isinstance(minibatch["attention_mask"], torch.Tensor))
                self.check_mini_batch(minibatch, 4)


class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
    def create_dummy_tensors(self, num_samples):
        input_ids = torch.randint(0, 100, (num_samples, 10))
        attention_mask = torch.randint(0, 2, (num_samples, 10))
        rewards = torch.randn(num_samples, 1)
        states_ixs = torch.randint(0, 100, (num_samples, 1))
        actions_ixs = torch.randint(0, 100, (num_samples, 1))
        dones = torch.randint(0, 2, (num_samples, 1), dtype=torch.bool)

        return input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones

    def test_minibatch_iterator_with_ilql_rollout_storage(self):
        # Create dummy data
        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)

        # Create ILQLRolloutStorage instance
        ilql_rollout_storage = ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)

        ilql_dataloader = ilql_rollout_storage.create_loader(batch_size=8)

        iterator = MiniBatchIterator(ilql_dataloader, mb_size=4, num_mb=2)

        for minibatches in iterator:
            self.assertEqual(len(minibatches), 2)
            for minibatch in minibatches:
                self.check_mini_batch(minibatch, expected_mini_batch_size=4)

    def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
        # Create dummy data
        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
        decoder_input_ids = torch.randint(0, 100, (32, 10))

        # Create ILQLSeq2SeqRolloutStorage instance
        ilql_seq2seq_rollout_storage = ILQLSeq2SeqRolloutStorage(
            input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
        )

        ilql_seq2seq_dataloader = ilql_seq2seq_rollout_storage.create_loader(batch_size=8)

        iterator = MiniBatchIterator(ilql_seq2seq_dataloader, mb_size=4, num_mb=2)

        for minibatches in iterator:
            self.assertEqual(len(minibatches), 2)
            for minibatch in minibatches:
                self.check_mini_batch(minibatch, expected_mini_batch_size=4)


if __name__ == "__main__":
    unittest.main()
34 changes: 34 additions & 0 deletions tests/test_trainers.py
@@ -2,6 +2,7 @@
import tempfile
import unittest
from typing import List, Mapping
from unittest.mock import patch

import trlx.utils.logging as logging
from trlx.data.configs import (
@@ -132,3 +133,36 @@ def test_save_checkpoint(self):
        if total_steps % interval != 0:
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}")))
        self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint")))

    def test_accumulate_context(self):
        config = self.get_default_config()
        trainer = self.get_trainer(config)
        trainer.accelerator.gradient_accumulation_steps = 3

        def run_test(mb_count, num_mb, total_steps, should_call_no_sync):
            trainer.mb_count = mb_count
            trainer.num_mb = num_mb
            trainer.config.train.total_steps = total_steps

            with patch.object(trainer.accelerator, "no_sync") as no_sync_tracker:
                with patch("contextlib.nullcontext") as nullcontext_tracker:
                    with trainer._accumulate():
                        pass

            self.assertEqual(no_sync_tracker.called, should_call_no_sync)
            self.assertEqual(nullcontext_tracker.called, not should_call_no_sync)

        # Test case 1: the context manager should call accelerator.no_sync
        run_test(mb_count=1, num_mb=2, total_steps=4, should_call_no_sync=True)

        # Test case 2: the context manager should sync because the next mb_count is 3,
        # which completes a gradient accumulation cycle
        run_test(mb_count=2, num_mb=2, total_steps=4, should_call_no_sync=False)

        # Test case 3: the context manager should sync because the next mb_count reaches the final step,
        # even though it is not divisible by 3
        run_test(mb_count=3, num_mb=1, total_steps=4, should_call_no_sync=False)

        # Test case 4: the context manager should call accelerator.no_sync
        run_test(mb_count=3, num_mb=1, total_steps=6, should_call_no_sync=True)

        # Test case 5: the context manager should sync because the next mb_count is 28
        # and 28 // num_mb equals total_steps, i.e. this is the last step
        run_test(mb_count=27, num_mb=4, total_steps=7, should_call_no_sync=False)
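
For orientation, the accumulation rule these five cases pin down can be sketched as follows. This is an illustration of the expected behavior under the stated assumptions (mb_count is the count before the current minibatch, and gradient_accumulation_steps comes from the accelerator, as in the test), not the trainer's actual _accumulate implementation:

import contextlib

def accumulate_sketch(accelerator, model, mb_count, num_mb, total_steps):
    # Sync gradients only when the next minibatch completes an accumulation
    # cycle or finishes the final optimization step; otherwise skip the
    # all-reduce by entering accelerator.no_sync.
    next_count = mb_count + 1
    completes_cycle = next_count % accelerator.gradient_accumulation_steps == 0
    is_final_step = next_count // num_mb >= total_steps
    if completes_cycle or is_final_step:
        return contextlib.nullcontext()
    return accelerator.no_sync(model)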
5 changes: 5 additions & 0 deletions trlx/data/configs.py
@@ -196,6 +196,9 @@ class TrainConfig:
    :param seed: Random seed
    :type seed: int

    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
    :type minibatch_size: int
    """

    total_steps: int
@@ -224,6 +227,8 @@ class TrainConfig:

    seed: int = 1000

    minibatch_size: Optional[int] = None

    @classmethod
    def from_dict(cls, config: Dict[str, Any]):
        return cls(**config)
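
As a usage note for the new field: assuming an existing TRLConfig named config (for example one of the default trainer configs), minibatch_size is meant to be a divisor of batch_size. The concrete values below are placeholders:

config.train.batch_size = 8
config.train.minibatch_size = 4  # must divide batch_size; each forward pass sees 4 samples, so 2 minibatches per batch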
73 changes: 73 additions & 0 deletions trlx/pipeline/__init__.py
@@ -1,15 +1,20 @@
import random
import sys
from abc import abstractmethod, abstractstaticmethod
from dataclasses import is_dataclass
from typing import Any, Callable, Dict, Iterable

from torch.utils.data import DataLoader, Dataset
from transformers.tokenization_utils_base import BatchEncoding

from trlx.data import GeneralElement, RLElement
from trlx.utils import logging

# specifies a dictionary of architectures
_DATAPIPELINE: Dict[str, any] = {} # registry

logger = logging.get_logger(__name__)


def register_datapipeline(name):
"""Decorator used register a CARP architecture
@@ -95,3 +100,71 @@ def create_loader(
        :type prep_fn: Callable
        """
        pass


class MiniBatchIterator:
    """
    A custom iterator for generating mini-batches from a PyTorch DataLoader.
    """

    def __init__(self, data_loader, mb_size, num_mb):
        """
        Initializes the MiniBatchIterator.

        Args:
            data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from.
            mb_size (int): The size of each mini-batch.
            num_mb (int): The number of mini-batches to generate for each iteration.
        """
        self.data_loader = data_loader
        self.data_loader_iter = iter(data_loader)
        self.mb_size = mb_size
        self.num_mb = num_mb

    def __iter__(self):
        return self

    def __next__(self):
        batch = next(self.data_loader_iter)
        minibatches = []

        for mbi in range(self.num_mb):
            sliced_data = {}
            batch_dict = batch
            if is_dataclass(batch):
                batch_dict = batch.__dict__
            for key, value in batch_dict.items():
                start_idx = mbi * self.mb_size
                end_idx = (mbi + 1) * self.mb_size
                sliced_data[key] = value[start_idx:end_idx]

                if len(sliced_data[key]) == 0:
                    logger.warning(
                        "WARNING: MiniBatchIterator generated a minibatch with 0 elements. "
                        "This may be due to the wrong mb_size and/or num_mb, or the last batch "
                        "in the dataset being smaller."
                    )
                    sliced_data.pop(key)
                    break
                elif len(sliced_data[key]) < self.mb_size:
                    logger.warning(
                        "WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. "
                        "This may be due to the wrong mb_size and/or num_mb, or the last batch in the dataset "
                        "being smaller."
                    )

            if not sliced_data:
                break

            if isinstance(batch, BatchEncoding):
                minibatch = BatchEncoding(sliced_data)
            elif is_dataclass(batch):
                minibatch = batch.__class__(**sliced_data)
            # else:
            #     minibatch = sliced_data

            minibatches.append(minibatch)

        if not minibatches:
            raise StopIteration

        return minibatches
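
A short usage sketch of the iterator above, assuming a DataLoader whose batches are BatchEncoding objects or dataclass instances (dataset and collate_fn here are placeholders):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
iterator = MiniBatchIterator(loader, mb_size=4, num_mb=2)

for minibatches in iterator:      # one list of minibatches per DataLoader batch
    for mb in minibatches:        # each minibatch holds at most mb_size samples
        ...                       # run the forward/backward pass on mb here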
