diff --git a/docs/source/index.rst b/docs/source/index.rst index 737f562f663e..dfd4164206b7 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -248,6 +248,7 @@ conversion utilities for the following models: model_doc/marian model_doc/mbart model_doc/mobilebert + model_doc/mt5 model_doc/gpt model_doc/gpt2 model_doc/pegasus diff --git a/docs/source/model_doc/mt5.rst b/docs/source/model_doc/mt5.rst new file mode 100644 index 000000000000..9171f5164913 --- /dev/null +++ b/docs/source/model_doc/mt5.rst @@ -0,0 +1,53 @@ +MT5 +----------------------------------------------------------------------------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer +<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya +Siddhant, Aditya Barua, Colin Raffel. + +The abstract from the paper is the following: + +*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain +state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a +multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe +the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual +benchmarks. All of the code and model checkpoints used in this work are publicly available.* + +The original code can be found `here <https://github.com/google-research/multilingual-t5>`__. + +MT5Config +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MT5Config + :members: + + +MT5Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MT5Model + :members: + + +MT5ForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.MT5ForConditionalGeneration + :members: + + +TFMT5Model +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.TFMT5Model + :members: + + +TFMT5ForConditionalGeneration +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: transformers.TFMT5ForConditionalGeneration + :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ee5e6806113c..65ad1bbfcd86 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -498,6 +498,7 @@ MobileBertPreTrainedModel, load_tf_weights_in_mobilebert, ) + from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model from .models.openai import ( OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, OpenAIGPTDoubleHeadsModel, @@ -791,6 +792,7 @@ TFMobileBertModel, TFMobileBertPreTrainedModel, ) + from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model from .models.openai import ( TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, TFOpenAIGPTDoubleHeadsModel, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index c8a744397428..8bc1f0ebd8e6 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -40,6 +40,7 @@ from ..marian.configuration_marian import MarianConfig from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig from ..mobilebert.configuration_mobilebert import MobileBertConfig +from ..mt5.configuration_mt5 import MT5Config from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig from ..pegasus.configuration_pegasus import PegasusConfig from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig @@ -101,6 +102,7 @@ [ # Add configs here ("retribert", RetriBertConfig), + ("mt5", MT5Config), ("t5", T5Config), ("mobilebert", MobileBertConfig), ("distilbert", DistilBertConfig), @@ -178,6 +180,7 @@ ("rag", "RAG"), ("xlm-prophetnet", "XLMProphetNet"), ("prophetnet", "ProphetNet"), + ("mt5", "mT5"), ] ) diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 0ba27b31e165..fe8c7bb4da74 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -120,6 +120,7 @@ MobileBertForTokenClassification, MobileBertModel, ) +from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel @@ -209,6 +210,7 @@ MarianConfig, MBartConfig, MobileBertConfig, + MT5Config, OpenAIGPTConfig, PegasusConfig, ProphetNetConfig, @@ -235,6 +237,7 @@ [ # Base model mapping (RetriBertConfig, RetriBertModel), + (MT5Config, MT5Model), (T5Config, T5Model), (DistilBertConfig, DistilBertModel), (AlbertConfig, AlbertModel), @@ -376,6 +379,7 @@ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (MT5Config, MT5ForConditionalGeneration), (T5Config, T5ForConditionalGeneration), (PegasusConfig, PegasusForConditionalGeneration), (MarianConfig, MarianMTModel), diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index c433d8c198d8..a1c34a137b23 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -106,6 +106,7 @@ TFMobileBertForTokenClassification, TFMobileBertModel, ) +from ..mt5.modeling_tf_mt5 import 
TFMT5ForConditionalGeneration, TFMT5Model from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration from ..roberta.modeling_tf_roberta import ( @@ -161,6 +162,7 @@ MarianConfig, MBartConfig, MobileBertConfig, + MT5Config, OpenAIGPTConfig, PegasusConfig, RobertaConfig, @@ -182,6 +184,7 @@ [ # Base model mapping (LxmertConfig, TFLxmertModel), + (MT5Config, TFMT5Model), (T5Config, TFT5Model), (DistilBertConfig, TFDistilBertModel), (AlbertConfig, TFAlbertModel), @@ -294,6 +297,7 @@ TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict( [ # Model for Seq2Seq Causal LM mapping + (MT5Config, TFMT5ForConditionalGeneration), (T5Config, TFT5ForConditionalGeneration), (MarianConfig, TFMarianMTModel), (MBartConfig, TFMBartForConditionalGeneration), diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py new file mode 100644 index 000000000000..c186d88b80cb --- /dev/null +++ b/src/transformers/models/mt5/__init__.py @@ -0,0 +1,13 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +from ...file_utils import is_tf_available, is_torch_available +from .configuration_mt5 import MT5Config + + +if is_torch_available(): + from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model + +if is_tf_available(): + from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py new file mode 100644 index 000000000000..23bde1004798 --- /dev/null +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2020, The T5 Authors and HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" mT5 model configuration """ + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + + +class MT5Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a + :class:`~transformers.TFMT5Model`. It is used to instantiate an mT5 model according to the specified arguments, + defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration + to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture. + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + Arguments: + vocab_size (:obj:`int`, `optional`, defaults to 250112): + Vocabulary size of the mT5 model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.MT5Model` or :class:`~transformers.TFMT5Model`. 
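To make the deltas from the stock T5 configuration concrete, the defaults defined further down in this file can be sanity-checked directly. This is a minimal sketch based only on the values this diff introduces, not part of the patch itself::

    from transformers import MT5Config

    config = MT5Config()
    # All four assertions follow from the defaults in __init__ below:
    assert config.vocab_size == 250112               # multilingual SentencePiece vocabulary
    assert config.feed_forward_proj == "gated-gelu"  # T5v1.1-style feed forward
    assert config.tie_word_embeddings is False       # input/output embeddings are untied
    assert config.tokenizer_class == "T5Tokenizer"   # mT5 reuses the T5 tokenizer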
+ d_model (:obj:`int`, `optional`, defaults to 512): + Size of the encoder layers and the pooler layer. + d_kv (:obj:`int`, `optional`, defaults to 64): + Size of the key, query, value projections per attention head. :obj:`d_kv` has to be equal to :obj:`d_model + // num_heads`. + d_ff (:obj:`int`, `optional`, defaults to 1024): + Size of the intermediate feed forward layer in each :obj:`T5Block`. + num_layers (:obj:`int`, `optional`, defaults to 8): + Number of hidden layers in the Transformer encoder. + num_decoder_layers (:obj:`int`, `optional`): + Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not + set. + num_heads (:obj:`int`, `optional`, defaults to 6): + Number of attention heads for each attention layer in the Transformer encoder. + relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32): + The number of buckets to use for each attention layer. + dropout_rate (:obj:`float`, `optional`, defaults to 0.1): + The ratio for all dropout layers. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6): + The epsilon used by the layer normalization layers. + initializer_factor (:obj:`float`, `optional`, defaults to 1): + A factor for initializing all weight matrices (should be kept to 1, used internally for initialization + testing). + feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`): + Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. + """ + model_type = "mt5" + + def __init__( + self, + vocab_size=250112, + d_model=512, + d_kv=64, + d_ff=1024, + num_layers=8, + num_decoder_layers=None, + num_heads=6, + relative_attention_num_buckets=32, + dropout_rate=0.1, + layer_norm_epsilon=1e-6, + initializer_factor=1.0, + feed_forward_proj="gated-gelu", + is_encoder_decoder=True, + tokenizer_class="T5Tokenizer", + tie_word_embeddings=False, + pad_token_id=0, + eos_token_id=1, + decoder_start_token_id=0, + **kwargs + ): + super().__init__( + is_encoder_decoder=is_encoder_decoder, + tokenizer_class=tokenizer_class, + tie_word_embeddings=tie_word_embeddings, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + decoder_start_token_id=decoder_start_token_id, + **kwargs, + ) + self.vocab_size = vocab_size + self.d_model = d_model + self.d_kv = d_kv + self.d_ff = d_ff + self.num_layers = num_layers + self.num_decoder_layers = ( + num_decoder_layers if num_decoder_layers is not None else self.num_layers + ) # default = symmetry + self.num_heads = num_heads + self.relative_attention_num_buckets = relative_attention_num_buckets + self.dropout_rate = dropout_rate + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj + + @property + def hidden_size(self): + return self.d_model + + @property + def num_attention_heads(self): + return self.num_heads + + @property + def num_hidden_layers(self): + return self.num_layers diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py new file mode 100644 index 000000000000..10d64faf305d --- /dev/null +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -0,0 +1,83 @@ +# coding=utf-8 +# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch mT5 model. """ + +from ...utils import logging +from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model +from .configuration_mt5 import MT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MT5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" + + +class MT5Model(T5Model): + r""" + This class overrides :class:`~transformers.T5Model`. Please check the superclass for the appropriate documentation + alongside usage examples. + + Examples:: + >>> from transformers import MT5Model, T5Tokenizer + >>> model = MT5Model.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt") + >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels) + >>> hidden_states = outputs.last_hidden_state + """ + model_type = "mt5" + config_class = MT5Config + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + keys_to_never_save = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] + + +class MT5ForConditionalGeneration(T5ForConditionalGeneration): + r""" + This class overrides :class:`~transformers.T5ForConditionalGeneration`. Please check the superclass for the + appropriate documentation alongside usage examples. + + Examples:: + >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer + >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt") + >>> outputs = model(**batch) + >>> loss = outputs.loss + """ + + model_type = "mt5" + config_class = MT5Config + authorized_missing_keys = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + r"lm_head\.weight", + r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", + ] + keys_to_never_save = [ + r"encoder\.embed_tokens\.weight", + r"decoder\.embed_tokens\.weight", + ] diff --git a/src/transformers/models/mt5/modeling_tf_mt5.py b/src/transformers/models/mt5/modeling_tf_mt5.py new file mode 100644 index 000000000000..21cf25dedcd5 --- /dev/null +++ b/src/transformers/models/mt5/modeling_tf_mt5.py @@ -0,0 +1,66 @@ +# coding=utf-8 +# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" TensorFlow mT5 model. """ + +from ...utils import logging +from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model +from .configuration_mt5 import MT5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MT5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" + + +class TFMT5Model(TFT5Model): + r""" + This class overrides :class:`~transformers.TFT5Model`. Please check the superclass for the appropriate + documentation alongside usage examples. + + Examples:: + >>> from transformers import TFMT5Model, T5Tokenizer + >>> model = TFMT5Model.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf") + >>> batch["decoder_input_ids"] = batch["labels"] + >>> del batch["labels"] + >>> outputs = model(batch) + >>> hidden_states = outputs.last_hidden_state + """ + model_type = "mt5" + config_class = MT5Config + + +class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration): + r""" + This class overrides :class:`~transformers.TFT5ForConditionalGeneration`. Please check the superclass for the + appropriate documentation alongside usage examples. + + Examples:: + >>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer + >>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small") + >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small") + >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien." + >>> summary = "Weiter Verhandlung in Syrien." + >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf") + >>> outputs = model(batch) + >>> loss = outputs.loss + """ + + model_type = "mt5" + config_class = MT5Config diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index b1a045cb18de..48bdb6c32944 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2010, The T5 Authors and HuggingFace Inc. +# Copyright 2020, The T5 Authors and HuggingFace Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -43,9 +43,6 @@ class T5Config(PretrainedConfig): vocab_size (:obj:`int`, `optional`, defaults to 32128): Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the :obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`. - n_positions (:obj:`int`, `optional`, defaults to 512): - The maximum sequence length that this model might ever be used with. Typically set this to something large - just in case (e.g., 512 or 1024 or 2048). d_model (:obj:`int`, `optional`, defaults to 512): Size of the encoder layers and the pooler layer. 
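The hunks that follow extend :class:`~transformers.T5Config` with the same :obj:`feed_forward_proj` switch, which is what lets the unchanged T5 modeling code serve T5v1.1-style checkpoints such as mT5. A hedged sketch of building such a model by hand, with tiny, randomly initialized sizes chosen purely for illustration::

    import torch
    from transformers import T5Config, T5ForConditionalGeneration

    # The two flags that separate T5v1.1 (and mT5) from original T5:
    config = T5Config(
        feed_forward_proj="gated-gelu",
        tie_word_embeddings=False,
        vocab_size=100, d_model=64, d_kv=16, d_ff=128, num_layers=2, num_heads=4,
    )
    model = T5ForConditionalGeneration(config)
    out = model(
        input_ids=torch.tensor([[5, 6, 7]]),
        decoder_input_ids=torch.tensor([[0, 5]]),
        return_dict=True,
    )
    print(out.logits.shape)  # torch.Size([1, 2, 100])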
d_kv (:obj:`int`, `optional`, defaults to 64): @@ -69,6 +66,9 @@ class T5Config(PretrainedConfig): initializer_factor (:obj:`float`, `optional`, defaults to 1): A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing). + feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`): + Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses + the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`. """ model_type = "t5" @@ -85,6 +85,7 @@ def __init__( dropout_rate=0.1, layer_norm_epsilon=1e-6, initializer_factor=1.0, + feed_forward_proj="relu", is_encoder_decoder=True, pad_token_id=0, eos_token_id=1, @@ -109,6 +110,7 @@ def __init__( self.dropout_rate = dropout_rate self.layer_norm_epsilon = layer_norm_epsilon self.initializer_factor = initializer_factor + self.feed_forward_proj = feed_forward_proj @property def hidden_size(self): diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py index aed3c1e5e25f..e38680df8427 100755 --- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py @@ -17,7 +17,7 @@ import argparse -from transformers import T5Config, T5Model, load_tf_weights_in_t5 +from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 from transformers.utils import logging @@ -28,7 +28,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du # Initialise PyTorch model config = T5Config.from_json_file(config_file) print("Building PyTorch model from configuration: {}".format(str(config))) - model = T5Model(config) + model = T5ForConditionalGeneration(config) # Load weights from tf checkpoint load_tf_weights_in_t5(model, config, tf_checkpoint_path) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index b2fee2503e36..39bff0f46a4b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,6 +25,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...activations import ACT2FN from ...file_utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -140,6 +141,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path): continue elif scope_names[0] == "logits": pointer = getattr(pointer, "lm_head") + elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit(): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue else: try: pointer = getattr(pointer, scope_names[0]) @@ -211,10 +215,36 @@ def forward(self, hidden_states): return hidden_states +class T5DenseGatedGeluDense(nn.Module): + def __init__(self, config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.gelu_act = ACT2FN["gelu_new"] + + def forward(self, hidden_states): + hidden_gelu = self.gelu_act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + class T5LayerFF(nn.Module): def __init__(self, config): 
super().__init__() - self.DenseReluDense = T5DenseReluDense(config) + if config.feed_forward_proj == "relu": + self.DenseReluDense = T5DenseReluDense(config) + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = T5DenseGatedGeluDense(config) + else: + raise ValueError( + f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`" + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -641,6 +671,16 @@ def _init_weights(self, module): module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) if hasattr(module.wo, "bias") and module.wo.bias is not None: module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedGeluDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() elif isinstance(module, T5Attention): # Mesh TensorFlow attention initialization to avoid scaling before softmax # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 @@ -1099,8 +1139,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel): r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight", - r"encoder\.embed_tokens\.weight", - r"decoder\.embed_tokens\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", ] @@ -1262,9 +1300,12 @@ def forward( ) sequence_output = decoder_outputs[0] - # Rescale output before projecting on vocab - # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 - sequence_output = sequence_output * (self.model_dim ** -0.5) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim ** -0.5) + lm_logits = self.lm_head(sequence_output) loss = None diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 74e703cb9de9..1f7a78f5bfdd 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 T5 Authors and The HuggingFace Inc. team. +# Copyright 2020 T5 Authors and The HuggingFace Inc. team. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
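For intuition about the ``T5DenseGatedGeluDense`` layer wired in above: instead of ``wo(relu(wi(x)))`` it computes a multiplicative gate, ``wo(gelu(wi_0(x)) * wi_1(x))``. A standalone restatement of that computation, using plain ``torch.nn.functional.gelu`` as a stand-in for the ``gelu_new`` tanh approximation the real layer pulls from ``ACT2FN``::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    d_model, d_ff = 512, 1024  # the mt5-small sizes from the new config
    wi_0 = nn.Linear(d_model, d_ff, bias=False)  # gate branch
    wi_1 = nn.Linear(d_model, d_ff, bias=False)  # linear branch
    wo = nn.Linear(d_ff, d_model, bias=False)    # output projection

    x = torch.randn(2, 7, d_model)               # (batch, seq, hidden)
    y = wo(F.gelu(wi_0(x)) * wi_1(x))            # elementwise gating, then project back
    assert y.shape == x.shape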
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +26,7 @@ from transformers.modeling_tf_utils import TFWrappedEmbeddings +from ...activations_tf import get_tf_activation from ...file_utils import ( DUMMY_INPUTS, DUMMY_MASK, @@ -103,10 +104,35 @@ def call(self, hidden_states, training=False): return hidden_states + +class TFT5GatedGeluDense(tf.keras.layers.Layer): + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0") + self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1") + self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo") + self.dropout = tf.keras.layers.Dropout(config.dropout_rate) + self.act = get_tf_activation("gelu_new") + + def call(self, hidden_states, training=False): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states, training=training) + hidden_states = self.wo(hidden_states) + return hidden_states + + class TFT5LayerFF(tf.keras.layers.Layer): def __init__(self, config, **kwargs): super().__init__(**kwargs) - self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense") + if config.feed_forward_proj == "relu": + self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense") + elif config.feed_forward_proj == "gated-gelu": + self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense") + else: + raise ValueError( + f"{config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`" + ) self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm") self.dropout = tf.keras.layers.Dropout(config.dropout_rate) @@ -547,9 +573,6 @@ def __init__(self, config, embed_tokens=None, **kwargs): def get_input_embeddings(self): return self.embed_tokens - def get_output_embeddings(self): - return self.embed_tokens - def set_embed_tokens(self, embed_tokens): self.embed_tokens = embed_tokens @@ -970,9 +993,6 @@ def __init__(self, config, *inputs, **kwargs): def get_input_embeddings(self): return self.shared - def get_output_embeddings(self): - return self.shared - def set_input_embeddings(self, new_embeddings): self.shared.weight = new_embeddings self.shared.vocab_size = self.shared.weight.shape[0] @@ -1165,11 +1185,17 @@ def __init__(self, config, *inputs, **kwargs): decoder_config.is_decoder = True self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder") + if not config.tie_word_embeddings: + self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head") + def get_input_embeddings(self): return self.shared def get_output_embeddings(self): - return self.shared + if self.config.tie_word_embeddings: + return self.shared + else: + return self.lm_head def set_input_embeddings(self, new_embeddings): self.shared.weight = new_embeddings @@ -1331,9 +1357,14 @@ def call( training=training, ) - sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5) - embed_tokens = self.get_output_embeddings() - logits = embed_tokens(sequence_output, mode="linear") + sequence_output = decoder_outputs[0] + + # T5v1.1 does not tie output word embeddings and thus does not require downscaling + if self.config.tie_word_embeddings: + sequence_output = sequence_output * (self.model_dim ** -0.5) + logits = self.get_output_embeddings()(sequence_output, mode="linear") + else: + logits = 
self.get_output_embeddings()(sequence_output) loss = None if labels is None else self.compute_loss(labels, logits) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 979e906fcf09..596992e54a2d 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -1361,6 +1361,29 @@ def load_tf_weights_in_mobilebert(*args, **kwargs): requires_pytorch(load_tf_weights_in_mobilebert) +class MT5Config: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + +class MT5ForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + +class MT5Model: + def __init__(self, *args, **kwargs): + requires_pytorch(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_pytorch(self) + + OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 33b58b90a47d..38e4d831abd3 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -970,6 +970,24 @@ def from_pretrained(self, *args, **kwargs): requires_tf(self) +class TFMT5ForConditionalGeneration: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + +class TFMT5Model: + def __init__(self, *args, **kwargs): + requires_tf(self) + + @classmethod + def from_pretrained(self, *args, **kwargs): + requires_tf(self) + + TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/tests/test_modeling_mt5.py b/tests/test_modeling_mt5.py new file mode 100644 index 000000000000..ce6e2925b3f1 --- /dev/null +++ b/tests/test_modeling_mt5.py @@ -0,0 +1,39 @@ +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device + + +if is_torch_available(): + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + +@require_torch +@require_sentencepiece +@require_tokenizers +class MT5IntegrationTest(unittest.TestCase): + @slow + def test_small_integration_test(self): + """ + For comparison, run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_mt5_checkpoint = '' + >>> path_to_mtf_small_mt5_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small", return_dict=True).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -84.9127 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) diff --git a/tests/test_modeling_t5.py b/tests/test_modeling_t5.py index 31a73c100106..90573d5a7890 100644 --- a/tests/test_modeling_t5.py +++ b/tests/test_modeling_t5.py @@ -490,6 +490,14 @@ def 
test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + # check that gated gelu feed forward and different word embeddings work + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_model(config, *config_and_inputs[1:]) + def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_lm_head(*config_and_inputs) @@ -569,7 +577,7 @@ def test_small_integration_test(self): >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) """ - model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device) + model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device) tokenizer = T5Tokenizer.from_pretrained("t5-small") input_ids = tokenizer("Hello there", return_tensors="pt").input_ids @@ -581,6 +589,32 @@ def test_small_integration_test(self): EXPECTED_SCORE = -19.0845 self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + @slow + def test_small_v1_1_integration_test(self): + """ + For comparison, run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_t5_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="pt").input_ids + labels = tokenizer("Hi I am", return_tensors="pt").input_ids + + loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss + mtf_score = -(labels.shape[-1] * loss.item()) + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + @slow def test_summarization(self): model = self.model diff --git a/tests/test_modeling_tf_mt5.py b/tests/test_modeling_tf_mt5.py new file mode 100644 index 000000000000..d2c65372d042 --- /dev/null +++ b/tests/test_modeling_tf_mt5.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
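A note on the scoring arithmetic the integration tests above (and the TF tests that follow) rely on: the model returns the mean cross entropy per target token, while ``t5.models.MtfModel.score`` reports the summed target log-likelihood, so the tests multiply back by the target length and negate. Restated as a minimal sketch using the same ``t5-small`` checkpoint as the tests::

    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    tokenizer = T5Tokenizer.from_pretrained("t5-small")

    input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
    labels = tokenizer("Hi I am", return_tensors="pt").input_ids

    with torch.no_grad():
        loss = model(input_ids, labels=labels, return_dict=True).loss  # mean NLL per target token
    mtf_score = -(labels.shape[-1] * loss.item())  # summed log-likelihood, Mesh TensorFlow convention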
+ + import unittest + + from transformers import is_tf_available + from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow + + + if is_tf_available(): + import tensorflow as tf + + from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM + + + @require_tf + @require_sentencepiece + @require_tokenizers + class TFMT5ModelIntegrationTest(unittest.TestCase): + @slow + def test_small_integration_test(self): + """ + For comparison, run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_mt5_checkpoint = '' + >>> path_to_mtf_small_mt5_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path, extra_ids=100) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small") + tokenizer = AutoTokenizer.from_pretrained("google/mt5-small") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_sum(loss).numpy() + + EXPECTED_SCORE = -84.9127 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) diff --git a/tests/test_modeling_tf_t5.py b/tests/test_modeling_tf_t5.py index d7e43d569cc3..45ba79ec220c 100644 --- a/tests/test_modeling_tf_t5.py +++ b/tests/test_modeling_tf_t5.py @@ -258,6 +258,13 @@ def test_t5_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_model(*config_and_inputs) + def test_t5_model_v1_1(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + config = config_and_inputs[0] + config.tie_word_embeddings = False + config.feed_forward_proj = "gated-gelu" + self.model_tester.create_and_check_t5_model(config, *config_and_inputs[1:]) + def test_with_lm_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs) @@ -296,6 +303,58 @@ class TFT5ModelIntegrationTests(unittest.TestCase): def model(self): return TFT5ForConditionalGeneration.from_pretrained("t5-base") + @slow + def test_small_integration_test(self): + """ + For comparison, run: + >>> import t5 # pip install t5==0.7.1 + >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_t5_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFT5ForConditionalGeneration.from_pretrained("t5-small") + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_sum(loss).numpy() + + EXPECTED_SCORE = -19.0845 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + + @slow + def test_small_v1_1_integration_test(self): + """ + For comparison, run: + >>> import t5 # pip install t5==0.7.1 + >>> from 
t5.data.sentencepiece_vocabulary import SentencePieceVocabulary + + >>> path_to_mtf_small_t5_v1_1_checkpoint = '' + >>> path_to_mtf_small_spm_model_path = '' + >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None) + >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100) + >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab) + """ + + model = TFT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small") + tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small") + + input_ids = tokenizer("Hello there", return_tensors="tf").input_ids + labels = tokenizer("Hi I am", return_tensors="tf").input_ids + + loss = model(input_ids, labels=labels).loss + mtf_score = -tf.math.reduce_sum(loss).numpy() + + EXPECTED_SCORE = -59.0293 + self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) + @slow def test_summarization(self): model = self.model diff --git a/utils/check_repo.py b/utils/check_repo.py index fd4316cd41d0..d05cbf8326e8 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -46,6 +46,7 @@ "test_modeling_flax_bert.py", "test_modeling_flax_roberta.py", "test_modeling_mbart.py", + "test_modeling_mt5.py", "test_modeling_pegasus.py", "test_modeling_tf_camembert.py", "test_modeling_tf_xlm_roberta.py",
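Taken together, the registrations above mean mT5 checkpoints resolve through the auto classes exactly like T5 ones; note that ``("mt5", MT5Config)`` is inserted before ``("t5", T5Config)`` so that, where the auto machinery falls back to matching patterns against the model identifier, the more specific type wins. An end-to-end sketch, assuming the ``google/mt5-small`` weights are reachable::

    from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

    config = AutoConfig.from_pretrained("google/mt5-small")
    assert config.model_type == "mt5"

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")      # resolves to T5Tokenizer
    model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")  # resolves to MT5ForConditionalGeneration

    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."],
        tgt_texts=["Weiter Verhandlung in Syrien."],
        return_tensors="pt",
    )
    loss = model(input_ids=batch.input_ids, labels=batch.labels, return_dict=True).loss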