Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fine-tuning with label mask #410

Merged
merged 11 commits into from
Jan 19, 2024
Merged

Conversation

epwalsh
Copy link
Member

@epwalsh epwalsh commented Jan 17, 2024

  • Add support for fine-tuning with a label mask.
  • Add a script for preparing Tulu V2 for fine-tuning.
  • Add fine-tuning instructions to README.

- Add support for fine-tuning with a label mask.
- Add a script for preparing Tulu V2 for fine-tuning.
- Add fine-tuning instructions to README.
@@ -0,0 +1,111 @@
"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hamishivi could you review this script?

Comment on lines 66 to 88
def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
parts = []
for msg in example["messages"]:
parts.append(f"<|{msg['role']}|>")
parts.append(msg["content"])

prompt = "\n".join(parts[:-1]) + "\n"
completion = parts[-1]

prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
completion_ids = tokenizer.encode(completion, add_special_tokens=True)

input_ids = (prompt_ids + completion_ids)[:max_seq_len]
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len]

if len(input_ids) < max_seq_len:
pad_len = max_seq_len - len(input_ids)
input_ids += [tokenizer.pad_token_id] * pad_len
label_mask += [False] * pad_len

assert len(input_ids) == len(label_mask)

return {"input_ids": input_ids, "label_mask": label_mask}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hamishivi in particular this function for preprocessing/tokenizing each example.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite right. Actually, the content for any message from the assistant role should be trained on, not just the final role. This is because we have some multi-turn dialogues in our dataset, and so this is important for that. This is a bit tricky to do but a reference is here: https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L292

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated to match your script

@@ -10,3 +10,19 @@
```
pip install ai2-olmo
```

## Fine-tuning
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AkshitaB, fine-tuning instructions added here.

Comment on lines 66 to 88
def preprocess(example, tokenizer: Tokenizer, max_seq_len: int):
parts = []
for msg in example["messages"]:
parts.append(f"<|{msg['role']}|>")
parts.append(msg["content"])

prompt = "\n".join(parts[:-1]) + "\n"
completion = parts[-1]

prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
completion_ids = tokenizer.encode(completion, add_special_tokens=True)

input_ids = (prompt_ids + completion_ids)[:max_seq_len]
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len]

if len(input_ids) < max_seq_len:
pad_len = max_seq_len - len(input_ids)
input_ids += [tokenizer.pad_token_id] * pad_len
label_mask += [False] * pad_len

assert len(input_ids) == len(label_mask)

return {"input_ids": input_ids, "label_mask": label_mask}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't quite right. Actually, the content for any message from the assistant role should be trained on, not just the final role. This is because we have some multi-turn dialogues in our dataset, and so this is important for that. This is a bit tricky to do but a reference is here: https://github.com/allenai/open-instruct/blob/main/open_instruct/finetune.py#L292

completion = parts[-1]

prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
completion_ids = tokenizer.encode(completion, add_special_tokens=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What special tokens does the olmo tokenizer add? There should be an eos token after every assistant message.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

prompt = "\n".join(parts[:-1]) + "\n"
completion = parts[-1]

prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it useful to add a bos token in training (or rather, using the eos as a bos marker), but I don't think its essential.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

input_ids = (prompt_ids + completion_ids)[:max_seq_len]
label_mask = ([False] * len(prompt_ids) + [True] * len(completion_ids))[:max_seq_len]

if len(input_ids) < max_seq_len:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random q: what happens when the sequence length is over your training max_seq_len? just naive truncation? (this is fine just curious)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea just naive truncation

Copy link
Collaborator

@2015aroras 2015aroras left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't see anything obviously wrong.

@epwalsh epwalsh merged commit f36ac42 into main Jan 19, 2024
10 checks passed
@epwalsh epwalsh deleted the epwalsh/fine-tune-with-label-masking branch January 19, 2024 22:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants