Generate: add Bloom fixes for contrastive search #20213
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
Super excited to see BLOOM with contrastive search 💪 Thanks a lot!
LGTM as long as all slow tests pass! I just left a small comment: I am unsure how device_to_beam_idx will behave in a multi-GPU setup, so feel free to ignore it if you think it works!
Suggested change:
- layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[1].device]),
Maybe this will fail in a multi-GPU setup? (thinking of BLOOM-176B) But I am not sure, I haven't tested it.
I just noticed that we were already doing layer_past[1].view(batch_size, num_heads, seq_length, head_dim).index_select(0, device_to_beam_idx[layer_past[0].device]) before, so what I propose is probably not needed!
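(For context, here is a minimal sketch of the pattern under discussion; the function name is made up and this is not the repo's actual code. The idea is that device_to_beam_idx maps each device holding cache tensors to a copy of beam_idx on that device, so every index_select stays on-device; since a layer's key and value normally sit on the same device, indexing the dict by layer_past[0].device or layer_past[1].device should behave the same.)

```python
def reorder_cache_sketch(past, beam_idx):
    # Put a copy of `beam_idx` (a torch.LongTensor) on every device that holds
    # cache tensors, so each `index_select` runs on the tensor's own device.
    device_to_beam_idx = {
        past_state.device: beam_idx.to(past_state.device)
        for layer_past in past
        for past_state in layer_past
    }
    # A given layer's key and value are normally placed on the same device,
    # so either layer_past[0].device or layer_past[1].device picks the same entry.
    return tuple(
        (
            layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
            layer_past[1].index_select(0, device_to_beam_idx[layer_past[1].device]),
        )
        for layer_past in past
    )
```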
I know absolutely nothing about torch + multi GPU -- how can I test it?
Here is a small snippet to try this out (of course you need a multi-GPU setup):
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b7")
max_memory = {0: "2GB", 1: "3GB"}
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b7", device_map="auto", max_memory=max_memory, torch_dtype="auto")
# Confirm the model is on multiple GPUs
print(model.hf_device_map)
>>> {'transformer.word_embeddings': 0, 'lm_head': 0, 'transformer.word_embeddings_layernorm': 0, 'transformer.h.0': 0, 'transformer.h.1': 0, 'transformer.h.2': 0, 'transformer.h.3': 0, 'transformer.h.4': 0, 'transformer.h.5': 0, 'transformer.h.6': 0, 'transformer.h.7': 0, 'transformer.h.8': 1, 'transformer.h.9': 1, 'transformer.h.10': 1, 'transformer.h.11': 1, 'transformer.h.12': 1, 'transformer.h.13': 1, 'transformer.h.14': 1, 'transformer.h.15': 1, 'transformer.h.16': 1, 'transformer.h.17': 1, 'transformer.h.18': 1, 'transformer.h.19': 1, 'transformer.h.20': 1, 'transformer.h.21': 1, 'transformer.h.22': 1, 'transformer.h.23': 1, 'transformer.ln_f': 1}
input_text = "Twitter is "
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(0)
out = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
print(tokenizer.decode(out[0]))
>>> 'Twitter is a social networking site that allows users to create and manage their own profiles, and to connect with friends and other users. The site is owned by Facebook Inc. and is operated by Facebook.</s>'

Seems to work like a charm! I think my comment can be safely ignored.
Btw, out = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512) is the correct way to try contrastive search, right?
Yup, it is! Thanks for double-checking 🙌
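(Side note, stated as my understanding of generate's decoding-strategy dispatch rather than as documentation: contrastive search is selected when a positive penalty_alpha is combined with top_k > 1 and sampling is off; without penalty_alpha the same call falls back to the default greedy path. A minimal sketch:)

```python
# Assumption: penalty_alpha > 0 together with top_k > 1 (and do_sample=False, the
# default) triggers contrastive search in generate().
out_contrastive = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)

# Without penalty_alpha, this is plain greedy decoding.
out_greedy = model.generate(input_ids, max_length=512)
```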
sgugger left a comment:
Thanks for the fix. I don't like adding a new method like this, but I don't see any other way of doing it either :-(
Force-pushed from 251d282 to e0ee202.
(rebasing to include #20200 in CI, the related test was failing)
Stumbled on the same issue and found this fix. Thanks a lot!
What does this PR do?
Bloom has a different cache format, where the batch size and the number of heads are packed into a single dimension. Contrastive search needs to manipulate the cache along the batch dimension, so naturally it fails.
This PR adds functionality to convert Bloom's cache back and forth between its own format and the standard cache format. It then propagates the use of these new functions to the places where the conversion logic was already being used, and finally fixes Bloom's contrastive search.
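For intuition, here is a minimal sketch of the kind of conversion involved. The helper names are illustrative, not the PR's actual code, and it assumes Bloom stores keys as [batch_size * num_heads, head_dim, seq_length] and values as [batch_size * num_heads, seq_length, head_dim], while the standard cache keeps batch and heads as separate leading dimensions.

```python
def to_standard_cache(past, batch_size):
    # Unpack Bloom's fused [batch_size * num_heads, ...] layout into the
    # standard [batch_size, num_heads, ...] layout, one (key, value) pair per layer.
    batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
    num_heads = batch_size_times_num_heads // batch_size
    return tuple(
        (
            layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
            layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
        )
        for layer_past in past
    )

def to_bloom_cache(past):
    # Repack the standard layout back into Bloom's fused [batch_size * num_heads, ...] layout.
    batch_size, num_heads, head_dim, seq_length = past[0][0].shape
    return tuple(
        (
            layer_past[0].view(batch_size * num_heads, head_dim, seq_length),
            layer_past[1].view(batch_size * num_heads, seq_length, head_dim),
        )
        for layer_past in past
    )
```

With helpers like these, batch-level operations such as contrastive search's cache reordering can work on the standard layout and convert back afterwards.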
All slow tests are passing.
This fix was also requested here