Conversation

@gante (Contributor) commented Nov 14, 2022

What does this PR do?

Bloom has a different cache format, where the batch size and the number of heads are packed in a single dimension. Contrastive search needs to manipulate the cache at the batch dimension, so naturally it fails.

This PR adds functionality to convert Bloom's cache back and forth between its own format and the standard cache format. It then propagates these new functions to the places where the conversion logic was already being used, and finally fixes Bloom's contrastive search.
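
For illustration, the conversion boils down to splitting (or re-fusing) that packed leading dimension of the key/value tensors. A minimal sketch of the idea; the helper names and the assumption that Bloom stores keys as [batch_size * num_heads, head_dim, seq_length] and values as [batch_size * num_heads, seq_length, head_dim] are mine, not necessarily the exact API added by this PR:

def to_batched_cache(past, batch_size):
    # Bloom packs batch and heads into one leading dimension; split it so that
    # generation utilities (e.g. contrastive search) can index along the batch dimension.
    batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
    num_heads = batch_size_times_num_heads // batch_size
    return tuple(
        (
            key.view(batch_size, num_heads, head_dim, seq_length),
            value.view(batch_size, num_heads, seq_length, head_dim),
        )
        for key, value in past
    )

def to_bloom_cache(past):
    # Re-fuse batch and heads into a single leading dimension (Bloom's native layout)
    # before feeding the cache back into the model.
    batch_size, num_heads, head_dim, seq_length = past[0][0].shape
    return tuple(
        (
            key.view(batch_size * num_heads, head_dim, seq_length),
            value.view(batch_size * num_heads, seq_length, head_dim),
        )
        for key, value in past
    )

The cache stays a tuple of (key, value) pairs either way; only the leading dimension changes, so a simple .view() is enough as long as the tensors are contiguous.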

All slow tests are passing.


This fix was also requested here.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada (Contributor) left a comment


Super excited to see BLOOM with contrastive search 💪 Thanks a lot!
LGTM as long as all slow tests pass! I just left a small comment - I am unsure how device_to_beam_idx will behave in a multi-GPU setup - so feel free to ignore my comment if you think it works!

Contributor

Suggested change
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
layer_past[1].index_select(0, device_to_beam_idx[layer_past[1].device]),

Maybe this will fail in a multi-GPU setup? (thinking of BLOOM-176), but I am not sure, I didn't test it.

Contributor

I just noticed that we were also doing layer_past[1].view(batch_size, num_heads, seq_length, head_dim).index_select(0, device_to_beam_idx[layer_past[0].device]) before, so probably what I propose is not needed!
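
For context, a sketch of how a device_to_beam_idx mapping of this kind is typically built (illustrative only; the helper name is mine and this is not necessarily the exact code in the PR):

def reorder_cache(past, beam_idx):
    # Copy beam_idx once onto every device that holds cache tensors, so index_select
    # never mixes devices in a model-parallel (multi-GPU) setup.
    device_to_beam_idx = {
        past_state.device: beam_idx.to(past_state.device)
        for layer_past in past
        for past_state in layer_past
    }
    # Both tensors of a layer live on the same device, so looking the mapping up with
    # layer_past[0].device or layer_past[1].device returns the same beam_idx copy.
    return tuple(
        (
            layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
            layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
        )
        for layer_past in past
    )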

@gante (Contributor Author) commented Nov 14, 2022


I know absolutely nothing about torch + multi-GPU -- how can I test it?

Contributor

Here is a small snippet to try this out (of course, you need a multi-GPU setup):

from transformers import AutoModelForCausalLM, AutoTokenizer


max_memory = {0: "2GB", 1: "3GB"}
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-1b7")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-1b7", device_map="auto", max_memory=max_memory, torch_dtype="auto")

# Confirm the model is split across multiple GPUs
print(model.hf_device_map)
# {'transformer.word_embeddings': 0, 'lm_head': 0, 'transformer.word_embeddings_layernorm': 0, 'transformer.h.0': 0, 'transformer.h.1': 0, 'transformer.h.2': 0, 'transformer.h.3': 0, 'transformer.h.4': 0, 'transformer.h.5': 0, 'transformer.h.6': 0, 'transformer.h.7': 0, 'transformer.h.8': 1, 'transformer.h.9': 1, 'transformer.h.10': 1, 'transformer.h.11': 1, 'transformer.h.12': 1, 'transformer.h.13': 1, 'transformer.h.14': 1, 'transformer.h.15': 1, 'transformer.h.16': 1, 'transformer.h.17': 1, 'transformer.h.18': 1, 'transformer.h.19': 1, 'transformer.h.20': 1, 'transformer.h.21': 1, 'transformer.h.22': 1, 'transformer.h.23': 1, 'transformer.ln_f': 1}

input_text = "Twitter is "
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(0)
out = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
print(tokenizer.decode(out[0]))
# 'Twitter is  a social networking site that allows users to create and manage their own profiles, and to connect with friends and other users. The site is owned by Facebook Inc. and is operated by Facebook.</s>'

Seems to work like a charm! I think that my comment can be safely ignored.

Contributor

Btw, out = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512) is the correct way to try contrastive search, right?

Contributor Author

Yup, it is! Thanks for double-checking 🙌

@sgugger (Collaborator) left a comment

Thanks for the fix. I don't like adding a new method like this, but I don't see any other way of doing it either :-(

@gante force-pushed the contrastive_bloom branch from 251d282 to e0ee202 on November 14, 2022 at 18:07.
@gante (Contributor Author) commented Nov 14, 2022

(rebasing to include #20200 in CI, the related test was failing)


@gante merged commit 938cb04 into huggingface:main on Nov 14, 2022.
@gante deleted the contrastive_bloom branch on November 14, 2022 at 18:34.
@rpryzant mentioned this pull request on Nov 17, 2022.
@rpryzant commented Nov 17, 2022

Stumbled on the same issue and found this fix. Thanks a lot!
