diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx
index e41ccae949e8..9561bbd8ec77 100644
--- a/docs/source/en/serialization.mdx
+++ b/docs/source/en/serialization.mdx
@@ -79,6 +79,7 @@ Ready-made configurations include the following architectures:
 - mBART
 - MobileBERT
 - MobileViT
+- MT5
 - OpenAI GPT-2
 - Perceiver
 - PLBart
diff --git a/src/transformers/models/mt5/__init__.py b/src/transformers/models/mt5/__init__.py
index 3f04a256918b..f6e717bd875b 100644
--- a/src/transformers/models/mt5/__init__.py
+++ b/src/transformers/models/mt5/__init__.py
@@ -43,7 +43,7 @@
 
 MT5TokenizerFast = T5TokenizerFast
 
-_import_structure = {"configuration_mt5": ["MT5Config"]}
+_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}
 
 try:
     if not is_torch_available():
@@ -71,7 +71,7 @@
 
 
 if TYPE_CHECKING:
-    from .configuration_mt5 import MT5Config
+    from .configuration_mt5 import MT5Config, MT5OnnxConfig
 
     try:
         if not is_torch_available():
diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py
index ad0345f53189..3e72831ad25f 100644
--- a/src/transformers/models/mt5/configuration_mt5.py
+++ b/src/transformers/models/mt5/configuration_mt5.py
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ mT5 model configuration"""
+from typing import Mapping
 
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
 from ...utils import logging
 
 
@@ -143,3 +145,29 @@ def num_attention_heads(self):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
+class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch", 1: "encoder_sequence"},
+            "attention_mask": {0: "batch", 1: "encoder_sequence"},
+        }
+        if self.use_past:
+            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py
index e7c24a8ad97a..ace5eb620a25 100644
--- a/src/transformers/onnx/features.py
+++ b/src/transformers/onnx/features.py
@@ -377,6 +377,13 @@ class FeaturesManager:
             "image-classification",
             onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
         ),
+        "mt5": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls="models.mt5.MT5OnnxConfig",
+        ),
         "m2m-100": supported_features_mapping(
             "default",
             "default-with-past",
diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py
index 6b22dc3420f2..238f9ed4c7c0 100644
--- a/tests/onnx/test_onnx_v2.py
+++ b/tests/onnx/test_onnx_v2.py
@@ -217,6 +217,7 @@ def test_values_override(self):
     ("mbart", "sshleifer/tiny-mbart"),
     ("t5", "t5-small"),
     ("marian", "Helsinki-NLP/opus-mt-en-de"),
+    ("mt5", "google/mt5-base"),
     ("m2m-100", "facebook/m2m100_418M"),
     ("blenderbot-small", "facebook/blenderbot_small-90M"),
     ("blenderbot", "facebook/blenderbot-400M-distill"),