From 1160c0f318703888750d20a6988febf52099df2b Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 29 Aug 2024 18:26:20 +0000
Subject: [PATCH 1/9] round 2: BC compatible inheritance removal

---
 src/transformers/generation/utils.py | 34 ++++++----------
 src/transformers/modeling_utils.py | 33 ++++++++++++----
 .../models/albert/modeling_albert.py | 3 +-
 src/transformers/models/auto/auto_factory.py | 39 +++++++++++++++++++
 src/transformers/models/bark/modeling_bark.py | 3 +-
 src/transformers/models/bart/modeling_bart.py | 5 ++-
 src/transformers/models/bert/modeling_bert.py | 3 +-
 .../modeling_bert_generation.py | 3 +-
 .../models/big_bird/modeling_big_bird.py | 3 +-
 .../modeling_bigbird_pegasus.py | 5 ++-
 .../models/biogpt/modeling_biogpt.py | 3 +-
 .../models/blenderbot/modeling_blenderbot.py | 5 ++-
 .../modeling_blenderbot_small.py | 5 ++-
 src/transformers/models/blip/modeling_blip.py | 3 +-
 .../models/blip/modeling_blip_text.py | 3 +-
 .../models/blip_2/modeling_blip_2.py | 3 +-
 .../models/bloom/modeling_bloom.py | 3 +-
 .../models/camembert/modeling_camembert.py | 3 +-
 .../models/chameleon/modeling_chameleon.py | 3 +-
 src/transformers/models/clvp/modeling_clvp.py | 6 +--
 .../models/codegen/modeling_codegen.py | 3 +-
 .../models/cohere/modeling_cohere.py | 3 +-
 .../models/cpmant/modeling_cpmant.py | 3 +-
 src/transformers/models/ctrl/modeling_ctrl.py | 3 +-
 .../models/data2vec/modeling_data2vec_text.py | 3 +-
 src/transformers/models/dbrx/modeling_dbrx.py | 3 +-
 .../models/electra/modeling_electra.py | 3 +-
 .../models/ernie/modeling_ernie.py | 3 +-
 .../models/falcon/modeling_falcon.py | 3 +-
 .../falcon_mamba/modeling_falcon_mamba.py | 3 +-
 .../models/flaubert/modeling_flaubert.py | 3 +-
 src/transformers/models/fsmt/modeling_fsmt.py | 3 +-
 src/transformers/models/fuyu/modeling_fuyu.py | 3 +-
 src/transformers/models/gemma/diff_gemma.py | 3 +-
 .../models/gemma/modeling_gemma.py | 3 +-
 src/transformers/models/gemma2/diff_gemma2.py | 3 +-
 .../models/gemma2/modeling_gemma2.py | 3 +-
 src/transformers/models/git/modeling_git.py | 3 +-
 src/transformers/models/gpt2/modeling_gpt2.py | 5 ++-
 .../gpt_bigcode/modeling_gpt_bigcode.py | 3 +-
 .../models/gpt_neo/modeling_gpt_neo.py | 3 +-
 .../models/gpt_neox/modeling_gpt_neox.py | 3 +-
 .../modeling_gpt_neox_japanese.py | 3 +-
 src/transformers/models/gptj/modeling_gptj.py | 3 +-
 .../models/idefics2/modeling_idefics2.py | 5 ++-
 .../models/imagegpt/modeling_imagegpt.py | 3 +-
 .../instructblip/modeling_instructblip.py | 3 +-
 .../diff_instructblipvideo.py | 3 +-
 .../modeling_instructblipvideo.py | 3 +-
 .../models/jamba/modeling_jamba.py | 3 +-
 .../models/jetmoe/modeling_jetmoe.py | 3 +-
 .../models/kosmos2/modeling_kosmos2.py | 5 ++-
 src/transformers/models/led/modeling_led.py | 3 +-
 .../models/llama/modeling_llama.py | 3 +-
 .../models/llava/modeling_llava.py | 5 ++-
 .../models/llava_next/modeling_llava_next.py | 5 ++-
 .../llava_next_video/diff_llava_next_video.py | 3 +-
 .../modeling_llava_next_video.py | 5 ++-
 .../models/longt5/modeling_longt5.py | 3 +-
 .../models/m2m_100/modeling_m2m_100.py | 3 +-
 .../models/mamba/modeling_mamba.py | 3 +-
 .../models/mamba2/modeling_mamba2.py | 3 +-
 .../models/marian/modeling_marian.py | 5 ++-
 .../models/mbart/modeling_mbart.py | 5 ++-
 .../megatron_bert/modeling_megatron_bert.py | 3 +-
 .../models/mistral/modeling_mistral.py | 3 +-
 .../models/mixtral/modeling_mixtral.py | 3 +-
 src/transformers/models/mpt/modeling_mpt.py | 3 +-
 src/transformers/models/mt5/modeling_mt5.py | 3 +-
 .../models/musicgen/modeling_musicgen.py | 15 ++++---
 .../modeling_musicgen_melody.py | 15 ++++---
 src/transformers/models/mvp/modeling_mvp.py | 5 ++-
 .../models/nemotron/modeling_nemotron.py | 3 +-
 .../models/nllb_moe/modeling_nllb_moe.py | 3 +-
 src/transformers/models/olmo/modeling_olmo.py | 3 +-
 .../models/openai/modeling_openai.py | 3 +-
 src/transformers/models/opt/modeling_opt.py | 3 +-
 .../models/paligemma/modeling_paligemma.py | 3 +-
 .../models/pegasus/modeling_pegasus.py | 5 ++-
 .../models/pegasus_x/modeling_pegasus_x.py | 3 +-
 .../models/persimmon/modeling_persimmon.py | 3 +-
 src/transformers/models/phi/modeling_phi.py | 3 +-
 src/transformers/models/phi3/modeling_phi3.py | 3 +-
 .../models/pix2struct/modeling_pix2struct.py | 3 +-
 .../models/plbart/modeling_plbart.py | 5 ++-
 .../models/pop2piano/modeling_pop2piano.py | 3 +-
 .../models/prophetnet/modeling_prophetnet.py | 5 ++-
 .../models/qwen2/modeling_qwen2.py | 3 +-
 .../qwen2_audio/modeling_qwen2_audio.py | 5 ++-
 .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +-
 .../models/qwen2_vl/modeling_qwen2_vl.py | 3 +-
 .../modeling_recurrent_gemma.py | 3 +-
 .../models/reformer/modeling_reformer.py | 3 +-
 .../models/rembert/modeling_rembert.py | 3 +-
 .../models/roberta/modeling_roberta.py | 3 +-
 .../modeling_roberta_prelayernorm.py | 3 +-
 .../models/roc_bert/modeling_roc_bert.py | 3 +-
 .../models/roformer/modeling_roformer.py | 3 +-
 src/transformers/models/rwkv/modeling_rwkv.py | 3 +-
 .../seamless_m4t/modeling_seamless_m4t.py | 5 ++-
 .../modeling_seamless_m4t_v2.py | 5 ++-
 .../speech_to_text/modeling_speech_to_text.py | 3 +-
 .../models/stablelm/modeling_stablelm.py | 3 +-
 .../models/starcoder2/modeling_starcoder2.py | 3 +-
 .../modeling_switch_transformers.py | 3 +-
 src/transformers/models/t5/modeling_t5.py | 3 +-
 .../models/trocr/modeling_trocr.py | 3 +-
 src/transformers/models/udop/modeling_udop.py | 3 +-
 src/transformers/models/umt5/modeling_umt5.py | 3 +-
 .../video_llava/modeling_video_llava.py | 5 ++-
 .../models/vipllava/modeling_vipllava.py | 5 ++-
 .../models/whisper/generation_whisper.py | 4 +-
 .../models/whisper/modeling_whisper.py | 3 +-
 src/transformers/models/xglm/modeling_xglm.py | 3 +-
 src/transformers/models/xlm/modeling_xlm.py | 3 +-
 .../xlm_roberta/modeling_xlm_roberta.py | 3 +-
 .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 3 +-
 .../models/xlnet/modeling_xlnet.py | 3 +-
 src/transformers/models/xmod/modeling_xmod.py | 3 +-
 tests/models/auto/test_modeling_auto.py | 18 +++++++++
 120 files changed, 363 insertions(+), 179 deletions(-)
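Nearly every model-file hunk below makes the same one-line change to a class declaration. As a quick orientation before the diffs — a sketch with placeholder names, not classes from this patch — the migration looks like this:

# Before: generation support arrived implicitly, because `PreTrainedModel`
# still inherits from `GenerationMixin` in v4.45 (scheduled for removal in v4.50):
#
#     class MyModelForCausalLM(MyModelPreTrainedModel):
#         ...
#
# After: the mixin is explicit, listed *after* the pretrained base class
# (listing it first would fail MRO linearization while the old inheritance exists):

from transformers import PreTrainedModel
from transformers.generation import GenerationMixin

class MyModelForCausalLM(PreTrainedModel, GenerationMixin):
    ...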
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index 0d2baea6d85f..c855b6f7dea8 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -40,13 +40,6 @@
 )
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
-from ..models.auto import (
-    MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-    MODEL_FOR_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-    MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
-    MODEL_FOR_VISION_2_SEQ_MAPPING,
-)
 from ..pytorch_utils import isin_mps_friendly
 from ..tokenization_utils import ExtensionsTrie
 from ..utils import (
@@ -1118,26 +1111,21 @@ def _validate_model_class(self):
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
+        # TODO(joao): remove this function in v4.50, i.e. when we remove the inheritance of `GenerationMixin` from
+        # `PreTrainedModel`. With that inheritance removed, all model classes inheriting from `GenerationMixin` can
+        # safely call `GenerationMixin.generate`
         if not is_torchdynamo_compiling() and not self.can_generate():
-            generate_compatible_mappings = [
-                MODEL_FOR_CAUSAL_LM_MAPPING,
-                MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
-                MODEL_FOR_VISION_2_SEQ_MAPPING,
-                MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
-                MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
+            terminations_with_generation_support = [
+                "ForCausalLM",
+                "ForConditionalGeneration",
+                "ForSpeechSeq2Seq",
+                "ForVision2Seq",
             ]
-            generate_compatible_classes = set()
-            for model_mapping in generate_compatible_mappings:
-                supported_models = model_mapping.get(type(self.config), default=None)
-                if supported_models is not None:
-                    generate_compatible_classes.add(supported_models.__name__)
-            exception_message = (
+            raise TypeError(
                 f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`, as "
-                "it doesn't have a language model head."
+                "it doesn't have a language model head. Classes that support generation often end in one of these "
+                f"names: {terminations_with_generation_support}."
             )
-            if generate_compatible_classes:
-                exception_message += f" Please use one of the following classes instead: {generate_compatible_classes}"
-            raise TypeError(exception_message)
 
     def _validate_assistant(self, assistant_model):
         if assistant_model is None:
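For illustration, here is what the simplified `_validate_model_class` failure looks like from user code — a sketch assuming any head-less encoder checkpoint (bert-base-uncased is just a convenient example):

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # bare encoder, no LM head
try:
    model.generate()  # `generate` still exists in v4.45 via the deprecated inheritance
except TypeError as err:
    print(err)
    # The current model class (BertModel) is not compatible with `.generate()`, as it
    # doesn't have a language model head. Classes that support generation often end in
    # one of these names: ['ForCausalLM', 'ForConditionalGeneration', ...]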
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b943b5e7989f..93797b3c92d9 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -212,7 +212,7 @@ def _skip_init(*args, **kwargs):
     setattr(torch.nn.init, name, init_func)
 
 
-def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     try:
         return next(parameter.parameters()).device
     except StopIteration:
@@ -227,7 +227,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].device
 
 
-def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_first_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first parameter dtype (can be non-floating) or asserts if none were found.
     """
@@ -245,7 +245,7 @@ def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
         return first_tuple[1].dtype
 
 
-def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
+def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
     """
     Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
     """
@@ -1624,11 +1624,28 @@ def can_generate(cls) -> bool:
         Returns:
             `bool`: Whether this model can generate sequences with `.generate()`.
         """
-        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
-        # Alternativelly, the model can also have a custom `generate` function.
-        if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
-            return False
-        return True
+        # Directly inherits `GenerationMixin` -> can generate
+        if "GenerationMixin" in str(cls.__bases__):
+            return True
+        # Model class overwrites `generate` -> can generate
+        if str(cls.__name__) in str(cls.generate):
+            return True
+        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
+        # was how we detected whether a model could generate.
+        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
+            logger.warning_once(
+                f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
+                "overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
+                "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
+                "to call `generate` and other related functions."
+                "\n  - If you are the owner of the model architecture code, please modify your model class such that "
+                "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
+                "\n  - If you are not the owner of the model architecture class, please contact the model code owner "
+                "to update it."
+            )
+            return True
+        # Otherwise, can't generate
+        return False
 
     @classmethod
     def _check_and_enable_flash_attn_2(
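The three detection rules in the new `can_generate` can be exercised directly. A sketch with hypothetical classes (the legacy class mirrors pre-v4.45 model code that only overrode `prepare_inputs_for_generation`):

from transformers import PreTrainedModel
from transformers.generation import GenerationMixin

class ExplicitModel(PreTrainedModel, GenerationMixin):  # rule 1: direct mixin base
    pass

class CustomLoopModel(PreTrainedModel):  # rule 2: defines its own `generate`
    def generate(self, *args, **kwargs):
        return None

class LegacyModel(PreTrainedModel):  # rule 3 (BC): override detected, warning emitted
    def prepare_inputs_for_generation(self, *args, **kwargs):
        return {}

class BareModel(PreTrainedModel):  # no rule matches
    pass

print(ExplicitModel.can_generate())    # True
print(CustomLoopModel.can_generate())  # True - its qualname appears in str(cls.generate)
print(LegacyModel.can_generate())      # True, plus the v4.50 deprecation warning above
print(BareModel.can_generate())        # False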
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index ac4958798b2c..f120f52cbf83 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -24,6 +24,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPooling,
@@ -852,7 +853,7 @@ def forward(
 )
 
 
-class AlbertMLMHead(nn.Module):
+class AlbertMLMHead(nn.Module, GenerationMixin):
     def __init__(self, config: AlbertConfig):
         super().__init__()
 
diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 6b572b252779..7e33956c7b21 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -30,12 +30,17 @@
     extract_commit_hash,
     find_adapter_config_file,
     is_peft_available,
+    is_torch_available,
     logging,
     requires_backends,
 )
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
 
 
+if is_torch_available():
+    from ...generation import GenerationMixin
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -432,6 +437,7 @@ def from_config(cls, config, **kwargs):
             else:
                 cls.register(config.__class__, model_class, exist_ok=True)
             _ = kwargs.pop("code_revision", None)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class._from_config(config, **kwargs)
         elif type(config) in cls._model_mapping.keys():
             model_class = _get_model_class(config, cls._model_mapping)
@@ -556,6 +562,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
                 model_class.register_for_auto_class(cls.__name__)
             else:
                 cls.register(config.__class__, model_class, exist_ok=True)
+            model_class = add_generation_mixin_to_remote_model(model_class)
             return model_class.from_pretrained(
                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
             )
@@ -705,6 +712,38 @@ def getattribute_from_module(module, attr):
     raise ValueError(f"Could not find {attr} in {transformers_module}!")
 
 
+def add_generation_mixin_to_remote_model(model_class):
+    """
+    Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model.
+
+    This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make
+    `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded
+    from the Hub may not have the `generate` method after we remove the inheritance.
+    """
+    # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing
+    if "torch.nn.modules.module.Module" not in str(model_class.__mro__):
+        return model_class
+
+    # 2. If it already **directly** inherits from GenerationMixin, do nothing
+    if "GenerationMixin" in str(model_class.__bases__):
+        return model_class
+
+    # 3. If the class name has a suffix that indicates that it should be able to generate, add `GenerationMixin` to
+    # the class inheritance. Otherwise, do nothing.
+    terminations_with_generation_support = [
+        "ForCausalLM",
+        "ForConditionalGeneration",
+        "ForSpeechSeq2Seq",
+        "ForVision2Seq",
+    ]
+    if any(model_class.__name__.endswith(termination) for termination in terminations_with_generation_support):
+        model_class_with_generation_mixin = type(
+            model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
+        )
+        return model_class_with_generation_mixin
+    return model_class
+
+
 class _LazyAutoMapping(OrderedDict):
     """ "
     A mapping config to object (model or tokenizer for instance) that will load keys and values when it is accessed.
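The `type()` call above is standard dynamic subclassing. A self-contained sketch of the same trick with stand-in classes (none of these are the real transformers classes), showing that the patched class keeps its name and namespace while gaining the mixin in its MRO:

class GenerationMixin:  # stand-in for transformers' mixin
    def generate(self):
        return "generated tokens"

class PreTrainedModel:  # stand-in base; remote-code classes derive from it
    pass

class RemoteForCausalLM(PreTrainedModel):  # hypothetical Hub class without the mixin
    pass

patched = type(
    RemoteForCausalLM.__name__,            # reuse the original class name
    (RemoteForCausalLM, GenerationMixin),  # append the mixin after the model class
    {**RemoteForCausalLM.__dict__},        # carry over the class namespace
)

print(patched().generate())  # generated tokens
print([c.__name__ for c in patched.__mro__])
# ['RemoteForCausalLM', 'RemoteForCausalLM', 'PreTrainedModel', 'GenerationMixin', 'object']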
Can be used for summarization.", BART_START_DOCSTRING ) -class BartForConditionalGeneration(BartPreTrainedModel): +class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2010,7 +2011,7 @@ def forward(self, *args, **kwargs): """, BART_START_DOCSTRING, ) -class BartForCausalLM(BartPreTrainedModel): +class BartForCausalLM(BartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 850e93ca59fb..ed0d58446f46 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa, @@ -1280,7 +1281,7 @@ def forward( @add_start_docstrings( """Bert Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_START_DOCSTRING ) -class BertLMHeadModel(BertPreTrainedModel): +class BertLMHeadModel(BertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index a5fb3d053115..8496d1f6072f 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer @@ -863,7 +864,7 @@ def _tie_weights(self): """BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.""", BERT_GENERATION_START_DOCSTRING, ) -class BertGenerationDecoder(BertGenerationPreTrainedModel): +class BertGenerationDecoder(BertGenerationPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index a6b1660d5ae1..41045cb5f000 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -2495,7 +2496,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING ) -class BigBirdForCausalLM(BigBirdPreTrainedModel): +class 
BigBirdForCausalLM(BigBirdPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9f8e3cd19cd8..e26dce1edfc2 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2436,7 +2437,7 @@ def forward( BIGBIRD_PEGASUS_START_DOCSTRING, ) # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS -class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"] @@ -2882,7 +2883,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): +class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 020f52833d5b..6e346016cfe1 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -596,7 +597,7 @@ def forward( @add_start_docstrings( """BioGPT Model with a `language modeling` head on top for CLM fine-tuning.""", BIOGPT_START_DOCSTRING ) -class BioGptForCausalLM(BioGptPreTrainedModel): +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 12d259fde71e..4ea5926d854c 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -26,6 +26,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1196,7 +1197,7 @@ def forward( @add_start_docstrings( "The Blenderbot Model with a language modeling head. 
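The BART family above covers both halves of generation support: the encoder-decoder `ForConditionalGeneration` head and the decoder-only `ForCausalLM` export get the same one-line change. A quick check of the result, assuming v4.45 or later:

from transformers import BartForCausalLM, BartForConditionalGeneration, BartModel

print(BartForConditionalGeneration.can_generate())  # True - explicit GenerationMixin base
print(BartForCausalLM.can_generate())               # True - same one-line change
print(BartModel.can_generate())                     # False - bare seq2seq body, no LM head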
Can be used for summarization.", BLENDERBOT_START_DOCSTRING ) -class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): +class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1397,7 +1398,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill -class BlenderbotForCausalLM(BlenderbotPreTrainedModel): +class BlenderbotForCausalLM(BlenderbotPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index aa0e38bd8e91..3e378f483a31 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1163,7 +1164,7 @@ def forward( "The BlenderbotSmall Model with a language modeling head. Can be used for summarization.", BLENDERBOT_SMALL_START_DOCSTRING, ) -class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] @@ -1349,7 +1350,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M -class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): +class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 46e3a6005b0a..04bb0dc76c87 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -25,6 +25,7 @@ from torch.nn.functional import normalize from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1027,7 +1028,7 @@ def forward( """, BLIP_START_DOCSTRING, ) -class BlipForConditionalGeneration(BlipPreTrainedModel): +class BlipForConditionalGeneration(BlipPreTrainedModel, GenerationMixin): config_class = BlipConfig _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"] main_input_name = "pixel_values" diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index a800ba89825d..78384e6ce2f7 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -23,6 +23,7 @@ from torch.nn import 
CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -808,7 +809,7 @@ def forward( # Adapted from https://github.com/salesforce/BLIP/blob/main/models/med.py#L811 -class BlipTextLMHeadModel(BlipTextPreTrainedModel): +class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index e89576c67ecc..1527685491de 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1583,7 +1584,7 @@ def forward( """, BLIP_2_START_DOCSTRING, ) -class Blip2ForConditionalGeneration(Blip2PreTrainedModel): +class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin): config_class = Blip2Config main_input_name = "pixel_values" diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 70e748343561..8159bfd145e5 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -26,6 +26,7 @@ from ...cache_utils import Cache, DynamicCache, StaticCache from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -856,7 +857,7 @@ def _update_causal_mask( """, BLOOM_START_DOCSTRING, ) -class BloomForCausalLM(BloomPreTrainedModel): +class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: BloomConfig): diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index f050b56d1178..dc2fee4625f5 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1399,7 +1400,7 @@ def forward( """CamemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", CAMEMBERT_START_DOCSTRING ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->Camembert, ROBERTA->CAMEMBERT, FacebookAI/roberta-base->almanach/camembert-base -class CamembertForCausalLM(CamembertPreTrainedModel): +class CamembertForCausalLM(CamembertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca95..c631181f00c5 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ 
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 23334311ca95..c631181f00c5 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -26,6 +26,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
@@ -1496,7 +1497,7 @@ def _update_causal_mask(
     "Chameleon Model with a head on top used for outputting logits for next token prediction.",
     CHAMELEON_START_DOCSTRING,
 )
-class ChameleonForConditionalGeneration(ChameleonPreTrainedModel):
+class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py
index b6d025a0b8e2..69d85ee778b7 100644
--- a/src/transformers/models/clvp/modeling_clvp.py
+++ b/src/transformers/models/clvp/modeling_clvp.py
@@ -26,7 +26,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
-from ...generation import GenerationConfig
+from ...generation import GenerationConfig, GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1278,7 +1278,7 @@ def forward(
     "The CLVP decoder model with a language modelling head on top.",
     CLVP_START_DOCSTRING,
 )
-class ClvpForCausalLM(ClvpPreTrainedModel):
+class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
     def __init__(self, config):
         super().__init__(config)
 
@@ -1509,7 +1509,7 @@ def _reorder_cache(
     "together to filter out the best speech_ids.",
     CLVP_START_DOCSTRING,
 )
-class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
+class ClvpModelForConditionalGeneration(ClvpPreTrainedModel, GenerationMixin):
     config_class = ClvpConfig
 
     def __init__(self, config: ClvpConfig):
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index bfa591f7bdaf..8ffee7b0c69a 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -23,6 +23,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -697,7 +698,7 @@ def _update_causal_mask(
     """,
     CODEGEN_START_DOCSTRING,
 )
-class CodeGenForCausalLM(CodeGenPreTrainedModel):
+class CodeGenForCausalLM(CodeGenPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 6912a4596370..19964bc57e0c 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -32,6 +32,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
@@ -972,7 +973,7 @@ def _update_causal_mask(
 
 
 # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
-class CohereForCausalLM(CoherePreTrainedModel):
+class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     # Ignore copy
diff --git a/src/transformers/models/cpmant/modeling_cpmant.py b/src/transformers/models/cpmant/modeling_cpmant.py
index c8a313505251..964d0bbfd145 100755
--- a/src/transformers/models/cpmant/modeling_cpmant.py
+++ b/src/transformers/models/cpmant/modeling_cpmant.py
@@ -24,6 +24,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
@@ -736,7 +737,7 @@ def forward(
     """,
     CPMANT_START_DOCSTRING,
 )
-class CpmAntForCausalLM(CpmAntPreTrainedModel):
+class CpmAntForCausalLM(CpmAntPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: CpmAntConfig):
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index bbf3b10a62ec..6d921621d47d 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -22,6 +22,7 @@
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_linear_layer
@@ -503,7 +504,7 @@ def forward(
     """,
     CTRL_START_DOCSTRING,
 )
-class CTRLLMHeadModel(CTRLPreTrainedModel):
+class CTRLLMHeadModel(CTRLPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py
index a41fdfb56ed1..fcddeab7a595 100644
--- a/src/transformers/models/data2vec/modeling_data2vec_text.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -23,6 +23,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, gelu
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -866,7 +867,7 @@ def forward(
 @add_start_docstrings(
     """Data2VecText Model with a `language modeling` head on top for CLM fine-tuning.""", DATA2VECTEXT_START_DOCSTRING
 )
-class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel):
+class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 9684fd174733..a779609b31b0 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -23,6 +23,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_utils import PreTrainedModel
@@ -1223,7 +1224,7 @@ def _update_causal_mask(
 
 
 @add_start_docstrings("The DBRX Model transformer for causal language modeling.", DBRX_START_DOCSTRING)
-class DbrxForCausalLM(DbrxPreTrainedModel):
+class DbrxForCausalLM(DbrxPreTrainedModel, GenerationMixin):
     def __init__(self, config: DbrxConfig):
         super().__init__(config)
         self.transformer = DbrxModel(config)
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index dd017170bef9..a200d716d451 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -25,6 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, get_activation
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithCrossAttentions,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -1524,7 +1525,7 @@ def forward(
 @add_start_docstrings(
     """ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
 )
-class ElectraForCausalLM(ElectraPreTrainedModel):
+class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["generator_lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py
index 6a0a26a5cbe5..6d81c97da023 100644
--- a/src/transformers/models/ernie/modeling_ernie.py
+++ b/src/transformers/models/ernie/modeling_ernie.py
@@ -25,6 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
@@ -1081,7 +1082,7 @@ def forward(
 @add_start_docstrings(
     """Ernie Model with a `language modeling` head on top for CLM fine-tuning.""", ERNIE_START_DOCSTRING
 )
-class ErnieForCausalLM(ErniePreTrainedModel):
+class ErnieForCausalLM(ErniePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
 
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index a9acd171c3ae..64ec0ce662ae 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -25,6 +25,7 @@
 from ...activations import get_activation
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import (
     AttentionMaskConverter,
 )
@@ -1220,7 +1221,7 @@ def _update_causal_mask(
     "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
     FALCON_START_DOCSTRING,
 )
-class FalconForCausalLM(FalconPreTrainedModel):
+class FalconForCausalLM(FalconPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: FalconConfig):
diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 07374fe1dfd7..e77bcd96e6f2 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -25,6 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import MambaCache
+from ...generation import GenerationMixin
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     ModelOutput,
@@ -702,7 +703,7 @@ def forward(
     FALCONMAMBA_START_DOCSTRING,
 )
 # Copied from transformers.models.mamba.modeling_mamba.MambaForCausalLM with MAMBA->FALCONMAMBA,Mamba->FalconMamba,mamba->falcon_mamba,FalconMambaCache->MambaCache
-class FalconMambaForCausalLM(FalconMambaPreTrainedModel):
+class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/flaubert/modeling_flaubert.py b/src/transformers/models/flaubert/modeling_flaubert.py
index 50c6f7ede222..ef1501e78035 100644
--- a/src/transformers/models/flaubert/modeling_flaubert.py
+++ b/src/transformers/models/flaubert/modeling_flaubert.py
@@ -25,6 +25,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import gelu
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutput,
     MaskedLMOutput,
@@ -644,7 +645,7 @@ def forward(
     FLAUBERT_START_DOCSTRING,
 )
 # Copied from transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
-class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
+class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["pred_layer.proj.weight"]
 
     def __init__(self, config):
Can be used for summarization.", FSMT_START_DOCSTRING ) -class FSMTForConditionalGeneration(PretrainedFSMTModel): +class FSMTForConditionalGeneration(PretrainedFSMTModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["decoder.embed_tokens.weight", "decoder.output_projection.weight"] diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index b4b6330d0d86..728c3f099c37 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -20,6 +20,7 @@ import torch.utils.checkpoint from torch import nn +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...models.auto.modeling_auto import AutoModelForCausalLM @@ -145,7 +146,7 @@ def _init_weights(self, module): "Fuyu Model with a language modeling head on top for causal language model conditioned on image patches and text.", FUYU_START_DOCSTRING, ) -class FuyuForCausalLM(FuyuPreTrainedModel): +class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin): def __init__(self, config: FuyuConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index fcdb0f0b3d7d..708b6021aec4 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -520,7 +521,7 @@ def forward( # Example where we ony modify the docstring and call super -class GemmaForCausalLM(LlamaForCausalLM): +class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 62917c73f332..e6701904981f 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -989,7 +990,7 @@ def _update_causal_mask( return causal_mask -class GemmaForCausalLM(GemmaPreTrainedModel): +class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 0e300c6337e2..6d6e318fe207 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -33,6 +33,7 @@ ) from ...cache_utils import Cache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging @@ -476,7 +477,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(GemmaForCausalLM): +class 
Gemma2ForCausalLM(GemmaForCausalLM, GenerationMixin): def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index bf6ff76189d4..e13624c3cc41 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, HybridCache +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -943,7 +944,7 @@ def _update_causal_mask( return causal_mask -class Gemma2ForCausalLM(Gemma2PreTrainedModel): +class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 4807289c927c..a13dbc2b3e3f 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...file_utils import ModelOutput +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1318,7 +1319,7 @@ def forward( @add_start_docstrings( """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING ) -class GitForCausalLM(GitPreTrainedModel): +class GitForCausalLM(GitPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8dfbfb906444..e99f4b126246 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1182,7 +1183,7 @@ def forward( """, GPT2_START_DOCSTRING, ) -class GPT2LMHeadModel(GPT2PreTrainedModel): +class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): @@ -1384,7 +1385,7 @@ def _reorder_cache( """, GPT2_START_DOCSTRING, ) -class GPT2DoubleHeadsModel(GPT2PreTrainedModel): +class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0f927a72469d..ca1c03fcd9f9 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -1040,7 +1041,7 @@ def forward( """, GPT_BIGCODE_START_DOCSTRING, ) -class 
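In the `diff_gemma*.py` files (precursors of the modular-model system), the mixin is listed even though the parent `*ForCausalLM` already carries it. That appears redundant but is harmless: Python's MRO linearization keeps a single copy, as this stand-in sketch shows:

class GenerationMixin: ...
class LlamaForCausalLM(GenerationMixin): ...                     # already has the mixin
class GemmaForCausalLM(LlamaForCausalLM, GenerationMixin): ...   # re-lists it

print([c.__name__ for c in GemmaForCausalLM.__mro__])
# ['GemmaForCausalLM', 'LlamaForCausalLM', 'GenerationMixin', 'object']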
diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py
index 4807289c927c..a13dbc2b3e3f 100644
--- a/src/transformers/models/git/modeling_git.py
+++ b/src/transformers/models/git/modeling_git.py
@@ -27,6 +27,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...file_utils import ModelOutput
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -1318,7 +1319,7 @@ def forward(
 @add_start_docstrings(
     """GIT Model with a `language modeling` head on top for autoregressive language modeling.""", GIT_START_DOCSTRING
 )
-class GitForCausalLM(GitPreTrainedModel):
+class GitForCausalLM(GitPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["output.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 8dfbfb906444..e99f4b126246 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -28,6 +28,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -1182,7 +1183,7 @@ def forward(
     """,
     GPT2_START_DOCSTRING,
 )
-class GPT2LMHeadModel(GPT2PreTrainedModel):
+class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
@@ -1384,7 +1385,7 @@ def _reorder_cache(
     """,
     GPT2_START_DOCSTRING,
 )
-class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
+class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 0f927a72469d..ca1c03fcd9f9 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -22,6 +22,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -1040,7 +1041,7 @@ def forward(
     """,
     GPT_BIGCODE_START_DOCSTRING,
 )
-class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel):
+class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index e59853677f83..8a52efc95421 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -24,6 +24,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
@@ -912,7 +913,7 @@ def _update_causal_mask(
     """,
     GPT_NEO_START_DOCSTRING,
 )
-class GPTNeoForCausalLM(GPTNeoPreTrainedModel):
+class GPTNeoForCausalLM(GPTNeoPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 7be35c0d137d..2ead484932dc 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -30,6 +30,7 @@
     add_start_docstrings_to_model_forward,
     replace_return_docstrings,
 )
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
@@ -1079,7 +1080,7 @@ def _update_causal_mask(
 @add_start_docstrings(
     """GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING
 )
-class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel):
+class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["embed_out.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index b9c4cad0fdc5..e3fa236aa307 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -23,6 +23,7 @@
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
+from ...generation import GenerationMixin
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
@@ -589,7 +590,7 @@ def forward(
     """GPTNeoXJapanese Model with a `language modeling` head on top for Classifier Model fine-tuning.""",
     GPT_NEOX_JAPANESE_START_DOCSTRING,
 )
-class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
+class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["embed_out.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index 1408bfe8a61d..738c12cb8535 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -25,6 +25,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
@@ -1006,7 +1007,7 @@ def _update_causal_mask(
     """,
     GPTJ_START_DOCSTRING,
 )
-class GPTJForCausalLM(GPTJPreTrainedModel):
+class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 5d4f8e408eb0..8ec5141a7d02 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -23,11 +23,12 @@
 from torch import nn
 from torch.nn import CrossEntropyLoss
 
-from ... import PreTrainedModel
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, ModelOutput
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
@@ -1438,7 +1439,7 @@ def forward(
     """The Idefics2 Model with a language modeling head. It is made up a SigLIP vision encoder, with a language modeling head on top. """,
     IDEFICS2_START_DOCSTRING,
 )
-class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
+class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config):
diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py
index 5d59a4ed90e4..a027876b43d3 100755
--- a/src/transformers/models/imagegpt/modeling_imagegpt.py
+++ b/src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -26,6 +26,7 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -880,7 +881,7 @@ def forward(
     """,
     IMAGEGPT_START_DOCSTRING,
 )
-class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
+class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: ImageGPTConfig):
diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py
index f59f72a6699c..6392c40a0c5e 100644
--- a/src/transformers/models/instructblip/modeling_instructblip.py
+++ b/src/transformers/models/instructblip/modeling_instructblip.py
@@ -24,6 +24,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,
@@ -1274,7 +1275,7 @@ def forward(
     """,
     INSTRUCTBLIP_START_DOCSTRING,
 )
-class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel):
+class InstructBlipForConditionalGeneration(InstructBlipPreTrainedModel, GenerationMixin):
     config_class = InstructBlipConfig
     main_input_name = "pixel_values"
+from ...generation import GenerationMixin from ...utils import logging @@ -128,7 +129,7 @@ class InstructBlipVideoQFormerModel(InstructBlipQFormerModel): pass -class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration): +class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration, GenerationMixin): def forward( self, pixel_values: torch.FloatTensor, diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index 701402241d4a..1d2ce82c4f64 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -30,6 +30,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1283,7 +1284,7 @@ def forward( """, INSTRUCTBLIPVIDEO_START_DOCSTRING, ) -class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel): +class InstructBlipVideoForConditionalGeneration(InstructBlipVideoPreTrainedModel, GenerationMixin): config_class = InstructBlipVideoConfig main_input_name = "pixel_values" diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 2722a5e06909..da4274436ee6 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1386,7 +1387,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # Adapted from transformers.models.mixtral.modeling_mixtral.MixtralForCausalLM with MIXTRAL->JAMBA, Mixtral->Jamba -class JambaForCausalLM(JambaPreTrainedModel): +class JambaForCausalLM(JambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: JambaConfig): diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 162478a7258c..18176e57f18c 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1195,7 +1196,7 @@ def _update_causal_mask( return causal_mask -class JetMoeForCausalLM(JetMoePreTrainedModel): +class JetMoeForCausalLM(JetMoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 69641790b2db..90e21ed2f558 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1521,7 +1522,7 @@ 
def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel): +class Kosmos2TextForCausalLM(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2TextConfig _tied_weights_keys = ["lm_head.weight"] @@ -1864,7 +1865,7 @@ def forward( """, KOSMOS2_START_DOCSTRING, ) -class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel): +class Kosmos2ForConditionalGeneration(Kosmos2PreTrainedModel, GenerationMixin): config_class = Kosmos2Config main_input_name = "pixel_values" _tied_weights_keys = ["text_model.lm_head.weight"] diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 41b6c0a2bea2..f96bfd82b526 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _create_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -2298,7 +2299,7 @@ def forward( @add_start_docstrings( "The LED Model with a language modeling head. Can be used for summarization.", LED_START_DOCSTRING ) -class LEDForConditionalGeneration(LEDPreTrainedModel): +class LEDForConditionalGeneration(LEDPreTrainedModel, GenerationMixin): base_model_prefix = "led" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 022ae5ce74c4..5e47d5f08322 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -28,6 +28,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -1097,7 +1098,7 @@ def _update_causal_mask( return causal_mask -class LlamaForCausalLM(LlamaPreTrainedModel): +class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 394c80edb540..368f10759413 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -240,7 +241,7 @@ def _supports_sdpa(self): """The LLAVA model which consists of a vision backbone and a language model.""", LLAVA_START_DOCSTRING, ) -class LlavaForConditionalGeneration(LlavaPreTrainedModel): +class LlavaForConditionalGeneration(LlavaPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 723d54c92dd9..dd506ae38aa0 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -23,10 +23,11 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -346,7 +347,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_START_DOCSTRING, ) -class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel): +class LlavaNextForConditionalGeneration(LlavaNextPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaNextConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py index b4018db586e7..765ed469c4b8 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py @@ -30,6 +30,7 @@ ) from ...cache_utils import Cache +from ...generation import GenerationMixin from ...utils import ( logging, replace_return_docstrings, @@ -219,7 +220,7 @@ class LlavaNextVideoMultiModalProjector(LlavaNextMultiModalProjector): pass -class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration): +class LlavaNextVideoForConditionalGeneration(LlavaNextForConditionalGeneration, GenerationMixin): def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): super().__init__(config, **super_kwargs) self.vision_resampler = LlavaNextVideoPooler(config) diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3430fbe590aa..87c5fae3a926 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -29,11 +29,12 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache +from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -387,7 +388,7 @@ def _supports_sdpa(self): """The LLAVA-NeXT model which consists of a vision backbone and a language model.""", LLAVA_NEXT_VIDEO_START_DOCSTRING, ) -class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel): +class LlavaNextVideoForConditionalGeneration(LlavaNextVideoPreTrainedModel, GenerationMixin): def __init__( self, config: LlavaNextVideoConfig, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index b2a6ed11ca57..8f9385c0fe76 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1900,7 +1901,7 @@ def forward( @add_start_docstrings("""LONGT5 Model with a `language modeling` head on top.""", LONGT5_START_DOCSTRING) -class LongT5ForConditionalGeneration(LongT5PreTrainedModel): +class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 23a855fff256..86a4378da29c 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1342,7 +1343,7 @@ def forward( @add_start_docstrings( "The M2M100 Model with a language modeling head. 
Can be used for summarization.", M2M_100_START_DOCSTRING ) -class M2M100ForConditionalGeneration(M2M100PreTrainedModel): +class M2M100ForConditionalGeneration(M2M100PreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 14a3dea1d1cc..6bed1caab23a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import MambaCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -657,7 +658,7 @@ def forward( """, MAMBA_START_DOCSTRING, ) -class MambaForCausalLM(MambaPreTrainedModel): +class MambaForCausalLM(MambaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mamba2/modeling_mamba2.py b/src/transformers/models/mamba2/modeling_mamba2.py index 69390ea9ad2b..f66ad1ebaea5 100644 --- a/src/transformers/models/mamba2/modeling_mamba2.py +++ b/src/transformers/models/mamba2/modeling_mamba2.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -931,7 +932,7 @@ def forward( """, MAMBA2_START_DOCSTRING, ) -class Mamba2ForCausalLM(Mamba2PreTrainedModel): +class Mamba2ForCausalLM(Mamba2PreTrainedModel, GenerationMixin): _tied_weights_keys = [] def __init__(self, config): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 2045f673540f..cb26bb11e094 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1224,7 +1225,7 @@ def forward( @add_start_docstrings( "The Marian Model with a language modeling head. 
Can be used for summarization.", MARIAN_START_DOCSTRING ) -class MarianMTModel(MarianPreTrainedModel): +class MarianMTModel(MarianPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ "final_logits_bias", @@ -1504,7 +1505,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en -class MarianForCausalLM(MarianPreTrainedModel): +class MarianForCausalLM(MarianPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 6cad7b08f994..7130f3e0451c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1389,7 +1390,7 @@ def forward( "The MBART Model with a language modeling head. Can be used for summarization, after fine-tuning the pretrained models.", MBART_START_DOCSTRING, ) -class MBartForConditionalGeneration(MBartPreTrainedModel): +class MBartForConditionalGeneration(MBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["model.encoder.embed_tokens.weight", "model.decoder.embed_tokens.weight", "lm_head.weight"] @@ -1830,7 +1831,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25 -class MBartForCausalLM(MBartPreTrainedModel): +class MBartForCausalLM(MBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 16641655e203..20506f91bcbc 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1110,7 +1111,7 @@ def forward( """MegatronBert Model with a `language modeling` head on top for CLM fine-tuning.""", MEGATRON_BERT_START_DOCSTRING, ) -class MegatronBertForCausalLM(MegatronBertPreTrainedModel): +class MegatronBertForCausalLM(MegatronBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder"] def __init__(self, config): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 240e229e0bb0..99f8fc35c77f 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import 
AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -947,7 +948,7 @@ def _update_causal_mask( return causal_mask -class MistralForCausalLM(MistralPreTrainedModel): +class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 919f32abc7fc..2d0985194512 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1180,7 +1181,7 @@ def _update_causal_mask( return causal_mask -class MixtralForCausalLM(MixtralPreTrainedModel): +class MixtralForCausalLM(MixtralPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 85579636dcc4..9c826c370b75 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -24,6 +24,7 @@ from torch.nn import functional as F from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, @@ -500,7 +501,7 @@ def forward( """, MPT_START_DOCSTRING, ) -class MptForCausalLM(MptPreTrainedModel): +class MptForCausalLM(MptPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config: MptConfig): diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 54943cf982dd..6a7406f11b5b 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1550,7 +1551,7 @@ def forward( @add_start_docstrings("""MT5 Model with a `language modeling` head on top.""", MT5_START_DOCSTRING) -class MT5ForConditionalGeneration(MT5PreTrainedModel): +class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f720faac038e..3109c4fc2431 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + 
LogitsProcessorList, + StoppingCriteriaList, +) from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1206,7 +1211,7 @@ def forward( "The MusicGen decoder model with a language modelling head on top.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForCausalLM(MusicgenPreTrainedModel): +class MusicgenForCausalLM(MusicgenPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenDecoderConfig): super().__init__(config) @@ -1658,7 +1663,7 @@ def generate( "for music generation tasks with one or both of text and audio prompts.", MUSICGEN_START_DOCSTRING, ) -class MusicgenForConditionalGeneration(PreTrainedModel): +class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenConfig base_model_prefix = "encoder_decoder" main_input_name = "input_ids" diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a8a8fe960989..c8345870b253 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -26,9 +26,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...generation.configuration_utils import GenerationConfig, GenerationMode -from ...generation.logits_process import ClassifierFreeGuidanceLogitsProcessor, LogitsProcessorList -from ...generation.stopping_criteria import StoppingCriteriaList +from ...generation import ( + ClassifierFreeGuidanceLogitsProcessor, + GenerationConfig, + GenerationMixin, + GenerationMode, + LogitsProcessorList, + StoppingCriteriaList, +) from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1117,7 +1122,7 @@ def forward( MUSICGEN_MELODY_START_DOCSTRING, ) # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenForCausalLM with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody,MusicGen->Musicgen Melody -class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel): +class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel, GenerationMixin): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__(config) @@ -1585,7 +1590,7 @@ def generate( decoder (`Optional[MusicgenMelodyForCausalLM]`, *optional*): MusicGen Melody decoder used to generate audio codes. """, ) -class MusicgenMelodyForConditionalGeneration(PreTrainedModel): +class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): config_class = MusicgenMelodyConfig main_input_name = "input_ids" supports_gradient_checkpointing = True diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 319f1760cef9..c47c4b26b539 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1351,7 +1352,7 @@ def forward( @add_start_docstrings( "The MVP Model with a language modeling head. 
Can be used for various text generation tasks.", MVP_START_DOCSTRING ) -class MvpForConditionalGeneration(MvpPreTrainedModel): +class MvpForConditionalGeneration(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: MvpConfig): @@ -1791,7 +1792,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class MvpForCausalLM(MvpPreTrainedModel): +class MvpForCausalLM(MvpPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 548732b371a5..93c053b27ac1 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -979,7 +980,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronForCausalLM(NemotronPreTrainedModel): +class NemotronForCausalLM(NemotronPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 2bec0fb84dce..c33844da0f55 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1604,7 +1605,7 @@ def forward( @add_start_docstrings( "The NllbMoe Model with a language modeling head. 
Can be used for summarization.", NLLB_MOE_START_DOCSTRING ) -class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): +class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 587ef92e4585..7e31373e4707 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1018,7 +1019,7 @@ def _update_causal_mask( # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->OLMO,Llama->Olmo -class OlmoForCausalLM(OlmoPreTrainedModel): +class OlmoForCausalLM(OlmoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py index 2b24850f3f0c..0aa02a6f5d84 100644 --- a/src/transformers/models/openai/modeling_openai.py +++ b/src/transformers/models/openai/modeling_openai.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu_new, silu +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, SequenceSummary from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer @@ -524,7 +525,7 @@ def forward( """, OPENAI_GPT_START_DOCSTRING, ) -class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): +class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f058171778e..f7782b8f6172 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -22,6 +22,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -882,7 +883,7 @@ def forward( ) -class OPTForCausalLM(OPTPreTrainedModel): +class OPTForCausalLM(OPTPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 8eff8cce50cc..4c17e6a1bf4a 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -22,6 +22,7 @@ from torch import nn from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -236,7 +237,7 @@ def _supports_sdpa(self): """The PALIGEMMA model which consists of a vision backbone and a language model.""", PALIGEMMA_START_DOCSTRING, ) -class 
PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel): +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): def __init__(self, config: PaliGemmaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config=config.vision_config) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 42cef3a63558..03d1574e9be2 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1244,7 +1245,7 @@ def forward( @add_start_docstrings( "The PEGASUS Model with a language modeling head. Can be used for summarization.", PEGASUS_START_DOCSTRING ) -class PegasusForConditionalGeneration(PegasusPreTrainedModel): +class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1456,7 +1457,7 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -class PegasusForCausalLM(PegasusPreTrainedModel): +class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 6d9072777bf6..77c0b32e6433 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,6 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1464,7 +1465,7 @@ def forward( @add_start_docstrings("The PEGASUS-X for conditional generation (e.g. 
summarization).", PEGASUS_X_START_DOCSTRING) -class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): +class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 90a7f355992e..6ad0377339d1 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -828,7 +829,7 @@ def _update_causal_mask( return causal_mask -class PersimmonForCausalLM(PersimmonPreTrainedModel): +class PersimmonForCausalLM(PersimmonPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->PERSIMMON,Llama->Persimmon diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3c647a9d8d81..8439cd0fe2c6 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1113,7 +1114,7 @@ def _update_causal_mask( return causal_mask -class PhiForCausalLM(PhiPreTrainedModel): +class PhiForCausalLM(PhiPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 4652294980fd..7c960bcbe4e7 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1153,7 +1154,7 @@ def _update_causal_mask( return causal_mask -class Phi3ForCausalLM(Phi3PreTrainedModel): +class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 94d882c80566..f209d7d88287 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -1553,7 +1554,7 @@ def forward( "A conditional generation model with a language modeling head. 
Can be used for sequence generation tasks.", PIX2STRUCT_START_DOCSTRING, ) -class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): +class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMixin): config_class = Pix2StructConfig main_input_name = "flattened_patches" _tied_weights_keys = ["decoder.lm_head.weight"] diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93d91e160089..d15e079770a3 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -1254,7 +1255,7 @@ def forward( "The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.", PLBART_START_DOCSTRING, ) -class PLBartForConditionalGeneration(PLBartPreTrainedModel): +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): base_model_prefix = "model" _keys_to_ignore_on_load_missing = ["final_logits_bias"] _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] @@ -1568,7 +1569,7 @@ def forward(self, *args, **kwargs): # Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base -class PLBartForCausalLM(PLBartPreTrainedModel): +class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index c769cff3c454..e6488898e8a9 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,6 +25,7 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1001,7 +1002,7 @@ def forward(self, feature, index_value, embedding_offset): @add_start_docstrings("""Pop2Piano Model with a `language modeling` head on top.""", Pop2Piano_START_DOCSTRING) -class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel): +class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: Pop2PianoConfig): diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 96fa2e2c12e5..7d23088f6e57 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -26,6 +26,7 @@ from torch.nn import LayerNorm from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -1856,7 +1857,7 @@ def forward( "The ProphetNet Model with a language modeling head. 
Can be used for sequence generation tasks.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): +class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.word_embeddings.weight", "decoder.word_embeddings.weight", "lm_head.weight"] def __init__(self, config: ProphetNetConfig): @@ -2073,7 +2074,7 @@ def get_decoder(self): " language modeling.", PROPHETNET_START_DOCSTRING, ) -class ProphetNetForCausalLM(ProphetNetPreTrainedModel): +class ProphetNetForCausalLM(ProphetNetPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "prophetnet.word_embeddings.weight", "prophetnet.decoder.word_embeddings.weight", diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 59413730ad4a..f4b22a4a5129 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1018,7 +1019,7 @@ def _update_causal_mask( return causal_mask -class Qwen2ForCausalLM(Qwen2PreTrainedModel): +class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 14235bf0aaf6..bf48e1c6a97e 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -22,10 +22,11 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN from ...cache_utils import Cache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -855,7 +856,7 @@ def forward(self, audio_features): """The QWEN2AUDIO model which consists of a audio backbone and a language model.""", QWEN2AUDIO_START_DOCSTRING, ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel): +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation) diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index c08735f45345..559e52324b41 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoeCausalLMOutputWithPast, @@ -1191,7 +1192,7 @@ def _update_causal_mask( return causal_mask -class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel): +class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 6ab813ad9ade..82076b0b96a9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -31,6 +31,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1306,7 +1307,7 @@ def _update_causal_mask( """ -class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel): +class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a8f076fad79c..e04929489984 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput from ...modeling_utils import PreTrainedModel @@ -777,7 +778,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position): # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->RECURRENTGEMMA,Llama->RecurrentGemma,llama->gemma -class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel): +class RecurrentGemmaForCausalLM(RecurrentGemmaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, 
config): diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 2e98a07217e6..37b675539e66 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -29,6 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import CausalLMOutput, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward @@ -2183,7 +2184,7 @@ def _pad_to_mult_of_chunk_length( @add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING) -class ReformerModelWithLMHead(ReformerPreTrainedModel): +class ReformerModelWithLMHead(ReformerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 31f7e3dce454..99016c1be429 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1002,7 +1003,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RemBERT Model with a `language modeling` head on top for CLM fine-tuning.""", REMBERT_START_DOCSTRING ) -class RemBertForCausalLM(RemBertPreTrainedModel): +class RemBertForCausalLM(RemBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index c15636f4b65c..29081b89a5e6 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -860,7 +861,7 @@ def forward( @add_start_docstrings( """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING ) -class RobertaForCausalLM(RobertaPreTrainedModel): +class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index 3e592f387768..15e485e2e192 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import 
GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -856,7 +857,7 @@ def forward( ROBERTA_PRELAYERNORM_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with FacebookAI/roberta-base->andreasmadsen/efficient_mlm_m0.40,ROBERTA->ROBERTA_PRELAYERNORM,Roberta->RobertaPreLayerNorm,roberta->roberta_prelayernorm, RobertaPreLayerNormTokenizer->RobertaTokenizer -class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel): +class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index c4efbf16323e..2969f7f1a3d0 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -1403,7 +1404,7 @@ def prepare_inputs_for_generation( @add_start_docstrings( """RoCBert Model with a `language modeling` head on top for CLM fine-tuning.""", ROC_BERT_START_DOCSTRING ) -class RoCBertForCausalLM(RoCBertPreTrainedModel): +class RoCBertForCausalLM(RoCBertPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"] # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 69588ff743a0..c98b525abe08 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -1033,7 +1034,7 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_ @add_start_docstrings( """RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING ) -class RoFormerForCausalLM(RoFormerPreTrainedModel): +class RoFormerForCausalLM(RoFormerPreTrainedModel, GenerationMixin): _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"] def __init__(self, config): diff --git a/src/transformers/models/rwkv/modeling_rwkv.py b/src/transformers/models/rwkv/modeling_rwkv.py index 7dec1f26e1a3..8361afbf727b 100644 --- a/src/transformers/models/rwkv/modeling_rwkv.py +++ b/src/transformers/models/rwkv/modeling_rwkv.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -751,7 +752,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): """, RWKV_START_DOCSTRING, ) -class RwkvForCausalLM(RwkvPreTrainedModel): +class 
RwkvForCausalLM(RwkvPreTrainedModel, GenerationMixin): _tied_weights_keys = ["head.weight"] def __init__(self, config): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index a79d1d4cf2b9..dcc665f48da3 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2150,7 +2151,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. """, ) -class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel): +class SeamlessM4TTextToUnitForConditionalGeneration(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2656,7 +2657,7 @@ def remove_weight_norm(self): "The text-to-text SeamlessM4T Model transformer which can be used for T2TT.", SEAMLESS_M4T_START_DOCSTRING, ) -class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel): +class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index a53f544bb34f..0157996ba182 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...deepspeed import is_deepspeed_zero3_enabled +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -2439,7 +2440,7 @@ def forward( embed_tokens_decoder (`nn.Embedding`, *optional*): input embedding of the decoder. 
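From the caller's perspective, the signature changes in these hunks are invisible: generation entry points behave exactly as before. A usage sketch, assuming transformers and torch are installed; the checkpoint name is illustrative, not one of the models touched by this patch:

    # Usage is unchanged by this patch: generate() is still found on the class,
    # now via the explicit GenerationMixin base rather than PreTrainedModel.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello,", return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=5)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))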
""", ) -class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2TextToUnitForConditionalGeneration(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = [ "vocoder", "speech_encoder", @@ -2914,7 +2915,7 @@ def remove_weight_norm(self): SEAMLESS_M4T_V2_START_DOCSTRING, ) # Copied from transformers.models.seamless_m4t.modeling_seamless_m4t.SeamlessM4TForTextToText with SeamlessM4T->SeamlessM4Tv2,SeamlessM4Tv2Tokenizer->SeamlessM4TTokenizer, SeamlessM4Tv2Processor->SeamlessM4TProcessor -class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel): +class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_missing = ["speech_encoder", "t2u_model", "vocoder"] main_input_name = "input_ids" diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 8353a172b212..bdd532fa25e8 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -22,6 +22,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -1207,7 +1208,7 @@ def forward( "The Speech2Text Model with a language modeling head. Can be used for summarization.", SPEECH_TO_TEXT_START_DOCSTRING, ) -class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): +class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 1ec4665fcfb7..0cf2448a75f1 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -1106,7 +1107,7 @@ def _update_causal_mask( # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM with PERSIMMON->STABLELM,Persimmon->StableLm -class StableLmForCausalLM(StableLmPreTrainedModel): +class StableLmForCausalLM(StableLmPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with LLAMA->STABLELM,Llama->StableLm diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 90603fd4e51e..b36c22c2e79e 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -993,7 +994,7 @@ def _update_causal_mask( # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2ForCausalLM with 
QWEN2->STARCODER2,Qwen2->Starcoder2 -class Starcoder2ForCausalLM(Starcoder2PreTrainedModel): +class Starcoder2ForCausalLM(Starcoder2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c5797d4573b7..96b6c7334b15 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,6 +24,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -1456,7 +1457,7 @@ def forward( @add_start_docstrings( """SWITCH_TRANSFORMERS Model with a `language modeling` head on top.""", SWITCH_TRANSFORMERS_START_DOCSTRING ) -class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel): +class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel, GenerationMixin): _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] def __init__(self, config: SwitchTransformersConfig): diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index a90101924c5b..43e3f3afa4a8 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1542,7 +1543,7 @@ def forward( @add_start_docstrings("""T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING) -class T5ForConditionalGeneration(T5PreTrainedModel): +class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin): _keys_to_ignore_on_load_unexpected = [ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 04eb40ab2a2f..67b97cf9c852 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -736,7 +737,7 @@ def forward(self, *args, **kwargs): " [`VisionEncoderDecoder`].", TROCR_START_DOCSTRING, ) -class TrOCRForCausalLM(TrOCRPreTrainedModel): +class TrOCRForCausalLM(TrOCRPreTrainedModel, GenerationMixin): _tied_weights_keys = ["output_projection.weight"] def __init__(self, config): diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 972248daaae5..3f95c10d24ce 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,6 +34,7 @@ ) from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PreTrainedModel 
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( @@ -1679,7 +1680,7 @@ def forward( This class is based on [`T5ForConditionalGeneration`], extended to deal with images and layout (2D) data.""", UDOP_START_DOCSTRING, ) -class UdopForConditionalGeneration(UdopPreTrainedModel): +class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin): _tied_weights_keys = [ "encoder.embed_tokens.weight", "decoder.embed_tokens.weight", diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 3271689540b9..a7d1e5bacc65 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1101,7 +1102,7 @@ def forward( @add_start_docstrings("""UMT5 Model with a `language modeling` head on top.""", UMT5_START_DOCSTRING) -class UMT5ForConditionalGeneration(UMT5PreTrainedModel): +class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin): r""" Examples: diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 425d46bd7741..bf31ed63ce53 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -238,7 +239,7 @@ def _supports_sdpa(self): """The VideoLlava model which consists of a vision backbone and a language model.""", VIDEO_LLAVA_START_DOCSTRING, ) -class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel): +class VideoLlavaForConditionalGeneration(VideoLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VideoLlavaConfig): super().__init__(config) self.video_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index b1df10fdb3dc..9d085ce47d5d 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -21,9 +21,10 @@ import torch.utils.checkpoint from torch import nn -from ... 
import PreTrainedModel from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_outputs import ModelOutput +from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -242,7 +243,7 @@ def _supports_sdpa(self): VIPLLAVA_START_DOCSTRING, ) # Copied from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration with LLAVA->VIPLLAVA,Llava->VipLlava -class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel): +class VipLlavaForConditionalGeneration(VipLlavaPreTrainedModel, GenerationMixin): def __init__(self, config: VipLlavaConfig): super().__init__(config) self.vision_tower = AutoModel.from_config(config.vision_config) diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py index d9a684183cab..95314125ab4d 100644 --- a/src/transformers/models/whisper/generation_whisper.py +++ b/src/transformers/models/whisper/generation_whisper.py @@ -25,7 +25,7 @@ from transformers.cache_utils import EncoderDecoderCache -from ...generation.configuration_utils import GenerationConfig +from ...generation import GenerationConfig, GenerationMixin from ...generation.logits_process import ( LogitsProcessorList, SuppressTokensAtBeginLogitsProcessor, @@ -172,7 +172,7 @@ def _pad_to_max_length( return sequences -class WhisperGenerationMixin: +class WhisperGenerationMixin(GenerationMixin): def _extract_token_timestamps(self, generate_outputs, alignment_heads, time_precision=0.02, num_frames=None): """ Calculates token-level timestamps using the encoder-decoder cross-attentions and dynamic time-warping (DTW) to diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 81f60edbfa98..ebb899ea7cf6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, @@ -1910,7 +1911,7 @@ def forward(self, *args, **kwargs): """, WHISPER_START_DOCSTRING, ) -class WhisperForCausalLM(WhisperPreTrainedModel): +class WhisperForCausalLM(WhisperPreTrainedModel, GenerationMixin): _tied_weights_keys = ["proj_out.weight"] main_input_name = "input_ids" diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 4f1693583494..3090bc2973cd 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -23,6 +23,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel @@ -696,7 +697,7 @@ def forward( """, XGLM_START_DOCSTRING, ) -class XGLMForCausalLM(XGLMPreTrainedModel): +class XGLMForCausalLM(XGLMPreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index 
280383630987..3acec2353b69 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -27,6 +27,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, MaskedLMOutput, @@ -657,7 +658,7 @@ def forward(self, x, y=None): """, XLM_START_DOCSTRING, ) -class XLMWithLMHeadModel(XLMPreTrainedModel): +class XLMWithLMHeadModel(XLMPreTrainedModel, GenerationMixin): _tied_weights_keys = ["pred_layer.proj.weight"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 5fa4812d350d..7ec21c6a2ac2 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -24,6 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -864,7 +865,7 @@ def forward( XLM_ROBERTA_START_DOCSTRING, ) # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA -class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel): +class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 86fdadeaef02..05aa006ffc9a 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -831,7 +832,7 @@ def forward( """XLM-RoBERTa-XL Model with a `language modeling` head on top for CLM fine-tuning.""", XLM_ROBERTA_XL_START_DOCSTRING, ) -class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel): +class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] def __init__(self, config): diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py index 5d424ebe12dd..7681fbafad6d 100755 --- a/src/transformers/models/xlnet/modeling_xlnet.py +++ b/src/transformers/models/xlnet/modeling_xlnet.py @@ -26,6 +26,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...generation import GenerationMixin from ...modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary from ...pytorch_utils import apply_chunking_to_forward from ...utils import ( @@ -1286,7 +1287,7 @@ def forward( """, XLNET_START_DOCSTRING, ) -class XLNetLMHeadModel(XLNetPreTrainedModel): +class XLNetLMHeadModel(XLNetPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_loss.weight"] def __init__(self, config): diff --git 
a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index b1ca8116a72a..71474cc9c45b 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -23,6 +23,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, @@ -956,7 +957,7 @@ def forward( "X-MOD Model with a `language modeling` head on top for CLM fine-tuning.", XMOD_START_DOCSTRING, ) -class XmodForCausalLM(XmodPreTrainedModel): +class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 363028c7f229..40f662e65463 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -66,6 +66,7 @@ BertModel, FunnelBaseModel, FunnelModel, + GenerationMixin, GPT2Config, GPT2LMHeadModel, ResNetBackbone, @@ -529,3 +530,20 @@ def test_attr_not_existing(self): _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")]) _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model) + + def test_generation_mixin_added_inheritance(self): + """ + Tests that our inheritance patching for generate-compatible models works as expected. Without this feature, + old Hub models lose the ability to call `generate`. + """ + model = AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/test_dynamic_model_generation", trust_remote_code=True + ) + self.assertTrue(model.__class__.__name__ == "NewModelForCausalLM") + + # It inherits from GenerationMixin. This means it can `generate`. This check would fail from v4.50, if patching + # was not present. + self.assertTrue(isinstance(model, GenerationMixin)) + # More precisely, it directly inherits from GenerationMixin. This check would fail from v4.45, if patching + # was not present. + self.assertTrue("GenerationMixin" in str(model.__class__.__bases__)) From 1b2cd4e19fd33ed087791b73764dadaf916a65b4 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 Aug 2024 18:49:39 +0000 Subject: [PATCH 2/9] better test --- src/transformers/models/auto/auto_factory.py | 14 +++++--------- tests/models/auto/test_modeling_auto.py | 10 +++++----- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 7e33956c7b21..8f4b2a0f5720 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -728,15 +728,11 @@ def add_generation_mixin_to_remote_model(model_class): if "GenerationMixin" in str(model_class.__bases__): return model_class - # 3. If the class name has a suffix that indicates that it should be able to generate, add `GenerationMixin` to - # the class inheritance. Otherwise, do nothing. - terminations_with_generation_support = [ - "ForCausalLM", - "ForConditionalGeneration", - "ForSpeechSeq2Seq", - "ForVision2Seq", - ] - if any(model_class.__name__.endswith(termination) for termination in terminations_with_generation_support): + # 3. 
From 1b2cd4e19fd33ed087791b73764dadaf916a65b4 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 29 Aug 2024 18:49:39 +0000
Subject: [PATCH 2/9] better test

---
 src/transformers/models/auto/auto_factory.py | 14 +++++---------
 tests/models/auto/test_modeling_auto.py      | 10 +++++-----
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py
index 7e33956c7b21..8f4b2a0f5720 100644
--- a/src/transformers/models/auto/auto_factory.py
+++ b/src/transformers/models/auto/auto_factory.py
@@ -728,15 +728,11 @@ def add_generation_mixin_to_remote_model(model_class):
     if "GenerationMixin" in str(model_class.__bases__):
         return model_class
 
-    # 3. If the class name has a suffix that indicates that it should be able to generate, add `GenerationMixin` to
-    # the class inheritance. Otherwise, do nothing.
-    terminations_with_generation_support = [
-        "ForCausalLM",
-        "ForConditionalGeneration",
-        "ForSpeechSeq2Seq",
-        "ForVision2Seq",
-    ]
-    if any(model_class.__name__.endswith(termination) for termination in terminations_with_generation_support):
+    # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or
+    # `prepare_inputs_for_generation` method.
+    has_custom_generate = "GenerationMixin" not in str(getattr(model_class, "generate"))
+    has_custom_prepare_inputs = "GenerationMixin" not in str(getattr(model_class, "prepare_inputs_for_generation"))
+    if has_custom_generate or has_custom_prepare_inputs:
         model_class_with_generation_mixin = type(
             model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__}
         )
diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py
index 40f662e65463..631bf52d10fa 100644
--- a/tests/models/auto/test_modeling_auto.py
+++ b/tests/models/auto/test_modeling_auto.py
@@ -531,7 +531,7 @@ def test_attr_not_existing(self):
         _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
         self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
 
-    def test_generation_mixin_added_inheritance(self):
+    def test_custom_model_patched_generation_inheritance(self):
         """
         Tests that our inheritance patching for generate-compatible models works as expected. Without this feature,
         old Hub models lose the ability to call `generate`.
@@ -541,9 +541,9 @@ def test_generation_mixin_added_inheritance(self):
         )
         self.assertTrue(model.__class__.__name__ == "NewModelForCausalLM")
 
-        # It inherits from GenerationMixin. This means it can `generate`. This check would fail from v4.50, if patching
-        # was not present.
+        # It inherits from GenerationMixin. This means it can `generate`. Because `PreTrainedModel` is scheduled to
+        # stop inheriting from `GenerationMixin` in v4.50, this check will fail if patching is not present.
         self.assertTrue(isinstance(model, GenerationMixin))
-        # More precisely, it directly inherits from GenerationMixin. This check would fail from v4.45, if patching
-        # was not present.
+        # More precisely, it directly inherits from GenerationMixin. This check would fail prior to v4.45 (inheritance
+        # patching was added in v4.45)
        self.assertTrue("GenerationMixin" in str(model.__class__.__bases__))
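
[Reviewer note] Two mechanics in the patched helper above are easy to miss in a diff: the repr of a method contains the qualified name of the class that defined it, so checking for the substring "GenerationMixin" in str(model_class.generate) reveals whether the remote class brought its own implementation; and type() rebuilds the class on the fly with `GenerationMixin` appended to its bases. A minimal standalone sketch of the same idea — `RemoteModelBase` and `LegacyRemoteModel` are illustrative stand-ins, not names from the patch:

    from transformers import GenerationMixin

    class RemoteModelBase:
        """Stand-in for the framework base class (`PreTrainedModel` in the real code)."""

    class LegacyRemoteModel(RemoteModelBase):
        """Stand-in for an old Hub model class that shipped its own generation hook."""

        def prepare_inputs_for_generation(self, *args, **kwargs):
            return {}

    def add_generation_mixin(model_class):
        # repr(method) looks like "<function LegacyRemoteModel.prepare_inputs_for_generation at 0x...>",
        # so the absence of "GenerationMixin" means the class itself (not the mixin) defined the method.
        has_custom_prepare_inputs = "GenerationMixin" not in str(model_class.prepare_inputs_for_generation)
        if has_custom_prepare_inputs:
            # Rebuild the class with GenerationMixin appended to the bases, preserving the
            # original name and class attributes. The result can now call `generate`.
            return type(model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__})
        return model_class

    PatchedModel = add_generation_mixin(LegacyRemoteModel)
    assert "GenerationMixin" in str(PatchedModel.__bases__)
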
- "\n -If you are not the owner of the model architecture class, please contact the model code owner " + "\n - If you are not the owner of the model architecture class, please contact the model code owner " "to update it." ) return True From 03e05b6a805f2fc700243a891dcd376113cc4c9a Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 29 Aug 2024 19:15:58 +0000 Subject: [PATCH 4/9] granite --- src/transformers/modeling_utils.py | 1 + src/transformers/models/granite/modeling_granite.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f55762d682b3..c80c4ba5f417 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1295,6 +1295,7 @@ def floating_point_ops( return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) +# TODO (joao): remove `GenerationMixin` inheritance in v4.50 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): r""" Base class for all models. diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index aee62fd249f3..a1fd00aea96f 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( @@ -976,7 +977,7 @@ def _update_causal_mask( return causal_mask -class GraniteForCausalLM(GranitePreTrainedModel): +class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite From 275b631bf3e271d8f83213210d09162484ead076 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 2 Sep 2024 09:14:07 +0000 Subject: [PATCH 5/9] make fixup --- tests/models/auto/test_modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 223aea6c79ab..22c504969c8e 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -531,7 +531,7 @@ def test_attr_not_existing(self): _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")]) _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model) - + def test_dynamic_saving_from_local_repo(self): with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as tmp_dir_out: _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-custom-architecture") From 96202739c48f0c40a83daeda724da64d2bb66421 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 6 Sep 2024 16:25:53 +0000 Subject: [PATCH 6/9] PR comments; Add can_generate test --- src/transformers/modeling_utils.py | 4 +-- .../models/albert/modeling_albert.py | 4 +-- tests/utils/test_modeling_utils.py | 27 +++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 59438fe9ced2..6d1c2ab8012c 100755 --- a/src/transformers/modeling_utils.py +++ 
From 03e05b6a805f2fc700243a891dcd376113cc4c9a Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 29 Aug 2024 19:15:58 +0000
Subject: [PATCH 4/9] granite

---
 src/transformers/modeling_utils.py                  | 1 +
 src/transformers/models/granite/modeling_granite.py | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f55762d682b3..c80c4ba5f417 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1295,6 +1295,7 @@ def floating_point_ops(
     return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
+# TODO (joao): remove `GenerationMixin` inheritance in v4.50
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
     r"""
     Base class for all models.
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index aee62fd249f3..a1fd00aea96f 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -22,6 +22,7 @@
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...generation import GenerationMixin
 from ...modeling_attn_mask_utils import AttentionMaskConverter
 from ...modeling_flash_attention_utils import _flash_attention_forward
 from ...modeling_outputs import (
@@ -976,7 +977,7 @@ def _update_causal_mask(
     return causal_mask
 
 
-class GraniteForCausalLM(GranitePreTrainedModel):
+class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
 
     # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Granite

From 275b631bf3e271d8f83213210d09162484ead076 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Mon, 2 Sep 2024 09:14:07 +0000
Subject: [PATCH 5/9] make fixup

---
 tests/models/auto/test_modeling_auto.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py
index 223aea6c79ab..22c504969c8e 100644
--- a/tests/models/auto/test_modeling_auto.py
+++ b/tests/models/auto/test_modeling_auto.py
@@ -531,7 +531,7 @@ def test_attr_not_existing(self):
         _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")])
         _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES)
         self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model)
- 
+
     def test_dynamic_saving_from_local_repo(self):
         with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as tmp_dir_out:
             _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-custom-architecture")

From 96202739c48f0c40a83daeda724da64d2bb66421 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Fri, 6 Sep 2024 16:25:53 +0000
Subject: [PATCH 6/9] PR comments; Add can_generate test

---
 src/transformers/modeling_utils.py   |  4 +--
 .../models/albert/modeling_albert.py |  4 +--
 tests/utils/test_modeling_utils.py   | 27 +++++++++++++++++++
 3 files changed, 31 insertions(+), 4 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 59438fe9ced2..6d1c2ab8012c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1628,8 +1628,8 @@ def can_generate(cls) -> bool:
         # Directly inherits `GenerationMixin` -> can generate
         if "GenerationMixin" in str(cls.__bases__):
             return True
-        # Model class overwrites `generate` -> can generate
-        if str(cls.__class__.__name__) in str(cls.generate):
+        # Model class overwrites `generate` (e.g. time series models) -> can generate
+        if str(cls.__name__) in str(cls.generate):
             return True
         # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
         # was how we detected whether a model could generate.
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 254d0c1bca99..989a890ddccb 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -937,7 +937,7 @@ def forward(
 )
 
 
-class AlbertMLMHead(nn.Module, GenerationMixin):
+class AlbertMLMHead(nn.Module):
     def __init__(self, config: AlbertConfig):
         super().__init__()
 
@@ -984,7 +984,7 @@ def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
     "Albert Model with a `language modeling` head on top.",
     ALBERT_START_DOCSTRING,
 )
-class AlbertForMaskedLM(AlbertPreTrainedModel):
+class AlbertForMaskedLM(AlbertPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
 
     def __init__(self, config):
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index f78285fdb90d..91ab21496820 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -90,6 +90,7 @@
     BertConfig,
     BertModel,
     CLIPTextModel,
+    GenerationMixin,
     PreTrainedModel,
     T5Config,
     T5ForConditionalGeneration,
@@ -1715,6 +1716,32 @@ def test_isin_mps_friendly(self):
             torch.equal(torch.isin(random_ids, random_test_tensor), isin_mps_friendly(random_ids, random_test_tensor))
         )
 
+    def test_can_generate(self):
+        """Tests the behavior of `PreTrainedModel.can_generate` method."""
+        # 1 - By default, a model CAN'T generate
+        self.assertFalse(BertModel.can_generate())
+
+        # 2 - The most common case for a model to be able to generate is to inherit from `GenerationMixin` directly
+        class DummyBertWithMixin(BertModel, GenerationMixin):
+            pass
+
+        self.assertTrue(DummyBertWithMixin.can_generate())
+
+        # 3 - Alternatively, a model can implement a `generate` method
+        class DummyBertWithGenerate(BertModel):
+            def generate(self):
+                pass
+
+        self.assertTrue(DummyBertWithGenerate.can_generate())
+
+        # 4 - BC: models with a custom `prepare_inputs_for_generation` can generate (it was assumed they inherited
+        # `GenerationMixin`)
+        class DummyBertWithPrepareInputs(BertModel):
+            def prepare_inputs_for_generation(self):
+                pass
+
+        self.assertTrue(DummyBertWithPrepareInputs.can_generate())
+
 
 @slow
 @require_torch
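
[Reviewer note] The one-line change to `can_generate` in PATCH 6 fixes a real bug: `cls` is already a class object, so `cls.__class__` is its metaclass and `cls.__class__.__name__` is normally the string "type", which never occurs in the repr of `generate`; the overridden-`generate` branch could therefore never fire. A standalone illustration — `TimeSeriesModel` is a hypothetical name, standing in for the time series models the new comment mentions:

    class TimeSeriesModel:
        def generate(self):
            ...

    # For a class object, `__class__` is the metaclass:
    assert TimeSeriesModel.__class__.__name__ == "type"
    # ... so the old check compared "type" against
    # "<function TimeSeriesModel.generate at 0x...>" and never matched:
    assert "type" not in str(TimeSeriesModel.generate)
    # The fixed check uses the class's own name, which does appear in the repr:
    assert TimeSeriesModel.__name__ in str(TimeSeriesModel.generate)
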
From 7aa36fbbbd004e00e0dd4eface140dbe22c98f4f Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Sun, 8 Sep 2024 11:02:06 +0000
Subject: [PATCH 7/9] update llama_onevision

---
 .../models/llava_onevision/modeling_llava_onevision.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 9496c3857027..3a3f01f86bd4 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -23,10 +23,11 @@
 import torch.utils.checkpoint
 from torch import nn
 
-from ... import PreTrainedModel
 from ...activations import ACT2FN
+from ...generation import GenerationMixin
 from ...image_processing_utils import select_best_resolution
 from ...modeling_outputs import ModelOutput
+from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_start_docstrings,
     logging,
@@ -350,7 +351,7 @@ def _init_weights(self, module):
     """The LLaVA-Onevision model which consists of a vision backbone and a language model.""",
     LLAVA_ONEVISION_START_DOCSTRING,
 )
-class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel):
+class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin):
     def __init__(self, config: LlavaOnevisionConfig):
         super().__init__(config)
         self.vision_tower = AutoModel.from_config(
+ """ + for model_class in self.all_generative_model_classes: + self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1): batch_size, seq_length = input_ids.shape config = config.text_config if hasattr(config, "text_config") else config From 8b40781c971dc0245720ae8ed9fe8b5bbe91a9cf Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 23 Sep 2024 15:52:00 +0000 Subject: [PATCH 9/9] make fixup --- .../models/llava_next_video/diff_llava_next_video.py | 1 - .../models/llava_next_video/modeling_llava_next_video.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/models/llava_next_video/diff_llava_next_video.py b/src/transformers/models/llava_next_video/diff_llava_next_video.py index 2a954fdd0eca..c5ca2bf00324 100644 --- a/src/transformers/models/llava_next_video/diff_llava_next_video.py +++ b/src/transformers/models/llava_next_video/diff_llava_next_video.py @@ -29,7 +29,6 @@ image_size_to_num_patches, ) -from ...cache_utils import Cache from ...generation import GenerationMixin from ...utils import ( logging, diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index e6d2b73318ad..7ad9e0769eb3 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -30,7 +30,6 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache from ...generation import GenerationMixin from ...image_processing_utils import select_best_resolution from ...modeling_outputs import ModelOutput