T5 & mT5 #8552
Merged: patrickvonplaten merged 9 commits into huggingface:master from patrickvonplaten:adapt_for_tv1.1 on Nov 17, 2020.
Commits (changes shown from all 9 commits):
- fec9ec8 add mt5 and t5v1_1 model (patrickvonplaten)
- d3470c5 fix tests (patrickvonplaten)
- 847a2ab solve merge conflicts (patrickvonplaten)
- 91295ad correct some imports (patrickvonplaten)
- ccc0b48 add tf model (patrickvonplaten)
- 4672304 finish tf t5 (patrickvonplaten)
- b087382 improve examples (patrickvonplaten)
- 1f3842e fix copies (patrickvonplaten)
- 0b48175 clean doc (patrickvonplaten)
New file (53 lines added): mT5 model documentation page (reStructuredText):

```rst
MT5
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
Siddhant, Aditya Barua, Colin Raffel.

The abstract from the paper is the following:

*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. All of the code and model checkpoints used in this work are publicly available.*

The original code can be found `here <https://github.com/google-research/multilingual-t5>`__.


MT5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5Config
    :members:


MT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5Model
    :members:


MT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.MT5ForConditionalGeneration
    :members:


TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5Model
    :members:


TFMT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.TFMT5ForConditionalGeneration
    :members:
```
New file (13 lines added): the mt5 module `__init__.py`:

```python
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

from ...file_utils import is_tf_available, is_torch_available
from .configuration_mt5 import MT5Config


if is_torch_available():
    from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model

if is_tf_available():
    from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
```
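Downstream code can use the same availability checks before touching the framework-specific classes. A minimal sketch, assuming the top-level re-exports added elsewhere in this PR are in place:

```python
from transformers import MT5Config
from transformers.file_utils import is_tf_available, is_torch_available

# The config is pure Python and importable no matter which backend is installed.
config = MT5Config()

if is_torch_available():
    from transformers import MT5Model  # PyTorch implementation
    model = MT5Model(config)
elif is_tf_available():
    from transformers import TFMT5Model  # TensorFlow implementation
    model = TFMT5Model(config)
else:
    raise RuntimeError("Neither PyTorch nor TensorFlow is installed.")
```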
New file (122 lines added): `configuration_mt5.py`:

```python
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 model configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)


class MT5Config(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a
    :class:`~transformers.TFMT5Model`. It is used to instantiate an mT5 model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a configuration
    similar to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Arguments:
        vocab_size (:obj:`int`, `optional`, defaults to 250112):
            Vocabulary size of the mT5 model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.MT5Model` or
            :class:`~transformers.TFMT5Model`.
        d_model (:obj:`int`, `optional`, defaults to 512):
            Size of the encoder layers and the pooler layer.
        d_kv (:obj:`int`, `optional`, defaults to 64):
            Size of the key, query, value projections per attention head. :obj:`d_kv` has to be equal to
            :obj:`d_model // num_heads`.
        d_ff (:obj:`int`, `optional`, defaults to 1024):
            Size of the intermediate feed forward layer in each :obj:`T5Block`.
        num_layers (:obj:`int`, `optional`, defaults to 8):
            Number of hidden layers in the Transformer encoder.
        num_decoder_layers (:obj:`int`, `optional`):
            Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
            set.
        num_heads (:obj:`int`, `optional`, defaults to 6):
            Number of attention heads for each attention layer in the Transformer encoder.
        relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
            The number of buckets to use for each attention layer.
        dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        initializer_factor (:obj:`float`, `optional`, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
            Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
    """

    model_type = "mt5"

    def __init__(
        self,
        vocab_size=250112,
        d_model=512,
        d_kv=64,
        d_ff=1024,
        num_layers=8,
        num_decoder_layers=None,
        num_heads=6,
        relative_attention_num_buckets=32,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        initializer_factor=1.0,
        feed_forward_proj="gated-gelu",
        is_encoder_decoder=True,
        tokenizer_class="T5Tokenizer",
        tie_word_embeddings=False,
        pad_token_id=0,
        eos_token_id=1,
        decoder_start_token_id=0,
        **kwargs
    ):
        super().__init__(
            is_encoder_decoder=is_encoder_decoder,
            tokenizer_class=tokenizer_class,
            tie_word_embeddings=tie_word_embeddings,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff
        self.num_layers = num_layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # default = symmetry
        self.num_heads = num_heads
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj

    @property
    def hidden_size(self):
        return self.d_model

    @property
    def num_attention_heads(self):
        return self.num_heads

    @property
    def num_hidden_layers(self):
        return self.num_layers
```
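Instantiating the config directly shows the defaults and the aliasing properties in action. A minimal sketch; the printed values follow from the defaults above:

```python
from transformers import MT5Config

config = MT5Config()  # defaults mirror google/mt5-small
print(config.vocab_size)           # 250112
print(config.hidden_size)          # 512, alias for d_model
print(config.num_attention_heads)  # 6, alias for num_heads
print(config.num_hidden_layers)    # 8, alias for num_layers

# Encoder and decoder depths may differ; the decoder defaults to the
# encoder depth only when num_decoder_layers is left unset.
asym = MT5Config(num_layers=8, num_decoder_layers=12)
print(asym.num_decoder_layers)     # 12
```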
New file (83 lines added): `modeling_mt5.py`:

```python
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch mT5 model. """

from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"


class MT5Model(T5Model):
    r"""
    This class overrides :class:`~transformers.T5Model`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Examples::

        >>> from transformers import MT5Model, T5Tokenizer
        >>> model = MT5Model.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
        >>> hidden_states = outputs.last_hidden_state
    """

    model_type = "mt5"
    config_class = MT5Config
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]


class MT5ForConditionalGeneration(T5ForConditionalGeneration):
    r"""
    This class overrides :class:`~transformers.T5ForConditionalGeneration`. Please check the superclass for the
    appropriate documentation alongside usage examples.

    Examples::

        >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
        >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
        >>> outputs = model(**batch)
        >>> loss = outputs.loss
    """

    model_type = "mt5"
    config_class = MT5Config
    authorized_missing_keys = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
        r"lm_head\.weight",
        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
    ]
    keys_to_never_save = [
        r"encoder\.embed_tokens\.weight",
        r"decoder\.embed_tokens\.weight",
    ]
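```

Since MT5ForConditionalGeneration inherits everything from T5ForConditionalGeneration, the usual generate API works unchanged. A minimal sketch, assuming PyTorch, sentencepiece, and access to the google/mt5-small checkpoint; note the raw pre-trained checkpoint is not fine-tuned on any downstream task, so the decoded output only exercises the API rather than producing a meaningful summary:

```python
import torch
from transformers import MT5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
model.eval()

article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
input_ids = tokenizer(article, return_tensors="pt").input_ids

# Beam search decoding; parameters here are illustrative only.
with torch.no_grad():
    generated = model.generate(input_ids, max_length=32, num_beams=4)

print(tokenizer.decode(generated[0], skip_special_tokens=True))
```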
Review comment: The config should always be imported, not just under `is_torch_available`.
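The committed `__init__.py` above already follows this advice: `MT5Config` is imported at module level and only the framework-specific model classes sit behind the availability guards. For contrast, a hypothetical sketch of the anti-pattern the comment warns against:

```python
from ...file_utils import is_torch_available

# Anti-pattern (hypothetical): hiding the config behind a backend check
# makes MT5Config unimportable when PyTorch is absent, even though the
# config has no torch dependency.
#
#     if is_torch_available():
#         from .configuration_mt5 import MT5Config
#         from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model

# Correct pattern, as committed: the config is pure Python, so import it
# unconditionally; only the model classes stay behind the guard.
from .configuration_mt5 import MT5Config

if is_torch_available():
    from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
```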