diff --git a/examples/ilql_sentiments_t5.py b/examples/ilql_sentiments_t5.py
index eda0d4e13..5b4fae922 100644
--- a/examples/ilql_sentiments_t5.py
+++ b/examples/ilql_sentiments_t5.py
@@ -41,6 +41,7 @@ def get_positive_score(scores):
     ),
     tokenizer=TokenizerConfig(
         tokenizer_path="lvwerra/t5-imdb",
+        padding_side="right",
         truncation_side="right",
     ),
     optimizer=OptimizerConfig(
diff --git a/examples/ppo_sentiments_t5.py b/examples/ppo_sentiments_t5.py
index f40e0a275..08dfa3ef4 100644
--- a/examples/ppo_sentiments_t5.py
+++ b/examples/ppo_sentiments_t5.py
@@ -43,6 +43,7 @@ def get_positive_score(scores):
     ),
     tokenizer=TokenizerConfig(
         tokenizer_path="lvwerra/t5-imdb",
+        padding_side="right",
         truncation_side="right",
     ),
     optimizer=OptimizerConfig(
diff --git a/examples/ppo_translation_t5.py b/examples/ppo_translation_t5.py
index 804d43a3d..945b7a521 100644
--- a/examples/ppo_translation_t5.py
+++ b/examples/ppo_translation_t5.py
@@ -54,6 +54,7 @@
     ),
     tokenizer=TokenizerConfig(
         tokenizer_path="t5-large",
+        padding_side="right",
         truncation_side="right",
     ),
     optimizer=OptimizerConfig(
diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py
index 1234ab4d8..0052cd842 100644
--- a/trlx/pipeline/offline_pipeline.py
+++ b/trlx/pipeline/offline_pipeline.py
@@ -109,10 +109,16 @@ class PromptPipeline(BasePipeline):
         max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated according to
             tokenizer's truncation setting.
         tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
+        add_special_tokens (`bool`): whether to encode prompts with tokenizer's special tokens (passed directly
+            into `tokenizer.encode`)
     """

     def __init__(
-        self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer
+        self,
+        prompts: Union[Dict[str, Any], List[str]],
+        max_prompt_length: int,
+        tokenizer: PreTrainedTokenizer,
+        add_special_tokens: bool = False,
     ):
         super().__init__()

@@ -123,7 +129,7 @@ def __init__(
             metadata = [{}] * len(prompts)

         model_inputs = tokenizer(
-            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=add_special_tokens
         )
         prompts_tokens = model_inputs["input_ids"]
diff --git a/trlx/pipeline/ppo_pipeline.py b/trlx/pipeline/ppo_pipeline.py
index 7808f35fb..762679d23 100644
--- a/trlx/pipeline/ppo_pipeline.py
+++ b/trlx/pipeline/ppo_pipeline.py
@@ -15,10 +15,11 @@ class PPORolloutStorage(BaseRolloutStore):
     Rollout storage for training PPO
     """

-    def __init__(self, pad_token_id):
+    def __init__(self, pad_token_id, padding_side):
         super().__init__()

         self.pad_token_id = pad_token_id
+        self.padding_side = padding_side
         self.history: Iterable[PPORLElement] = [None]

     def push(self, exps: Iterable[PPORLElement]):
@@ -51,13 +52,23 @@ def create_loader(
         shuffle: bool,
     ) -> DataLoader:
         def collate_fn(elems: Iterable[PPORLElement]):
-            return PPORLBatch(
+            if self.padding_side == "right":
+                # Right padding of already right-padded queries
+                query_tensors = pad_sequence(
+                    [elem.query_tensor for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                )
+            else:
                 # Left padding of already left-padded queries
-                pad_sequence(
+                query_tensors = pad_sequence(
                     [elem.query_tensor.flip(0) for elem in elems],
                     padding_value=self.pad_token_id,
                     batch_first=True,
-                ).flip(1),
+                ).flip(1)
+
+            return PPORLBatch(
+                query_tensors,
                 # Right pad the rest, to have a single horizontal query/response split
                 pad_sequence(
                     [elem.response_tensor for elem in elems],
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index e1105fa09..b36de866f 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -54,7 +54,7 @@ def __init__(self, config: TRLConfig, **kwargs):

         # Setup the rollout store
         # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
-        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
+        self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)

         # Create the rollout store dataloader (for batching up rollouts)
         # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future
diff --git a/trlx/trlx.py b/trlx/trlx.py
index a97a674fc..f4cab2f8e 100644
--- a/trlx/trlx.py
+++ b/trlx/trlx.py
@@ -94,7 +94,9 @@ def train(  # noqa: C901
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]

-        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
+        pipeline = get_pipeline(config.train.pipeline)(
+            prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
+        )
         trainer.add_prompt_pipeline(pipeline)

         if eval_prompts is None:
@@ -118,7 +120,9 @@ def train(  # noqa: C901
     else:
         raise ValueError("Either `samples` or `reward_fn` should be given for training")

-    eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer)
+    eval_pipeline = get_pipeline(config.train.pipeline)(
+        eval_prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
+    )
     trainer.add_eval_pipeline(eval_pipeline)

     trainer.learn()