[fix] Preserve <eos> token and in-place it after trimming #401

Merged · 2 commits · Mar 29, 2023

11 changes: 10 additions & 1 deletion trlx/trainer/accelerate_base_trainer.py
@@ -180,6 +180,7 @@ def decode(
prompts: List[torch.LongTensor],
samples: List[torch.LongTensor],
prompt_sizes: torch.LongTensor = None,
append_eos_token: bool = False,
) -> Tuple[List[str], List[str], List[str]]:
"""
Decode tensor generations into lists of strings (`samples`: List[str], `prompts`: List[str], `outputs`: List[str])
@@ -197,13 +198,21 @@ def decode(

str_prompt = self.tokenizer.decode(prompt[:prompt_size], skip_special_tokens=True)
str_output = self.tokenizer.decode(sample[output_start_ix:], skip_special_tokens=True)

# Trim outputs up to `self.stop_sequences` if any are present
trimmed = False
if self.stop_sequences:
for stop in self.stop_sequences:
stop_ix = str_output.find(stop)
if stop_ix >= 0:
str_output = str_output[:stop_ix].rstrip()
trimmed = True

# Recover the last <eos> if it was present in the original sample,
# or add one if the output was trimmed with `self.stop_sequences`.
# The <eos> token is absent from the original sample only when
# generation ended because `max_new_tokens` was exhausted.
if append_eos_token and (trimmed or sample[-1] == self.tokenizer.eos_token_id):
str_output += self.tokenizer.eos_token

str_prompts.append(str_prompt)
str_outputs.append(str_output)
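
Below is a minimal standalone sketch (not part of the PR) of the trim-and-append behaviour added above, with a made-up stop sequence and eos string standing in for `self.stop_sequences` and `self.tokenizer.eos_token`:

```python
# Standalone sketch of the new logic: trim at a stop sequence, then
# re-attach <eos> if the sample originally ended with it or was trimmed.
# `EOS` and `STOP_SEQUENCES` are illustrative stand-ins.
EOS = "</s>"
STOP_SEQUENCES = ["User:"]

def trim_and_append(str_output: str, ended_with_eos: bool, append_eos_token: bool = False) -> str:
    trimmed = False
    for stop in STOP_SEQUENCES:
        stop_ix = str_output.find(stop)
        if stop_ix >= 0:
            str_output = str_output[:stop_ix].rstrip()
            trimmed = True
    if append_eos_token and (trimmed or ended_with_eos):
        str_output += EOS
    return str_output

print(trim_and_append("Sure, here you go. User: thanks", ended_with_eos=False, append_eos_token=True))
# "Sure, here you go.</s>"
```

A sample that merely ran out of `max_new_tokens` (no stop sequence hit, no trailing <eos>) is left untouched, which matches the comment above.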
15 changes: 6 additions & 9 deletions trlx/trainer/accelerate_ppo_trainer.py
@@ -302,7 +302,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq

if self.accelerator.is_main_process:
all_str_samples, all_str_prompts, all_str_outputs = self.decode(
-     gathered_prompts, gathered_samples, gathered_prompt_sizes
+     gathered_prompts, gathered_samples, gathered_prompt_sizes, append_eos_token=True
)

exp_score_time = time()
Expand All @@ -327,7 +327,7 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
else:
scores = all_scores[0].clone().detach()

- str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples)
+ str_samples, str_prompts, str_outputs = self.decode(prompt_tensors, samples, append_eos_token=True)

# Pad the sample outputs
outputs = self.tokenizer(str_outputs).input_ids
@@ -445,10 +445,10 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
sample_outputs = sample_outputs.cpu()
values = values.cpu()[:, :-1]

- ends = start + attention_mask[:, start:].sum(1)
- # Get the logprobs and values, for tokens that are not padding
- # or beginning of sequences tokens. These are from the model (not the reference model)
+ # Get the logprobs and values for tokens that are not padding,
+ # from the start of the prompt up to and including the <eos> token
+ # (these are taken from the student model, not the reference model)
+ ends = start + attention_mask[:, start:].sum(1) + 1
Review comment from @jon-tow (Collaborator), Mar 28, 2023:
Nit: Maybe add a comment saying something like "add 1 to account for appended eos token" since the decode call is a few blocks up

all_values = [values[ix, start : ends[ix]] for ix in range(n_samples)]
all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n_samples)]
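
For context, here is a toy illustration (not trlx code, numbers made up) of what the extra `+ 1` changes: each per-sample slice becomes one position longer, so the logprob and value at the appended `<eos>` position are collected as well.

```python
import torch

# Toy setup: one sample whose outputs begin at index `start`;
# the attention mask counts 3 non-padding output tokens.
start = 2
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0, 0]])
logprobs = torch.arange(7.0).unsqueeze(0)  # dummy per-token logprobs

ends_old = start + attention_mask[:, start:].sum(1)      # tensor([5])
ends_new = start + attention_mask[:, start:].sum(1) + 1  # tensor([6])

print(logprobs[0, start:ends_old[0]])  # tensor([2., 3., 4.])
print(logprobs[0, start:ends_new[0]])  # tensor([2., 3., 4., 5.]) -- one extra position for <eos>
```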

@@ -458,9 +458,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): # noq
rollout_count = 0

for sample_idx in range(n_samples):
- if len(kl_penalty[sample_idx]) == 0 or len(all_logprobs[sample_idx]) == 0:
-     continue

rewards = kl_penalty[sample_idx]
rewards[-1] += scores[sample_idx].cpu()
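
As a side note, a toy sketch (made-up numbers, not trlx code) of the reward layout for one sample: the per-token KL penalty is the base reward, and the scalar score is added on the final position, which, with the <eos> now preserved, is typically the <eos> token rather than whatever token the output happened to end on.

```python
import torch

kl_penalty_sample = torch.tensor([-0.01, -0.02, -0.03, -0.01])  # one entry per response token
score = torch.tensor(0.8)                                        # scalar reward for the whole sample

rewards = kl_penalty_sample.clone()
rewards[-1] += score  # score lands on the last (now <eos>) position
print(rewards)  # tensor([-0.0100, -0.0200, -0.0300,  0.7900])
```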
