
Conversation

@gante (Contributor) commented on Sep 24, 2024

What does this PR do?

Step 3 of #32685 [move as many prepare_inputs_for_generation variants as possible into GenerationMixin]

This PR does the following changes on llama:

  • move prepare_inputs_for_generation to GenerationMixin. The goal is to have a base prepare_inputs_for_generation that works on most generate-capable models, since a) this is not modeling code and b) most existing variants are copy-pastes of one another;
  • move _prepare_4d_causal_attention_mask_with_cache_position from a standalone function into LlamaModel. prepare_inputs_for_generation calls it, and its implementation may differ per model, so it must live on the model object. Since it is also part of the model's forward pass, I've placed it in each model's base model class, as opposed to PreTrainedModel.

Because llama is central and has many derived models, this change required touching many models. All changes follow the logic above.
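
For context, a much-simplified sketch of the shape such a shared method takes (illustrative only; the merged `GenerationMixin` version additionally handles cache positions, `inputs_embeds`, and several other edge cases):

```python
# Illustrative sketch only -- not the merged implementation.
class GenerationMixinSketch:
    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        # When a cache is present, only the tokens not yet in the cache need to be
        # forwarded (simplification: assume exactly one new token per decoding step).
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        # Forward any extra model-specific kwargs untouched, so the base method works for most LLMs.
        model_inputs.update(kwargs)
        return model_inputs
```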

👉 Suggested review order: generation/utils.py -> llama -> the rest


✅ slow llama tests are green

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +357 to +363
"""
Prepare the model inputs for generation. It includes operations like computing the 4D attention mask or
slicing inputs given the existing cache.

See the documentation in the used model for the arguments (different models might have different requirements
for e.g. `past_key_values`). Should work as is for most LLMs.
"""
@gante (Contributor, PR author):

new docstring
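
Models with non-standard inputs can still override the method and delegate the common part to the base implementation. A hypothetical example (the subclass, its extra argument, and the convention used are made up for illustration):

```python
from transformers import LlamaForCausalLM


class MyMultimodalForCausalLM(LlamaForCausalLM):  # hypothetical wrapper, for illustration only
    def prepare_inputs_for_generation(self, input_ids, pixel_values=None, **kwargs):
        # Reuse the generic logic inherited (ultimately) from GenerationMixin,
        # then attach the model-specific input.
        model_inputs = super().prepare_inputs_for_generation(input_ids, **kwargs)
        if pixel_values is not None:
            # Hypothetical: this model consumes image features on every forward pass.
            model_inputs["pixel_values"] = pixel_values
        return model_inputs
```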

Comment on lines +401 to +413
# Create the causal mask with fixed shape in advance, to reduce recompilations. If the function to create
# the 4D causal mask exists, it should be present in the base model (XXXModel class).
base_model = getattr(self, self.base_model_prefix)
causal_mask_creation_function = getattr(
    base_model, "_prepare_4d_causal_attention_mask_with_cache_position", None
)
if causal_mask_creation_function is None:
    logger.warning_once(
        f"{self.__class__.__name__} has no `_prepare_4d_causal_attention_mask_with_cache_position` method "
        "defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're "
        "writing code, see Llama for an example implementation. If you're a user, please report this "
        "issue on GitHub."
    )
@gante (Contributor, PR author):

This handles the case where _prepare_4d_causal_attention_mask_with_cache_position doesn't exist.

I've added this extra logic to emit a warning in case something goes wrong while moving the function :) It will also be useful for other models in the future, since not all of them have this function.
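
For readers who haven't seen it: roughly, the function builds an additive 4D causal mask with a fixed key length, so compiled forward passes don't recompile as the cache grows. A simplified sketch of the idea, not the exact Llama implementation (some argument handling and special cases are omitted):

```python
import torch


def causal_mask_sketch(attention_mask, sequence_length, target_length, dtype, device, cache_position, batch_size):
    # Additive mask: 0.0 where attention is allowed, the most negative representable
    # value of `dtype` where it is not. `target_length` is fixed in advance.
    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype, device=device)
    # Each query position may attend to every key position up to its own absolute position.
    allowed = torch.arange(target_length, device=device) <= cache_position.reshape(-1, 1)
    causal_mask = causal_mask.masked_fill(allowed, 0.0)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()
    if attention_mask is not None:
        # Fold the 2D padding mask in: padded key positions also get min_dtype.
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[..., :mask_length] + attention_mask[:, None, None, :]
        causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask == 0, min_dtype)
    return causal_mask
```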

attention_mask,
sequence_length=sequence_length,
target_length=past_key_values.get_max_length(),
dtype=self.get_output_embeddings().weight.dtype,
@gante (Contributor, PR author):

uses .get_output_embeddings() to fetch the mask dtype in a model-agnostic way
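
For illustration (gpt2 is used here only because it is small; any generate-capable checkpoint exposes the same accessor):

```python
from transformers import AutoModelForCausalLM

# Any generate-capable causal LM exposes its LM head through the same accessor,
# so the dtype for the mask can be read without model-specific attributes.
model = AutoModelForCausalLM.from_pretrained("gpt2")  # any checkpoint would do
mask_dtype = model.get_output_embeddings().weight.dtype  # e.g. torch.float32
```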

target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
@gante (Contributor, PR author):

min_dtype is derived from dtype inside the function -> shorter signature
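
In other words, the function can recompute it locally; a minimal sketch, assuming a floating-point dtype:

```python
import torch

dtype = torch.float16               # whatever dtype the mask is built in
min_dtype = torch.finfo(dtype).min  # most negative representable value, used to mask positions out
```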

@zucchini-nlp (Member) left a comment:

Thanks for cleaning up, looks a lot nicer!

@ArthurZucker (Collaborator) left a comment:

🤗 looks nice thanks for taking the time to update most of the models!

@HuggingFaceDocBuilderDev

Hey! 🤗 Thanks for your contribution to the transformers library!

Before merging this pull request, slow tests CI should be triggered. To enable this:

  • Add the run-slow label to the PR
  • When your PR is ready for merge and all reviewers' comments have been addressed, push an empty commit with the command [run-slow] followed by a comma-separated list of all the models to be tested, e.g. [run_slow] model_to_test_1, model_to_test_2
    • If the pull request affects a lot of models, put at most 10 models in the commit message
  • A transformers maintainer will then approve the workflow to start the tests

(For maintainers) The documentation for slow tests CI on PRs is here.

@gante gante merged commit 22266be into huggingface:main Oct 1, 2024
@gante gante deleted the move_llama_prepare_inputs_for_generation branch October 1, 2024 11:32
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024