-
Notifications
You must be signed in to change notification settings - Fork 31.9k
Generation: deprecate PreTrainedModel inheriting from GenerationMixin
#33203
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
1160c0f
round 2: BC compatible inheritation removal
gante 1b2cd4e
better test
gante 8e018c8
better warning
gante 898315f
Merge branch 'main' into isolated_mixin
gante 03e05b6
granite
gante 00aca81
Merge branch 'main' into isolated_mixin
gante 275b631
make fixup
gante 3709bf8
Merge branch 'main' into isolated_mixin
gante 9620273
PR comments; Add can_generate test
gante 01c66a1
Merge branch 'main' into isolated_mixin
gante 7aa36fb
update llama_onevision
gante 9337cd4
Merge branch 'main' into isolated_mixin
gante d57cefa
Merge branch 'main' into isolated_mixin
gante ed22155
add test to check inheritance is in place; fix missing models
gante 8b40781
make fixup
gante File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs): | |
| setattr(torch.nn.init, name, init_func) | ||
|
|
||
|
|
||
| def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| try: | ||
| return next(parameter.parameters()).device | ||
| except StopIteration: | ||
|
|
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | |
| return first_tuple[1].device | ||
|
|
||
|
|
||
| def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| """ | ||
| Returns the first parameter dtype (can be non-floating) or asserts if none were found. | ||
| """ | ||
|
|
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: | |
| return first_tuple[1].dtype | ||
|
|
||
|
|
||
| def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): | ||
| def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]): | ||
| """ | ||
| Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. | ||
| """ | ||
|
|
@@ -1309,6 +1309,7 @@ def floating_point_ops( | |
| return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) | ||
|
|
||
|
|
||
| # TODO (joao): remove `GenerationMixin` inheritance in v4.50 | ||
| class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): | ||
| r""" | ||
| Base class for all models. | ||
|
|
@@ -1638,11 +1639,30 @@ def can_generate(cls) -> bool: | |
| Returns: | ||
| `bool`: Whether this model can generate sequences with `.generate()`. | ||
| """ | ||
| # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. | ||
| # Alternativelly, the model can also have a custom `generate` function. | ||
| if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): | ||
| return False | ||
| return True | ||
| # Directly inherits `GenerationMixin` -> can generate | ||
| if "GenerationMixin" in str(cls.__bases__): | ||
| return True | ||
| # Model class overwrites `generate` (e.g. time series models) -> can generate | ||
| if str(cls.__name__) in str(cls.generate): | ||
| return True | ||
| # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this | ||
| # was how we detected whether a model could generate. | ||
| if "GenerationMixin" not in str(cls.prepare_inputs_for_generation): | ||
| logger.warning_once( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This warning should only appear in the case that is not BC after removing the |
||
| f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly " | ||
| "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From πv4.50π onwards, " | ||
| "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability " | ||
| "to call `generate` and other related functions." | ||
| "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the " | ||
| "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes" | ||
| "\n - If you are the owner of the model architecture code, please modify your model class such that " | ||
| "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)." | ||
| "\n - If you are not the owner of the model architecture class, please contact the model code owner " | ||
| "to update it." | ||
| ) | ||
| return True | ||
| # Otherwise, can't generate | ||
| return False | ||
|
|
||
| @classmethod | ||
| def _check_and_enable_flash_attn_2( | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This function was touched to avoid a circular import (the
from ..models.autoimports)