diff --git a/tests/test_minibatch.py b/tests/test_minibatch.py
new file mode 100644
index 000000000..c9ab04d04
--- /dev/null
+++ b/tests/test_minibatch.py
@@ -0,0 +1,234 @@
+import unittest
+from dataclasses import dataclass, is_dataclass
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+
+from trlx.pipeline import MiniBatchIterator
+from trlx.pipeline.offline_pipeline import (
+    ILQLRolloutStorage,
+    ILQLSeq2SeqRolloutStorage,
+    PromptPipeline,
+)
+
+
+@dataclass
+class DataclassBatch:
+    query_tensors: torch.Tensor
+    response_tensors: torch.Tensor
+    logprobs: torch.Tensor
+    values: torch.Tensor
+    rewards: torch.Tensor
+
+
+class DummyDataset(Dataset, DataclassBatch):
+    def __init__(self, num_samples):
+        self.query_tensors = torch.randn(num_samples, 64)
+        self.response_tensors = torch.randn(num_samples, 64)
+        self.logprobs = torch.randn(num_samples, 1)
+        self.values = torch.randn(num_samples, 1)
+        self.rewards = torch.randn(num_samples, 1)
+
+    def __len__(self):
+        return len(self.query_tensors)
+
+    def __getitem__(self, idx) -> DataclassBatch:
+        return DataclassBatch(
+            query_tensors=self.query_tensors[idx],
+            response_tensors=self.response_tensors[idx],
+            logprobs=self.logprobs[idx],
+            values=self.values[idx],
+            rewards=self.rewards[idx],
+        )
+
+
+def collate_fn(batch):
+    return DataclassBatch(
+        query_tensors=torch.stack([sample.query_tensors for sample in batch]),
+        response_tensors=torch.stack([sample.response_tensors for sample in batch]),
+        logprobs=torch.stack([sample.logprobs for sample in batch]),
+        values=torch.stack([sample.values for sample in batch]),
+        rewards=torch.stack([sample.rewards for sample in batch]),
+    )
+
+
+class BaseTestMiniBatchIterator(unittest.TestCase):
+    def check_mini_batch(self, mb, expected_mini_batch_size):
+        if is_dataclass(mb):
+            mb = mb.__dict__
+        for key, value in mb.items():
+            self.assertEqual(value.size(0), expected_mini_batch_size)
+
+
+class TestMiniBatchDL(BaseTestMiniBatchIterator):
+    def test_batch(self):
+        batch = DataclassBatch(
+            torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
+        )
+        self.assertTrue(is_dataclass(batch))
+        self.assertTrue(all(isinstance(v, torch.Tensor) for v in batch.__dict__.values()))
+
+    def test_minibatch_iterator(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 4)
+
+    def test_minibatch_iterator_with_indivisible_mbsize(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=3, num_mb=3)
+
+        for minibatches in iterator:
+            for minibatch in minibatches[:-1]:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 3)
+
+            # last minibatch has only 2 samples
+            minibatch = minibatches[-1]
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+            self.check_mini_batch(minibatch, 2)
+
+    def test_minibatch_iterator_with_remainder(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(36)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
+
+        for i in range(4):
+            minibatches = next(iterator)
+            for minibatch in minibatches[:-1]:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 2)
+
+        # last iteration has only 2 minibatches
+        minibatches = next(iterator)
+        self.assertEqual(len(minibatches), 2)
+        for minibatch in minibatches:
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+            self.check_mini_batch(minibatch, 2)
+
+    def test_minibatch_iterator_with_smaller_dataset(self):
+        # Create Dummy Dataset and DataLoader with size smaller than batch size
+        dummy_dataset = DummyDataset(6)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
+
+        minibatches = next(iterator)
+
+        for minibatch in minibatches:
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+
+        with self.assertRaises(StopIteration):
+            minibatches = next(iterator)
+
+    def test_minibatch_content(self):
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
+
+        idx = 0
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                for key in minibatch.__dict__.keys():
+                    original_data = getattr(dummy_dataset, key)
+                    start_idx = idx * minibatch.__dict__[key].size(0)
+                    end_idx = start_idx + minibatch.__dict__[key].size(0)
+                    expected_data = original_data[start_idx:end_idx]
+
+                    # Check if the tensor content in the minibatch is consistent with the original dataset
+                    self.assertTrue(torch.all(torch.eq(minibatch.__dict__[key], expected_data)))
+                idx += 1
+
+        # Test if the iterator covered all the samples in the dataset
+        self.assertEqual(idx * iterator.mb_size, len(dummy_dataset))
+
+
+class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
+    def test_minibatch_iterator_with_prompt_pipeline(self):
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+        # Create prompts
+        prompts = ["This is a test prompt."] * 32
+
+        prompt_pipeline = PromptPipeline(prompts, max_prompt_length=20, tokenizer=tokenizer)
+
+        prompt_dataloader = prompt_pipeline.create_loader(batch_size=8, shuffle=True)
+
+        iterator = MiniBatchIterator(prompt_dataloader, mb_size=4, num_mb=2)
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                self.assertTrue("input_ids" in minibatch)
+                self.assertTrue("attention_mask" in minibatch)
+                self.assertTrue(isinstance(minibatch["input_ids"], torch.Tensor))
+                self.assertTrue(isinstance(minibatch["attention_mask"], torch.Tensor))
+                self.check_mini_batch(minibatch, 4)
+
+
+class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
+    def create_dummy_tensors(self, num_samples):
+        input_ids = torch.randint(0, 100, (num_samples, 10))
+        attention_mask = torch.randint(0, 2, (num_samples, 10))
+        rewards = torch.randn(num_samples, 1)
+        states_ixs = torch.randint(0, 100, (num_samples, 1))
+        actions_ixs = torch.randint(0, 100, (num_samples, 1))
+        dones = torch.randint(0, 2, (num_samples, 1), dtype=torch.bool)
+
+        return input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
+
+    def test_minibatch_iterator_with_ilql_rollout_storage(self):
+        # Create dummy data
+        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
+
+        # Create ILQLRolloutStorage instance
+        ilql_rollout_storage = ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)
+
+        ilql_dataloader = ilql_rollout_storage.create_loader(batch_size=8)
+
+        iterator = MiniBatchIterator(ilql_dataloader, mb_size=4, num_mb=2)
+
+        for minibatches in iterator:
+            self.assertEqual(len(minibatches), 2)
+            for minibatch in minibatches:
+                self.check_mini_batch(minibatch, expected_mini_batch_size=4)
+
+    def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
+        # Create dummy data
+        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
+        decoder_input_ids = torch.randint(0, 100, (32, 10))
+
+        # Create ILQLSeq2SeqRolloutStorage instance
+        ilql_seq2seq_rollout_storage = ILQLSeq2SeqRolloutStorage(
+            input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
+        )
+
+        ilql_seq2seq_dataloader = ilql_seq2seq_rollout_storage.create_loader(batch_size=8)
+
+        iterator = MiniBatchIterator(ilql_seq2seq_dataloader, mb_size=4, num_mb=2)
+
+        for minibatches in iterator:
+            self.assertEqual(len(minibatches), 2)
+            for minibatch in minibatches:
+                self.check_mini_batch(minibatch, expected_mini_batch_size=4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index a54da8b3a..667c1c766 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -2,6 +2,7 @@
 import tempfile
 import unittest
 from typing import List, Mapping
+from unittest.mock import patch
 
 import trlx.utils.logging as logging
 from trlx.data.configs import (
@@ -132,3 +133,36 @@ def test_save_checkpoint(self):
             if total_steps % interval != 0:
                 self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}")))
             self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint")))
+
+    def test_accumulate_context(self):
+        config = self.get_default_config()
+        trainer = self.get_trainer(config)
+        trainer.accelerator.gradient_accumulation_steps = 3
+
+        def run_test(mb_count, num_mb, total_steps, should_call_no_sync):
+            trainer.mb_count = mb_count
+            trainer.num_mb = num_mb
+            trainer.config.train.total_steps = total_steps
+
+            with patch.object(trainer.accelerator, "no_sync") as no_sync_tracker:
+                with patch("contextlib.nullcontext") as nullcontext_tracker:
+                    with trainer._accumulate():
+                        pass
+
+            self.assertEqual(no_sync_tracker.called, should_call_no_sync)
+            self.assertEqual(nullcontext_tracker.called, not should_call_no_sync)
+
+        # Test case 1: the context manager should call accelerator.no_sync
+        run_test(mb_count=1, num_mb=2, total_steps=4, should_call_no_sync=True)
+
+        # Test case 2: the context manager should sync because the next mb_count is 3 (matches gradient accumulation steps)
+        run_test(mb_count=2, num_mb=2, total_steps=4, should_call_no_sync=False)
+
+        # Test case 3: the context manager should sync because the next mb_count is the final step, though not divisible by 3
+        run_test(mb_count=3, num_mb=1, total_steps=4, should_call_no_sync=False)
+
+        # Test case 4: the context manager should call accelerator.no_sync
+        run_test(mb_count=3, num_mb=1, total_steps=6, should_call_no_sync=True)
+
+        # Test case 5: the context manager should sync because the next mb_count is 28 and 28 // num_mb is the last step
+        run_test(mb_count=27, num_mb=4, total_steps=7, should_call_no_sync=False)
diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index 500f26019..26ad5b52f 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -196,6 +196,9 @@ class TrainConfig:
 
     :param seed: Random seed
    :type seed: int
+
+    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
+    :type minibatch_size: int
     """
 
     total_steps: int
@@ -224,6 +227,8 @@ class TrainConfig:
 
     seed: int = 1000
 
+    minibatch_size: Optional[int] = None
+
     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
         return cls(**config)
diff --git a/trlx/pipeline/__init__.py b/trlx/pipeline/__init__.py
index a9e6d6b6f..b3fcbd519 100644
--- a/trlx/pipeline/__init__.py
+++ b/trlx/pipeline/__init__.py
@@ -1,15 +1,20 @@
 import random
 import sys
 from abc import abstractmethod, abstractstaticmethod
+from dataclasses import is_dataclass
 from typing import Any, Callable, Dict, Iterable
 
 from torch.utils.data import DataLoader, Dataset
+from transformers.tokenization_utils_base import BatchEncoding
 
 from trlx.data import GeneralElement, RLElement
+from trlx.utils import logging
 
 # specifies a dictionary of architectures
 _DATAPIPELINE: Dict[str, any] = {}  # registry
 
+logger = logging.get_logger(__name__)
+
 
 def register_datapipeline(name):
     """Decorator used register a CARP architecture
@@ -95,3 +100,71 @@ def create_loader(
         :type prep_fn: Callable
         """
         pass
+
+
+class MiniBatchIterator:
+    """
+    A custom iterator for generating mini-batches from a PyTorch DataLoader.
+    """
+
+    def __init__(self, data_loader, mb_size, num_mb):
+        """
+        Initializes the MiniBatchIterator.
+
+        Args:
+            data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from.
+            mb_size (int): The size of each mini-batch.
+            num_mb (int): The number of mini-batches to generate for each iteration.
+        """
+        self.data_loader = data_loader
+        self.data_loader_iter = iter(data_loader)
+        self.mb_size = mb_size
+        self.num_mb = num_mb
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = next(self.data_loader_iter)
+        minibatches = []
+
+        for mbi in range(self.num_mb):
+            sliced_data = {}
+            batch_dict = batch
+            if is_dataclass(batch):
+                batch_dict = batch.__dict__
+            for key, value in batch_dict.items():
+                start_idx = mbi * self.mb_size
+                end_idx = (mbi + 1) * self.mb_size
+                sliced_data[key] = value[start_idx:end_idx]
+
+                if len(sliced_data[key]) == 0:
+                    logger.warning(
+                        "WARNING: MiniBatchIterator generated a minibatch with 0 elements. "
+                        "This may be due to an incorrect mb_size and/or num_mb, or the last batch "
+                        "in the dataset being smaller."
+                    )
+                    sliced_data.pop(key)
+                    break
+                elif len(sliced_data[key]) < self.mb_size:
+                    logger.warning(
+                        "WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. "
+                        "This may be due to an incorrect mb_size and/or num_mb, or the last batch "
+                        "in the dataset being smaller."
+                    )
+            if not sliced_data:
+                break
+
+            if isinstance(batch, BatchEncoding):
+                minibatch = BatchEncoding(sliced_data)
+            elif is_dataclass(batch):
+                minibatch = batch.__class__(**sliced_data)
+            else:
+                minibatch = sliced_data
+
+            minibatches.append(minibatch)
+
+        if not minibatches:
+            raise StopIteration
+
+        return minibatches
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 251a1943d..a67d5faf4 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -1,7 +1,9 @@
+import contextlib
 import json
 import os
 import sys
 from abc import abstractmethod
+from contextlib import contextmanager
 from time import time
 from typing import Dict, List, Optional, Tuple
 
@@ -15,6 +17,7 @@
 
 import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
+from trlx.pipeline import MiniBatchIterator
 from trlx.trainer import BaseRLTrainer, register_trainer
 from trlx.utils import (
     filter_non_scalars,
@@ -44,6 +47,13 @@ class AccelerateRLTrainer(BaseRLTrainer):
     def __init__(self, config, **kwargs):  # noqa: C901
         super().__init__(config, **kwargs)
         self.max_length = config.train.seq_length
+        if config.train.minibatch_size:
+            assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size"
+            self.mb_size = config.train.minibatch_size
+        else:
+            self.mb_size = config.train.batch_size
+        self.num_mb = config.train.batch_size // self.mb_size
+        self.mb_count = 0
         self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
 
         if self.accelerator.state.deepspeed_plugin is not None:
@@ -430,6 +440,22 @@ def evaluate(self):  # noqa: C901
         self.nth_evaluation += 1
         return stats
 
+    @contextmanager
+    def _accumulate(self):
+        # We can't use accelerator.accumulate() here since that checks whether the dataloader is exhausted,
+        # and we do exhaust the eval dataloader right before each training loop
+        self.mb_count += 1
+        assert self.mb_count // self.num_mb <= self.config.train.total_steps, "Beyond total steps, something is wrong"
+        if (
+            self.mb_count % self.accelerator.gradient_accumulation_steps == 0
+            or self.mb_count // self.num_mb >= self.config.train.total_steps
+        ):
+            context = contextlib.nullcontext
+        else:
+            context = self.accelerator.no_sync
+        with context(self.model):
+            yield
+
     def learn(self):  # noqa: C901
         """
         Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
@@ -475,19 +501,31 @@ def learn(self):  # noqa: C901
         # For each epoch
         for _ in range(self.config.train.epochs):
             # For each batch
-            for batch in self.train_dataloader:
+            for mbs in MiniBatchIterator(self.train_dataloader, self.mb_size, self.num_mb):
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one
                     # gradient update per batch, PPO for example commonly performs
                     # multiple gradient updates on the same batch of data.
                     # https://arxiv.org/pdf/1707.06347.pdf
-                    forward_time = time()
-                    loss, stats = self.loss(batch)
-                    forward_time = time() - forward_time
-                    backward_time = time()
-                    self.accelerator.backward(loss)
-                    backward_time = time() - backward_time
+                    forward_time = 0
+                    backward_time = 0
+                    stats_accum = []
+                    for mb in mbs:
+                        with self._accumulate():
+                            forward_time -= time()
+                            loss, stats = self.loss(mb)
+                            forward_time += time()
+                            backward_time -= time()
+                            self.accelerator.backward(loss)
+                            backward_time += time()
+                            stats_accum.append(stats)
+
+                    forward_time /= self.num_mb
+                    backward_time /= self.num_mb
+                    # TODO(Dahoas): Best way to combine stats between mbs?
+                    # How does accelerate do it?
+                    stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}
 
                     self.opt.step()
                     self.opt.zero_grad()