34 changes: 0 additions & 34 deletions src/transformers/generation/utils.py
@@ -40,13 +40,6 @@
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from ..models.auto import (
-    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-    MODEL_FOR_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
-    MODEL_FOR_VISION_2_SEQ_MAPPING,
-)
from ..pytorch_utils import isin_mps_friendly
from ..tokenization_utils import ExtensionsTrie
from ..utils import (
@@ -1113,32 +1106,6 @@ def compute_transition_scores(

return transition_scores

-    def _validate_model_class(self):
-        """
-        Confirms that the model class is compatible with generation. If not, raises an exception that points to the
-        right class to use.
-        """
-        if not is_torchdynamo_compiling() and not self.can_generate():
-            generate_compatible_mappings = [
-                MODEL_FOR_CAUSAL_LM_MAPPING,
-                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-                MODEL_FOR_VISION_2_SEQ_MAPPING,
-                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
-            ]
-            generate_compatible_classes = set()
-            for model_mapping in generate_compatible_mappings:
-                supported_models = model_mapping.get(type(self.config), default=None)
-                if supported_models is not None:
-                    generate_compatible_classes.add(supported_models.__name__)
-            exception_message = (
-                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
-                "it doesn't have a language model head."
-            )
-            if generate_compatible_classes:
-                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
-            raise TypeError(exception_message)

def _validate_assistant(self, assistant_model):
if assistant_model is None:
return
@@ -1777,7 +1744,6 @@ def generate(
- [`~generation.GenerateBeamEncoderDecoderOutput`]
"""
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
-        self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
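Reviewer note: with `_validate_model_class()` gone, `.generate()` no longer consults the auto-class mappings to decide compatibility; that is now the job of `can_generate()`, which the `modeling_utils.py` change below reduces to an inheritance check. A minimal sketch of the resulting behaviour (assumes transformers >= 4.45; `BertForSequenceClassification` is just an example of a class without a generation head):

```python
# Minimal sketch of the new behaviour (assumes transformers >= 4.45):
# generation compatibility is an inheritance check, not a mapping lookup.
from transformers import BartForConditionalGeneration, BertForSequenceClassification
from transformers.generation import GenerationMixin

assert issubclass(BartForConditionalGeneration, GenerationMixin)  # opted in below
assert BartForConditionalGeneration.can_generate()
assert not BertForSequenceClassification.can_generate()  # no mixin, no generate()
```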
44 changes: 35 additions & 9 deletions src/transformers/modeling_utils.py
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs):
setattr(torch.nn.init, name, init_func)


-def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
try:
return next(parameter.parameters()).device
except StopIteration:
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].device


-def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first parameter dtype (can be non-floating) or asserts if none were found.
"""
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
return first_tuple[1].dtype


-def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
"""
Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
"""
@@ -1295,7 +1295,7 @@ def floating_point_ops(
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
r"""
Base class for all models.

@@ -1624,11 +1624,7 @@ def can_generate(cls) -> bool:
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
-        # Alternatively, the model can also have a custom `generate` function.
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
-            return False
-        return True
+        return issubclass(cls, GenerationMixin)

@classmethod
def _check_and_enable_flash_attn_2(
@@ -4716,6 +4712,36 @@ def _is_quantized_training_enabled(self):

return self.hf_quantizer.is_trainable

+    def generate(self, *args, **kwargs):
+        raise NotImplementedError(
+            f"{self.__class__.__name__} doesn't have a `generate` method. If you were not "
+            "expecting this exception, here are some things to check:"
+            "\n 1. If you are using a custom model (or `trust_remote_code=True`): Load your model with an "
+            "`AutoModel...` class. If that is not possible, make sure your model class inherits from "
+            "`GenerationMixin` as its first inherited class. `PreTrainedModel` no longer inherits "
+            "`GenerationMixin` as of v4.45"
+            "\n 2. If you are using a transformers model: You might be using a class that is not meant for "
+            "generation. Your model must be compatible with one of the following auto classes: "
+            "`AutoModelForCausalLM`, `ForConditionalGeneration`, `AutoModelForSpeechSeq2Seq`, and "
+            "`AutoModelForVision2Seq`"
+            "\n 3. If none of the cases above apply, please open an issue on GitHub πŸ€—"
+        )
+
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        raise NotImplementedError(
+            f"{self.__class__.__name__} doesn't have a `prepare_inputs_for_generation` method. If you were not "
+            "expecting this exception, here are some things to check:"
+            "\n 1. If you are using a custom model (or `trust_remote_code=True`): Load your model with an "
+            "`AutoModel...` class. If that is not possible, make sure your model class inherits from "
+            "`GenerationMixin` as its first inherited class. `PreTrainedModel` no longer inherits "
+            "`GenerationMixin` as of v4.45"
+            "\n 2. If you are using a transformers model: You might be using a class that is not meant for "
+            "generation. Your model must be compatible with one of the following auto classes: "
+            "`AutoModelForCausalLM`, `ForConditionalGeneration`, `AutoModelForSpeechSeq2Seq`, and "
+            "`AutoModelForVision2Seq`"
+            "\n 3. If none of the cases above apply, please open an issue on GitHub πŸ€—"
+        )


PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
if PreTrainedModel.push_to_hub.__doc__ is not None:
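Reviewer note: for custom and remote-code models, the new `NotImplementedError` stubs above spell out the migration path. A hedged sketch of a post-v4.45 custom model (`MyConfig` and `MyModelForCausalLM` are hypothetical names; a real model would also implement `forward()`):

```python
# Hypothetical custom model showing the migration the error message asks for:
# inherit GenerationMixin explicitly, listed before PreTrainedModel.
import torch.nn as nn

from transformers import PretrainedConfig, PreTrainedModel
from transformers.generation import GenerationMixin


class MyConfig(PretrainedConfig):
    model_type = "my-custom-model"  # hypothetical


class MyModelForCausalLM(GenerationMixin, PreTrainedModel):  # mixin first
    config_class = MyConfig

    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
    # a real model also defines forward() returning CausalLMOutputWithPast


assert MyModelForCausalLM.can_generate()  # True: the inheritance check passes
```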
37 changes: 37 additions & 0 deletions src/transformers/models/auto/auto_factory.py
@@ -30,12 +30,17 @@
extract_commit_hash,
find_adapter_config_file,
is_peft_available,
+    is_torch_available,
logging,
requires_backends,
)
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings


+if is_torch_available():
+    from ...generation import GenerationMixin


logger = logging.get_logger(__name__)


@@ -432,6 +437,7 @@ def from_config(cls, config, **kwargs):
else:
cls.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
+            model_class = add_generation_mixin_to_remote_model(model_class)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping)
@@ -556,6 +562,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
model_class.register_for_auto_class(cls.__name__)
else:
cls.register(config.__class__, model_class, exist_ok=True)
+            model_class = add_generation_mixin_to_remote_model(model_class)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
@@ -705,6 +712,36 @@ def getattribute_from_module(module, attr):
raise ValueError(f"Could not find {attr} in {transformers_module}!")


+def add_generation_mixin_to_remote_model(model_class):
+    """
+    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model. This function is
+    used for backwards compatibility purposes: prior to v4.45, `PreTrainedModel` inherited from `GenerationMixin`.
+    Without this function, older models dynamically loaded from the Hub may not have the `generate` method.
+    """
+    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
+    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
+        return model_class
+
+    # 2. If it already inherits from GenerationMixin, do nothing
+    if issubclass(model_class, GenerationMixin):
+        return model_class
+
+    # 3. If the class name has a suffix that indicates that it should be able to generate, add `GenerationMixin` to
+    # the class inheritance. Otherwise, do nothing.
+    terminations_with_generation_support = [
+        "ForCausalLM",
+        "ForConditionalGeneration",
+        "ForSpeechSeq2Seq",
+        "ForVision2Seq",
+    ]
+    if any(model_class.__name__.endswith(termination) for termination in terminations_with_generation_support):
+        model_class_with_generation_mixin = type(
+            model_class.__name__, (GenerationMixin, model_class), {**model_class.__dict__}
+        )
+        return model_class_with_generation_mixin
+    return model_class


class _LazyAutoMapping(OrderedDict):
"""
" A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
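Reviewer note: the `type(...)` call rebuilds the remote class with `GenerationMixin` prepended to its MRO, so the mixin's `generate` shadows the new `PreTrainedModel.generate` stub while the class name is preserved. A toy illustration of the same mechanism (`LegacyRemoteForCausalLM` is a stand-in for a pre-4.45 Hub class):

```python
# Toy version of the shim above: prepend GenerationMixin to an existing class.
import torch.nn as nn

from transformers.generation import GenerationMixin


class LegacyRemoteForCausalLM(nn.Module):  # stand-in for an old remote-code model
    pass


Patched = type(
    LegacyRemoteForCausalLM.__name__,
    (GenerationMixin, LegacyRemoteForCausalLM),
    {**LegacyRemoteForCausalLM.__dict__},
)

assert Patched.__name__ == "LegacyRemoteForCausalLM"  # name is unchanged
assert issubclass(Patched, GenerationMixin)  # generate() resolves to the mixin
```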
7 changes: 4 additions & 3 deletions src/transformers/models/bark/modeling_bark.py
@@ -22,6 +22,7 @@
from torch import nn
from torch.nn import functional as F

+from ...generation import GenerationMixin
from ...generation.logits_process import (
AlternatingCodebooksLogitsProcessor,
BarkEosPrioritizerLogitsProcessor,
@@ -546,7 +547,7 @@ def device(self) -> torch.device:


# GPT2-like autoregressive model
-class BarkCausalModel(BarkPreTrainedModel):
+class BarkCausalModel(GenerationMixin, BarkPreTrainedModel):
config_class = BarkSubModelConfig

def __init__(self, config):
@@ -1136,7 +1137,7 @@ def generate(
language modeling heads, one for each codebook.""",
BARK_MODEL_START_DOCSTRING.format(config="BarkFineConfig"),
)
-class BarkFineModel(BarkPreTrainedModel):
+class BarkFineModel(GenerationMixin, BarkPreTrainedModel):
base_model_prefix = "fine_acoustics"
config_class = BarkFineConfig
main_input_name = "codebook_idx"
@@ -1536,7 +1537,7 @@ def generate(
""",
BARK_START_DOCSTRING,
)
-class BarkModel(BarkPreTrainedModel):
+class BarkModel(GenerationMixin, BarkPreTrainedModel):
config_class = BarkConfig

def __init__(self, config):
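Reviewer note: Bark's sub-models ship their own `generate()` implementations, and under the inheritance-based `can_generate()` a bespoke `generate` alone no longer counts, so each class opts in via the mixin as above. A one-line sanity check (assumes this branch is installed):

```python
from transformers import BarkModel

# True only because BarkModel now lists GenerationMixin among its bases.
assert BarkModel.can_generate()
```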
5 changes: 3 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -25,6 +25,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
@@ -1557,7 +1558,7 @@ def forward(
@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.", BART_START_DOCSTRING
)
-class BartForConditionalGeneration(BartPreTrainedModel):
+class BartForConditionalGeneration(GenerationMixin, BartPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs):
""",
BART_START_DOCSTRING,
)
-class BartForCausalLM(BartPreTrainedModel):
+class BartForCausalLM(GenerationMixin, BartPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
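Reviewer note: for users of the library's own classes nothing changes; the migrated `BartForConditionalGeneration` still generates as before. A quick usage check (downloads the public facebook/bart-large-cnn checkpoint):

```python
# Usage sanity check: the migrated class still generates as before.
from transformers import AutoTokenizer, BartForConditionalGeneration

tok = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

inputs = tok("Studies have shown that owning a dog is good for you.", return_tensors="pt")
summary_ids = model.generate(**inputs, max_new_tokens=20)
print(tok.batch_decode(summary_ids, skip_special_tokens=True))
```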
3 changes: 2 additions & 1 deletion src/transformers/models/bert/modeling_bert.py
@@ -28,6 +28,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
@@ -1280,7 +1281,7 @@ def forward(
@add_start_docstrings(
"""Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING
)
-class BertLMHeadModel(BertPreTrainedModel):
+class BertLMHeadModel(GenerationMixin, BertPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/bert_generation/modeling_bert_generation.py
@@ -23,6 +23,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
@@ -863,7 +864,7 @@ def _tie_weights(self):
"""BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""",
BERT_GENERATION_START_DOCSTRING,
)
-class BertGenerationDecoder(BertGenerationPreTrainedModel):
+class BertGenerationDecoder(GenerationMixin, BertGenerationPreTrainedModel):
_tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/big_bird/modeling_big_bird.py
@@ -26,6 +26,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
@@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
@add_start_docstrings(
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
)
-class BigBirdForCausalLM(BigBirdPreTrainedModel):
+class BigBirdForCausalLM(GenerationMixin, BigBirdPreTrainedModel):
_tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -24,6 +24,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -2436,7 +2437,7 @@ def forward(
BIGBIRD_PEGASUS_START_DOCSTRING,
)
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
-class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
+class BigBirdPegasusForConditionalGeneration(GenerationMixin, BigBirdPegasusPreTrainedModel):
base_model_prefix = "model"
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
@@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)


-class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
+class BigBirdPegasusForCausalLM(GenerationMixin, BigBirdPegasusPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):
3 changes: 2 additions & 1 deletion src/transformers/models/biogpt/modeling_biogpt.py
@@ -23,6 +23,7 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
@@ -596,7 +597,7 @@ def forward(
@add_start_docstrings(
"""BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING
)
-class BioGptForCausalLM(BioGptPreTrainedModel):
+class BioGptForCausalLM(GenerationMixin, BioGptPreTrainedModel):
_tied_weights_keys = ["output_projection.weight"]

def __init__(self, config):
5 changes: 3 additions & 2 deletions src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -26,6 +26,7 @@
from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
+from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
from ...modeling_outputs import (
BaseModelOutput,
@@ -1196,7 +1197,7 @@ def forward(
@add_start_docstrings(
"The Blenderbot Model with a language modeling head. Can be used for summarization.", BLENDERBOT_START_DOCSTRING
)
-class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
+class BlenderbotForConditionalGeneration(GenerationMixin, BlenderbotPreTrainedModel):
base_model_prefix = "model"
_keys_to_ignore_on_load_missing = ["final_logits_bias"]
_tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"]
@@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs):


# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
-class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
+class BlenderbotForCausalLM(GenerationMixin, BlenderbotPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config):