Fix 'Can't infer missing attention mask on mps device' #31744
Hey @BlueBlazin - do you have a reproducible code snippet for this issue? You can likely bypass it using:

```python
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer
import torch

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(model.device)

# create a "dummy" attention mask to pass to the model
# (built from model_inputs so it lands on the same device)
dummy_attention_mask = torch.ones_like(model_inputs)

generated_ids = model.generate(model_inputs, attention_mask=dummy_attention_mask, max_new_tokens=1024, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)
```
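(For a single, unpadded sequence, an all-ones mask is what `generate` would infer anyway, so the workaround only bypasses the inference step rather than changing the result.)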
@sanchit-gandhi Thanks. I just used the code snippet from that repo's README, which was:

```python
from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-27b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(model.device)

generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)
```

and had installed local-gemma using
I've just had the same error with superprompt, which was working before, but I can't remember the last time I tried it; a couple of months ago, maybe:

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto")

input_text = "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("mps")

outputs = model.generate(input_ids, max_new_tokens=77)
print(tokenizer.decode(outputs[0]))
```
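A minimal sketch of the same explicit-mask workaround applied to the superprompt snippet (untested; it assumes passing the tokenizer's attention_mask to generate is enough to skip the inference step that fails on mps):

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto")

input_text = "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it."
# keep the full tokenizer output so the attention mask can be passed explicitly
inputs = tokenizer(input_text, return_tensors="pt").to("mps")

outputs = model.generate(
    inputs.input_ids,
    attention_mask=inputs.attention_mask,  # explicit mask avoids the mps inference path
    max_new_tokens=77,
)
print(tokenizer.decode(outputs[0]))
```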
Looks like it's now hard-coded to fail:

```python
# Otherwise we have may have information -> try to infer the attention mask
if inputs.device.type == "mps":
    # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
    raise ValueError(
        "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
    )
```

Rather than letting it
Feature request
I was trying out local-gemma just now on my M1 MacBook and ran into this error:

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.
The proposed temporary solution of setting PYTORCH_ENABLE_MPS_FALLBACK=1 was not working. The cause of the error is this bit of code in generation/utils.py (the hard-coded mps check quoted in the comment above): mainly, torch.isin is not supported on mps. I don't know all the places where _prepare_attention_mask_for_generation is used or how it's used, but just a cursory glance at the code made me wonder if the torch.isin checks are really necessary. I fixed my issue by changing those checks and removing the mps conditional; a sketch of roughly what that could look like is below.
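For reference, here is a minimal, self-contained sketch of the kind of change described above. It is an illustration, not the actual transformers code: the helper name infer_attention_mask is hypothetical, and it assumes pad_token_id and eos_token_id are plain ints.

```python
import torch

# Hypothetical, simplified sketch of the attention-mask inference performed by
# _prepare_attention_mask_for_generation, written without torch.isin so it also
# runs on mps. Names and structure are illustrative, not the upstream code.
def infer_attention_mask(input_ids: torch.Tensor, pad_token_id, eos_token_id) -> torch.Tensor:
    if pad_token_id is None:
        # no pad token -> nothing to mask, attend to every position
        return torch.ones_like(input_ids)

    # plain equality instead of torch.isin(input_ids, pad_token_id)
    pad_in_inputs = bool((input_ids == pad_token_id).any())
    # plain comparison instead of torch.isin(eos_token_id, pad_token_id)
    pad_differs_from_eos = eos_token_id is None or eos_token_id != pad_token_id

    if pad_in_inputs and pad_differs_from_eos:
        # mask out positions that hold the pad token
        return (input_ids != pad_token_id).long()

    # otherwise fall back to attending everywhere
    return torch.ones_like(input_ids)
```

The torch.isin form in the real helper presumably exists to cover cases such as eos_token_id being a list of ids; for scalar ids, the plain comparisons above should behave the same, which may be why the simple replacement worked in this case.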
Again, I'm not sure if this will actually cover all use cases of that helper function, but if it does, could we please go with the simple solution until isin is supported by MPS?

Motivation
This should just work on a MacBook with mps, without hacks like setting PYTORCH_ENABLE_MPS_FALLBACK=1 (which didn't work anyway).

Your contribution
If the proposed solution looks good, then I can submit a PR (or someone else can; either way, I just want my code to work).