From b7ec6943571c96759c4d1853b749a45d513a433f Mon Sep 17 00:00:00 2001
From: Mikael Johansson
Date: Wed, 5 Apr 2023 18:14:33 +0200
Subject: [PATCH 1/4] Pad prompts to the right in T5 examples. Add EOS token
 to prompts for seq2seq models like T5

---
 examples/ilql_sentiments_t5.py         |  1 +
 examples/ppo_sentiments_t5.py          |  1 +
 examples/ppo_translation_t5.py         |  1 +
 trlx/pipeline/offline_pipeline.py      |  4 ++--
 trlx/pipeline/ppo_pipeline.py          | 19 +++++++++++++++----
 trlx/trainer/accelerate_ppo_trainer.py |  2 +-
 trlx/trlx.py                           | 10 ++++++++--
 7 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/examples/ilql_sentiments_t5.py b/examples/ilql_sentiments_t5.py
index eda0d4e13..5b4fae922 100644
--- a/examples/ilql_sentiments_t5.py
+++ b/examples/ilql_sentiments_t5.py
@@ -41,6 +41,7 @@ def get_positive_score(scores):
         ),
         tokenizer=TokenizerConfig(
             tokenizer_path="lvwerra/t5-imdb",
+            padding_side="right",
             truncation_side="right",
         ),
         optimizer=OptimizerConfig(
diff --git a/examples/ppo_sentiments_t5.py b/examples/ppo_sentiments_t5.py
index f40e0a275..08dfa3ef4 100644
--- a/examples/ppo_sentiments_t5.py
+++ b/examples/ppo_sentiments_t5.py
@@ -43,6 +43,7 @@ def get_positive_score(scores):
         ),
         tokenizer=TokenizerConfig(
             tokenizer_path="lvwerra/t5-imdb",
+            padding_side="right",
             truncation_side="right",
         ),
         optimizer=OptimizerConfig(
diff --git a/examples/ppo_translation_t5.py b/examples/ppo_translation_t5.py
index 804d43a3d..945b7a521 100644
--- a/examples/ppo_translation_t5.py
+++ b/examples/ppo_translation_t5.py
@@ -54,6 +54,7 @@
     ),
     tokenizer=TokenizerConfig(
         tokenizer_path="t5-large",
+        padding_side="right",
         truncation_side="right",
     ),
     optimizer=OptimizerConfig(
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 1234ab4d8..27870fe7d 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -112,7 +112,7 @@ class PromptPipeline(BasePipeline):
     """

     def __init__(
-        self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer
+        self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer, add_special_tokens: bool
     ):
         super().__init__()

@@ -123,7 +123,7 @@ def __init__(
         metadata = [{}] * len(prompts)

         model_inputs = tokenizer(
-            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=add_special_tokens
         )

         prompts_tokens = model_inputs["input_ids"]
diff --git a/trlx/pipeline/ppo_pipeline.py b/trlx/pipeline/ppo_pipeline.py
index 7808f35fb..762679d23 100644
--- a/trlx/pipeline/ppo_pipeline.py
+++ b/trlx/pipeline/ppo_pipeline.py
@@ -15,10 +15,11 @@ class PPORolloutStorage(BaseRolloutStore):
     Rollout storage for training PPO
     """

-    def __init__(self, pad_token_id):
+    def __init__(self, pad_token_id, padding_side):
         super().__init__()

         self.pad_token_id = pad_token_id
+        self.padding_side = padding_side
         self.history: Iterable[PPORLElement] = [None]

     def push(self, exps: Iterable[PPORLElement]):
@@ -51,13 +52,23 @@ def create_loader(
         shuffle: bool,
     ) -> DataLoader:
         def collate_fn(elems: Iterable[PPORLElement]):
-            return PPORLBatch(
+            if self.padding_side == "right":
+                # Right padding of already right-padded queries
+                query_tensors = pad_sequence(
+                    [elem.query_tensor for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                )
+            else:
                 # Left padding of already left-padded queries
-                pad_sequence(
+                query_tensors = pad_sequence(
                     [elem.query_tensor.flip(0) for elem in elems],
                     padding_value=self.pad_token_id,
                     batch_first=True,
-                ).flip(1),
+                ).flip(1)
+
+            return PPORLBatch(
+                query_tensors,
                 # Right pad the rest, to have a single horizontal query/response split
                 pad_sequence(
                     [elem.response_tensor for elem in elems],
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index e1105fa09..b36de866f 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -54,7 +54,7 @@ def __init__(self, config: TRLConfig, **kwargs):

         # Setup the rollout store
         # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
-        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
+        self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)

         # Create the rollout store dataloader (for batching up rollouts)
         # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future
diff --git a/trlx/trlx.py b/trlx/trlx.py
index a97a674fc..e72cb305c 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -94,7 +94,10 @@ def train(  # noqa: C901
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]

-        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
+        pipeline = get_pipeline(config.train.pipeline)(
+            prompts, max_prompt_length, trainer.tokenizer,
+            add_special_tokens=config.model.model_arch_type == "seq2seq"
+        )
         trainer.add_prompt_pipeline(pipeline)

         if eval_prompts is None:
@@ -118,7 +121,10 @@ def train(  # noqa: C901
     else:
         raise ValueError("Either `samples` or `reward_fn` should be given for training")

-    eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer)
+    eval_pipeline = get_pipeline(config.train.pipeline)(
+        eval_prompts, max_prompt_length, trainer.tokenizer,
+        add_special_tokens=config.model.model_arch_type == "seq2seq"
+    )
     trainer.add_eval_pipeline(eval_pipeline)

     trainer.learn()

From dffc60a0040b8d0a0922208ef2c3aecb21656e3e Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Thu, 27 Apr 2023 17:16:52 +0300
Subject: [PATCH 2/4] style: satisfy black

---
 trlx/pipeline/offline_pipeline.py | 6 +++++-
 trlx/trlx.py                      | 6 ++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 27870fe7d..9f951b3fc 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -112,7 +112,11 @@ class PromptPipeline(BasePipeline):
     """

     def __init__(
-        self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer, add_special_tokens: bool
+        self,
+        prompts: Union[Dict[str, Any], List[str]],
+        max_prompt_length: int,
+        tokenizer: PreTrainedTokenizer,
+        add_special_tokens: bool,
     ):
         super().__init__()

diff --git a/trlx/trlx.py b/trlx/trlx.py
index e72cb305c..f4cab2f8e 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -95,8 +95,7 @@ def train(  # noqa: C901
             eval_prompts = prompts[:batch_size]

         pipeline = get_pipeline(config.train.pipeline)(
-            prompts, max_prompt_length, trainer.tokenizer,
-            add_special_tokens=config.model.model_arch_type == "seq2seq"
+            prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
         )
         trainer.add_prompt_pipeline(pipeline)

@@ -122,8 +121,7 @@ def train(  # noqa: C901
         raise ValueError("Either `samples` or `reward_fn` should be given for training")

     eval_pipeline = get_pipeline(config.train.pipeline)(
-        eval_prompts, max_prompt_length, trainer.tokenizer,
-        add_special_tokens=config.model.model_arch_type == "seq2seq"
+        eval_prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
     )
     trainer.add_eval_pipeline(eval_pipeline)

     trainer.learn()

From 1865a8ff7b2cb3217a263e937a2a4480ed4f1ba2 Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Thu, 27 Apr 2023 17:20:44 +0300
Subject: [PATCH 3/4] fix(offline_pipeline): default `add_special_tokens` to
 `False`

---
 trlx/pipeline/offline_pipeline.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 9f951b3fc..23816404c 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -109,6 +109,7 @@ class PromptPipeline(BasePipeline):
         max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated
             according to tokenizer's truncation setting.
         tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
+        add_special_tokens (`bool`): Whether to encode prompts with tokenizer's special tokens (passed directly into `tokenizer.encode`)
     """

     def __init__(
@@ -116,7 +117,7 @@ def __init__(
         prompts: Union[Dict[str, Any], List[str]],
         max_prompt_length: int,
         tokenizer: PreTrainedTokenizer,
-        add_special_tokens: bool,
+        add_special_tokens: bool = False,
     ):
         super().__init__()

From 77efc17598802ae0f3d49b1560109ddde383875b Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Thu, 27 Apr 2023 17:25:03 +0300
Subject: [PATCH 4/4] style: satisfy flake

---
 trlx/pipeline/offline_pipeline.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 23816404c..0052cd842 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -109,7 +109,8 @@ class PromptPipeline(BasePipeline):
         max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated
             according to tokenizer's truncation setting.
         tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
-        add_special_tokens (`bool`): Whether to encode prompts with tokenizer's special tokens (passed directly into `tokenizer.encode`)
+        add_special_tokens (`bool`): whether to encode prompts with tokenizer's special tokens (passed directly
+            into `tokenizer.encode`)
     """

     def __init__(