Generate: move llama prepare_inputs_for_generation to GenerationMixin
#33677
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
| """ | ||
| Prepare the model inputs for generation. In includes operations like computing the 4D attention mask or | ||
| slicing inputs given the existing cache. | ||
|
|
||
| See the documentation in the used model for the arguments (different models might have different requirements | ||
| for e.g. `past_key_values`). Should work as is for most LLMs. | ||
| """ |
new docstring
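To make the "slicing inputs given the existing cache" part of the docstring concrete, here is a toy illustration (the ids and shapes are made up; the real method operates on the full model input dict):

```python
import torch

# Toy illustration of input slicing during cached decoding: once past key/values already cover
# the first N positions, only the not-yet-processed tail of `input_ids` (selected via
# `cache_position`) has to be passed to the model on the next forward call.
input_ids = torch.tensor([[101, 2009, 2003, 1037, 2742]])  # full sequence so far (made-up ids)
cache_position = torch.tensor([4])                         # only the last position is uncached
sliced_input_ids = input_ids[:, cache_position]
print(sliced_input_ids)  # tensor([[2742]])
```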
    # Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
    # the 4D causal mask exists, it should be present in the base model (XXXModel class).
    base_model = getattr(self, self.base_model_prefix)
    causal_mask_creation_function = getattr(
        base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
    )
    if causal_mask_creation_function is None:
        logger.warning_once(
            f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
            "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
            "writing code, see Llama for an example implementation. If you're a user, please report this "
            "issue on GitHub."
        )
Handles the case where `_prepare_4d_causal_attention_mask_with_cache_position` doesn't exist.
I've added this extra logic to emit the warning in case something goes wrong when moving the function :) It will also be useful for other models in the future, since not all of them have this function.
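For reference, a minimal sketch of what such a helper roughly looks like, loosely following the Llama implementation; the exact signature and details in the repository may differ:

```python
import torch

def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
) -> torch.Tensor:
    # Build a (batch, 1, query_len, kv_len) additive mask: 0 for visible positions,
    # the most negative representable value for masked ones.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    causal_mask = torch.triu(causal_mask, diagonal=1)
    # Positions beyond the current cache position are always masked.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # contiguous copy before in-place edits
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(dtype)
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask == 0, min_dtype
        )
    return causal_mask
```

A base model class can expose something like this as a `@staticmethod`, so the shared `prepare_inputs_for_generation` finds it through `getattr` as shown in the snippet above.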
        attention_mask,
        sequence_length=sequence_length,
        target_length=past_key_values.get_max_length(),
        dtype=self.get_output_embeddings().weight.dtype,
Uses `.get_output_embeddings()` so the dtype lookup is model-agnostic.
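A small, hypothetical usage sketch of why `.get_output_embeddings()` generalizes (the checkpoint name is just an example):

```python
from transformers import AutoModelForCausalLM

# Any generate-capable model exposes `get_output_embeddings()`, so shared GenerationMixin code can
# read the compute dtype without knowing the model-specific attribute name of the LM head.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # example checkpoint; any causal LM works
dtype = model.get_output_embeddings().weight.dtype    # model-agnostic
# A model-specific equivalent such as `model.lm_head.weight.dtype` would not generalize here.
print(dtype)  # torch.float32
```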
        target_length=target_length,
        dtype=dtype,
        device=device,
        min_dtype=min_dtype,
`min_dtype` is derived from `dtype` inside the helper -> shorter signature.
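A minimal sketch of the derivation that makes the extra argument redundant (this is the standard PyTorch way to get the most negative value of a floating dtype):

```python
import torch

# `min_dtype` is fully determined by `dtype`, so the helper can recompute it internally
# instead of receiving it as an argument.
dtype = torch.float16
min_dtype = torch.finfo(dtype).min  # most negative representable value, used to mask positions
print(min_dtype)  # -65504.0
```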
zucchini-nlp
left a comment
Thanks for cleaning up, looks a lot nicer!
ArthurZucker
left a comment
🤗 looks nice thanks for taking the time to update most of the models!
Hey! 🤗 Thanks for your contribution to the `transformers` library! Before merging this pull request, slow tests CI should be triggered. To enable this:
(For maintainers) The documentation for slow tests CI on PRs is here.
What does this PR do?
Step 3 of #32685 [move as many `prepare_inputs_for_generation` variants as possible into `GenerationMixin`]

This PR does the following changes on `llama`:
- Moves `prepare_inputs_for_generation` to `GenerationMixin`. The goal is to have a base `prepare_inputs_for_generation` that works on most generate-capable models, as a) this is not modeling code and b) most of them are copy-paste;
- Moves `_prepare_4d_causal_attention_mask_with_cache_position` from a standalone method into `LlamaModel`. `prepare_inputs_for_generation` calls this function, and the function may change on a given model, so it must be part of the model object. It is also part of the model forward pass, so I've decided to place it in each model's base model class, as opposed to in `PreTrainedModel`.

Because `llama` is central and has many derived models, this change required touching many models. All changes follow the logic above.

👉 review suggestion: `generation/utils.py` -> `llama` -> the rest

✅ slow llama tests are green
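A hedged, end-user sanity check of the change (the checkpoint name is only an example; gated models need authentication): after the move, llama-like models inherit `prepare_inputs_for_generation` from `GenerationMixin` instead of carrying their own copy, and `generate()` is expected to behave exactly as before.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # example checkpoint; swap in any causal LM you have access to
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("The capital of France is", return_tensors="pt")
# GenerationMixin.prepare_inputs_for_generation is now the code called at every decoding step
# to slice `input_ids` against the cache and build the 4D attention mask.
outputs = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```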