diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index d941b00318b0..4a4b59e9c16f 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -194,6 +194,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] TFAutoModelForMultipleChoice +## TFAutoModelForNextSentencePrediction + +[[autodoc]] TFAutoModelForNextSentencePrediction + ## TFAutoModelForTableQuestionAnswering [[autodoc]] TFAutoModelForTableQuestionAnswering diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 4ae5c9a57ecb..b1918bf4609d 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -67,6 +67,7 @@ Ready-made configurations include the following architectures: - M2M100 - Marian - mBART +- MobileBert - OpenAI GPT-2 - PLBart - RoBERTa diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 6d976ef6f2d7..2c41ff883fb9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1798,6 +1798,7 @@ "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -3964,6 +3965,7 @@ TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index 6dace993cd74..fa34a11964b0 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -108,6 +108,7 @@ "TFAutoModelForSeq2SeqLM", "TFAutoModelForSequenceClassification", "TFAutoModelForSpeechSeq2Seq", + "TFAutoModelForNextSentencePrediction", "TFAutoModelForTableQuestionAnswering", "TFAutoModelForTokenClassification", "TFAutoModelForVision2Seq", @@ -224,6 +225,7 @@ TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForPreTraining, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, diff --git a/src/transformers/models/mobilebert/__init__.py b/src/transformers/models/mobilebert/__init__.py index 505dabe18791..b35fe8a9c11a 100644 --- a/src/transformers/models/mobilebert/__init__.py +++ b/src/transformers/models/mobilebert/__init__.py @@ -22,7 +22,11 @@ _import_structure = { - "configuration_mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig"], + "configuration_mobilebert": [ + "MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "MobileBertConfig", + "MobileBertOnnxConfig", + ], "tokenization_mobilebert": ["MobileBertTokenizer"], } @@ -62,7 +66,11 @@ if TYPE_CHECKING: - from .configuration_mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig + from .configuration_mobilebert import ( + MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + MobileBertConfig, + MobileBertOnnxConfig, + ) from .tokenization_mobilebert import MobileBertTokenizer if is_tokenizers_available(): diff --git a/src/transformers/models/mobilebert/configuration_mobilebert.py b/src/transformers/models/mobilebert/configuration_mobilebert.py index 27863235b3d7..73b8844ed763 100644 --- a/src/transformers/models/mobilebert/configuration_mobilebert.py +++ b/src/transformers/models/mobilebert/configuration_mobilebert.py @@ -13,8 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """ MobileBERT model configuration""" +from collections import OrderedDict +from typing import Mapping from ...configuration_utils import PretrainedConfig +from ...onnx import OnnxConfig from ...utils import logging @@ -165,3 +168,20 @@ def __init__( self.true_hidden_size = hidden_size self.classifier_dropout = classifier_dropout + + +# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->MobileBert +class MobileBertOnnxConfig(OnnxConfig): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("token_type_ids", dynamic_axis), + ] + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 50e941332a95..516288e8d6a7 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -24,6 +24,7 @@ from ..models.m2m_100 import M2M100OnnxConfig from ..models.marian import MarianOnnxConfig from ..models.mbart import MBartOnnxConfig +from ..models.mobilebert import MobileBertOnnxConfig from ..models.roberta import RobertaOnnxConfig from ..models.roformer import RoFormerOnnxConfig from ..models.t5 import T5OnnxConfig @@ -43,6 +44,7 @@ AutoModelForMaskedImageModeling, AutoModelForMaskedLM, AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, @@ -54,6 +56,7 @@ TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, + TFAutoModelForNextSentencePrediction, TFAutoModelForQuestionAnswering, TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, @@ -107,6 +110,7 @@ class FeaturesManager: "question-answering": AutoModelForQuestionAnswering, "image-classification": AutoModelForImageClassification, "masked-im": AutoModelForMaskedImageModeling, + "next-sentence-prediction": AutoModelForNextSentencePrediction, } if is_tf_available(): _TASKS_TO_TF_AUTOMODELS = { @@ -118,6 +122,7 @@ class FeaturesManager: "token-classification": TFAutoModelForTokenClassification, "multiple-choice": TFAutoModelForMultipleChoice, "question-answering": TFAutoModelForQuestionAnswering, + "next-sentence-prediction": TFAutoModelForNextSentencePrediction, } # Set of model topologies we support associated to the features supported by each topology and the factory @@ -152,6 +157,7 @@ class FeaturesManager: "multiple-choice", "token-classification", "question-answering", + "next-sentence-prediction", onnx_config_cls=BertOnnxConfig, ), "big-bird": supported_features_mapping( @@ -304,6 +310,16 @@ class FeaturesManager: "question-answering", onnx_config_cls=MBartOnnxConfig, ), + "mobilebert": supported_features_mapping( + "default", + "masked-lm", + "next-sentence-prediction", + "sequence-classification", + "multiple-choice", + "token-classification", + "question-answering", + onnx_config_cls=MobileBertOnnxConfig, + ), "m2m-100": supported_features_mapping( "default", "default-with-past", "seq2seq-lm", "seq2seq-lm-with-past", onnx_config_cls=M2M100OnnxConfig ), diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 4ecfc917d56b..90c0b5be7adf 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -180,6 +180,7 @@ def test_values_override(self): ("electra", "google/electra-base-generator"), ("roberta", "roberta-base"), ("roformer", "junnyu/roformer_chinese_base"), + ("mobilebert", "google/mobilebert-uncased"), ("xlm-roberta", "xlm-roberta-base"), ("layoutlm", "microsoft/layoutlm-base-uncased"), ("vit", "google/vit-base-patch16-224"),