
[generate] fix eos/pad id check on mps devices #31695

Open
wants to merge 1 commit into main

Conversation

@sanchit-gandhi (Contributor) commented Jun 28, 2024

What does this PR do?

Generation currently fails on main for MPS devices:

from transformers.models.gemma2 import Gemma2ForCausalLM, Gemma2Config
import torch

# tiny randomly-initialised config so the repro runs quickly
config = Gemma2Config(num_hidden_layers=1, vocab_size=128, hidden_size=16, intermediate_size=32, num_attention_heads=1, num_key_value_heads=1)
model = Gemma2ForCausalLM(config).to("mps")

# input_ids is all ones, so it doubles as an all-ones attention mask
input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
model.generate(input_ids, attention_mask=input_ids)
Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[2], line 8
      5 model = Gemma2ForCausalLM(config).to("mps")
      7 input_ids = torch.ones((1, 10), dtype=torch.int).to("mps")
----> 8 model.generate(input_ids, attention_mask=input_ids)

File ~/venv/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/transformers/src/transformers/generation/utils.py:1664, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1661 batch_size = inputs_tensor.shape[0]
   1663 device = inputs_tensor.device
-> 1664 self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
   1666 # decoder-only models must use left-padding for batched generation.
   1667 if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
   1668     # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
   1669     # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.

File ~/transformers/src/transformers/generation/utils.py:1513, in GenerationMixin._prepare_special_tokens(self, generation_config, kwargs_has_attention_mask, device)
   1510     logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")
   1512 # we can't infer attn mask if pad token is set to be eos token in model's generation config
-> 1513 if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
   1514     if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
   1515         logger.warning_once(
   1516             "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
   1517             "As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` "
   1518             "to obtain reliable results."
   1519         )

NotImplementedError: The operator 'aten::isin.Tensor_Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

=> This is due to the torch.isin operator not being implemented for the torch MPS backend. This PR removes the torch.isin call from the main body of generation while keeping compatibility with the eos/pad checks added in #31254. A sketch of an equivalent check is shown below.
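For illustration, here is a minimal sketch of one way the eos/pad membership check can be expressed without torch.isin, assuming eos_token_id and pad_token_id are tensors on the same device (the helper name eos_contains_pad is hypothetical, not part of this PR):

import torch

def eos_contains_pad(eos_token_id: torch.Tensor, pad_token_id: torch.Tensor) -> bool:
    # Broadcast equality instead of torch.isin: compare every eos id against
    # every pad id, then reduce. Runs on MPS, where the
    # aten::isin.Tensor_Tensor_out operator is not implemented.
    return bool((eos_token_id.reshape(-1, 1) == pad_token_id.reshape(1, -1)).any())

# e.g. a config where the pad token is one of the eos tokens
assert eos_contains_pad(torch.tensor([1, 2]), torch.tensor(2))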

Following this PR, Gemma-2 (and other generate-compatible models) can be run on MPS devices.

@amyeroberts (Collaborator) left a comment

Thanks for fixing!

@@ -1510,7 +1510,7 @@ def _tensor_or_none(token_kwargs, token_self, device=None):
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.")

        # we can't infer attn mask if pad token is set to be eos token in model's generation config
        if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
Collaborator

cc @gante Are there any reasons, e.g. compilation, for using torch.isin?

Member

yes, compilation requires this torch.isin 💔 cc @sanchit-gandhi

IMO, we should create an internal function containing this torch.isin workaround (one that works with compile AND on MPS devices), and replace all torch.isin calls with this function.
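A hypothetical sketch of such a helper (the name _isin_mps_friendly and the dispatch logic are assumptions, not the final implementation):

import torch

def _isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Use torch.isin where it is supported (CPU/CUDA, and under torch.compile);
    # fall back to a broadcast comparison on MPS, where aten::isin is missing.
    if elements.device.type == "mps":
        return elements.unsqueeze(-1).eq(test_elements.reshape(-1)).any(dim=-1)
    return torch.isin(elements, test_elements)

The call site in _prepare_special_tokens would then read something like _isin_mps_friendly(elements=eos_token_id, test_elements=pad_token_id).any().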

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@gante (Member) left a comment

(see comment above)

@gante (Member) commented Jul 15, 2024

@sanchit-gandhi lmk if you have the bandwidth to address the change :)
