
Fix 'Can't infer missing attention mask on mps device' #31744

Open
BlueBlazin opened this issue Jul 2, 2024 · 4 comments
Labels
Feature request (Request for a new feature)

Comments


BlueBlazin commented Jul 2, 2024

Feature request

I was trying out local-gemma just now on my M1 MacBook and ran into this error:

ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.

The suggested temporary workaround of setting PYTORCH_ENABLE_MPS_FALLBACK=1 did not work. The cause of the error is this bit of code in generation/utils.py:

def _prepare_attention_mask_for_generation(
        self,
        inputs: torch.Tensor,
        pad_token_id: Optional[torch.Tensor],
        eos_token_id: Optional[torch.Tensor],
    ) -> torch.LongTensor:
        # ...

        if inputs.device.type == "mps":
            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
            raise ValueError(
                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
            )

        is_pad_token_in_inputs = (pad_token_id is not None) and (
            torch.isin(elements=inputs, test_elements=pad_token_id).any()
        )
        is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ~(
            torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()
        )

The core problem is that torch.isin is not supported on mps. I don't know all the places where _prepare_attention_mask_for_generation is used or how it's used, but a cursory glance at the code made me wonder whether these calls are really necessary:

torch.isin(elements=inputs, test_elements=pad_token_id).any()
torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()

I fixed my issue by changing those to:

(inputs == pad_token_id).any()  #  was: torch.isin(elements=inputs, test_elements=pad_token_id).any()
# and
eos_token_id == pad_token_id  # was: torch.isin(elements=eos_token_id, test_elements=pad_token_id).any()

and removing the mps conditional.

Again, I'm not sure whether this covers all use cases of that helper function, but if it does, could we please go with the simple solution until torch.isin is supported on MPS?
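
For what it's worth, here is a quick CPU-only sanity check (my own sketch, not transformers code; the token ids below are made up) showing that the two forms agree, at least when pad_token_id and eos_token_id are 0-dim tensors:

import torch

# Made-up example ids; pad_token_id and eos_token_id are assumed to be 0-dim
# tensors, which is the case the simpler equality checks rely on.
inputs = torch.tensor([[2, 15, 7, 0, 0]])
pad_token_id = torch.tensor(0)
eos_token_id = torch.tensor(1)

# torch.isin with a single test element behaves like a plain equality check
assert torch.isin(elements=inputs, test_elements=pad_token_id).any() == (inputs == pad_token_id).any()
assert torch.isin(elements=eos_token_id, test_elements=pad_token_id).any() == (eos_token_id == pad_token_id).any()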

Motivation

This should just work on a MacBook with MPS, without hacks like setting PYTORCH_ENABLE_MPS_FALLBACK=1 (which didn't work anyway).

Your contribution

If the proposed solution looks good, I can submit a PR (or someone else can; either way, I just want my code to work).

BlueBlazin added the Feature request label on Jul 2, 2024
sanchit-gandhi (Contributor) commented Jul 2, 2024

Hey @BlueBlazin - do you have a reproducible code-snippet for this issue? You can likely bypass it using:

from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer
import torch

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-9b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(model.device)

# create a "dummy" attention mask to pass to the model
dummy_attention_mask = torch.ones_like(encodeds)

generated_ids = model.generate(model_inputs, attention_mask= dummy_attention_mask, max_new_tokens=1024, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)

BlueBlazin (Author) commented

@sanchit-gandhi Thanks. I just used the code snippet from that repo's README, which was:

from local_gemma import LocalGemma2ForCausalLM
from transformers import AutoTokenizer

model = LocalGemma2ForCausalLM.from_pretrained("google/gemma-2-27b-it", preset="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
model_inputs = encodeds.to(model.device)

generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
decoded_text = tokenizer.batch_decode(generated_ids)

and had installed local-gemma using pip install local-gemma"[mps]" in a new virtual environment. I'll use your code snippet. Thanks.


Vargol commented Jul 4, 2024

I've just had the same error with superprompt, which was working the last time I tried it, but I can't remember when that was; maybe a couple of months ago.

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto")

input_text = "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to("mps")

outputs = model.generate(input_ids, max_new_tokens=77)
print(tokenizer.decode(outputs[0]))
$ python super.py 
tokenizer_config.json: 100%|██████████| 2.54k/2.54k [00:00<00:00, 5.73MB/s]
spiece.model: 100%|██████████| 792k/792k [00:00<00:00, 11.5MB/s]
tokenizer.json: 100%|██████| 2.42M/2.42M [00:00<00:00, 5.49MB/s]
You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
  File "/Volumes/SSD2TB/AI/Diffusers/super.py", line 9, in <module>
    outputs = model.generate(input_ids, max_new_tokens=77)
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.10/site-packages/transformers/generation/utils.py", line 1591, in generate
    model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
  File "/Volumes/SSD2TB/AI/Diffusers/lib/python3.10/site-packages/transformers/generation/utils.py", line 468, in _prepare_attention_mask_for_generation
    raise ValueError(
ValueError: Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device.
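
In the meantime, the explicit attention mask workaround mentioned above seems to apply here too. A rough sketch (untested; it just forwards the attention_mask the tokenizer already returns, so _prepare_attention_mask_for_generation is never called):

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("roborovski/superprompt-v1")
model = T5ForConditionalGeneration.from_pretrained("roborovski/superprompt-v1", device_map="auto")

input_text = "Expand the following prompt to add more detail: A storefront with 'Text to Image' written on it."
# The tokenizer output already contains an attention_mask; move everything to mps
encoded = tokenizer(input_text, return_tensors="pt").to("mps")

# Passing the mask explicitly avoids the mps torch.isin code path entirely
outputs = model.generate(encoded.input_ids, attention_mask=encoded.attention_mask, max_new_tokens=77)
print(tokenizer.decode(outputs[0]))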


Vargol commented Jul 4, 2024

Looks like it's now hard-coded to fail:

        # Otherwise we have may have information -> try to infer the attention mask
        if inputs.device.type == "mps":
            # mps does not support torch.isin (https://github.com/pytorch/pytorch/issues/77764)
            raise ValueError(
                "Can't infer missing attention mask on `mps` device. Please provide an `attention_mask` or use a different device."
            )

Rather than letting torch.isin fall back to the CPU implementation when PYTORCH_ENABLE_MPS_FALLBACK=1 is set.
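
A hypothetical sketch of that behaviour (not the actual transformers code, just an illustration of gating the error on the fallback switch):

import os
import torch

# Hypothetical helper, not transformers code: only refuse to infer the mask on
# mps when the CPU fallback is disabled, so torch.isin could fall through otherwise.
def should_raise_for_mps_isin(inputs: torch.Tensor) -> bool:
    fallback_enabled = os.environ.get("PYTORCH_ENABLE_MPS_FALLBACK") == "1"
    return inputs.device.type == "mps" and not fallback_enabled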
