From 1a243ae12548f3b2cf11197136be937376b01cf0 Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Mon, 13 Mar 2023 16:26:06 +0000
Subject: [PATCH 1/9] fixes half exp not implemented error

---
 trlx/trainer/accelerate_ppo_trainer.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 3e89761e4..16791f62a 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -434,11 +434,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
             ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
 
         n_samples: int = samples.shape[0]
-        logprobs = logprobs.cpu()
-        ref_logprobs = ref_logprobs.cpu()
-        prompt_tensors = prompt_tensors.cpu()
-        sample_outputs = sample_outputs.cpu()
-        values = values.cpu()[:, :-1]
 
         # Estimate the KL divergence between the model and reference model
         if self.config.model.model_arch_type == "seq2seq":
@@ -447,6 +442,15 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         else:
             start = prompt_tensors.shape[1] - 1
 
+        log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1]
+        self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
+
+        logprobs = logprobs.cpu()
+        ref_logprobs = ref_logprobs.cpu()
+        prompt_tensors = prompt_tensors.cpu()
+        sample_outputs = sample_outputs.cpu()
+        values = values.cpu()[:, :-1]
+
         ends = start + attention_mask[:, start:].sum(1)
 
         # Get the logprobs and values, for tokens that are not padding
@@ -454,9 +458,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
         all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
 
-        log_ratio = (logprobs - ref_logprobs) * attention_mask[:, :-1].cpu()
-        self.mean_kl = (log_ratio.exp() - 1 - log_ratio).mean().to(device)
-        kl_penalty = self.kl_ctl.value * -log_ratio
+        kl_penalty = self.kl_ctl.value * -log_ratio.cpu()
         kl_penalty = [xs[start : ends[ix]] for ix, xs in enumerate(kl_penalty)]
 
         rollout_count = 0

From 800c433c9871dbe9b44a9fb9988e99d300d3b91b Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Mon, 13 Mar 2023 16:02:19 +0000
Subject: [PATCH 2/9] added minibatching

---
 trlx/data/configs.py                    |  5 +++++
 trlx/trainer/accelerate_base_trainer.py | 29 ++++++++++++++++++++-----
 2 files changed, 28 insertions(+), 6 deletions(-)

diff --git a/trlx/data/configs.py b/trlx/data/configs.py
index b19fc99e3..d4be30c74 100644
--- a/trlx/data/configs.py
+++ b/trlx/data/configs.py
@@ -196,6 +196,9 @@ class TrainConfig:
 
     :param seed: Random seed
     :type seed: int
+
+    :param minibatch_size: Size of model input during one forward pass. Must divide batch size
+    :type minibatch_size: int
     """

     total_steps: int
@@ -223,6 +226,8 @@ class TrainConfig:
 
     seed: int = 1000
 
+    minibatch_size: Optional[int] = None
+
     @classmethod
     def from_dict(cls, config: Dict[str, Any]):
         return cls(**config)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 1526cde41..7b93bd889 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -45,6 +45,12 @@ class AccelerateRLTrainer(BaseRLTrainer):
     def __init__(self, config, **kwargs):  # noqa: C901
         super().__init__(config, **kwargs)
         self.max_length = config.train.seq_length
+        if config.train.minibatch_size:
+            assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size"
+            self.mbs = config.train.minibatch_size
+        else:
+            self.mbs = config.train.batch_size
+        self.num_mb = config.train.batch_size // self.mbs
         self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
 
         if self.accelerator.state.deepspeed_plugin is not None:
@@ -474,12 +480,23 @@ def learn(self):  # noqa: C901
                     # gradient update per batch, PPO for example commonly performs
                     # multiple gradient updates on the same batch of data.
                     # https://arxiv.org/pdf/1707.06347.pdf
-                    forward_time = time()
-                    loss, stats = self.loss(batch)
-                    forward_time = time() - forward_time
-                    backward_time = time()
-                    self.accelerator.backward(loss)
-                    backward_time = time() - backward_time
+                    forward_time = 0
+                    backward_time = 0
+                    stats_accum = []
+                    for mbi in range(self.mb_num):
+                        forward_time -= time()
+                        loss, stats = self.loss(batch)
+                        forward_time += time()
+                        backward_time -= time()
+                        self.accelerator.backward(loss)
+                        backward_time += time()
+                        stats_accum.append(stats)
+
+                    forward_time /= self.mb_num
+                    backward_time /= self.mb_num
+                    # TODO(Dahoas): Best way to combine stats between mbs?
+                    # How does accelerate do it?
+                    stats = {key: sum([stats[key] for stats in stats_accum]) / self.mb_num for key in stats_accum[0]}
 
                     self.opt.step()
                     self.opt.zero_grad()
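The minibatch_size option added above only controls how a training batch is split for forward/backward passes; the optimizer still steps once per batch, since opt.step() stays outside the new inner loop. A stand-alone sketch of the invariant the new assert enforces (illustrative numbers, not taken from the PR):

    # illustrative values only
    batch_size = 128
    minibatch_size = 32  # must divide batch_size
    assert batch_size % minibatch_size == 0, "Minibatch size must divide batch size"
    num_mb = batch_size // minibatch_size  # 4 forward/backward passes per optimizer step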
From 17b543a3bc3e1956ca0797f31737dca90cf446e3 Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Mon, 13 Mar 2023 16:38:46 +0000
Subject: [PATCH 3/9] fix num_mb name

---
 trlx/trainer/accelerate_base_trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 7b93bd889..3ba74f248 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -483,7 +483,7 @@ def learn(self):  # noqa: C901
                     forward_time = 0
                     backward_time = 0
                     stats_accum = []
-                    for mbi in range(self.mb_num):
+                    for mbi in range(self.num_mb):
                         forward_time -= time()
                         loss, stats = self.loss(batch)
                         forward_time += time()
@@ -492,11 +492,11 @@ def learn(self):  # noqa: C901
                         backward_time += time()
                         stats_accum.append(stats)
 
-                    forward_time /= self.mb_num
-                    backward_time /= self.mb_num
+                    forward_time /= self.num_mb
+                    backward_time /= self.num_mb
                     # TODO(Dahoas): Best way to combine stats between mbs?
                     # How does accelerate do it?
-                    stats = {key: sum([stats[key] for stats in stats_accum]) / self.mb_num for key in stats_accum[0]}
+                    stats = {key: sum([stats[key] for stats in stats_accum]) / self.num_mb for key in stats_accum[0]}
 
                     self.opt.step()
                     self.opt.zero_grad()

From a196418a076f37acb10629e997d541ddf6e056d5 Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Wed, 22 Mar 2023 15:27:09 +0000
Subject: [PATCH 4/9] fix minibatch indexing

---
 trlx/trainer/accelerate_base_trainer.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 3ba74f248..06dab3c69 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -16,6 +16,7 @@
 import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
+from trlx.data.ppo_types import PPORLBatch
 from trlx.trainer import BaseRLTrainer, register_trainer
 from trlx.utils import (
     filter_non_scalars,
@@ -47,10 +48,10 @@ def __init__(self, config, **kwargs):  # noqa: C901
         self.max_length = config.train.seq_length
         if config.train.minibatch_size:
             assert config.train.batch_size % config.train.minibatch_size == 0, "Minibatch size must divide batch size"
-            self.mbs = config.train.minibatch_size
+            self.mb_size = config.train.minibatch_size
         else:
-            self.mbs = config.train.batch_size
-        self.num_mb = config.train.batch_size // self.mbs
+            self.mb_size = config.train.batch_size
+        self.num_mb = config.train.batch_size // self.mb_size
         self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
 
         if self.accelerator.state.deepspeed_plugin is not None:
@@ -474,6 +475,15 @@ def learn(self):  # noqa: C901
         for _ in range(self.config.train.epochs):
             # For each batch
             for batch in self.train_dataloader:
+                mbs = [
+                    PPORLBatch(
+                        query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi+1) * self.mb_size],
+                        response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi+1) * self.mb_size],
+                        logprobs=batch.logprobs[mbi * self.mb_size : (mbi+1) * self.mb_size],
+                        values=batch.values[mbi * self.mb_size : (mbi+1) * self.mb_size],
+                        rewards=batch.rewards[mbi * self.mb_size : (mbi+1) * self.mb_size],
+                    ) for mbi in range(self.num_mb)
+                ]
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one
@@ -483,10 +493,11 @@ def learn(self):  # noqa: C901
                     forward_time = 0
                     backward_time = 0
                     stats_accum = []
-                    for mbi in range(self.num_mb):
+                    for mb in mbs:
                         forward_time -= time()
-                        loss, stats = self.loss(batch)
+                        loss, stats = self.loss(mb)
                         forward_time += time()
+                        loss /= self.num_mb
                         backward_time -= time()
                         self.accelerator.backward(loss)
                         backward_time += time()

From 17b12be50b884c3d8d90742b7b2487d920c5a65d Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Wed, 22 Mar 2023 15:34:13 +0000
Subject: [PATCH 5/9] fixing style

---
 trlx/trainer/accelerate_base_trainer.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 06dab3c69..e7ab7a072 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -477,13 +477,14 @@ def learn(self):  # noqa: C901
             for batch in self.train_dataloader:
                 mbs = [
                     PPORLBatch(
-                        query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi+1) * self.mb_size],
-                        response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi+1) * self.mb_size],
-                        logprobs=batch.logprobs[mbi * self.mb_size : (mbi+1) * self.mb_size],
-                        values=batch.values[mbi * self.mb_size : (mbi+1) * self.mb_size],
-                        rewards=batch.rewards[mbi * self.mb_size : (mbi+1) * self.mb_size],
-                    ) for mbi in range(self.num_mb)
-                ]
++                    query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
++                    response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
++                    logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
++                    values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
++                    rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
++                    )
++                    for mbi in range(self.num_mb)
++                ]
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one

From 899418750c5076b6445c57aeb29004d3697bb8cb Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Wed, 22 Mar 2023 15:42:37 +0000
Subject: [PATCH 6/9] fixing style

---
 trlx/trainer/accelerate_base_trainer.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index e7ab7a072..625cdc455 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -477,14 +477,14 @@ def learn(self):  # noqa: C901
             for batch in self.train_dataloader:
                 mbs = [
                     PPORLBatch(
-+                    query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-+                    response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-+                    logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-+                    values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-+                    rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-+                    )
-+                    for mbi in range(self.num_mb)
-+                ]
+                        query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                        rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
+                    )
+                    for mbi in range(self.num_mb)
+                ]
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one

From f389941f57e5e34276f310733b19c130e49bed79 Mon Sep 17 00:00:00 2001
From: Dahoas
Date: Wed, 22 Mar 2023 15:44:42 +0000
Subject: [PATCH 7/9] fixing style

---
 trlx/trainer/accelerate_base_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 625cdc455..6f8f146f1 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -476,7 +476,7 @@ def learn(self):  # noqa: C901
             # For each batch
             for batch in self.train_dataloader:
                 mbs = [
-                PPORLBatch(
+                    PPORLBatch(
                         query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
                         response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
                         logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
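After these style passes, each batch field is simply sliced into num_mb contiguous chunks of mb_size rows, and the per-minibatch loss is divided by num_mb so the accumulated gradient matches the full-batch update. A stand-alone sketch of that slicing (illustrative values only, plain tensors instead of PPORLBatch):

    import torch

    # illustrative: one field of a batch split into num_mb minibatches of mb_size rows
    mb_size, num_mb = 2, 4
    field = torch.arange(mb_size * num_mb)  # stands in for e.g. batch.query_tensors
    minibatches = [field[i * mb_size : (i + 1) * mb_size] for i in range(num_mb)]
    assert all(mb.shape[0] == mb_size for mb in minibatches)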
From 8485e782dd05bada8cc229ff27ae1111828996ac Mon Sep 17 00:00:00 2001
From: Enxhell
Date: Mon, 3 Apr 2023 20:36:28 +0200
Subject: [PATCH 8/9] Minibatch iterator (#403)

* Add minibatch iterator

* Add tests
---
 tests/test_minibatch.py                 | 234 ++++++++++++++++++++++++
 trlx/pipeline/__init__.py               |  73 ++++++++
 trlx/trainer/accelerate_base_trainer.py |  14 +-
 3 files changed, 309 insertions(+), 12 deletions(-)
 create mode 100644 tests/test_minibatch.py

diff --git a/tests/test_minibatch.py b/tests/test_minibatch.py
new file mode 100644
index 000000000..c9ab04d04
--- /dev/null
+++ b/tests/test_minibatch.py
@@ -0,0 +1,234 @@
+import unittest
+from dataclasses import dataclass, is_dataclass
+
+import torch
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoTokenizer
+
+from trlx.pipeline import MiniBatchIterator
+from trlx.pipeline.offline_pipeline import (
+    ILQLRolloutStorage,
+    ILQLSeq2SeqRolloutStorage,
+    PromptPipeline,
+)
+
+
+@dataclass
+class DataclassBatch:
+    query_tensors: torch.Tensor
+    response_tensors: torch.Tensor
+    logprobs: torch.Tensor
+    values: torch.Tensor
+    rewards: torch.Tensor
+
+
+class DummyDataset(Dataset, DataclassBatch):
+    def __init__(self, num_samples):
+        self.query_tensors = torch.randn(num_samples, 64)
+        self.response_tensors = torch.randn(num_samples, 64)
+        self.logprobs = torch.randn(num_samples, 1)
+        self.values = torch.randn(num_samples, 1)
+        self.rewards = torch.randn(num_samples, 1)
+
+    def __len__(self):
+        return len(self.query_tensors)
+
+    def __getitem__(self, idx) -> DataclassBatch:
+        return DataclassBatch(
+            query_tensors=self.query_tensors[idx],
+            response_tensors=self.response_tensors[idx],
+            logprobs=self.logprobs[idx],
+            values=self.values[idx],
+            rewards=self.rewards[idx],
+        )
+
+
+def collate_fn(batch):
+    return DataclassBatch(
+        query_tensors=torch.stack([sample.query_tensors for sample in batch]),
+        response_tensors=torch.stack([sample.response_tensors for sample in batch]),
+        logprobs=torch.stack([sample.logprobs for sample in batch]),
+        values=torch.stack([sample.values for sample in batch]),
+        rewards=torch.stack([sample.rewards for sample in batch]),
+    )
+
+
+class BaseTestMiniBatchIterator(unittest.TestCase):
+    def check_mini_batch(self, mb, expected_mini_batch_size):
+        if is_dataclass(mb):
+            mb = mb.__dict__
+        for key, value in mb.items():
+            self.assertEqual(value.size(0), expected_mini_batch_size)
+
+
+class TestMiniBatchDL(BaseTestMiniBatchIterator):
+    def test_batch(self):
+        batch = DataclassBatch(
+            torch.tensor([1]), torch.tensor([2]), torch.tensor([3]), torch.tensor([4]), torch.tensor([5])
+        )
+        self.assertTrue(is_dataclass(batch))
+        self.assertTrue(all(isinstance(v, torch.Tensor) for v in batch.__dict__.values()))
+
+    def test_minibatch_iterator(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 4)
+
+    def test_minibatch_iterator_with_undivisible_mbsize(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=3, num_mb=3)
+
+        for minibatches in iterator:
+            for minibatch in minibatches[:-1]:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 3)
+
+            # last minibatch has only 2 samples
+            minibatch = minibatches[-1]
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+            self.check_mini_batch(minibatch, 2)
+
+    def test_minibatch_iterator_with_remainder(self):
+        # Create Dummy Dataset and DataLoader
+        dummy_dataset = DummyDataset(36)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
+
+        for i in range(4):
+            minibatches = next(iterator)
+            for minibatch in minibatches[:-1]:
+                self.assertIsInstance(minibatch, DataclassBatch)
+                self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+                self.check_mini_batch(minibatch, 2)
+
+        # last iteration has only 2 minibatches
+        minibatches = next(iterator)
+        self.assertEqual(len(minibatches), 2)
+        for minibatch in minibatches:
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+            self.check_mini_batch(minibatch, 2)
+
+    def test_minibatch_iterator_with_smaller_dataset(self):
+        # Create Dummy Dataset and DataLoader with size smaller than batch size
+        dummy_dataset = DummyDataset(6)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=2, num_mb=4)
+
+        minibatches = next(iterator)
+
+        for minibatch in minibatches:
+            self.assertIsInstance(minibatch, DataclassBatch)
+            self.assertTrue(all(isinstance(v, torch.Tensor) for v in minibatch.__dict__.values()))
+
+        with self.assertRaises(StopIteration):
+            minibatches = next(iterator)
+
+    def test_minibatch_content(self):
+        dummy_dataset = DummyDataset(32)
+        dummy_dataloader = DataLoader(dummy_dataset, batch_size=8, shuffle=False, num_workers=0, collate_fn=collate_fn)
+
+        iterator = MiniBatchIterator(dummy_dataloader, mb_size=4, num_mb=2)
+
+        idx = 0
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                for key in minibatch.__dict__.keys():
+                    original_data = getattr(dummy_dataset, key)
+                    start_idx = idx * minibatch.__dict__[key].size(0)
+                    end_idx = start_idx + minibatch.__dict__[key].size(0)
+                    expected_data = original_data[start_idx:end_idx]
+
+                    # Check if the tensor content in the minibatch is consistent with the original dataset
+                    self.assertTrue(torch.all(torch.eq(minibatch.__dict__[key], expected_data)))
+                idx += 1
+
+        # Test if the iterator covered all the samples in the dataset
+        self.assertEqual(idx * iterator.mb_size, len(dummy_dataset))
+
+
+class TestMiniBatchIteratorWithPromptPipeline(BaseTestMiniBatchIterator):
+    def test_minibatch_iterator_with_prompt_pipeline(self):
+        # Load tokenizer
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+        # Create prompts
+        prompts = ["This is a test prompt."] * 32
+
+        prompt_pipeline = PromptPipeline(prompts, max_prompt_length=20, tokenizer=tokenizer)
+
+        prompt_dataloader = prompt_pipeline.create_loader(batch_size=8, shuffle=True)
+
+        iterator = MiniBatchIterator(prompt_dataloader, mb_size=4, num_mb=2)
+        for minibatches in iterator:
+            for minibatch in minibatches:
+                self.assertTrue("input_ids" in minibatch)
+                self.assertTrue("attention_mask" in minibatch)
+                self.assertTrue(isinstance(minibatch["input_ids"], torch.Tensor))
+                self.assertTrue(isinstance(minibatch["attention_mask"], torch.Tensor))
+                self.check_mini_batch(minibatch, 4)
+
+
+class TestMiniBatchIteratorWithILQLRollouts(BaseTestMiniBatchIterator):
+    def create_dummy_tensors(self, num_samples):
+        input_ids = torch.randint(0, 100, (num_samples, 10))
+        attention_mask = torch.randint(0, 2, (num_samples, 10))
+        rewards = torch.randn(num_samples, 1)
+        states_ixs = torch.randint(0, 100, (num_samples, 1))
+        actions_ixs = torch.randint(0, 100, (num_samples, 1))
+        dones = torch.randint(0, 2, (num_samples, 1), dtype=torch.bool)
+
+        return input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones
+
+    def test_minibatch_iterator_with_ilql_rollout_storage(self):
+        # Create dummy data
+        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
+
+        # Create ILQLRolloutStorage instance
+        ilql_rollout_storage = ILQLRolloutStorage(input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones)
+
+        ilql_dataloader = ilql_rollout_storage.create_loader(batch_size=8)
+
+        iterator = MiniBatchIterator(ilql_dataloader, mb_size=4, num_mb=2)
+
+        for minibatches in iterator:
+            self.assertEqual(len(minibatches), 2)
+            for minibatch in minibatches:
+                self.check_mini_batch(minibatch, expected_mini_batch_size=4)
+
+    def test_minibatch_iterator_with_ilql_seq2seq_rollout_storage(self):
+        # Create dummy data
+        input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones = self.create_dummy_tensors(32)
+        decoder_input_ids = torch.randint(0, 100, (32, 10))
+
+        # Create ILQLSeq2SeqRolloutStorage instance
+        ilql_seq2seq_rollout_storage = ILQLSeq2SeqRolloutStorage(
+            input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones
+        )
+
+        ilql_seq2seq_dataloader = ilql_seq2seq_rollout_storage.create_loader(batch_size=8)
+
+        iterator = MiniBatchIterator(ilql_seq2seq_dataloader, mb_size=4, num_mb=2)
+
+        for minibatches in iterator:
+            self.assertEqual(len(minibatches), 2)
+            for minibatch in minibatches:
+                self.check_mini_batch(minibatch, expected_mini_batch_size=4)
+
+
+if __name__ == "__main__":
+    unittest.main()

diff --git a/trlx/pipeline/__init__.py b/trlx/pipeline/__init__.py
index a9e6d6b6f..b3fcbd519 100644
--- a/trlx/pipeline/__init__.py
+++ b/trlx/pipeline/__init__.py
@@ -1,15 +1,20 @@
 import random
 import sys
 from abc import abstractmethod, abstractstaticmethod
+from dataclasses import is_dataclass
 from typing import Any, Callable, Dict, Iterable
 
 from torch.utils.data import DataLoader, Dataset
+from transformers.tokenization_utils_base import BatchEncoding
 
 from trlx.data import GeneralElement, RLElement
+from trlx.utils import logging
 
 # specifies a dictionary of architectures
 _DATAPIPELINE: Dict[str, any] = {}  # registry
 
+logger = logging.get_logger(__name__)
+
 
 def register_datapipeline(name):
     """Decorator used register a CARP architecture
@@ -95,3 +100,71 @@ def create_loader(
         :type prep_fn: Callable
         """
         pass
+
+
+class MiniBatchIterator:
+    """
+    A custom iterator for generating mini-batches from a PyTorch DataLoader.
+    """
+
+    def __init__(self, data_loader, mb_size, num_mb):
+        """
+        Initializes the MiniBatchIterator.
+
+        Args:
+            data_loader (torch.utils.data.DataLoader): The DataLoader to generate mini-batches from.
+            mb_size (int): The size of each mini-batch.
+            num_mb (int): The number of mini-batches to generate for each iteration.
+        """
+        self.data_loader = data_loader
+        self.data_loader_iter = iter(data_loader)
+        self.mb_size = mb_size
+        self.num_mb = num_mb
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        batch = next(self.data_loader_iter)
+        minibatches = []
+
+        for mbi in range(self.num_mb):
+            sliced_data = {}
+            batch_dict = batch
+            if is_dataclass(batch):
+                batch_dict = batch.__dict__
+            for key, value in batch_dict.items():
+                start_idx = mbi * self.mb_size
+                end_idx = (mbi + 1) * self.mb_size
+                sliced_data[key] = value[start_idx:end_idx]
+
+                if len(sliced_data[key]) == 0:
+                    logger.warning(
+                        "WARNING: MiniBatchIterator generated a minibatch with 0 elements. "
+                        "This may be due to the wrong mb_size and/or num_mb or the last batch"
+                        "in the dataset being smaller."
+                    )
+                    sliced_data.pop(key)
+                    break
+                elif len(sliced_data[key]) < self.mb_size:
+                    logger.warning(
+                        "WARNING: MiniBatchIterator generated a minibatch with fewer elements than mb_size. "
+                        "This may be due to the wrong mb_size and/or num_mb or the last batch in the dataset "
+                        "being smaller."
+                    )
+
+            if not sliced_data:
+                break
+
+            if isinstance(batch, BatchEncoding):
+                minibatch = BatchEncoding(sliced_data)
+            elif is_dataclass(batch):
+                minibatch = batch.__class__(**sliced_data)
+            # else:
+            #     minibatch = sliced_data
+
+            minibatches.append(minibatch)
+
+        if not minibatches:
+            raise StopIteration
+
+        return minibatches

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 6f8f146f1..0aea0f273 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -16,7 +16,7 @@
 import trlx.utils.logging as logging
 from trlx.data.configs import TRLConfig
-from trlx.data.ppo_types import PPORLBatch
+from trlx.pipeline import MiniBatchIterator
 from trlx.trainer import BaseRLTrainer, register_trainer
 from trlx.utils import (
     filter_non_scalars,
@@ -474,17 +474,7 @@ def learn(self):  # noqa: C901
         # For each epoch
         for _ in range(self.config.train.epochs):
             # For each batch
-            for batch in self.train_dataloader:
-                mbs = [
-                    PPORLBatch(
-                        query_tensors=batch.query_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-                        response_tensors=batch.response_tensors[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-                        logprobs=batch.logprobs[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-                        values=batch.values[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-                        rewards=batch.rewards[mbi * self.mb_size : (mbi + 1) * self.mb_size],
-                    )
-                    for mbi in range(self.num_mb)
-                ]
+            for mbs in MiniBatchIterator(self.train_dataloader, self.mb_size, self.num_mb):
                 # For each update per batch
                 for _ in range(self.n_updates_per_batch):
                     # Note that whereas standard policy gradient methods perform one
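The tests above double as usage documentation for the new iterator. A minimal sketch of how it is driven (illustrative: my_dataset and my_collate_fn are placeholders; the wrapped DataLoader should yield dataclass batches or transformers BatchEncoding objects, since those are the two cases the iterator reassembles):

    from torch.utils.data import DataLoader

    from trlx.pipeline import MiniBatchIterator

    loader = DataLoader(my_dataset, batch_size=8, collate_fn=my_collate_fn)
    # each outer iteration slices one DataLoader batch into num_mb minibatches of mb_size rows
    for minibatches in MiniBatchIterator(loader, mb_size=4, num_mb=2):
        for mb in minibatches:
            ...  # one forward/backward pass per minibatch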
From 833d049b56cd04d536fca9af7de9e79fb73345cd Mon Sep 17 00:00:00 2001
From: Enxhell
Date: Wed, 5 Apr 2023 18:58:23 +0200
Subject: [PATCH 9/9] Avoid gradient synchronization when accumulating (#396)

* Avoid gradient synchronization when accumulating

* Fix accumulation to account for dataloader

* Add some tests
---
 tests/test_trainers.py                  | 34 ++++++++++++++++++++++++
 trlx/trainer/accelerate_base_trainer.py | 35 +++++++++++++++++++------
 2 files changed, 61 insertions(+), 8 deletions(-)

diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index a54da8b3a..667c1c766 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -2,6 +2,7 @@
 import tempfile
 import unittest
 from typing import List, Mapping
+from unittest.mock import patch
 
 import trlx.utils.logging as logging
 from trlx.data.configs import (
@@ -132,3 +133,36 @@ def test_save_checkpoint(self):
             if total_steps % interval != 0:
                 self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}")))
             self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint")))
+
+    def test_accumulate_context(self):
+        config = self.get_default_config()
+        trainer = self.get_trainer(config)
+        trainer.accelerator.gradient_accumulation_steps = 3
+
+        def run_test(mb_count, num_mb, total_steps, should_call_no_sync):
+            trainer.mb_count = mb_count
+            trainer.num_mb = num_mb
+            trainer.config.train.total_steps = total_steps
+
+            with patch.object(trainer.accelerator, "no_sync") as no_sync_tracker:
+                with patch("contextlib.nullcontext") as nullcontext_tracker:
+                    with trainer._accumulate():
+                        pass
+
+            self.assertEqual(no_sync_tracker.called, should_call_no_sync)
+            self.assertEqual(nullcontext_tracker.called, not should_call_no_sync)
+
+        # Test case 1: the context manager should call accelerator.no_sync
+        run_test(mb_count=1, num_mb=2, total_steps=4, should_call_no_sync=True)
+
+        # Test case 2: the context manager should sync because next mb_count is 3 (corresponds with gradient accumulation)
+        run_test(mb_count=2, num_mb=2, total_steps=4, should_call_no_sync=False)
+
+        # Test case 3: the context manager should sync because next mb_count is final step even though it is not % by 3
+        run_test(mb_count=3, num_mb=1, total_steps=4, should_call_no_sync=False)
+
+        # Test case 4: the context manager should call accelerator.no_sync
+        run_test(mb_count=3, num_mb=1, total_steps=6, should_call_no_sync=True)
+
+        # Test case 5: the context manager should sync because next mb_count is 28 and 28 // num_mb means it is the last step
+        run_test(mb_count=27, num_mb=4, total_steps=7, should_call_no_sync=False)

diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index ead565972..a67d5faf4 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -1,7 +1,9 @@
+import contextlib
 import json
 import os
 import sys
 from abc import abstractmethod
+from contextlib import contextmanager
 from time import time
 from typing import Dict, List, Optional, Tuple
@@ -51,6 +53,7 @@ def __init__(self, config, **kwargs):  # noqa: C901
         else:
             self.mb_size = config.train.batch_size
         self.num_mb = config.train.batch_size // self.mb_size
+        self.mb_count = 0
         self.accelerator = Accelerator(log_with=config.train.tracker, logging_dir=config.train.logging_dir)
 
         if self.accelerator.state.deepspeed_plugin is not None:
@@ -437,6 +440,22 @@ def evaluate(self):  # noqa: C901
         self.nth_evaluation += 1
         return stats
 
+    @contextmanager
+    def _accumulate(self):
+        # We can't use accelerator.accumulate() since that checks if the dataloader is exhausted
+        # and we do exhaust the eval dataloader right before each training loop
+        self.mb_count += 1
+        assert self.mb_count // self.num_mb <= self.config.train.total_steps, "Beyond total steps, something is wrong"
+        if (
+            self.mb_count % self.accelerator.gradient_accumulation_steps == 0
+            or self.mb_count // self.num_mb >= self.config.train.total_steps
+        ):
+            context = contextlib.nullcontext
+        else:
+            context = self.accelerator.no_sync
+        with context(self.model):
+            yield
+
     def learn(self):  # noqa: C901
         """
         Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader`
@@ -493,14 +512,14 @@ def learn(self):  # noqa: C901
                     backward_time = 0
                     stats_accum = []
                     for mb in mbs:
-                        forward_time -= time()
-                        loss, stats = self.loss(mb)
-                        forward_time += time()
-                        loss /= self.num_mb
-                        backward_time -= time()
-                        self.accelerator.backward(loss)
-                        backward_time += time()
-                        stats_accum.append(stats)
+                        with self._accumulate():
+                            forward_time -= time()
+                            loss, stats = self.loss(mb)
+                            forward_time += time()
+                            backward_time -= time()
+                            self.accelerator.backward(loss)
+                            backward_time += time()
+                            stats_accum.append(stats)
 
                     forward_time /= self.num_mb
                     backward_time /= self.num_mb