Skip to content

Commit

Permalink
Pad on correct side in PPORolloutStorage
Browse files Browse the repository at this point in the history
  • Loading branch information
mikljohansson committed Apr 5, 2023
1 parent 33ba4ce commit 01ed405
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
19 changes: 15 additions & 4 deletions trlx/pipeline/ppo_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ class PPORolloutStorage(BaseRolloutStore):
Rollout storage for training PPO
"""

def __init__(self, pad_token_id):
def __init__(self, pad_token_id, padding_side):
super().__init__()

self.pad_token_id = pad_token_id
self.padding_side = padding_side
self.history: Iterable[PPORLElement] = [None]

def push(self, exps: Iterable[PPORLElement]):
Expand Down Expand Up @@ -51,13 +52,23 @@ def create_loader(
shuffle: bool,
) -> DataLoader:
def collate_fn(elems: Iterable[PPORLElement]):
return PPORLBatch(
if self.padding_side == "right":
# Right padding of already right-padded queries
query_tensors = pad_sequence(
[elem.query_tensor for elem in elems],
padding_value=self.pad_token_id,
batch_first=True,
)
else:
# Left padding of already left-padded queries
pad_sequence(
query_tensors = pad_sequence(
[elem.query_tensor.flip(0) for elem in elems],
padding_value=self.pad_token_id,
batch_first=True,
).flip(1),
).flip(1)

return PPORLBatch(
query_tensors,
# Right pad the rest, to have a single horizontal query/response split
pad_sequence(
[elem.response_tensor for elem in elems],
Expand Down
2 changes: 1 addition & 1 deletion trlx/trainer/accelerate_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(self, config: TRLConfig, **kwargs):

# Setup the rollout store
# Rollouts contain the prompt & response, log probs, values and rewards - from each rollout
self.store = PPORolloutStorage(self.tokenizer.pad_token_id)
self.store = PPORolloutStorage(self.tokenizer.pad_token_id, self.tokenizer.padding_side)

# Create the rollout store dataloader (for batching up rollouts)
# TODO (jon-tow): This is only used to satisfy the `accelerator.prepare` call constraint below - remove in future
Expand Down

0 comments on commit 01ed405

Please sign in to comment.