From 8b416c139614e48eb61929b1bd4cc05ec1341b12 Mon Sep 17 00:00:00 2001
From: Alex Muzio
Date: Tue, 7 Mar 2023 01:12:53 +0000
Subject: [PATCH 1/2] Enable infinite PromptPipeline

---
 trlx/pipeline/offline_pipeline.py      | 11 +++++++++--
 trlx/trainer/accelerate_ppo_trainer.py | 14 ++++----------
 trlx/trlx.py                           |  2 +-
 3 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 17077c647..619067bcb 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -62,7 +62,7 @@ class PromptPipeline(BasePipeline):
     Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
     """
 
-    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer, infinite=False):
         super().__init__()
 
         model_inputs = tokenizer(
@@ -77,11 +77,18 @@ def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTra
             {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
         ]
 
+        self.infinite = infinite
+
     def __getitem__(self, ix: int):
+        if self.infinite:
+            ix = ix % len(self.prompts)
         return self.prompts[ix]
 
     def __len__(self) -> int:
-        return len(self.prompts)
+        if self.infinite:
+            return torch.iinfo(torch.int32).max
+        else:
+            return len(self.prompts)
 
     def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
         collate_fn = DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index 7d1c34f45..c87b53225 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -246,8 +246,8 @@ def prepare_learning(self):
     def add_prompt_pipeline(self, pipeline: PromptPipeline):
         """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage"""
         prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True)
-        self.prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)
-        self.prompt_iterator = iter(self.prompt_dataloader)
+        prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)
+        self.prompt_iterator = iter(prompt_dataloader)
 
     def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
         """Make experiences
@@ -277,14 +277,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noq
         clock = Clock()
 
         while len(ppo_rl_elements) < num_rollouts:
-            # Get next batch in prompt dataset and refresh if exhausted
-            # TOOD (jon-tow): Make `prompt_dataloader` a cyclic/infinite DataLoader to not require manually
-            # "refreshing" the contents of the `prompt_iterator`
-            try:
-                batch: PromptBatch = next(self.prompt_iterator)
-            except StopIteration:
-                self.prompt_iterator = iter(self.prompt_dataloader)
-                batch = next(self.prompt_iterator)
+            # Get next batch in prompt dataset
+            batch: PromptBatch = next(self.prompt_iterator)
 
             exp_generate_time = time()
 
diff --git a/trlx/trlx.py b/trlx/trlx.py
index f50753d14..108ac02d0 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -92,7 +92,7 @@ def train(  # noqa: C901
         if eval_prompts is None:
            eval_prompts = prompts[:batch_size]
 
-        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
+        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer, infinite=True)
         trainer.add_prompt_pipeline(pipeline)
 
         if eval_prompts is None:
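
For context before the second patch: the approach in PATCH 1/2 makes the prompt dataset itself look unbounded. `__len__` advertises `torch.iinfo(torch.int32).max` and `__getitem__` wraps the index with a modulo, so a DataLoader sampler can keep drawing indices indefinitely. A minimal standalone sketch of that idea follows; the `ToyPrompts` class and its sample data are illustrative only (not part of trlx), and it assumes torch is installed:

    import torch
    from torch.utils.data import Dataset


    class ToyPrompts(Dataset):
        """Toy stand-in for the patched PromptPipeline indexing behaviour."""

        def __init__(self, prompts, infinite=False):
            self.prompts = list(prompts)
            self.infinite = infinite

        def __getitem__(self, ix):
            if self.infinite:
                ix = ix % len(self.prompts)  # wrap around instead of raising IndexError
            return self.prompts[ix]

        def __len__(self):
            # Advertise a huge length so a sampler never runs out of indices
            return torch.iinfo(torch.int32).max if self.infinite else len(self.prompts)


    pipeline = ToyPrompts(["a", "b", "c"], infinite=True)
    assert pipeline[5] == "c"  # 5 % 3 == 2, indexing wraps around

One caveat with this approach: when the loader is built with `shuffle=True` (as `add_prompt_pipeline` does), PyTorch's default RandomSampler materializes a full permutation of every reported index, which is impractical for a length of 2**31 - 1. That may be related to why the follow-up patch abandons it, although its message only says the solution doesn't seem to work as expected.
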
From a50cc07727ab0636c0879673c4bdb6fb1e9b27d8 Mon Sep 17 00:00:00 2001
From: Alex Muzio
Date: Tue, 7 Mar 2023 03:24:06 +0000
Subject: [PATCH 2/2] Add `infinite_dataloader`, since the previous solution
 doesn't seem to work as expected

---
 trlx/pipeline/offline_pipeline.py      | 11 ++---------
 trlx/trainer/accelerate_ppo_trainer.py |  4 ++--
 trlx/trlx.py                           |  2 +-
 trlx/utils/__init__.py                 | 14 +++++++++++++-
 4 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 619067bcb..17077c647 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -62,7 +62,7 @@ class PromptPipeline(BasePipeline):
     Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right
     """
 
-    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer, infinite=False):
+    def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer):
         super().__init__()
 
         model_inputs = tokenizer(
@@ -77,18 +77,11 @@ def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTra
             {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask)
         ]
 
-        self.infinite = infinite
-
     def __getitem__(self, ix: int):
-        if self.infinite:
-            ix = ix % len(self.prompts)
         return self.prompts[ix]
 
     def __len__(self) -> int:
-        if self.infinite:
-            return torch.iinfo(torch.int32).max
-        else:
-            return len(self.prompts)
+        return len(self.prompts)
 
     def create_loader(self, batch_size: int, shuffle=False) -> DataLoader:
         collate_fn = DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index c87b53225..066f23322 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -25,7 +25,7 @@
 from trlx.pipeline.ppo_pipeline import PPORolloutStorage
 from trlx.trainer import register_trainer
 from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer
-from trlx.utils import Clock
+from trlx.utils import Clock, infinite_dataloader
 from trlx.utils.modeling import RunningMoments, logprobs_of_labels
 
 logger = logging.get_logger(__name__)
@@ -247,7 +247,7 @@ def add_prompt_pipeline(self, pipeline: PromptPipeline):
         """Add a prompt pipeline dataloader to a trainer instance for the `make_experience` stage"""
         prompt_dataloader = pipeline.create_loader(self.config.method.chunk_size, shuffle=True)
         prompt_dataloader = self.accelerator.prepare_data_loader(prompt_dataloader)
-        self.prompt_iterator = iter(prompt_dataloader)
+        self.prompt_iterator = infinite_dataloader(prompt_dataloader)
 
     def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):  # noqa:
         """Make experiences
diff --git a/trlx/trlx.py b/trlx/trlx.py
index 108ac02d0..f50753d14 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -92,7 +92,7 @@ def train(  # noqa: C901
         if eval_prompts is None:
            eval_prompts = prompts[:batch_size]
 
-        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer, infinite=True)
+        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
         trainer.add_prompt_pipeline(pipeline)
 
         if eval_prompts is None:
diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py
index 803557724..abc4d5458 100644
--- a/trlx/utils/__init__.py
+++ b/trlx/utils/__init__.py
@@ -5,8 +5,9 @@
 import time
 from dataclasses import is_dataclass
 from enum import Enum
+from itertools import repeat
 from numbers import Number
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Iterable, Tuple
 
 import numpy as np
 import torch
@@ -226,3 +227,14 @@ def get_git_tag() -> Tuple[str, str]:
         return branch.decode()[:-1], output.decode()[1:-2]
     except subprocess.CalledProcessError:
         return "unknown", "unknown"
+
+
+# Iter utils
+
+
+def infinite_dataloader(dataloader: Iterable) -> Iterable:
+    """
+    Returns a cyclic infinite dataloader from a finite dataloader
+    """
+    for _ in repeat(dataloader):
+        yield from dataloader
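
For reference, the `infinite_dataloader` generator added to trlx/utils/__init__.py can be exercised on its own. The sketch below assumes torch is installed; the toy TensorDataset, batch size, and loop count are illustrative and not part of the patch. It shows that the wrapped iterator keeps producing batches past the end of the underlying finite loader:

    from itertools import repeat
    from typing import Iterable

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def infinite_dataloader(dataloader: Iterable) -> Iterable:
        """Cycle over a finite dataloader forever, mirroring the patched utility."""
        for _ in repeat(dataloader):
            yield from dataloader


    # A finite loader holding only two batches; the wrapper restarts it whenever it is exhausted.
    loader = DataLoader(TensorDataset(torch.arange(4)), batch_size=2)
    batches = infinite_dataloader(loader)
    for _ in range(5):  # more steps than the finite loader can serve in one pass
        batch = next(batches)  # never raises StopIteration

Because each pass re-enters `yield from dataloader`, a fresh iterator is taken from the DataLoader on every cycle, so a loader built with `shuffle=True` reshuffles each pass, and, unlike `itertools.cycle`, no batches are cached in memory between passes.
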