Pad prompts to the right in T5 examples and add EOS token to seq2seq prompts #422

Merged
1 change: 1 addition & 0 deletions examples/ilql_sentiments_t5.py

@@ -41,6 +41,7 @@ def get_positive_score(scores):
         ),
         tokenizer=TokenizerConfig(
             tokenizer_path="lvwerra/t5-imdb",
+            padding_side="right",
             truncation_side="right",
         ),
         optimizer=OptimizerConfig(
1 change: 1 addition & 0 deletions examples/ppo_sentiments_t5.py

@@ -43,6 +43,7 @@ def get_positive_score(scores):
         ),
         tokenizer=TokenizerConfig(
             tokenizer_path="lvwerra/t5-imdb",
+            padding_side="right",
             truncation_side="right",
         ),
         optimizer=OptimizerConfig(
1 change: 1 addition & 0 deletions examples/ppo_translation_t5.py

@@ -54,6 +54,7 @@
         ),
         tokenizer=TokenizerConfig(
             tokenizer_path="t5-large",
+            padding_side="right",
             truncation_side="right",
         ),
         optimizer=OptimizerConfig(
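Taken together, the three example changes set both padding and truncation to the right for the T5 tokenizers. Since T5 is an encoder-decoder model, prompts feed the encoder and can safely be right-padded, unlike decoder-only models that must generate from the end of the input. A minimal sketch of the effect, assuming the standard Hugging Face tokenizer API (the batch contents shown are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("lvwerra/t5-imdb")
tokenizer.padding_side = "right"
tokenizer.truncation_side = "right"

batch = tokenizer(
    ["a short prompt", "a somewhat longer movie review prompt"],
    padding=True,
    return_tensors="pt",
)
# With padding_side="right", pad tokens follow the prompt tokens, and the
# attention mask is a run of 1s followed by 0s for the shorter sequence.
print(batch["input_ids"])
print(batch["attention_mask"])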
10 changes: 8 additions & 2 deletions trlx/pipeline/offline_pipeline.py

@@ -109,10 +109,16 @@ class PromptPipeline(BasePipeline):
         max_prompt_length (`int`): max length of the prompt, if exceeded the prompt will be truncated according to
             tokenizer's truncation setting.
         tokenizer (`transformers.PreTrainedTokenizer`): a tokenizer to tokenize prompts with.
+        add_special_tokens (`bool`): whether to encode prompts with tokenizer's special tokens (passed directly
+            into `tokenizer.encode`)
     """

     def __init__(
-        self, prompts: Union[Dict[str, Any], List[str]], max_prompt_length: int, tokenizer: PreTrainedTokenizer
+        self,
+        prompts: Union[Dict[str, Any], List[str]],
+        max_prompt_length: int,
+        tokenizer: PreTrainedTokenizer,
+        add_special_tokens: bool = False,
     ):
         super().__init__()

@@ -123,7 +129,7 @@ def __init__(
         metadata = [{}] * len(prompts)

         model_inputs = tokenizer(
-            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False
+            prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=add_special_tokens
         )

         prompts_tokens = model_inputs["input_ids"]
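This is where the "add EOS token to seq2seq prompts" part of the title lands: when `add_special_tokens=True`, a T5-style tokenizer appends its end-of-sequence token `</s>` during encoding. A quick sketch using the stock `t5-small` tokenizer (not one of the checkpoints above) to show the behaviour:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")

plain = tokenizer.encode("translate English to German: hello", add_special_tokens=False)
special = tokenizer.encode("translate English to German: hello", add_special_tokens=True)

# For T5, the only special token added on encode is EOS (</s>, id 1),
# so the second encoding is the first plus a trailing eos_token_id.
assert special == plain + [tokenizer.eos_token_id]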
19 changes: 15 additions & 4 deletions trlx/pipeline/ppo_pipeline.py

@@ -15,10 +15,11 @@ class PPORolloutStorage(BaseRolloutStore):
     Rollout storage for training PPO
     """

-    def __init__(self, pad_token_id):
+    def __init__(self, pad_token_id, padding_side):
         super().__init__()

         self.pad_token_id = pad_token_id
+        self.padding_side = padding_side
         self.history: Iterable[PPORLElement] = [None]

     def push(self, exps: Iterable[PPORLElement]):

@@ -51,13 +52,23 @@ def create_loader(
         shuffle: bool,
     ) -> DataLoader:
         def collate_fn(elems: Iterable[PPORLElement]):
-            return PPORLBatch(
+            if self.padding_side == "right":
+                # Right padding of already right-padded queries
+                query_tensors = pad_sequence(
+                    [elem.query_tensor for elem in elems],
+                    padding_value=self.pad_token_id,
+                    batch_first=True,
+                )
+            else:
                 # Left padding of already left-padded queries
-                pad_sequence(
+                query_tensors = pad_sequence(
                     [elem.query_tensor.flip(0) for elem in elems],
                     padding_value=self.pad_token_id,
                     batch_first=True,
-                ).flip(1),
+                ).flip(1)
+
+            return PPORLBatch(
+                query_tensors,
                 # Right pad the rest, to have a single horizontal query/response split
                 pad_sequence(
                     [elem.response_tensor for elem in elems],
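The `else` branch keeps the existing flip-pad-flip trick: `torch.nn.utils.rnn.pad_sequence` only pads at the end of each sequence, so left padding is obtained by reversing every sequence, right-padding, and reversing the result back along the time dimension. A self-contained sketch with illustrative token ids:

import torch
from torch.nn.utils.rnn import pad_sequence

pad_token_id = 0
queries = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# Right padding: pad_sequence's native behaviour.
right_padded = pad_sequence(queries, padding_value=pad_token_id, batch_first=True)
# tensor([[5, 6, 7],
#         [8, 9, 0]])

# Left padding: reverse each sequence, right-pad, then flip along time.
left_padded = pad_sequence(
    [q.flip(0) for q in queries], padding_value=pad_token_id, batch_first=True
).flip(1)
# tensor([[5, 6, 7],
#         [0, 8, 9]])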
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_ppo_trainer.py

@@ -54,7 +54,7 @@ def __init__(self, config: TRLConfig, **kwargs):

         # Setup the rollout store
         # Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
-        self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
+        self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)

         # Create the rollout store dataloader (for batching up rollouts)
         # TODO (jon-tow): This is only used to satisfy to `accelerator.prepare` call constraint below - remove in future
8 changes: 6 additions & 2 deletions trlx/trlx.py

@@ -94,7 +94,9 @@ def train( # noqa: C901
         if eval_prompts is None:
             eval_prompts = prompts[:batch_size]

-        pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer)
+        pipeline = get_pipeline(config.train.pipeline)(
+            prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
+        )
         trainer.add_prompt_pipeline(pipeline)

         if eval_prompts is None:

@@ -118,7 +120,9 @@
     else:
         raise ValueError("Either `samples` or `reward_fn` should be given for training")

-    eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer)
+    eval_pipeline = get_pipeline(config.train.pipeline)(
+        eval_prompts, max_prompt_length, trainer.tokenizer, add_special_tokens=config.model.model_arch_type == "seq2seq"
+    )
     trainer.add_eval_pipeline(eval_pipeline)

     trainer.learn()
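End to end, a config whose `model_arch_type` is "seq2seq" now builds prompt pipelines that encode with special tokens, while causal-LM configs keep the previous behaviour. A hedged usage sketch, assuming `PromptPipeline` items are dicts of token ids as constructed in the file above:

from transformers import AutoTokenizer
from trlx.pipeline.offline_pipeline import PromptPipeline

tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Mirrors the call sites above when model_arch_type == "seq2seq".
pipeline = PromptPipeline(
    ["translate English to German: hello"],
    max_prompt_length=64,
    tokenizer=tokenizer,
    add_special_tokens=True,
)
# The last prompt token should now be T5's EOS (</s>).
print(pipeline[0]["input_ids"][-1] == tokenizer.eos_token_id)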