From a38fd8073571215ece50b7a3d081faa5c518d7f4 Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Wed, 1 Jul 2020 23:30:57 -0400 Subject: [PATCH 01/18] Add DeBERTa model --- README.md | 5 +- docs/source/index.rst | 1 + docs/source/model_doc/deberta.rst | 62 ++++ docs/source/pretrained_models.rst | 10 + model_cards/microsoft/DeBERTa-base/README.md | 36 ++ model_cards/microsoft/DeBERTa-large/README.md | 37 ++ setup.py | 1 + src/transformers/__init__.py | 9 + src/transformers/configuration_auto.py | 3 + src/transformers/configuration_deberta.py | 105 ++++++ src/transformers/modeling_auto.py | 4 + src/transformers/modeling_deberta.py | 328 ++++++++++++++++++ src/transformers/tokenization_auto.py | 3 + src/transformers/tokenization_deberta.py | 220 ++++++++++++ tests/test_modeling_deberta.py | 269 ++++++++++++++ tests/test_tokenization_deberta.py | 72 ++++ 16 files changed, 1163 insertions(+), 2 deletions(-) create mode 100644 docs/source/model_doc/deberta.rst create mode 100644 model_cards/microsoft/DeBERTa-base/README.md create mode 100644 model_cards/microsoft/DeBERTa-large/README.md create mode 100644 src/transformers/configuration_deberta.py create mode 100644 src/transformers/modeling_deberta.py create mode 100644 src/transformers/tokenization_deberta.py create mode 100644 tests/test_modeling_deberta.py create mode 100644 tests/test_tokenization_deberta.py diff --git a/README.md b/README.md index c12fdc63f6dd..bd1cda412705 100644 --- a/README.md +++ b/README.md @@ -183,8 +183,9 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih. 24. **[MBart](https://github.com/pytorch/fairseq/tree/master/examples/mbart)** (from Facebook) released with the paper [Multilingual Denoising Pre-training for Neural Machine Translation](https://arxiv.org/abs/2001.08210) by Yinhan Liu, Jiatao Gu, Naman Goyal, Xian Li, Sergey Edunov, Marjan Ghazvininejad, Mike Lewis, Luke Zettlemoyer. 25. **[LXMERT](https://github.com/airsplay/lxmert)** (from UNC Chapel Hill) released with the paper [LXMERT: Learning Cross-Modality Encoder Representations from Transformers for Open-Domain Question Answering](https://arxiv.org/abs/1908.07490) by Hao Tan and Mohit Bansal. 26. **[Funnel Transformer](https://github.com/laiguokun/Funnel-Transformer)** (from CMU/Google Brain) released with the paper [Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing](https://arxiv.org/abs/2006.03236) by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le. -27. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users). -28. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. +27. **[DeBERTa](https://huggingface.co/transformers/model_doc/deberta.html)** (from Microsoft Research) released with the paper [DeBERTa: Decoding-enhanced BERT with Disentangled Attention](https://arxiv.org/abs/2006.03654) by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. +28. **[Other community models](https://huggingface.co/models)**, contributed by the [community](https://huggingface.co/users). +29. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations. You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html). diff --git a/docs/source/index.rst b/docs/source/index.rst index 08e65b7549b6..6d66bdfa492b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -227,6 +227,7 @@ conversion utilities for the following models: model_doc/funnel model_doc/lxmert model_doc/bertgeneration + model_doc/deberta internal/modeling_utils internal/tokenization_utils internal/pipelines_utils diff --git a/docs/source/model_doc/deberta.rst b/docs/source/model_doc/deberta.rst new file mode 100644 index 000000000000..56c925c38f86 --- /dev/null +++ b/docs/source/model_doc/deberta.rst @@ -0,0 +1,62 @@ +DeBERTa +---------------------------------------------------- + +Overview +~~~~~~~~~~~~~~~~~~~~~ + +The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `_ +by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen +It is based on Google's BERT model released in 2018 and Facebook's RoBERTa model released in 2019. + +It builds on RoBERTa with disentangled attention and enhanced mask decoder training with half of the data used in RoBERTa. + +The abstract from the paper is the following: + +*Recent progress in pre-trained neural language models has significantly improved the performance of many natural language processing (NLP) tasks. +In this paper we propose a new model architecture DeBERTa (Decoding-enhanced BERT with disentangled attention) that improves the BERT and RoBERTa +models using two novel techniques. The first is the disentangled attention mechanism, where each word is represented using two vectors that encode +its content and position, respectively, and the attention weights among words are computed using disentangled matrices on their contents and +relative positions. Second, an enhanced mask decoder is used to replace the output softmax layer to predict the masked tokens for model pretraining. +We show that these two techniques significantly improve the efficiency of model pre-training and performance of downstream tasks. Compared to +RoBERTa-Large, a DeBERTa model trained on half of the training data performs consistently better on a wide range of NLP tasks, achieving improvements +on MNLI by +0.9% (90.2% vs. 91.1%), on SQuAD v2.0 by +2.3% (88.4% vs. 90.7%) and RACE by +3.6% (83.2% vs. 86.8%). The DeBERTa code and pre-trained +models will be made publicly available at https://github.com/microsoft/DeBERTa.* + + +The original code can be found `here `_. + + +DeBERTaConfig +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeBERTaConfig + :members: + + +DeBERTaTokenizer +~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeBERTaTokenizer + :members: build_inputs_with_special_tokens, get_special_tokens_mask, + create_token_type_ids_from_sequences, save_vocabulary + + +DeBERTaModel +~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeBERTaModel + :members: + + +DeBERTaPreTrainedModel +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeBERTaPreTrainedModel + :members: + + +DeBERTaForSequenceClassification +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: transformers.DeBERTaForSequenceClassification + :members: diff --git a/docs/source/pretrained_models.rst b/docs/source/pretrained_models.rst index 9c277f4a83e6..50a51fffc6fa 100644 --- a/docs/source/pretrained_models.rst +++ b/docs/source/pretrained_models.rst @@ -408,3 +408,13 @@ For a list that includes community-uploaded models, refer to `https://huggingfac | | | | | | | (see `details `__) | +--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| DeBERTa | ``microsoft/deberta-base`` | | 12-layer, 768-hidden, 12-heads, ~125M parameters | +| | | | DeBERTa using the BERT-base architecture | +| | | | +| | | (see `details `__) | +| +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ +| | ``microsoft/deberta-large`` | | 24-layer, 1024-hidden, 16-heads, ~390M parameters | +| | | | DeBERTa using the BERT-large architecture | +| | | | +| | | (see `details `__) | ++--------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/model_cards/microsoft/DeBERTa-base/README.md b/model_cards/microsoft/DeBERTa-base/README.md new file mode 100644 index 000000000000..8c53040bfaae --- /dev/null +++ b/model_cards/microsoft/DeBERTa-base/README.md @@ -0,0 +1,36 @@ +--- +thumbnail: https://huggingface.co/front/thumbnails/microsoft.png +license: mit +--- + +## DeBERTa: Decoding-enhanced BERT with Disentangled Attention + +[DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data. + +Please check the [official repository](https://github.com/microsoft/DeBERTa) for more details and updates. + + +#### Fine-tuning on NLU tasks + +We present the dev results on SQuAD 1.1/2.0 and MNLI tasks. + +| Model | SQuAD 1.1 | SQuAD 2.0 | MNLI-m | +|-------------------|-----------|-----------|--------| +| RoBERTa-base | 91.5/84.6 | 83.7/80.5 | 87.6 | +| XLNet-Large | -/- | -/80.2 | 86.8 | +| **DeBERTa-base** | 93.1/87.2 | 86.2/83.1 | 88.8 | + +### Citation + +If you find DeBERTa useful for your work, please cite the following paper: + +``` latex +@misc{he2020deberta, + title={DeBERTa: Decoding-enhanced BERT with Disentangled Attention}, + author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen}, + year={2020}, + eprint={2006.03654}, + archivePrefix={arXiv}, + primaryClass={cs.CL} + } +``` diff --git a/model_cards/microsoft/DeBERTa-large/README.md b/model_cards/microsoft/DeBERTa-large/README.md new file mode 100644 index 000000000000..9e36c951100f --- /dev/null +++ b/model_cards/microsoft/DeBERTa-large/README.md @@ -0,0 +1,37 @@ +--- +thumbnail: https://huggingface.co/front/thumbnails/microsoft.png +license: mit +--- + +## DeBERTa: Decoding-enhanced BERT with Disentangled Attention + +[DeBERTa](https://arxiv.org/abs/2006.03654) improves the BERT and RoBERTa models using disentangled attention and enhanced mask decoder. With those two improvements, DeBERTa out perform RoBERTa on a majority of NLU tasks with 80GB training data. + +Please check the [official repository](https://github.com/microsoft/DeBERTa) for more details and updates. + + +#### Fine-tuning on NLU tasks + +We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks. + +| Model | SQuAD 1.1 | SQuAD 2.0 | MNLI-m | SST-2 | QNLI | CoLA | RTE | MRPC | QQP |STS-B| +|-------------------|-----------|-----------|--------|-------|------|------|------|------|------|-----| +| BERT-Large | 90.9/84.1 | 81.8/79.0 | 86.6 | 93.2 | 92.3 | 60.6 | 70.4 | 88.0 | 91.3 |90.0 | +| RoBERTa-Large | 94.6/88.9 | 89.4/86.5 | 90.2 | 96.4 | 93.9 | 68.0 | 86.6 | 90.9 | 92.2 |92.4 | +| XLNet-Large | 95.1/89.7 | 90.6/87.9 | 90.8 | 97.0 | 94.9 | 69.0 | 85.9 | 90.8 | 92.3 |92.5 | +| **DeBERTa-Large** | 95.5/90.1 | 90.7/88.0 | 91.1 | 96.5 | 95.3 | 69.5 | 88.1 | 92.5 | 92.3 |92.5 | + +### Citation + +If you find DeBERTa useful for your work, please cite the following paper: + +``` latex +@misc{he2020deberta, + title={DeBERTa: Decoding-enhanced BERT with Disentangled Attention}, + author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen}, + year={2020}, + eprint={2006.03654}, + archivePrefix={arXiv}, + primaryClass={cs.CL} + } +``` diff --git a/setup.py b/setup.py index 8527abdec7e3..1559accc8d77 100644 --- a/setup.py +++ b/setup.py @@ -127,6 +127,7 @@ "sentencepiece != 0.1.92", # for XLM "sacremoses", + "deberta", ], extras_require=extras, entry_points={ diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 154cccff9dd2..4722bbcb26d9 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -25,6 +25,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig +from .configuration_deberta import DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -149,6 +150,7 @@ from .tokenization_bertweet import BertweetTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer +from .tokenization_deberta import DeBERTaTokenizer from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast from .tokenization_dpr import ( DPRContextEncoderTokenizer, @@ -479,6 +481,13 @@ load_tf_weights_in_xlnet, ) + from .modeling_deberta import ( + DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST, + DeBERTaPreTrainedModel, + DeBERTaModel, + DeBERTaForSequenceClassification, + ) + # Optimization from .optimization import ( Adafactor, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 9f429e3307f3..7962aa9ed6dc 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -23,6 +23,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig +from .configuration_deberta import DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig from .configuration_encoder_decoder import EncoderDecoderConfig @@ -73,6 +74,7 @@ RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, ] for key, value, in pretrained_map.items() ) @@ -108,6 +110,7 @@ ("encoder-decoder", EncoderDecoderConfig), ("funnel", FunnelConfig), ("lxmert", LxmertConfig), + ("deberta", DeBERTaConfig), ] ) diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py new file mode 100644 index 000000000000..5a82256f84f8 --- /dev/null +++ b/src/transformers/configuration_deberta.py @@ -0,0 +1,105 @@ +# coding=utf-8 +# Copyright 2020, Microsoft +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" DeBERTa model configuration """ + +import logging + +from .configuration_utils import PretrainedConfig + + +__all__ = ["DeBERTaConfig", "DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP"] + +logger = logging.getLogger(__name__) + +DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/config.json", + "microsoft/deberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-large/config.json", +} + + +class DeBERTaConfig(PretrainedConfig): + r""" + :class:`~transformers.DeBERTaConfig` is the configuration class to store the configuration of a + `DeBERTaModel`. + + Arguments: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. + num_attention_heads (int): Number of attention heads for each attention layer in + the Transformer encoder, default: `12`. + intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder, default: `3072`. + hidden_act (str): The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. + hidden_dropout_prob (float): The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler, default: `0.1`. + attention_probs_dropout_prob (float): The dropout ratio for the attention + probabilities, default: `0.1`. + max_position_embeddings (int): The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048), default: `512`. + type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into + `DeBERTa` model, default: `-1`. + initializer_range (int): The sttdev of the _normal_initializer for + initializing all weight matrices, default: `0.02`. + relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. + max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. + padding_idx (int): The value used to pad input_ids, default: `0`. + position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. + pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None". + vocab_size (int): The size of the vocabulary, default: `-1`. + """ + model_type = "deberta" + + def __init__( + self, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=-1, + initializer_range=0.02, + relative_attention=False, + max_relative_positions=-1, + padding_idx=0, + position_biased_input=True, + pos_att_type="None", + vocab_size=-1, + layer_norm_eps=1e-7, + **kwargs + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.relative_attention = relative_attention + self.max_relative_positions = max_relative_positions + self.padding_idx = padding_idx + self.position_biased_input = position_biased_input + self.pos_att_type = pos_att_type + self.vocab_size = vocab_size + self.layer_norm_eps = layer_norm_eps diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index e482b170f6d0..a13e8268e042 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -26,6 +26,7 @@ BertGenerationConfig, CamembertConfig, CTRLConfig, + DeBERTaConfig, DistilBertConfig, ElectraConfig, EncoderDecoderConfig, @@ -88,6 +89,7 @@ CamembertModel, ) from .modeling_ctrl import CTRLLMHeadModel, CTRLModel +from .modeling_deberta import DeBERTaForSequenceClassification, DeBERTaModel from .modeling_distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, @@ -221,6 +223,7 @@ (FunnelConfig, FunnelModel), (LxmertConfig, LxmertModel), (BertGenerationConfig, BertGenerationEncoder), + (DeBERTaConfig, DeBERTaModel), ] ) @@ -344,6 +347,7 @@ (XLMConfig, XLMForSequenceClassification), (ElectraConfig, ElectraForSequenceClassification), (FunnelConfig, FunnelForSequenceClassification), + (DeBERTaConfig, DeBERTaForSequenceClassification), ] ) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py new file mode 100644 index 000000000000..9b294546e823 --- /dev/null +++ b/src/transformers/modeling_deberta.py @@ -0,0 +1,328 @@ +# coding=utf-8 +# Copyright 2020 Microsoft +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa model. """ + +import logging + +import torch +from DeBERTa.deberta import BertEmbeddings as DeBERTaEmbeddings +from DeBERTa.deberta import BertEncoder as DeBERTaEncoder +from DeBERTa.deberta import ContextPooler, PoolConfig, StableDropout +from torch import nn +from torch.nn import CrossEntropyLoss + +from .configuration_deberta import DeBERTaConfig +from .file_utils import add_start_docstrings +from .modeling_utils import PreTrainedModel + + +__all__ = [ + "DeBERTaModel", + "DeBERTaForSequenceClassification", + "DeBERTaPreTrainedModel", + "DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST", +] + +logger = logging.getLogger(__name__) + +#################################################### +# This list contrains shortcut names for some of +# the pretrained weights provided with the models +#################################################### +DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-base", + "microsoft/deberta-large", +] + + +class DeBERTaPreTrainedModel(PreTrainedModel): + """ An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. + """ + + config_class = DeBERTaConfig + base_model_prefix = "deberta" + + def _init_weights(self, module): + """ Initialize the weights """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +DeBERTa_START_DOCSTRING = r""" The DeBERTa model was proposed in + `DeBERTa: Decoding-enhanced BERT with Disentangled Attention`_ + by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of BERT/RoBERTa with two improvements, i.e. + disentangled attention and enhanced mask decoder. With those two improvements, it out perform BERT/RoBERTa on a majority + of tasks with 80GB pre-trianing data. + + This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. + + .. _`DeBERTa: Decoding-enhanced BERT with Disentangled Attention`: + https://arxiv.org/abs/2006.03654 + + .. _`torch.nn.Module`: + https://pytorch.org/docs/stable/nn.html#module + + Parameters: + config (:class:`~transformers.DeBERTaConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. +""" + +DeBERTa_INPUTS_DOCSTRING = r""" + Inputs: + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + To match pre-training, DeBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows: + + (a) For sequence pairs: + + ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` + + + (b) For single sequences: + + ``tokens: [CLS] the dog is hairy . [SEP]`` + + Indices can be obtained using :class:`transformers.DeBERTaTokenizer`. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Segment token indices to indicate first and second portions of the inputs. + Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` + corresponds to a `sentence B` token + **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of positions of each input sequence tokens in the position embeddings. + Selected in the range ``[0, config.max_position_embeddings - 1]``. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``: + Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DeBERTa_START_DOCSTRING, + DeBERTa_INPUTS_DOCSTRING, +) +class DeBERTaModel(DeBERTaPreTrainedModel): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during DeBERTa pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = DeBERTaTokenizer.from_pretrained('deberta-base-uncased') + model = DeBERTaModel.from_pretrained('deberta-base-uncased') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + outputs = model(input_ids) + last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple + + """ + + def __init__(self, config): + super().__init__(config) + + self.embeddings = DeBERTaEmbeddings(config) + self.encoder = DeBERTaEncoder(config) + self.z_steps = 0 + self.config = config + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + ): + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + embedding_output = self.embeddings( + input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask + ) + + encoded_layers = self.encoder(embedding_output, attention_mask, output_all_encoded_layers=True) + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + return_att=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + return (sequence_output, sequence_output[0]) + + +@add_start_docstrings( + """DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of + the pooled output) e.g. for GLUE tasks. """, + DeBERTa_START_DOCSTRING, + DeBERTa_INPUTS_DOCSTRING, +) +class DeBERTaForSequenceClassification(DeBERTaPreTrainedModel): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. + **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + tokenizer = DeBERTaTokenizer.from_pretrained('deberta-base-uncased') + model = DeBERTaForSequenceClassification.from_pretrained('deberta-base-uncased') + input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 + labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 + outputs = model(input_ids, labels=labels) + loss, logits = outputs[:2] + + """ + + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.bert = DeBERTaModel(config) + pool_config = PoolConfig(self.config) + output_dim = self.bert.config.hidden_size + self.pooler = ContextPooler(pool_config) + output_dim = self.pooler.output_dim() + + self.classifier = torch.nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + labels=None, + output_hidden_states=None, + ): + + encoder_layer, cls = self.bert( + input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, position_ids=position_ids + ) + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = torch.tensor(0).to(logits) + if labels is not None: + if self.num_labels == 1: + # regression task + loss_fn = torch.nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = torch.nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + return (loss, logits) diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index 5de28f70f00a..a892ad31e17f 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -25,6 +25,7 @@ BertGenerationConfig, CamembertConfig, CTRLConfig, + DeBERTaConfig, DistilBertConfig, ElectraConfig, EncoderDecoderConfig, @@ -58,6 +59,7 @@ from .tokenization_bertweet import BertweetTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer +from .tokenization_deberta import DeBERTaTokenizer from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast from .tokenization_flaubert import FlaubertTokenizer @@ -117,6 +119,7 @@ (CTRLConfig, (CTRLTokenizer, None)), (FSMTConfig, (FSMTTokenizer, None)), (BertGenerationConfig, (BertGenerationTokenizer, None)), + (DeBERTaConfig, (DeBERTaTokenizer, None)), ] ) diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py new file mode 100644 index 000000000000..d2798877af81 --- /dev/null +++ b/src/transformers/tokenization_deberta.py @@ -0,0 +1,220 @@ +# coding=utf-8 +# Copyright 2020 Microsoft. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Tokenization class for model DeBERTa.""" + +import logging +import os + +from DeBERTa.deberta import GPT2Tokenizer + +from .tokenization_utils import PreTrainedTokenizer + + +logger = logging.getLogger(__name__) + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to file names for serializing Tokenizer instances +#################################################### +VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} + +#################################################### +# Mapping from the keyword arguments names of Tokenizer `__init__` +# to pretrained vocabulary URL for all the model shortcut names. +#################################################### +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/bpe_encoder.bin", + "microsoft/deberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-large/bpe_encoder.bin", + } +} + +#################################################### +# Mapping from model shortcut names to max length of inputs +#################################################### +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "microsoft/deberta-base": 512, + "microsoft/deberta-large": 512, +} + +#################################################### +# Mapping from model shortcut names to a dictionary of additional +# keyword arguments for Tokenizer `__init__`. +# To be used for checkpoint specific configurations. +#################################################### +PRETRAINED_INIT_CONFIGURATION = { + "microsoft/deberta-base": {"do_lower_case": False}, + "microsoft/deberta-large": {"do_lower_case": False}, +} + +__all__ = ["DeBERTaTokenizer"] + + +class DeBERTaTokenizer(PreTrainedTokenizer): + r""" + Constructs a XxxTokenizer. + :class:`~transformers.DeBERTaTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece + + Args: + vocab_file: Path to a one-wordpiece-per-line vocabulary file + do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + + def __init__( + self, + vocab_file, + do_lower_case=False, + unk_token="[UNK]", + sep_token="[SEP]", + pad_token="[PAD]", + cls_token="[CLS]", + mask_token="[MASK]", + **kwargs + ): + """Constructs a XxxTokenizer. + + Args: + **vocab_file**: Path to a one-wordpiece-per-line vocabulary file + **do_lower_case**: (`optional`) boolean (default False) + Whether to lower case the input + Only has an effect when do_basic_tokenize=True + """ + super().__init__( + unk_token=unk_token, + sep_token=sep_token, + pad_token=pad_token, + cls_token=cls_token, + mask_token=mask_token, + **kwargs, + ) + + if not os.path.isfile(vocab_file): + raise ValueError( + "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " + "model use `tokenizer = XxxTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) + ) + self.do_lower_case = do_lower_case + self.gpt2_tokenizer = GPT2Tokenizer(vocab_file) + + @property + def vocab_size(self): + return len(self.vocab) + + @property + def vocab(self): + return self.gpt2_tokenizer.vocab + + def get_vocab(self): + vocab = self.vocab.copy() + vocab.update(self.get_added_vocab()) + return vocab + + def _tokenize(self, text): + """ Take as input a string and return a list of strings (tokens) for words/sub-words + """ + if self.do_lower_case: + text = text.lower() + return self.gpt2_tokenizer.tokenize(text) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.vocab.get(token, self.vocab.get(self.unk_token)) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.gpt2_tokenizer.sym(index) if index < self.vocab_size else self.unk_token + + def convert_tokens_to_string(self, tokens): + """ Converts a sequence of tokens (string) in a single string. """ + return self.gpt2_tokenizer.decode(tokens) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. + A BERT sequence has the following format: + + - single sequence: [CLS] X [SEP] + - pair of sequences: [CLS] A [SEP] B [SEP] + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + token_ids_1 + sep + + def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False): + """ + Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods. + + Args: + token_ids_0: list of ids (must not contain special tokens) + token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids + for sequence pairs + already_has_special_tokens: (default False) Set to True if the token list is already formated with + special tokens for the model + + Returns: + A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + if token_ids_1 is not None: + raise ValueError( + "You should not supply a second sequence if the provided sequence of " + "ids is already formated with special tokens for the model." + ) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + + if token_ids_1 is not None: + return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1] + + def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. + A BERT sequence pair mask has the following format: + 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence + + if token_ids_1 is None, only returns the first portion of the mask (0's). + """ + sep = [self.sep_token_id] + cls = [self.cls_token_id] + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + + def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs): + add_prefix_space = kwargs.pop("add_prefix_space", False) + if is_pretokenized or add_prefix_space: + text = " " + text + return (text, kwargs) + + def save_vocabulary(self, vocab_path): + """Save the tokenizer vocabulary to a directory or file.""" + if os.path.isdir(vocab_path): + vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"]) + else: + vocab_file = vocab_path + self.gpt2_tokenizer.save_pretrained(vocab_file) + return (vocab_file,) diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py new file mode 100644 index 000000000000..a28454efcb68 --- /dev/null +++ b/tests/test_modeling_deberta.py @@ -0,0 +1,269 @@ +# coding=utf-8 +# Copyright 2018 Microsoft Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +from transformers import is_torch_available +from transformers.testing_utils import require_torch, torch_device + +from .test_configuration_common import ConfigTester +from .test_modeling_common import ModelTesterMixin, ids_tensor + + +if is_torch_available(): + from transformers import ( + DeBERTaConfig, + DeBERTaModel, + # XxxForMaskedLM, + # XxxForQuestionAnswering, + DeBERTaForSequenceClassification, + # XxxForTokenClassification, + ) + from transformers.modeling_deberta import DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST + + +@require_torch +class DeBERTaModelTest(ModelTesterMixin, unittest.TestCase): + + all_model_classes = ( + ( + DeBERTaModel, + DeBERTaForSequenceClassification, + ) # , DeBERTaForMaskedLM, DeBERTaForQuestionAnswering, DeBERTaForTokenClassification) + if is_torch_available() + else () + ) + + test_torchscript = False + test_pruning = False + test_resize_embeddings = False + test_head_masking = False + test_missing_keys = False + is_encoder_decoder = False + + class DeBERTaModelTester(object): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + relative_attention=False, + position_biased_input=True, + pos_att_type="None", + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_input_mask = use_input_mask + self.use_token_type_ids = use_token_type_ids + self.use_labels = use_labels + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.initializer_range = initializer_range + self.num_labels = num_labels + self.num_choices = num_choices + self.relative_attention = relative_attention + self.position_biased_input = position_biased_input + self.pos_att_type = pos_att_type + self.scope = scope + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = DeBERTaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + relative_attention=self.relative_attention, + position_biased_input=self.position_biased_input, + pos_att_type=self.pos_att_type, + ) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def check_loss_output(self, result): + self.parent.assertListEqual(list(result["loss"].size()), []) + + def create_and_check_deberta_model( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = DeBERTaModel(config=config) + model.to(torch_device) + model.eval() + sequence_output, cls = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) + sequence_output, cls = model(input_ids, token_type_ids=token_type_ids) + sequence_output, cls = model(input_ids) + + result = { + "sequence_output": sequence_output, + } + self.parent.assertListEqual( + list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size] + ) + + def create_and_check_deberta_for_sequence_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + config.num_labels = self.num_labels + model = DeBERTaForSequenceClassification(config) + model.to(torch_device) + model.eval() + loss, logits = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels + ) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) + self.check_loss_output(result) + + # TODO: TBD + def create_and_check_deberta_for_masked_lm( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + pass + + # TODO: TBD + def create_and_check_deberta_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + pass + + # TODO: TBD + def create_and_check_deberta_for_token_classification( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + pass + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + def setUp(self): + self.model_tester = DeBERTaModelTest.DeBERTaModelTester(self) + self.config_tester = ConfigTester(self, config_class=DeBERTaConfig, hidden_size=37) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_deberta_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_model(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_sequence_classification(*config_and_inputs) + + @unittest.skip(reason="Model not available yet") + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_masked_lm(*config_and_inputs) + + @unittest.skip(reason="Model not available yet") + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_question_answering(*config_and_inputs) + + @unittest.skip(reason="Model not available yet") + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs) + + @unittest.skip + def test_attention_outputs(self): + pass + + @unittest.skip + def test_hidden_states_output(self): + pass + + @unittest.skip + def test_inputs_embeds(self): + pass + + @unittest.skip + def test_model_common_attributes(self): + pass + + # @slow + def test_model_from_pretrained(self): + for model_name in DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = DeBERTaModel.from_pretrained(model_name) + self.assertIsNotNone(model) diff --git a/tests/test_tokenization_deberta.py b/tests/test_tokenization_deberta.py new file mode 100644 index 000000000000..b2cbc769b9b2 --- /dev/null +++ b/tests/test_tokenization_deberta.py @@ -0,0 +1,72 @@ +# coding=utf-8 +# Copyright 2018 Microsoft. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re +import unittest +from typing import Tuple + +from transformers.tokenization_deberta import DeBERTaTokenizer + +from .test_tokenization_common import TokenizerTesterMixin + + +class DeBERTaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + + tokenizer_class = DeBERTaTokenizer + + def setUp(self): + super().setUp() + + def get_tokenizer(self, name="microsoft/deberta-base", **kwargs): + return DeBERTaTokenizer.from_pretrained(name, **kwargs) + + def get_input_output_texts(self, tokenizer): + input_text = "lower newer" + output_text = "lower newer" + return input_text, output_text + + def get_clean_sequence(self, tokenizer, with_prefix_space=False, max_length=20) -> Tuple[str, list]: + toks = [ + (i, tokenizer.decode([i], clean_up_tokenization_spaces=False)) + for i in range(5, min(len(tokenizer), 50260)) + ] + toks = list(filter(lambda t: re.match(r"^[ a-zA-Z]+$", t[1]), toks)) + toks = list(filter(lambda t: [t[0]] == tokenizer.encode(t[1], add_special_tokens=False), toks)) + if max_length is not None and len(toks) > max_length: + toks = toks[:max_length] + # toks_str = [t[1] for t in toks] + toks_ids = [t[0] for t in toks] + + # Ensure consistency + output_txt = tokenizer.decode(toks_ids, clean_up_tokenization_spaces=False) + if " " not in output_txt and len(toks_ids) > 1: + output_txt = ( + tokenizer.decode([toks_ids[0]], clean_up_tokenization_spaces=False) + + " " + + tokenizer.decode(toks_ids[1:], clean_up_tokenization_spaces=False) + ) + if with_prefix_space and not output_txt.startswith(" "): + output_txt = " " + output_txt + output_ids = tokenizer.encode(output_txt, add_special_tokens=False) + return output_txt, output_ids + + def test_full_tokenizer(self): + tokenizer = self.get_tokenizer("microsoft/deberta-base") + input_str = "UNwant\u00E9d,running" + tokens = tokenizer.tokenize(input_str) + token_ids = tokenizer.convert_tokens_to_ids(tokens) + + self.assertEqual(tokenizer.decode(token_ids), input_str) From ce0b94a52df167954c778f6cb4db4b225561ac0f Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Tue, 4 Aug 2020 13:03:00 -0400 Subject: [PATCH 02/18] Remove dependency of deberta --- setup.py | 1 - src/transformers/modeling_deberta.py | 938 ++++++++++++++++++++++- src/transformers/tokenization_deberta.py | 426 +++++++++- tests/test_modeling_deberta.py | 22 +- tests/test_tokenization_deberta.py | 2 + 5 files changed, 1361 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 1559accc8d77..8527abdec7e3 100644 --- a/setup.py +++ b/setup.py @@ -127,7 +127,6 @@ "sentencepiece != 0.1.92", # for XLM "sacremoses", - "deberta", ], extras_require=extras, entry_points={ diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 9b294546e823..bbf63e79c4d9 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -14,12 +14,15 @@ # limitations under the License. """ PyTorch DeBERTa model. """ +import copy +import json import logging +import math +import os +from collections import Sequence import torch -from DeBERTa.deberta import BertEmbeddings as DeBERTaEmbeddings -from DeBERTa.deberta import BertEncoder as DeBERTaEncoder -from DeBERTa.deberta import ContextPooler, PoolConfig, StableDropout +from packaging import version from torch import nn from torch.nn import CrossEntropyLoss @@ -28,6 +31,11 @@ from .modeling_utils import PreTrainedModel +if version.Version(torch.__version__) >= version.Version("1.0.0"): + from torch import _softmax_backward_data as _softmax_backward_data +else: + from torch import softmax_backward_data as _softmax_backward_data + __all__ = [ "DeBERTaModel", "DeBERTaForSequenceClassification", @@ -47,6 +55,924 @@ ] +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +def linear_act(x): + return x + + +ACT2FN = { + "gelu": gelu, + "relu": torch.nn.functional.relu, + "swish": swish, + "tanh": torch.nn.functional.tanh, + "linear": linear_act, + "sigmoid": torch.sigmoid, +} + + +def traceable(cls): + """ Decorator over customer functions + There is an issue for tracing customer python torch Function, using this decorator to work around it. + e.g. + @traceable + class MyOp(torch.autograd.Function): + xxx + """ + + class _Function(object): + @staticmethod + def apply(*args): + jit_trace = os.getenv("JIT_TRACE", "False").lower() == "true" + if jit_trace: + return cls.forward(_Function, *args) + else: + return cls.apply(*args) + + @staticmethod + def save_for_backward(*args): + pass + + _Function.__name__ = cls.__name__ + _Function.__doc__ = cls.__doc__ + return _Function + + +class AbsModelConfig(object): + def __init__(self): + pass + + @classmethod + def from_dict(cls, json_object): + """Constructs a `ModelConfig` from a Python dictionary of parameters.""" + config = cls() + for key, value in json_object.items(): + if isinstance(value, dict): + value = AbsModelConfig.from_dict(value) + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `ModelConfig` from a json file of parameters.""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + + def _json_default(obj): + if isinstance(obj, AbsModelConfig): + return obj.__dict__ + + return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n" + + +class ModelConfig(AbsModelConfig): + """Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model. + + Attributes: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. + num_attention_heads (int): Number of attention heads for each attention layer in + the Transformer encoder, default: `12`. + intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder, default: `3072`. + hidden_act (str): The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. + hidden_dropout_prob (float): The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler, default: `0.1`. + attention_probs_dropout_prob (float): The dropout ratio for the attention + probabilities, default: `0.1`. + max_position_embeddings (int): The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048), default: `512`. + type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into + `DeBERTa` model, default: `-1`. + initializer_range (int): The sttdev of the _normal_initializer for + initializing all weight matrices, default: `0.02`. + relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. + max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. + padding_idx (int): The value used to pad input_ids, default: `0`. + position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. + pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p"., default: "None". + + + """ + + def __init__(self): + """Constructs ModelConfig. + + """ + + self.hidden_size = 768 + self.num_hidden_layers = 12 + self.num_attention_heads = 12 + self.hidden_act = "gelu" + self.intermediate_size = 3072 + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 0 + self.initializer_range = 0.02 + self.layer_norm_eps = 1e-7 + self.padding_idx = 0 + self.vocab_size = -1 + + +class PoolConfig(AbsModelConfig): + """Configuration class to store the configuration of `pool layer`. + + Parameters: + config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initalized with the `pooling` field in model config. + + Attributes: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + + dropout (float): The dropout rate applied on the output of `[CLS]` token, + + hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh']. + + Example:: + # Here is the content of an exmple model config file in json format + + { + "hidden_size": 768, + "num_hidden_layers" 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + ... + "pooling": { + "hidden_size": 768, + "hidden_act": "gelu", + "dropout": 0.1 + } + } + + """ + + def __init__(self, config=None): + """Constructs PoolConfig. + + Args: + `config`: the config of the model. The field of pool config will be initalized with the 'pooling' field in model config. + """ + + self.hidden_size = 768 + self.dropout = 0 + self.hidden_act = "gelu" + if config: + pool_config = getattr(config, "pooling", config) + if isinstance(pool_config, dict): + pool_config = AbsModelConfig.from_dict(pool_config) + self.hidden_size = getattr(pool_config, "hidden_size", config.hidden_size) + self.dropout = getattr(pool_config, "dropout", 0) + self.hidden_act = getattr(pool_config, "hidden_act", "gelu") + + +class TraceMode: + """ Trace context used when tracing modules contains customer operators/Functions + """ + + def __enter__(self): + os.environ["JIT_TRACE"] = "True" + return self + + def __exit__(self, exp_value, exp_type, trace): + del os.environ["JIT_TRACE"] + + +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = StableDropout(config.dropout) + self.config = config + + def forward(self, hidden_states, mask=None): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.hidden_act](pooled_output) + return pooled_output + + def output_dim(self): + return self.config.hidden_size + + +@traceable +class XSoftmax(torch.autograd.Function): + """ Masked Softmax which is optimized for saving memory + + Args: + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax caculation. + dim (int): The dimenssion that will apply softmax. + Example:: + import torch + from DeBERTa.deberta import XSoftmax + # Make a tensor + x = torch.randn([4,20,100]) + # Create a mask + mask = (x>0).int() + y = XSoftmax.apply(x, mask, dim=-1) + """ + + @staticmethod + def forward(self, input, mask, dim): + """ + """ + + self.dim = dim + if version.Version(torch.__version__) >= version.Version("1.2.0a"): + rmask = ~(mask.bool()) + else: + rmask = (1 - mask).byte() # This line is not supported by Onnx tracing. + + output = input.masked_fill(rmask, float("-inf")) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + """ + """ + + (output,) = self.saved_tensors + inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + return inputGrad, None, None + + +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + if version.Version(torch.__version__) >= version.Version("1.2.0a"): + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + else: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).byte() + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +@traceable +class XDropout(torch.autograd.Function): + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +class StableDropout(torch.nn.Module): + """ Optimized dropout module for stabilizing the training + + Args: + + drop_prob (float): the dropout probabilities + + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ Call the module + + Args: + x (:obj:`torch.tensor`): The input tensor to apply dropout + + + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +def MaskedLayerNorm(layerNorm, input, mask=None): + """ Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module. + Args: + layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` + + Example:: + + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.BertLayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) + + """ + output = layerNorm(input).to(input) + if mask is None: + return output + if mask.dim() != input.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(output.dtype) + return output * mask + + +class BertLayerNorm(nn.Module): + """LayerNorm module in the TF style (epsilon inside the square root). + """ + + def __init__(self, size, eps=1e-12): + super().__init__() + self.weight = nn.Parameter(torch.ones(size)) + self.bias = nn.Parameter(torch.zeros(size)) + self.variance_epsilon = eps + + def forward(self, x): + input_type = x.dtype + x = x.float() + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + x = x.to(input_type) + y = self.weight * x + self.bias + return y + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_states, mask=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states += input_states + hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = BertSelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if return_att: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states, attention_mask) + + if return_att: + return (attention_output, att_matrix) + else: + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = ( + ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act + ) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config): + super(BertOutput, self).__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_states, mask=None): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states += input_states + hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config): + super(BertLayer, self).__init__() + self.attention = BertAttention(config) + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + return_att=return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if return_att: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, attention_mask) + if return_att: + return (layer_output, att_matrix) + else: + return layer_output + + +class DeBERTaEncoder(nn.Module): + """ Modified BertEncoder with relative position bias support + """ + + def __init__(self, config): + super().__init__() + layer = BertLayer(config) + self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_all_encoded_layers=True, + return_att=False, + query_states=None, + relative_pos=None, + ): + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_encoder_layers = [] + att_matrixs = [] + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + for i, layer_module in enumerate(self.layer): + output_states = layer_module( + next_kv, + attention_mask, + return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + ) + if return_att: + output_states, att_m = output_states + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_all_encoded_layers: + all_encoder_layers.append(output_states) + if return_att: + att_matrixs.append(att_m) + if not output_all_encoded_layers: + all_encoder_layers.append(output_states) + if return_att: + att_matrixs.append(att_m) + if return_att: + return (all_encoder_layers, att_matrixs) + else: + return all_encoder_layers + + +def build_relative_position(query_size, key_size, device): + """ Build relative position according to the query and key + + We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size), + The relative positions from query to key is + :math:`R_{q \\rightarrow k} = P_q - P_k` + + Args: + query_size (int): the length of query + key_size (int): the length of key + + Return: + :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + + q_ids = torch.arange(query_size, dtype=torch.long, device=device) + k_ids = torch.arange(key_size, dtype=torch.long, device=device) + rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(torch.nn.Module): + """ Disentangled self-attention module + + Parameters: + config (:obj:`str`): + A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \ + for more details, please refer :class:`~DeBERTa.deberta.ModelConfig` + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) + self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.pos_att_type = [x.strip() for x in getattr(config, "pos_att_type", "none").lower().split("|")] # c2p|p2c + + self.relative_attention = getattr(config, "relative_attention", False) + self.talking_head = getattr(config, "talking_head", False) + + if self.talking_head: + self.head_logits_proj = torch.nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) + self.head_weights_proj = torch.nn.Linear( + config.num_attention_heads, config.num_attention_heads, bias=False + ) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_proj = torch.nn.Linear(config.hidden_size, self.all_head_size, bias=False) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_q_proj = torch.nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + ): + """ Call the module + + Args: + hidden_states (:obj:`torch.FloatTensor`): + Input states to the module usally the output from previous layer, it will be the Q,K and V in `Attention(Q,K,V)` + + attention_mask (:obj:`torch.ByteTensor`): + An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maxium sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` th token. + + return_att (:obj:`bool`, optional): + Whether return the attention maxitrix. + + query_states (:obj:`torch.FloatTensor`, optional): + The `Q` state in `Attention(Q,K,V)`. + + relative_pos (:obj:`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with values ranging in [`-max_relative_positions`, `max_relative_positions`]. + + rel_embeddings (:obj:`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [:math:`2 \\times \\text{max_relative_positions}`, `hidden_size`]. + + + """ + if query_states is None: + qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) + query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) + else: + + def linear(w, b, x): + if b is not None: + return torch.matmul(x, w.t()) + b.t() + else: + return torch.matmul(x, w.t()) # + b.t() + + ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) + qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] + qkvb = [None] * 3 + + q = linear(qkvw[0], qkvb[0], query_states) + k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] + query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + + query_layer += self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0)) + value_layer += self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0)) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + if "p2p" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + query_layer = query_layer / scale + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + + if rel_att is not None: + attention_scores = attention_scores + rel_att + + # bxhxlxd + if self.talking_head: + attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + if self.talking_head: + attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(*new_context_layer_shape) + if return_att: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bxhxqxk + elif relative_pos.dim() != 4: + raise ValueError(f"Relative postion ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions) + relative_pos = relative_pos.long().to(query_layer.device) + rel_embeddings = rel_embeddings[ + self.max_relative_positions - att_span : self.max_relative_positions + att_span, : + ].unsqueeze(0) + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.pos_proj(rel_embeddings) + pos_key_layer = self.transpose_for_scores(pos_key_layer) + + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.pos_q_proj(rel_embeddings) + pos_query_layer = self.transpose_for_scores(pos_query_layer) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)) + score += c2p_att + + # position->content + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor) + if query_layer.size(-2) != key_layer.size(-2): + r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device) + else: + r_pos = relative_pos + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + if query_layer.size(-2) != key_layer.size(-2): + pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) + + if "p2c" in self.pos_att_type: + p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer) + ).transpose(-1, -2) + if query_layer.size(-2) != key_layer.size(-2): + p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer)) + score += p2c_att + + return score + + +class DeBERTaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + + def __init__(self, config): + super().__init__() + padding_idx = getattr(config, "padding_idx", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=padding_idx) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.output_to_half = False + self.config = config + + def forward(self, input_ids, token_type_ids=None, position_ids=None, mask=None): + seq_length = input_ids.size(1) + if position_ids is None: + position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(words_embeddings) + + embeddings = words_embeddings + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask) + embeddings = self.dropout(embeddings) + return embeddings + + class DeBERTaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. @@ -191,10 +1117,11 @@ def forward( position_ids=None, head_mask=None, inputs_embeds=None, + return_tuple=True, ): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is None and inputs_embeds is None: + elif (input_ids is None) and (inputs_embeds is not None): raise ValueError("You have to specify either input_ids or inputs_embeds") if attention_mask is None: @@ -228,7 +1155,7 @@ def forward( sequence_output = encoded_layers[-1] - return (sequence_output, sequence_output[0]) + return (sequence_output, sequence_output[:, 0]) @add_start_docstrings( @@ -296,6 +1223,7 @@ def forward( position_ids=None, labels=None, output_hidden_states=None, + return_tuple=True, ): encoder_layer, cls = self.bert( diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index d2798877af81..0f76d3abc3d8 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -16,12 +16,24 @@ import logging import os +import pathlib +import random +import unicodedata +from functools import lru_cache +from zipfile import ZipFile -from DeBERTa.deberta import GPT2Tokenizer +import requests +import tqdm from .tokenization_utils import PreTrainedTokenizer +try: + import regex as re +except ImportError: + raise ImportError("Please install regex with: pip install regex") + + logger = logging.getLogger(__name__) #################################################### @@ -62,6 +74,416 @@ __all__ = ["DeBERTaTokenizer"] +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = ( + list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) + ) + cs = bs[:] + n = 0 + for b in range(2 ** 8): + if b not in bs: + bs.append(b) + cs.append(2 ** 8 + n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip([tuple(k) for k in bpe_merges], range(len(bpe_merges)))) + self.cache = {} + self.random = random.Random(0) + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def split_to_words(self, text): + return list(re.findall(self.pat, text)) + + def encode(self, text): + bpe_tokens = [] + for token in self.split_to_words(text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text + + +def get_encoder(encoder, vocab): + return Encoder(encoder=encoder, bpe_merges=vocab,) + + +def _is_whitespace(char): + """Checks whether `chars` is a whitespace character.""" + # \t, \n, and \r are technically contorl characters but we treat them + # as whitespace since they are generally considered as such. + if char == " " or char == "\t" or char == "\n" or char == "\r": + return True + cat = unicodedata.category(char) + if cat == "Zs": + return True + return False + + +def _is_control(char): + """Checks whether `chars` is a control character.""" + # These are technically control characters but we count them as whitespace + # characters. + if char == "\t" or char == "\n" or char == "\r": + return False + cat = unicodedata.category(char) + if cat.startswith("C"): + return True + return False + + +def _is_punctuation(char): + """Checks whether `chars` is a punctuation character.""" + cp = ord(char) + # We treat all non-letter/number ASCII as punctuation. + # Characters such as "^", "$", and "`" are not in the Unicode + # Punctuation class but we treat them as punctuation anyways, for + # consistency. + if (cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126): + return True + cat = unicodedata.category(char) + if cat.startswith("P"): + return True + return False + + +def download_asset(name, tag=None, no_cache=False, cache_dir=None): + _tag = tag + if _tag is None: + _tag = "latest" + if not cache_dir: + cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") + os.makedirs(cache_dir, exist_ok=True) + output = os.path.join(cache_dir, name) + if os.path.exists(output) and (not no_cache): + return output + + repo = "https://api.github.com/repos/microsoft/DeBERTa/releases" + releases = requests.get(repo).json() + if tag and tag != "latest": + release = [r for r in releases if r["name"].lower() == tag.lower()] + if len(release) != 1: + raise Exception(f"{tag} can't be found in the repository.") + else: + release = releases[0] + asset = [s for s in release["assets"] if s["name"].lower() == name.lower()] + if len(asset) != 1: + raise Exception(f"{name} can't be found in the release.") + url = asset[0]["url"] + headers = {} + headers["Accept"] = "application/octet-stream" + resp = requests.get(url, stream=True, headers=headers) + if resp.status_code != 200: + raise Exception(f"Request for {url} return {resp.status_code}, {resp.text}") + try: + with open(output, "wb") as fs: + progress = tqdm( + total=int(resp.headers["Content-Length"]) if "Content-Length" in resp.headers else -1, + ncols=80, + desc=f"Downloading {name}", + ) + for c in resp.iter_content(chunk_size=1024 * 1024): + fs.write(c) + progress.update(len(c)) + progress.close() + except Exception: + os.remove(output) + raise + + return output + + +def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None): + import torch + + if name is None: + name = "bpe_encoder" + + model_path = name + if model_path and (not os.path.exists(model_path)) and not (("/" in model_path) or ("\\" in model_path)): + _tag = tag + if _tag is None: + _tag = "latest" + if not cache_dir: + cache_dir = os.path.join(pathlib.Path.home(), f".~DeBERTa/assets/{_tag}/") + os.makedirs(cache_dir, exist_ok=True) + out_dir = os.path.join(cache_dir, name) + model_path = os.path.join(out_dir, "bpe_encoder.bin") + if (not os.path.exists(model_path)) or no_cache: + asset = download_asset(name + ".zip", tag=tag, no_cache=no_cache, cache_dir=cache_dir) + with ZipFile(asset, "r") as zipf: + for zip_info in zipf.infolist(): + if zip_info.filename[-1] == "/": + continue + zip_info.filename = os.path.basename(zip_info.filename) + zipf.extract(zip_info, out_dir) + elif not model_path: + return None, None + + encoder_state = torch.load(model_path) + return encoder_state + + +class GPT2Tokenizer(object): + """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + + Args: + vocab_file (:obj:`str`, optional): + The local path of vocabulary package or the release name of vocabulary in `DeBERTa GitHub releases `_, \ + e.g. "bpe_encoder", default: `None`. + + If it's `None`, then it will download the vocabulary in the latest release from GitHub. The vocabulary file is a \ + state dictionary with three items, "dict_map", "vocab", "encoder" which correspond to three files used in `RoBERTa`, i.e. `dict.txt`, `vocab.txt` and `encoder.json`. \ + The difference between our wrapped GPT2 tokenizer and RoBERTa wrapped tokenizer are, + + - Special tokens, unlike `RoBERTa` which use ``, `` as the `start` token and `end` token of a sentence. We use `[CLS]` and `[SEP]` as the `start` and `end`\ + token of input sentence which is the same as `BERT`. + + - We remapped the token ids in our dictionary with regarding to the new special tokens, `[PAD]` => 0, `[CLS]` => 1, `[SEP]` => 2, `[UNK]` => 3, `[MASK]` => 50264 + + do_lower_case (:obj:`bool`, optional): + Whether to convert inputs to lower case. **Not used in GPT2 tokenizer**. + + special_tokens (:obj:`list`, optional): + List of special tokens to be added to the end of the vocabulary. + + + """ + + def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None): + self.pad_token = "[PAD]" + self.sep_token = "[SEP]" + self.unk_token = "[UNK]" + self.cls_token = "[CLS]" + + self.symbols = [] + self.count = [] + self.indices = {} + self.pad_token_id = self.add_symbol(self.pad_token) + self.cls_token_id = self.add_symbol(self.cls_token) + self.sep_token_id = self.add_symbol(self.sep_token) + self.unk_token_id = self.add_symbol(self.unk_token) + + self.gpt2_encoder = load_vocab(vocab_file) + self.bpe = get_encoder(self.gpt2_encoder["encoder"], self.gpt2_encoder["vocab"]) + for w, n in self.gpt2_encoder["dict_map"]: + self.add_symbol(w, n) + + self.mask_token = "[MASK]" + self.mask_id = self.add_symbol(self.mask_token) + self.special_tokens = ["[MASK]", "[SEP]", "[PAD]", "[UNK]", "[CLS]"] + if special_tokens is not None: + for t in special_tokens: + self.add_special_token(t) + + self.vocab = self.indices + self.ids_to_tokens = self.symbols + + def tokenize(self, text): + """ Convert an input text to tokens. + + Args: + text (:obj:`str`): input text to be tokenized. + + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + """ + bpe = self._encode(text) + + return [t for t in bpe.split(" ") if t] + + def convert_tokens_to_ids(self, tokens): + """ Convert list of tokens to ids. + Args: + tokens (:obj:`list`): list of tokens + + Returns: + List of ids + """ + + return [self.vocab[t] for t in tokens] + + def convert_ids_to_tokens(self, ids): + """ Convert list of ids to tokens. + Args: + ids (:obj:`list`): list of ids + + Returns: + List of tokens + """ + + tokens = [] + for i in ids: + tokens.append(self.ids_to_tokens[i]) + return tokens + + def split_to_words(self, text): + return self.bpe.split_to_words(text) + + def decode(self, tokens): + """ Decode list of tokens to text strings. + Args: + tokens (:obj:`list`): list of tokens. + + Returns: + Text string corresponds to the input tokens. + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + >>> tokenizer.decode(tokens) + 'Hello world!' + """ + return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) + + def add_special_token(self, token): + """Adds a special token to the dictionary. + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + + Returns: + The id of new token in the vocabulary. + + """ + self.special_tokens.append(token) + return self.add_symbol(token) + + def part_of_whole_word(self, token, is_bos=False): + if is_bos: + return True + s = self._decode(token) + if len(s) == 1 and (_is_whitespace(list(s)[0]) or _is_control(list(s)[0]) or _is_punctuation(list(s)[0])): + return False + + return not s.startswith(" ") + + def sym(self, id): + return self.ids_to_tokens[id] + + def id(self, sym): + return self.vocab[sym] + + def _encode(self, x: str) -> str: + return " ".join(map(str, self.bpe.encode(x))) + + def _decode(self, x: str) -> str: + return self.bpe.decode(map(int, x.split())) + + def add_symbol(self, word, n=1): + """Adds a word to the dictionary. + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. + + Returns: + The id of the new word. + + """ + if word in self.indices: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def save_pretrained(self, path: str): + import torch + + torch.save(self.gpt2_encoder, path) + + class DeBERTaTokenizer(PreTrainedTokenizer): r""" Constructs a XxxTokenizer. @@ -183,7 +605,7 @@ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_spe "You should not supply a second sequence if the provided sequence of " "ids is already formated with special tokens for the model." ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0,)) if token_ids_1 is not None: return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index a28454efcb68..b86580b025d4 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -17,7 +17,7 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, torch_device +from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -184,24 +184,6 @@ def create_and_check_deberta_for_sequence_classification( self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_labels]) self.check_loss_output(result) - # TODO: TBD - def create_and_check_deberta_for_masked_lm( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - pass - - # TODO: TBD - def create_and_check_deberta_for_question_answering( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - pass - - # TODO: TBD - def create_and_check_deberta_for_token_classification( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - ): - pass - def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -262,7 +244,7 @@ def test_inputs_embeds(self): def test_model_common_attributes(self): pass - # @slow + @slow def test_model_from_pretrained(self): for model_name in DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = DeBERTaModel.from_pretrained(model_name) diff --git a/tests/test_tokenization_deberta.py b/tests/test_tokenization_deberta.py index b2cbc769b9b2..e2085396512b 100644 --- a/tests/test_tokenization_deberta.py +++ b/tests/test_tokenization_deberta.py @@ -18,11 +18,13 @@ import unittest from typing import Tuple +from transformers.testing_utils import require_torch from transformers.tokenization_deberta import DeBERTaTokenizer from .test_tokenization_common import TokenizerTesterMixin +@require_torch class DeBERTaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = DeBERTaTokenizer From 525d6fd142a8b4a61068f549a3b2de78e50ec9b2 Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Wed, 2 Sep 2020 15:33:47 -0400 Subject: [PATCH 03/18] Address comments --- src/transformers/__init__.py | 15 +- src/transformers/configuration_auto.py | 2 +- src/transformers/configuration_deberta.py | 62 ++--- src/transformers/modeling_deberta.py | 265 +++++++++------------- src/transformers/tokenization_deberta.py | 122 +++++----- tests/test_modeling_deberta.py | 11 +- 6 files changed, 219 insertions(+), 258 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4722bbcb26d9..1ce0d4fcfa3e 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -25,7 +25,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig -from .configuration_deberta import DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig +from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -303,6 +303,12 @@ CamembertModel, ) from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel + from .modeling_deberta import ( + DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + DeBERTaForSequenceClassification, + DeBERTaModel, + DeBERTaPreTrainedModel, + ) from .modeling_distilbert import ( DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, DistilBertForMaskedLM, @@ -481,13 +487,6 @@ load_tf_weights_in_xlnet, ) - from .modeling_deberta import ( - DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST, - DeBERTaPreTrainedModel, - DeBERTaModel, - DeBERTaForSequenceClassification, - ) - # Optimization from .optimization import ( Adafactor, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index 7962aa9ed6dc..e59e16b76539 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -23,7 +23,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig -from .configuration_deberta import DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig +from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig from .configuration_encoder_decoder import EncoderDecoderConfig diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 5a82256f84f8..46cf676a7156 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -19,11 +19,11 @@ from .configuration_utils import PretrainedConfig -__all__ = ["DeBERTaConfig", "DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP"] +__all__ = ["DeBERTaConfig", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"] logger = logging.getLogger(__name__) -DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP = { +DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { "microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/config.json", "microsoft/deberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-large/config.json", } @@ -31,35 +31,35 @@ class DeBERTaConfig(PretrainedConfig): r""" - :class:`~transformers.DeBERTaConfig` is the configuration class to store the configuration of a - `DeBERTaModel`. + :class:`~transformers.DeBERTaConfig` is the configuration class to store the configuration of a + `DeBERTaModel`. - Arguments: - hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. - num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. - num_attention_heads (int): Number of attention heads for each attention layer in - the Transformer encoder, default: `12`. - intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder, default: `3072`. - hidden_act (str): The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. - hidden_dropout_prob (float): The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler, default: `0.1`. - attention_probs_dropout_prob (float): The dropout ratio for the attention - probabilities, default: `0.1`. - max_position_embeddings (int): The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048), default: `512`. - type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into - `DeBERTa` model, default: `-1`. - initializer_range (int): The sttdev of the _normal_initializer for - initializing all weight matrices, default: `0.02`. - relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. - max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. - padding_idx (int): The value used to pad input_ids, default: `0`. - position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. - pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None". - vocab_size (int): The size of the vocabulary, default: `-1`. + Arguments: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. + num_attention_heads (int): Number of attention heads for each attention layer in + the Transformer encoder, default: `12`. + intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder, default: `3072`. + hidden_act (str): The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. + hidden_dropout_prob (float): The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler, default: `0.1`. + attention_probs_dropout_prob (float): The dropout ratio for the attention + probabilities, default: `0.1`. + max_position_embeddings (int): The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048), default: `512`. + type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into + `DeBERTa` model, default: `0`. + initializer_range (int): The sttdev of the _normal_initializer for + initializing all weight matrices, default: `0.02`. + relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. + max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. + padding_idx (int): The value used to pad input_ids, default: `0`. + position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. + pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None". + vocab_size (int): The size of the vocabulary, default: `-1`. """ model_type = "deberta" @@ -73,7 +73,7 @@ def __init__( hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, - type_vocab_size=-1, + type_vocab_size=0, initializer_range=0.02, relative_attention=False, max_relative_positions=-1, diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index bbf63e79c4d9..dc00e1942e51 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -28,6 +28,7 @@ from .configuration_deberta import DeBERTaConfig from .file_utils import add_start_docstrings +from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput from .modeling_utils import PreTrainedModel @@ -40,7 +41,7 @@ "DeBERTaModel", "DeBERTaForSequenceClassification", "DeBERTaPreTrainedModel", - "DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST", + "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", ] logger = logging.getLogger(__name__) @@ -49,7 +50,7 @@ # This list contrains shortcut names for some of # the pretrained weights provided with the models #################################################### -DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST = [ +DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", "microsoft/deberta-large", ] @@ -59,7 +60,7 @@ def gelu(x): """Implementation of the gelu activation function. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - """ + """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) @@ -82,13 +83,13 @@ def linear_act(x): def traceable(cls): - """ Decorator over customer functions - There is an issue for tracing customer python torch Function, using this decorator to work around it. - e.g. - @traceable - class MyOp(torch.autograd.Function): - xxx - """ + """Decorator over customer functions + There is an issue for tracing customer python torch Function, using this decorator to work around it. + e.g. + @traceable + class MyOp(torch.autograd.Function): + xxx + """ class _Function(object): @staticmethod @@ -147,86 +148,34 @@ def _json_default(obj): return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n" -class ModelConfig(AbsModelConfig): - """Configuration class to store the configuration of a :class:`~DeBERTa.deberta.DeBERTa` model. - - Attributes: - hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. - num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. - num_attention_heads (int): Number of attention heads for each attention layer in - the Transformer encoder, default: `12`. - intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder, default: `3072`. - hidden_act (str): The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. - hidden_dropout_prob (float): The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler, default: `0.1`. - attention_probs_dropout_prob (float): The dropout ratio for the attention - probabilities, default: `0.1`. - max_position_embeddings (int): The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048), default: `512`. - type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into - `DeBERTa` model, default: `-1`. - initializer_range (int): The sttdev of the _normal_initializer for - initializing all weight matrices, default: `0.02`. - relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. - max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. - padding_idx (int): The value used to pad input_ids, default: `0`. - position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. - pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p"., default: "None". - - - """ - - def __init__(self): - """Constructs ModelConfig. - - """ - - self.hidden_size = 768 - self.num_hidden_layers = 12 - self.num_attention_heads = 12 - self.hidden_act = "gelu" - self.intermediate_size = 3072 - self.hidden_dropout_prob = 0.1 - self.attention_probs_dropout_prob = 0.1 - self.max_position_embeddings = 512 - self.type_vocab_size = 0 - self.initializer_range = 0.02 - self.layer_norm_eps = 1e-7 - self.padding_idx = 0 - self.vocab_size = -1 - - class PoolConfig(AbsModelConfig): """Configuration class to store the configuration of `pool layer`. - Parameters: - config (:class:`~DeBERTa.deberta.ModelConfig`): The model config. The field of pool config will be initalized with the `pooling` field in model config. + Parameters: + config (:class:`~transformers.DeBERTaConfig`): The model config. The field of pool config will be initialized with the `pooling` field in model config. - Attributes: - hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. + Attributes: + hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. - dropout (float): The dropout rate applied on the output of `[CLS]` token, + dropout (float): The dropout rate applied on the output of `[CLS]` token, - hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh']. + hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh']. - Example:: - # Here is the content of an exmple model config file in json format + Example:: + # Here is the content of an example model config file in json format - { - "hidden_size": 768, - "num_hidden_layers" 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - ... - "pooling": { - "hidden_size": 768, - "hidden_act": "gelu", - "dropout": 0.1 - } - } + { + "hidden_size": 768, + "num_hidden_layers" 12, + "num_attention_heads": 12, + "intermediate_size": 3072, + ... + "pooling": { + "hidden_size": 768, + "hidden_act": "gelu", + "dropout": 0.1 + } + } """ @@ -234,7 +183,7 @@ def __init__(self, config=None): """Constructs PoolConfig. Args: - `config`: the config of the model. The field of pool config will be initalized with the 'pooling' field in model config. + `config`: the config of the model. The field of pool config will be initialized with the 'pooling' field in model config. """ self.hidden_size = 768 @@ -250,8 +199,7 @@ def __init__(self, config=None): class TraceMode: - """ Trace context used when tracing modules contains customer operators/Functions - """ + """Trace context used when tracing modules contains customer operators/Functions""" def __enter__(self): os.environ["JIT_TRACE"] = "True" @@ -284,27 +232,24 @@ def output_dim(self): @traceable class XSoftmax(torch.autograd.Function): - """ Masked Softmax which is optimized for saving memory - - Args: - input (:obj:`torch.tensor`): The input tensor that will apply softmax. - mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax caculation. - dim (int): The dimenssion that will apply softmax. - Example:: - import torch - from DeBERTa.deberta import XSoftmax - # Make a tensor - x = torch.randn([4,20,100]) - # Create a mask - mask = (x>0).int() - y = XSoftmax.apply(x, mask, dim=-1) - """ + """Masked Softmax which is optimized for saving memory - @staticmethod - def forward(self, input, mask, dim): - """ + Args: + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax caculation. + dim (int): The dimenssion that will apply softmax. + Example:: + import torch + from transformers.modeling_deroberta import XSoftmax + # Make a tensor + x = torch.randn([4,20,100]) + # Create a mask + mask = (x>0).int() + y = XSoftmax.apply(x, mask, dim=-1) """ + @staticmethod + def forward(self, input, mask, dim): self.dim = dim if version.Version(torch.__version__) >= version.Version("1.2.0a"): rmask = ~(mask.bool()) @@ -319,9 +264,6 @@ def forward(self, input, mask, dim): @staticmethod def backward(self, grad_output): - """ - """ - (output,) = self.saved_tensors inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) return inputGrad, None, None @@ -359,6 +301,8 @@ def get_mask(input, local_context): @traceable class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + @staticmethod def forward(ctx, input, local_ctx): mask, dropout = get_mask(input, local_ctx) @@ -379,13 +323,13 @@ def backward(ctx, grad_output): class StableDropout(torch.nn.Module): - """ Optimized dropout module for stabilizing the training + """Optimized dropout module for stabilizing the training - Args: + Args: - drop_prob (float): the dropout probabilities + drop_prob (float): the dropout probabilities - """ + """ def __init__(self, drop_prob): super().__init__() @@ -394,13 +338,13 @@ def __init__(self, drop_prob): self.context_stack = None def forward(self, x): - """ Call the module + """Call the module - Args: - x (:obj:`torch.tensor`): The input tensor to apply dropout + Args: + x (:obj:`torch.tensor`): The input tensor to apply dropout - """ + """ if self.training and self.drop_prob > 0: return XDropout.apply(x, self.get_context()) return x @@ -430,21 +374,21 @@ def get_context(self): def MaskedLayerNorm(layerNorm, input, mask=None): - """ Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module. - Args: - layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function - input (:obj:`torch.tensor`): The input tensor - mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` + """Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module. + Args: + layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function + input (:obj:`torch.tensor`): The input tensor + mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` - Example:: + Example:: - # Create a tensor b x n x d - x = torch.randn([1,10,100]) - m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) - LayerNorm = DeBERTa.deberta.BertLayerNorm(100) - y = MaskedLayerNorm(LayerNorm, x, m) + # Create a tensor b x n x d + x = torch.randn([1,10,100]) + m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) + LayerNorm = DeBERTa.deberta.BertLayerNorm(100) + y = MaskedLayerNorm(LayerNorm, x, m) - """ + """ output = layerNorm(input).to(input) if mask is None: return output @@ -457,8 +401,7 @@ def MaskedLayerNorm(layerNorm, input, mask=None): class BertLayerNorm(nn.Module): - """LayerNorm module in the TF style (epsilon inside the square root). - """ + """LayerNorm module in the TF style (epsilon inside the square root).""" def __init__(self, size, eps=1e-12): super().__init__() @@ -594,8 +537,7 @@ def forward( class DeBERTaEncoder(nn.Module): - """ Modified BertEncoder with relative position bias support - """ + """Modified BertEncoder with relative position bias support""" def __init__(self, config): super().__init__() @@ -681,7 +623,7 @@ def forward( def build_relative_position(query_size, key_size, device): - """ Build relative position according to the query and key + """Build relative position according to the query and key We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key :math:`P_k` is range from (0, key_size), The relative positions from query to key is @@ -725,7 +667,7 @@ class DisentangledSelfAttention(torch.nn.Module): Parameters: config (:obj:`str`): A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \ - for more details, please refer :class:`~DeBERTa.deberta.ModelConfig` + for more details, please refer :class:`~transformers.DeBERTaConfig` """ @@ -780,7 +722,7 @@ def forward( relative_pos=None, rel_embeddings=None, ): - """ Call the module + """Call the module Args: hidden_states (:obj:`torch.FloatTensor`): @@ -919,8 +861,7 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd class DeBERTaEmbeddings(nn.Module): - """Construct the embeddings from word, position and token_type embeddings. - """ + """Construct the embeddings from word, position and token_type embeddings.""" def __init__(self, config): super().__init__() @@ -974,8 +915,8 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, mask=None): class DeBERTaPreTrainedModel(PreTrainedModel): - """ An abstract class to handle weights initialization and - a simple interface for downloading and loading pretrained models. + """An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. """ config_class = DeBERTaConfig @@ -991,7 +932,7 @@ def _init_weights(self, module): module.bias.data.zero_() -DeBERTa_START_DOCSTRING = r""" The DeBERTa model was proposed in +DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two improvements, it out perform BERT/RoBERTa on a majority @@ -1012,7 +953,7 @@ def _init_weights(self, module): Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ -DeBERTa_INPUTS_DOCSTRING = r""" +DEBERTA_INPUTS_DOCSTRING = r""" Inputs: **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: Indices of input sequence tokens in the vocabulary. @@ -1054,8 +995,8 @@ def _init_weights(self, module): @add_start_docstrings( "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", - DeBERTa_START_DOCSTRING, - DeBERTa_INPUTS_DOCSTRING, + DEBERTA_START_DOCSTRING, + DEBERTA_INPUTS_DOCSTRING, ) class DeBERTaModel(DeBERTaPreTrainedModel): r""" @@ -1079,8 +1020,8 @@ class DeBERTaModel(DeBERTaPreTrainedModel): Examples:: - tokenizer = DeBERTaTokenizer.from_pretrained('deberta-base-uncased') - model = DeBERTaModel.from_pretrained('deberta-base-uncased') + tokenizer = DeBERTaTokenizer.from_pretrained('microsoft/deberta-base') + model = DeBERTaModel.from_pretrained('microsoft/deberta-base') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 outputs = model(input_ids) last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple @@ -1103,9 +1044,9 @@ def set_input_embeddings(self, new_embeddings): self.embeddings.word_embeddings = new_embeddings def _prune_heads(self, heads_to_prune): - """ Prunes heads of the model. - heads_to_prune: dict of {layer_num: list of heads to prune in this layer} - See base class PreTrainedModel + """Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel """ raise NotImplementedError("The prune function is not implemented in DeBERTa model.") @@ -1117,7 +1058,9 @@ def forward( position_ids=None, head_mask=None, inputs_embeds=None, - return_tuple=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -1155,14 +1098,21 @@ def forward( sequence_output = encoded_layers[-1] - return (sequence_output, sequence_output[:, 0]) + if not return_dict: + return (sequence_output, sequence_output[:, 0]) + else: + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=sequence_output[:, 0], + hidden_states=encoded_layers, + ) @add_start_docstrings( """DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, - DeBERTa_START_DOCSTRING, - DeBERTa_INPUTS_DOCSTRING, + DEBERTA_START_DOCSTRING, + DEBERTA_INPUTS_DOCSTRING, ) class DeBERTaForSequenceClassification(DeBERTaPreTrainedModel): r""" @@ -1187,8 +1137,8 @@ class DeBERTaForSequenceClassification(DeBERTaPreTrainedModel): Examples:: - tokenizer = DeBERTaTokenizer.from_pretrained('deberta-base-uncased') - model = DeBERTaForSequenceClassification.from_pretrained('deberta-base-uncased') + tokenizer = DeBERTaTokenizer.from_pretrained('microsoft/deberta-base') + model = DeBERTaForSequenceClassification.from_pretrained('microsoft/deberta-base') input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 outputs = model(input_ids, labels=labels) @@ -1222,8 +1172,9 @@ def forward( token_type_ids=None, position_ids=None, labels=None, + output_attentions=None, output_hidden_states=None, - return_tuple=True, + return_dict=None, ): encoder_layer, cls = self.bert( @@ -1253,4 +1204,10 @@ def forward( else: log_softmax = torch.nn.LogSoftmax(-1) loss = -((log_softmax(logits) * labels).sum(-1)).mean() - return (loss, logits) + if not return_dict: + return (loss, logits) + else: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + ) diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index 0f76d3abc3d8..83056d9b5592 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -22,9 +22,10 @@ from functools import lru_cache from zipfile import ZipFile -import requests import tqdm +import requests + from .tokenization_utils import PreTrainedTokenizer @@ -183,7 +184,10 @@ def decode(self, tokens): def get_encoder(encoder, vocab): - return Encoder(encoder=encoder, bpe_merges=vocab,) + return Encoder( + encoder=encoder, + bpe_merges=vocab, + ) def _is_whitespace(char): @@ -358,44 +362,44 @@ def __init__(self, vocab_file=None, do_lower_case=True, special_tokens=None): self.ids_to_tokens = self.symbols def tokenize(self, text): - """ Convert an input text to tokens. - - Args: - text (:obj:`str`): input text to be tokenized. + """Convert an input text to tokens. - Returns: - A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + Args: + text (:obj:`str`): input text to be tokenized. - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - """ + Returns: + A list of byte tokens where each token represent the byte id in GPT2 byte dictionary + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + """ bpe = self._encode(text) return [t for t in bpe.split(" ") if t] def convert_tokens_to_ids(self, tokens): - """ Convert list of tokens to ids. - Args: - tokens (:obj:`list`): list of tokens + """Convert list of tokens to ids. + Args: + tokens (:obj:`list`): list of tokens - Returns: - List of ids - """ + Returns: + List of ids + """ return [self.vocab[t] for t in tokens] def convert_ids_to_tokens(self, ids): - """ Convert list of ids to tokens. - Args: - ids (:obj:`list`): list of ids + """Convert list of ids to tokens. + Args: + ids (:obj:`list`): list of ids - Returns: - List of tokens - """ + Returns: + List of tokens + """ tokens = [] for i in ids: @@ -406,33 +410,33 @@ def split_to_words(self, text): return self.bpe.split_to_words(text) def decode(self, tokens): - """ Decode list of tokens to text strings. - Args: - tokens (:obj:`list`): list of tokens. - - Returns: - Text string corresponds to the input tokens. - - Example:: - >>> tokenizer = GPT2Tokenizer() - >>> text = "Hello world!" - >>> tokens = tokenizer.tokenize(text) - >>> print(tokens) - ['15496', '995', '0'] - >>> tokenizer.decode(tokens) - 'Hello world!' - """ + """Decode list of tokens to text strings. + Args: + tokens (:obj:`list`): list of tokens. + + Returns: + Text string corresponds to the input tokens. + + Example:: + >>> tokenizer = GPT2Tokenizer() + >>> text = "Hello world!" + >>> tokens = tokenizer.tokenize(text) + >>> print(tokens) + ['15496', '995', '0'] + >>> tokenizer.decode(tokens) + 'Hello world!' + """ return self.bpe.decode([int(t) for t in tokens if t not in self.special_tokens]) def add_special_token(self, token): """Adds a special token to the dictionary. - Args: - token (:obj:`str`): Tthe new token/word to be added to the vocabulary. + Args: + token (:obj:`str`): Tthe new token/word to be added to the vocabulary. - Returns: - The id of new token in the vocabulary. + Returns: + The id of new token in the vocabulary. - """ + """ self.special_tokens.append(token) return self.add_symbol(token) @@ -459,14 +463,14 @@ def _decode(self, x: str) -> str: def add_symbol(self, word, n=1): """Adds a word to the dictionary. - Args: - word (:obj:`str`): Tthe new token/word to be added to the vocabulary. - n (int, optional): The frequency of the word. + Args: + word (:obj:`str`): Tthe new token/word to be added to the vocabulary. + n (int, optional): The frequency of the word. - Returns: - The id of the new word. + Returns: + The id of the new word. - """ + """ if word in self.indices: idx = self.indices[word] self.count[idx] = self.count[idx] + n @@ -549,8 +553,7 @@ def get_vocab(self): return vocab def _tokenize(self, text): - """ Take as input a string and return a list of strings (tokens) for words/sub-words - """ + """Take as input a string and return a list of strings (tokens) for words/sub-words""" if self.do_lower_case: text = text.lower() return self.gpt2_tokenizer.tokenize(text) @@ -605,7 +608,12 @@ def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_spe "You should not supply a second sequence if the provided sequence of " "ids is already formated with special tokens for the model." ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0,)) + return list( + map( + lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, + token_ids_0, + ) + ) if token_ids_1 is not None: return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index b86580b025d4..cdb4a6615358 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -24,15 +24,12 @@ if is_torch_available(): - from transformers import ( + from transformers import ( # XxxForMaskedLM,; XxxForQuestionAnswering,; XxxForTokenClassification, DeBERTaConfig, - DeBERTaModel, - # XxxForMaskedLM, - # XxxForQuestionAnswering, DeBERTaForSequenceClassification, - # XxxForTokenClassification, + DeBERTaModel, ) - from transformers.modeling_deberta import DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST @require_torch @@ -246,6 +243,6 @@ def test_model_common_attributes(self): @slow def test_model_from_pretrained(self): - for model_name in DeBERTa_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = DeBERTaModel.from_pretrained(model_name) self.assertIsNotNone(model) From b1323c2c0c13ca1ea1527fd3f697cbd4233d065a Mon Sep 17 00:00:00 2001 From: Lysandre Date: Wed, 9 Sep 2020 16:27:07 +0200 Subject: [PATCH 04/18] Patch DeBERTa Documentation Style --- src/transformers/configuration_deberta.py | 76 +++--- src/transformers/modeling_deberta.py | 294 +++++++++++----------- src/transformers/tokenization_deberta.py | 4 +- tests/test_modeling_deberta.py | 22 +- 4 files changed, 205 insertions(+), 191 deletions(-) diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 46cf676a7156..bd3cd7e71090 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -32,39 +32,56 @@ class DeBERTaConfig(PretrainedConfig): r""" :class:`~transformers.DeBERTaConfig` is the configuration class to store the configuration of a - `DeBERTaModel`. + :class:`~transformers.DeBERTaModel`. Arguments: - hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. - num_hidden_layers (int): Number of hidden layers in the Transformer encoder, default: `12`. - num_attention_heads (int): Number of attention heads for each attention layer in - the Transformer encoder, default: `12`. - intermediate_size (int): The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder, default: `3072`. - hidden_act (str): The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported, default: `gelu`. - hidden_dropout_prob (float): The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler, default: `0.1`. - attention_probs_dropout_prob (float): The dropout ratio for the attention - probabilities, default: `0.1`. - max_position_embeddings (int): The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048), default: `512`. - type_vocab_size (int): The vocabulary size of the `token_type_ids` passed into - `DeBERTa` model, default: `0`. - initializer_range (int): The sttdev of the _normal_initializer for - initializing all weight matrices, default: `0.02`. - relative_attention (:obj:`bool`): Whether use relative position encoding, default: `False`. - max_relative_positions (int): The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`], default: -1, use the same value as `max_position_embeddings`. - padding_idx (int): The value used to pad input_ids, default: `0`. - position_biased_input (:obj:`bool`): Whether add absolute position embedding to content embedding, default: `True`. - pos_att_type (:obj:`str`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p", default: "None". - vocab_size (int): The size of the vocabulary, default: `-1`. + vocab_size (:obj:`int`, optional, defaults to 50265): + Vocabulary size of the DeBERTa model. Defines the different tokens that + can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.DeBERTaModel`. + hidden_size (:obj:`int`, optional, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (:obj:`int`, optional, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (:obj:`int`, optional, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (:obj:`int`, optional, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): + The non-linear activation function (function or string) in the encoder and pooler. + If string, "gelu", "relu", "swish" and "gelu_new" are supported. + hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, optional, defaults to 512): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, optional, defaults to 2): + The vocabulary size of the `token_type_ids` passed into :class:`~transformers.DeBERTaModel`. + initializer_range (:obj:`float`, optional, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. + relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether use relative position encoding. + max_relative_positions (:obj:`int`, `optional`, defaults to 1): + The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`]. + Use the same value as `max_position_embeddings`. + pad_token_id (:obj:`int`, `optional`, defaults to 0): + The value used to pad input_ids. + position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether add absolute position embedding to content embedding. + pos_att_type (:obj:`str`, `optional`, defaults to "None"): + The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], + e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p". + layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + The epsilon used by the layer normalization layers. """ model_type = "deberta" def __init__( self, + vocab_size=50265, hidden_size=768, num_hidden_layers=12, num_attention_heads=12, @@ -75,13 +92,12 @@ def __init__( max_position_embeddings=512, type_vocab_size=0, initializer_range=0.02, + layer_norm_eps=1e-7, relative_attention=False, max_relative_positions=-1, - padding_idx=0, + pad_token_id=0, position_biased_input=True, pos_att_type="None", - vocab_size=-1, - layer_norm_eps=1e-7, **kwargs ): super().__init__(**kwargs) @@ -98,7 +114,7 @@ def __init__( self.initializer_range = initializer_range self.relative_attention = relative_attention self.max_relative_positions = max_relative_positions - self.padding_idx = padding_idx + self.pad_token_id = pad_token_id self.position_biased_input = position_biased_input self.pos_att_type = pos_att_type self.vocab_size = vocab_size diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index dc00e1942e51..e633f519793c 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -27,8 +27,8 @@ from torch.nn import CrossEntropyLoss from .configuration_deberta import DeBERTaConfig -from .file_utils import add_start_docstrings -from .modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput +from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable +from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput from .modeling_utils import PreTrainedModel @@ -46,8 +46,11 @@ logger = logging.getLogger(__name__) +_CONFIG_FOR_DOC = "DeBERTaConfig" +_TOKENIZER_FOR_DOC = "DeBERTaTokenizer" + #################################################### -# This list contrains shortcut names for some of +# This list contains shortcut names for some of # the pretrained weights provided with the models #################################################### DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -574,52 +577,57 @@ def forward( self, hidden_states, attention_mask, - output_all_encoded_layers=True, - return_att=False, + output_hidden_states=True, + output_attentions=False, query_states=None, relative_pos=None, + return_dict=False, ): attention_mask = self.get_attention_mask(attention_mask) relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) - all_encoder_layers = [] - att_matrixs = [] + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + if isinstance(hidden_states, Sequence): next_kv = hidden_states[0] else: next_kv = hidden_states rel_embeddings = self.get_rel_embedding() for i, layer_module in enumerate(self.layer): - output_states = layer_module( + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = layer_module( next_kv, attention_mask, - return_att, + output_attentions, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings, ) - if return_att: - output_states, att_m = output_states + if output_attentions: + hidden_states, att_m = hidden_states if query_states is not None: - query_states = output_states + query_states = hidden_states if isinstance(hidden_states, Sequence): next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None else: - next_kv = output_states - - if output_all_encoded_layers: - all_encoder_layers.append(output_states) - if return_att: - att_matrixs.append(att_m) - if not output_all_encoded_layers: - all_encoder_layers.append(output_states) - if return_att: - att_matrixs.append(att_m) - if return_att: - return (all_encoder_layers, att_matrixs) - else: - return all_encoder_layers + next_kv = hidden_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) def build_relative_position(query_size, key_size, device): @@ -865,9 +873,9 @@ class DeBERTaEmbeddings(nn.Module): def __init__(self, config): super().__init__() - padding_idx = getattr(config, "padding_idx", 0) + pad_token_id = getattr(config, "pad_token_id", 0) self.embedding_size = getattr(config, "embedding_size", config.hidden_size) - self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=padding_idx) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) self.position_biased_input = getattr(config, "position_biased_input", True) if not self.position_biased_input: @@ -885,21 +893,32 @@ def __init__(self, config): self.output_to_half = False self.config = config - def forward(self, input_ids, token_type_ids=None, position_ids=None, mask=None): - seq_length = input_ids.size(1) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + if position_ids is None: - position_ids = torch.arange(0, seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + position_ids = self.position_ids[:, :seq_length] + if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) - words_embeddings = self.word_embeddings(input_ids) if self.position_embeddings is not None: position_embeddings = self.position_embeddings(position_ids.long()) else: - position_embeddings = torch.zeros_like(words_embeddings) + position_embeddings = torch.zeros_like(inputs_embeds) - embeddings = words_embeddings + embeddings = inputs_embeds if self.position_biased_input: embeddings += position_embeddings if self.config.type_vocab_size > 0: @@ -954,80 +973,51 @@ def _init_weights(self, module): """ DEBERTA_INPUTS_DOCSTRING = r""" - Inputs: - **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. - To match pre-training, DeBERTa input sequence should be formatted with [CLS] and [SEP] tokens as follows: - - (a) For sequence pairs: - - ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]`` - - - (b) For single sequences: - - ``tokens: [CLS] the dog is hairy . [SEP]`` Indices can be obtained using :class:`transformers.DeBERTaTokenizer`. See :func:`transformers.PreTrainedTokenizer.encode` and - :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. - **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + :func:`transformers.PreTrainedTokenizer.__call__` for details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. - **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token - **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. - **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: - Mask to nullify selected heads of the self-attention modules. - Mask values selected in ``[0, 1]``: - ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. - **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``: - Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation. + + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. + return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a + plain tuple. """ @add_start_docstrings( "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", DEBERTA_START_DOCSTRING, - DEBERTA_INPUTS_DOCSTRING, ) class DeBERTaModel(DeBERTaPreTrainedModel): - r""" - Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` - Sequence of hidden-states at the output of the last layer of the model. - **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` - Last layer hidden-state of the first token of the sequence (classification token) - further processed by a Linear layer and a Tanh activation function. The Linear - layer weights are trained from the next sentence prediction (classification) - objective during DeBERTa pretraining. This output is usually *not* a good summary - of the semantic content of the input, you're often better with averaging or pooling - the sequence of hidden-states for the whole input sequence. - **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) - list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) - of shape ``(batch_size, sequence_length, hidden_size)``: - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - **attentions**: (`optional`, returned when ``output_attentions=True``) - list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - - Examples:: - - tokenizer = DeBERTaTokenizer.from_pretrained('microsoft/deberta-base') - model = DeBERTaModel.from_pretrained('microsoft/deberta-base') - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 - outputs = model(input_ids) - last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple - - """ - def __init__(self, config): super().__init__(config) @@ -1050,33 +1040,62 @@ def _prune_heads(self, heads_to_prune): """ raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="microsoft/deberta-base", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, - head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif (input_ids is None) and (inputs_embeds is not None): + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: raise ValueError("You have to specify either input_ids or inputs_embeds") + device = input_ids.device if input_ids is not None else inputs_embeds.device + if attention_mask is None: - attention_mask = torch.ones_like(input_ids) + attention_mask = torch.ones(input_shape, device=device) if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) embedding_output = self.embeddings( - input_ids.to(torch.long), token_type_ids.to(torch.long), position_ids, attention_mask + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=attention_mask, + inputs_embeds=inputs_embeds, ) - encoded_layers = self.encoder(embedding_output, attention_mask, output_all_encoded_layers=True) + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + ) + encoded_layers = encoder_outputs[1] if self.z_steps > 1: hidden_states = encoded_layers[-2] @@ -1099,62 +1118,29 @@ def forward( sequence_output = encoded_layers[-1] if not return_dict: - return (sequence_output, sequence_output[:, 0]) - else: - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=sequence_output[:, 0], - hidden_states=encoded_layers, - ) + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) @add_start_docstrings( """DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, DEBERTA_START_DOCSTRING, - DEBERTA_INPUTS_DOCSTRING, ) class DeBERTaForSequenceClassification(DeBERTaPreTrainedModel): - r""" - **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: - Labels for computing the sequence classification/regression loss. - Indices should be in ``[0, ..., config.num_labels - 1]``. - If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), - If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). - - Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: - **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: - Classification (or regression if config.num_labels==1) loss. - **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` - Classification (or regression if config.num_labels==1) scores (before SoftMax). - **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) - list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) - of shape ``(batch_size, sequence_length, hidden_size)``: - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - **attentions**: (`optional`, returned when ``output_attentions=True``) - list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. - - Examples:: - - tokenizer = DeBERTaTokenizer.from_pretrained('microsoft/deberta-base') - model = DeBERTaForSequenceClassification.from_pretrained('microsoft/deberta-base') - input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 - labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 - outputs = model(input_ids, labels=labels) - loss, logits = outputs[:2] - - """ - def __init__(self, config): super().__init__(config) num_labels = getattr(config, "num_labels", 2) self.num_labels = num_labels - self.bert = DeBERTaModel(config) + self.deberta = DeBERTaModel(config) pool_config = PoolConfig(self.config) - output_dim = self.bert.config.hidden_size self.pooler = ContextPooler(pool_config) output_dim = self.pooler.output_dim() @@ -1165,26 +1151,51 @@ def __init__(self, config): self.init_weights() + @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_code_sample_docstrings( + tokenizer_class=_TOKENIZER_FOR_DOC, + checkpoint="microsoft/deberta-base", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) def forward( self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, + inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): - - encoder_layer, cls = self.bert( - input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, position_ids=position_ids + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) + + encoder_layer = outputs[0] pooled_output = self.pooler(encoder_layer) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) - loss = torch.tensor(0).to(logits) + loss = None if labels is not None: if self.num_labels == 1: # regression task @@ -1205,9 +1216,12 @@ def forward( log_softmax = torch.nn.LogSoftmax(-1) loss = -((log_softmax(logits) * labels).sum(-1)).mean() if not return_dict: - return (loss, logits) + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output else: return SequenceClassifierOutput( loss=loss, logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, ) diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index 83056d9b5592..6de8ad5a6f01 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -490,8 +490,8 @@ def save_pretrained(self, path: str): class DeBERTaTokenizer(PreTrainedTokenizer): r""" - Constructs a XxxTokenizer. - :class:`~transformers.DeBERTaTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece + Constructs a :class:`~transformers.DeBERTaTokenizer`, which runs end-to-end tokenization: punctuation + splitting + wordpiece Args: vocab_file: Path to a one-wordpiece-per-line vocabulary file diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index cdb4a6615358..04bbd81e02fd 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -153,9 +153,9 @@ def create_and_check_deberta_model( model = DeBERTaModel(config=config) model.to(torch_device) model.eval() - sequence_output, cls = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids) - sequence_output, cls = model(input_ids, token_type_ids=token_type_ids) - sequence_output, cls = model(input_ids) + sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0] + sequence_output = model(input_ids, token_type_ids=token_type_ids)[0] + sequence_output = model(input_ids)[0] result = { "sequence_output": sequence_output, @@ -225,22 +225,6 @@ def test_for_token_classification(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_deberta_for_token_classification(*config_and_inputs) - @unittest.skip - def test_attention_outputs(self): - pass - - @unittest.skip - def test_hidden_states_output(self): - pass - - @unittest.skip - def test_inputs_embeds(self): - pass - - @unittest.skip - def test_model_common_attributes(self): - pass - @slow def test_model_from_pretrained(self): for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: From e39816f37f3e39b8c2ed73e1bfc2d376e71bb76a Mon Sep 17 00:00:00 2001 From: Pengcheng He Date: Sat, 19 Sep 2020 17:16:25 -0400 Subject: [PATCH 05/18] Add final tests --- src/transformers/configuration_auto.py | 5 +-- src/transformers/modeling_deberta.py | 10 ++++-- tests/test_modeling_deberta.py | 46 +++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index e59e16b76539..4d18c02f328d 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -74,7 +74,7 @@ RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - DeBERTa_PRETRAINED_CONFIG_ARCHIVE_MAP, + DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ] for key, value, in pretrained_map.items() ) @@ -97,6 +97,7 @@ ("reformer", ReformerConfig), ("longformer", LongformerConfig), ("roberta", RobertaConfig), + ("deberta", DeBERTaConfig), ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), ("bert", BertConfig), @@ -110,7 +111,6 @@ ("encoder-decoder", EncoderDecoderConfig), ("funnel", FunnelConfig), ("lxmert", LxmertConfig), - ("deberta", DeBERTaConfig), ] ) @@ -144,6 +144,7 @@ ("encoder-decoder", "Encoder decoder"), ("funnel", "Funnel Transformer"), ("lxmert", "LXMERT"), + ("deberta", "DeBERTa"), ] ) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index e633f519793c..64f49a703b54 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -1139,7 +1139,7 @@ def __init__(self, config): num_labels = getattr(config, "num_labels", 2) self.num_labels = num_labels - self.deberta = DeBERTaModel(config) + self.bert = DeBERTaModel(config) pool_config = PoolConfig(self.config) self.pooler = ContextPooler(pool_config) output_dim = self.pooler.output_dim() @@ -1151,6 +1151,12 @@ def __init__(self, config): self.init_weights() + def get_input_embeddings(self): + return self.bert.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.bert.set_input_embeddings(new_embeddings) + @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, @@ -1179,7 +1185,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.deberta( + outputs = self.bert( input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index 04bbd81e02fd..6ed8717b3c82 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -14,8 +14,11 @@ # limitations under the License. +import random import unittest +import numpy as np + from transformers import is_torch_available from transformers.testing_utils import require_torch, slow, torch_device @@ -24,12 +27,15 @@ if is_torch_available(): + import torch + + from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST + from transformers import ( # XxxForMaskedLM,; XxxForQuestionAnswering,; XxxForTokenClassification, DeBERTaConfig, DeBERTaForSequenceClassification, DeBERTaModel, ) - from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST @require_torch @@ -230,3 +236,41 @@ def test_model_from_pretrained(self): for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: model = DeBERTaModel.from_pretrained(model_name) self.assertIsNotNone(model) + + +@require_torch +class DeBERTaModelIntegrationTest(unittest.TestCase): + @unittest.skip(reason="Model not available yet") + def test_inference_masked_lm(self): + pass + + @slow + def test_inference_no_head(self): + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + DeBERTaModel.base_model_prefix = "bert" + model = DeBERTaModel.from_pretrained("microsoft/deberta-base") + + input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + output = model(input_ids)[0] + # compare the actual values for a slice. + expected_slice = torch.tensor( + [[[-0.0218, -0.6641, -0.3665], [-0.3907, -0.4716, -0.6640], [0.7461, 1.2570, -0.9063]]] + ) + self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4), f"{output[:, :3, :3]}") + + @slow + def test_inference_classification_head(self): + random.seed(0) + np.random.seed(0) + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + model = DeBERTaForSequenceClassification.from_pretrained("microsoft/deberta-base") + input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) + output = model(input_ids)[0] + expected_shape = torch.Size((1, 2)) + self.assertEqual(output.shape, expected_shape) + expected_tensor = torch.tensor([[0.0884, -0.1047]]) + self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4), f"{output}") From 19b4e263934cbe6f0271ab80e0c3b7bccdaf9da8 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 21 Sep 2020 13:34:34 +0200 Subject: [PATCH 06/18] Style --- tests/test_modeling_deberta.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index 6ed8717b3c82..7930ff7b6f48 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -29,13 +29,12 @@ if is_torch_available(): import torch - from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST - from transformers import ( # XxxForMaskedLM,; XxxForQuestionAnswering,; XxxForTokenClassification, DeBERTaConfig, DeBERTaForSequenceClassification, DeBERTaModel, ) + from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST @require_torch From 44eef87adbfb9a00626e864583d7682073ae9f05 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 22 Sep 2020 11:00:33 +0200 Subject: [PATCH 07/18] Enable tests + nitpicks --- src/transformers/configuration_deberta.py | 12 +- src/transformers/modeling_deberta.py | 155 +--------------------- tests/test_modeling_deberta.py | 2 - 3 files changed, 15 insertions(+), 154 deletions(-) diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index bd3cd7e71090..5bb52f7c51aa 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -14,14 +14,12 @@ # limitations under the License. """ DeBERTa model configuration """ -import logging +from .utils import logging from .configuration_utils import PretrainedConfig -__all__ = ["DeBERTaConfig", "DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP"] - -logger = logging.getLogger(__name__) +logger = logging.get_logger(__name__) DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { "microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/config.json", @@ -98,6 +96,8 @@ def __init__( pad_token_id=0, position_biased_input=True, pos_att_type="None", + pooler_dropout=0, + pooler_hidden_act='gelu', **kwargs ): super().__init__(**kwargs) @@ -119,3 +119,7 @@ def __init__( self.pos_att_type = pos_att_type self.vocab_size = vocab_size self.layer_norm_eps = layer_norm_eps + + self.pooler_hidden_size = kwargs.get('pooler_hidden_size', hidden_size) + self.pooler_dropout = pooler_dropout + self.pooler_hidden_act = pooler_hidden_act diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 64f49a703b54..8c2d108de04c 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -16,7 +16,7 @@ import copy import json -import logging +from .utils import logging import math import os from collections import Sequence @@ -30,21 +30,10 @@ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput from .modeling_utils import PreTrainedModel +from torch import _softmax_backward_data as _softmax_backward_data -if version.Version(torch.__version__) >= version.Version("1.0.0"): - from torch import _softmax_backward_data as _softmax_backward_data -else: - from torch import softmax_backward_data as _softmax_backward_data - -__all__ = [ - "DeBERTaModel", - "DeBERTaForSequenceClassification", - "DeBERTaPreTrainedModel", - "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", -] - -logger = logging.getLogger(__name__) +logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeBERTaConfig" _TOKENIZER_FOR_DOC = "DeBERTaTokenizer" @@ -85,138 +74,11 @@ def linear_act(x): } -def traceable(cls): - """Decorator over customer functions - There is an issue for tracing customer python torch Function, using this decorator to work around it. - e.g. - @traceable - class MyOp(torch.autograd.Function): - xxx - """ - - class _Function(object): - @staticmethod - def apply(*args): - jit_trace = os.getenv("JIT_TRACE", "False").lower() == "true" - if jit_trace: - return cls.forward(_Function, *args) - else: - return cls.apply(*args) - - @staticmethod - def save_for_backward(*args): - pass - - _Function.__name__ = cls.__name__ - _Function.__doc__ = cls.__doc__ - return _Function - - -class AbsModelConfig(object): - def __init__(self): - pass - - @classmethod - def from_dict(cls, json_object): - """Constructs a `ModelConfig` from a Python dictionary of parameters.""" - config = cls() - for key, value in json_object.items(): - if isinstance(value, dict): - value = AbsModelConfig.from_dict(value) - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `ModelConfig` from a json file of parameters.""" - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - - def _json_default(obj): - if isinstance(obj, AbsModelConfig): - return obj.__dict__ - - return json.dumps(self.__dict__, indent=2, sort_keys=True, default=_json_default) + "\n" - - -class PoolConfig(AbsModelConfig): - """Configuration class to store the configuration of `pool layer`. - - Parameters: - config (:class:`~transformers.DeBERTaConfig`): The model config. The field of pool config will be initialized with the `pooling` field in model config. - - Attributes: - hidden_size (int): Size of the encoder layers and the pooler layer, default: `768`. - - dropout (float): The dropout rate applied on the output of `[CLS]` token, - - hidden_act (:obj:`str`): The activation function of the projection layer, it can be one of ['gelu', 'tanh']. - - Example:: - # Here is the content of an example model config file in json format - - { - "hidden_size": 768, - "num_hidden_layers" 12, - "num_attention_heads": 12, - "intermediate_size": 3072, - ... - "pooling": { - "hidden_size": 768, - "hidden_act": "gelu", - "dropout": 0.1 - } - } - - """ - - def __init__(self, config=None): - """Constructs PoolConfig. - - Args: - `config`: the config of the model. The field of pool config will be initialized with the 'pooling' field in model config. - """ - - self.hidden_size = 768 - self.dropout = 0 - self.hidden_act = "gelu" - if config: - pool_config = getattr(config, "pooling", config) - if isinstance(pool_config, dict): - pool_config = AbsModelConfig.from_dict(pool_config) - self.hidden_size = getattr(pool_config, "hidden_size", config.hidden_size) - self.dropout = getattr(pool_config, "dropout", 0) - self.hidden_act = getattr(pool_config, "hidden_act", "gelu") - - -class TraceMode: - """Trace context used when tracing modules contains customer operators/Functions""" - - def __enter__(self): - os.environ["JIT_TRACE"] = "True" - return self - - def __exit__(self, exp_value, exp_type, trace): - del os.environ["JIT_TRACE"] - - class ContextPooler(nn.Module): def __init__(self, config): super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = StableDropout(config.dropout) + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) self.config = config def forward(self, hidden_states, mask=None): @@ -226,14 +88,13 @@ def forward(self, hidden_states, mask=None): context_token = hidden_states[:, 0] context_token = self.dropout(context_token) pooled_output = self.dense(context_token) - pooled_output = ACT2FN[self.config.hidden_act](pooled_output) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) return pooled_output def output_dim(self): return self.config.hidden_size -@traceable class XSoftmax(torch.autograd.Function): """Masked Softmax which is optimized for saving memory @@ -302,7 +163,6 @@ def get_mask(input, local_context): return mask, dropout -@traceable class XDropout(torch.autograd.Function): """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" @@ -1140,8 +1000,7 @@ def __init__(self, config): self.num_labels = num_labels self.bert = DeBERTaModel(config) - pool_config = PoolConfig(self.config) - self.pooler = ContextPooler(pool_config) + self.pooler = ContextPooler(config) output_dim = self.pooler.output_dim() self.classifier = torch.nn.Linear(output_dim, num_labels) diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index 7930ff7b6f48..541cfa9684be 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -51,9 +51,7 @@ class DeBERTaModelTest(ModelTesterMixin, unittest.TestCase): test_torchscript = False test_pruning = False - test_resize_embeddings = False test_head_masking = False - test_missing_keys = False is_encoder_decoder = False class DeBERTaModelTester(object): From f5c6277eb9ea5e1252f4deb725da9df2417d9eef Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 22 Sep 2020 11:47:43 +0200 Subject: [PATCH 08/18] position IDs --- src/transformers/modeling_deberta.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 8c2d108de04c..2f0188caebe1 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -800,6 +800,7 @@ class DeBERTaPreTrainedModel(PreTrainedModel): config_class = DeBERTaConfig base_model_prefix = "deberta" + authorized_missing_keys = ['position_ids'] def _init_weights(self, module): """ Initialize the weights """ From 83729c0f2bb7870335b29332ebfd791c3b921b58 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 24 Sep 2020 17:25:09 +0200 Subject: [PATCH 09/18] BERT -> DeBERTa --- src/transformers/configuration_deberta.py | 7 +++---- src/transformers/modeling_deberta.py | 14 +++++++------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 5bb52f7c51aa..13a5953899f3 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -14,9 +14,8 @@ # limitations under the License. """ DeBERTa model configuration """ -from .utils import logging - from .configuration_utils import PretrainedConfig +from .utils import logging logger = logging.get_logger(__name__) @@ -97,7 +96,7 @@ def __init__( position_biased_input=True, pos_att_type="None", pooler_dropout=0, - pooler_hidden_act='gelu', + pooler_hidden_act="gelu", **kwargs ): super().__init__(**kwargs) @@ -120,6 +119,6 @@ def __init__( self.vocab_size = vocab_size self.layer_norm_eps = layer_norm_eps - self.pooler_hidden_size = kwargs.get('pooler_hidden_size', hidden_size) + self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size) self.pooler_dropout = pooler_dropout self.pooler_hidden_act = pooler_hidden_act diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 2f0188caebe1..50279b2d42cc 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -16,13 +16,13 @@ import copy import json -from .utils import logging import math import os from collections import Sequence import torch from packaging import version +from torch import _softmax_backward_data as _softmax_backward_data from torch import nn from torch.nn import CrossEntropyLoss @@ -30,7 +30,7 @@ from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput from .modeling_utils import PreTrainedModel -from torch import _softmax_backward_data as _softmax_backward_data +from .utils import logging logger = logging.get_logger(__name__) @@ -800,7 +800,7 @@ class DeBERTaPreTrainedModel(PreTrainedModel): config_class = DeBERTaConfig base_model_prefix = "deberta" - authorized_missing_keys = ['position_ids'] + authorized_missing_keys = ["position_ids"] def _init_weights(self, module): """ Initialize the weights """ @@ -1000,7 +1000,7 @@ def __init__(self, config): num_labels = getattr(config, "num_labels", 2) self.num_labels = num_labels - self.bert = DeBERTaModel(config) + self.deberta = DeBERTaModel(config) self.pooler = ContextPooler(config) output_dim = self.pooler.output_dim() @@ -1012,10 +1012,10 @@ def __init__(self, config): self.init_weights() def get_input_embeddings(self): - return self.bert.get_input_embeddings() + return self.deberta.get_input_embeddings() def set_input_embeddings(self, new_embeddings): - self.bert.set_input_embeddings(new_embeddings) + self.deberta.set_input_embeddings(new_embeddings) @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) @add_code_sample_docstrings( @@ -1045,7 +1045,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.bert( + outputs = self.deberta( input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, From 83f37cd7d64f41be5ff1a61c0a572b6dddf0bb29 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 24 Sep 2020 17:29:52 +0200 Subject: [PATCH 10/18] Quality --- src/transformers/modeling_deberta.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 50279b2d42cc..2ae4b3827c67 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -15,14 +15,12 @@ """ PyTorch DeBERTa model. """ import copy -import json import math -import os from collections import Sequence import torch from packaging import version -from torch import _softmax_backward_data as _softmax_backward_data +from torch import _softmax_backward_data from torch import nn from torch.nn import CrossEntropyLoss From 28f80bb7561bb0986c90f1d5f6216ad06be09484 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 24 Sep 2020 17:57:28 +0200 Subject: [PATCH 11/18] Style --- src/transformers/modeling_deberta.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 2ae4b3827c67..14a2f9f8d5fb 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -20,8 +20,7 @@ import torch from packaging import version -from torch import _softmax_backward_data -from torch import nn +from torch import _softmax_backward_data, nn from torch.nn import CrossEntropyLoss from .configuration_deberta import DeBERTaConfig From 370f9f06d4b2ceacc9795d6c10222994777b2dd2 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 25 Sep 2020 11:30:59 +0200 Subject: [PATCH 12/18] Tokenization --- src/transformers/tokenization_deberta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index 6de8ad5a6f01..8bf3cd6e50f5 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -634,9 +634,9 @@ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): return len(cls + token_ids_0 + sep) * [0] return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] - def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs): + def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): add_prefix_space = kwargs.pop("add_prefix_space", False) - if is_pretokenized or add_prefix_space: + if is_split_into_words or add_prefix_space: text = " " + text return (text, kwargs) From ae6282b2717bdd7874ed3c685c71bc3a86ce5254 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 29 Sep 2020 10:35:06 +0200 Subject: [PATCH 13/18] Last updates. --- docs/source/index.rst | 1 + src/transformers/activations.py | 6 ++++ src/transformers/modeling_deberta.py | 47 ++++++---------------------- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 3202daf409e6..942b52ba5358 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -214,6 +214,7 @@ conversion utilities for the following models: model_doc/bertgeneration model_doc/camembert model_doc/ctrl + model_doc/deberta model_doc/dialogpt model_doc/distilbert model_doc/dpr diff --git a/src/transformers/activations.py b/src/transformers/activations.py index 52483cab2132..30a1a7ac01ec 100644 --- a/src/transformers/activations.py +++ b/src/transformers/activations.py @@ -44,6 +44,10 @@ def mish(x): return x * torch.tanh(torch.nn.functional.softplus(x)) +def linear_act(x): + return x + + ACT2FN = { "relu": F.relu, "swish": swish, @@ -52,6 +56,8 @@ def mish(x): "gelu_new": gelu_new, "gelu_fast": gelu_fast, "mish": mish, + "linear": linear_act, + "sigmoid": torch.sigmoid, } diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 14a2f9f8d5fb..0ad42980755f 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -23,6 +23,7 @@ from torch import _softmax_backward_data, nn from torch.nn import CrossEntropyLoss +from .activations import ACT2FN from .configuration_deberta import DeBERTaConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput @@ -35,42 +36,12 @@ _CONFIG_FOR_DOC = "DeBERTaConfig" _TOKENIZER_FOR_DOC = "DeBERTaTokenizer" -#################################################### -# This list contains shortcut names for some of -# the pretrained weights provided with the models -#################################################### DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", "microsoft/deberta-large", ] -def gelu(x): - """Implementation of the gelu activation function. - For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): - 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - """ - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - -def swish(x): - return x * torch.sigmoid(x) - - -def linear_act(x): - return x - - -ACT2FN = { - "gelu": gelu, - "relu": torch.nn.functional.relu, - "swish": swish, - "tanh": torch.nn.functional.tanh, - "linear": linear_act, - "sigmoid": torch.sigmoid, -} - - class ContextPooler(nn.Module): def __init__(self, config): super().__init__() @@ -840,32 +811,32 @@ def _init_weights(self, module): :func:`transformers.PreTrainedTokenizer.__call__` for details. `What are input IDs? <../glossary.html#input-ids>`__ - attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`): Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. `What are attention masks? <../glossary.html#attention-mask>`__ - token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` corresponds to a `sentence B` token `What are token type IDs? <../glossary.html#token-type-ids>`_ - position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, config.max_position_embeddings - 1]``. `What are position IDs? <../glossary.html#position-ids>`_ - inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`): + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): + output_attentions (:obj:`bool`, `optional`): If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail. - output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): + output_hidden_states (:obj:`bool`, `optional`): If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail. - return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`): + return_dict (:obj:`bool`, `optional`): If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. """ @@ -1034,7 +1005,7 @@ def forward( return_dict=None, ): r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), From 5bef2ea347288f48bf6d69bb86ee49cd14d75fa7 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 29 Sep 2020 17:17:46 +0200 Subject: [PATCH 14/18] @patrickvonplaten's comments --- src/transformers/configuration_deberta.py | 50 +++++---- src/transformers/modeling_deberta.py | 124 +++++++++------------- 2 files changed, 79 insertions(+), 95 deletions(-) diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 13a5953899f3..9a3813920ef8 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -32,32 +32,35 @@ class DeBERTaConfig(PretrainedConfig): :class:`~transformers.DeBERTaModel`. Arguments: - vocab_size (:obj:`int`, optional, defaults to 50265): - Vocabulary size of the DeBERTa model. Defines the different tokens that - can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.DeBERTaModel`. - hidden_size (:obj:`int`, optional, defaults to 768): + vocab_size (:obj:`int`, `optional`, defaults to 30522): + Vocabulary size of the DeBERTa model. Defines the number of different tokens that can be represented by the + :obj:`inputs_ids` passed when calling :class:`~transformers.DebertaModel` or + :class:`~transformers.TFDebertaModel`. + hidden_size (:obj:`int`, `optional`, defaults to 768): Dimensionality of the encoder layers and the pooler layer. - num_hidden_layers (:obj:`int`, optional, defaults to 12): + num_hidden_layers (:obj:`int`, `optional`, defaults to 12): Number of hidden layers in the Transformer encoder. - num_attention_heads (:obj:`int`, optional, defaults to 12): + num_attention_heads (:obj:`int`, `optional`, defaults to 12): Number of attention heads for each attention layer in the Transformer encoder. - intermediate_size (:obj:`int`, optional, defaults to 3072): - Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. - hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"): + intermediate_size (:obj:`int`, `optional`, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`): The non-linear activation function (function or string) in the encoder and pooler. - If string, "gelu", "relu", "swish" and "gelu_new" are supported. - hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1): - The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1): + If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"`, :obj:`"gelu"`, :obj:`"tanh"`, :obj:`"gelu_fast"`, + :obj:`"mish"`, :obj:`"linear"`, :obj:`"sigmoid"` and :obj:`"gelu_new"` are supported. + hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): + The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1): The dropout ratio for the attention probabilities. - max_position_embeddings (:obj:`int`, optional, defaults to 512): + max_position_embeddings (:obj:`int`, `optional`, defaults to 512): The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). - type_vocab_size (:obj:`int`, optional, defaults to 2): - The vocabulary size of the `token_type_ids` passed into :class:`~transformers.DeBERTaModel`. - initializer_range (:obj:`float`, optional, defaults to 0.02): + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.DebertaModel` or + :class:`~transformers.TFDebertaModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): + layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12): The epsilon used by the layer normalization layers. relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether use relative position encoding. @@ -68,9 +71,9 @@ class DeBERTaConfig(PretrainedConfig): The value used to pad input_ids. position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether add absolute position embedding to content embedding. - pos_att_type (:obj:`str`, `optional`, defaults to "None"): + pos_att_type (:obj:`List[str]`, `optional`): The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], - e.g. "p2c", "p2c|c2p", "p2c|c2p|p2p". + e.g. :obj:`["p2c"]`, :obj:`["p2c", "c2p"]`, :obj:`["p2c", "c2p", 'p2p"]`. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. """ @@ -94,7 +97,7 @@ def __init__( max_relative_positions=-1, pad_token_id=0, position_biased_input=True, - pos_att_type="None", + pos_att_type=None, pooler_dropout=0, pooler_hidden_act="gelu", **kwargs @@ -115,6 +118,11 @@ def __init__( self.max_relative_positions = max_relative_positions self.pad_token_id = pad_token_id self.position_biased_input = position_biased_input + + # Backwards compatibility + if type(pos_att_type) == str: + pos_att_type = [x.strip() for x in pos_att_type.lower().split("|")] + self.pos_att_type = pos_att_type self.vocab_size = vocab_size self.layer_norm_eps = layer_norm_eps diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 0ad42980755f..3ab0543311cf 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -14,7 +14,6 @@ # limitations under the License. """ PyTorch DeBERTa model. """ -import copy import math from collections import Sequence @@ -59,6 +58,7 @@ def forward(self, hidden_states, mask=None): pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) return pooled_output + @property def output_dim(self): return self.config.hidden_size @@ -204,34 +204,7 @@ def get_context(self): return self.drop_prob -def MaskedLayerNorm(layerNorm, input, mask=None): - """Masked LayerNorm which will apply mask over the output of LayerNorm to avoid inaccurate updatings to the LayerNorm module. - Args: - layernorm (:obj:`~DeBERTa.deberta.BertLayerNorm`): LayerNorm module or function - input (:obj:`torch.tensor`): The input tensor - mask (:obj:`torch.IntTensor`): The mask to applied on the output of LayerNorm where `0` indicate the output of that element will be ignored, i.e. set to `0` - - Example:: - - # Create a tensor b x n x d - x = torch.randn([1,10,100]) - m = torch.tensor([[1,1,1,0,0,0,0,0,0,0]], dtype=torch.int) - LayerNorm = DeBERTa.deberta.BertLayerNorm(100) - y = MaskedLayerNorm(LayerNorm, x, m) - - """ - output = layerNorm(input).to(input) - if mask is None: - return output - if mask.dim() != input.dim(): - if mask.dim() == 4: - mask = mask.squeeze(1).squeeze(1) - mask = mask.unsqueeze(2) - mask = mask.to(output.dtype) - return output * mask - - -class BertLayerNorm(nn.Module): +class DebertaLayerNorm(nn.Module): """LayerNorm module in the TF style (epsilon inside the square root).""" def __init__(self, size, eps=1e-12): @@ -240,38 +213,37 @@ def __init__(self, size, eps=1e-12): self.bias = nn.Parameter(torch.zeros(size)) self.variance_epsilon = eps - def forward(self, x): - input_type = x.dtype - x = x.float() - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - x = x.to(input_type) - y = self.weight * x + self.bias + def forward(self, hidden_states): + input_type = hidden_states.dtype + hidden_states = hidden_states.float() + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.to(input_type) + y = self.weight * hidden_states + self.bias return y -class BertSelfOutput(nn.Module): +# Copied from transformers.modeling_bert.BertSelfOutput with Bert->Deberta +class DebertaSelfOutput(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) - self.config = config - def forward(self, hidden_states, input_states, mask=None): + def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states += input_states - hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states -class BertAttention(nn.Module): +class DebertaAttention(nn.Module): def __init__(self, config): super().__init__() self.self = DisentangledSelfAttention(config) - self.output = BertSelfOutput(config) + self.output = DebertaSelfOutput(config) self.config = config def forward( @@ -303,13 +275,15 @@ def forward( return attention_output -class BertIntermediate(nn.Module): +# Copied from transformers.modeling_bert.BertIntermediate with Bert->Deberta +class DebertaIntermediate(nn.Module): def __init__(self, config): super().__init__() self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - self.intermediate_act_fn = ( - ACT2FN[config.hidden_act] if isinstance(config.hidden_act, str) else config.hidden_act - ) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act def forward(self, hidden_states): hidden_states = self.dense(hidden_states) @@ -317,28 +291,28 @@ def forward(self, hidden_states): return hidden_states -class BertOutput(nn.Module): +# Copied from transformers.modeling_bert.BertOutpu with Bert->Deberta +class DebertaOutput(nn.Module): def __init__(self, config): - super(BertOutput, self).__init__() + super(DebertaOutput, self).__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.config = config - def forward(self, hidden_states, input_states, mask=None): + def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states += input_states - hidden_states = MaskedLayerNorm(self.LayerNorm, hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) return hidden_states -class BertLayer(nn.Module): +class DebertaLayer(nn.Module): def __init__(self, config): - super(BertLayer, self).__init__() - self.attention = BertAttention(config) - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) + super(DebertaLayer, self).__init__() + self.attention = DebertaAttention(config) + self.intermediate = DebertaIntermediate(config) + self.output = DebertaOutput(config) def forward( self, @@ -372,8 +346,7 @@ class DeBERTaEncoder(nn.Module): def __init__(self, config): super().__init__() - layer = BertLayer(config) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)]) self.relative_attention = getattr(config, "relative_attention", False) if self.relative_attention: self.max_relative_positions = getattr(config, "max_relative_positions", -1) @@ -520,7 +493,7 @@ def __init__(self, config): self.in_proj = torch.nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) self.q_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) self.v_bias = torch.nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) - self.pos_att_type = [x.strip() for x in getattr(config, "pos_att_type", "none").lower().split("|")] # c2p|p2c + self.pos_att_type = config.pos_att_type self.relative_attention = getattr(config, "relative_attention", False) self.talking_head = getattr(config, "talking_head", False) @@ -600,18 +573,12 @@ def linear(w, b, x): k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] - query_layer += self.transpose_for_scores(self.q_bias.unsqueeze(0).unsqueeze(0)) - value_layer += self.transpose_for_scores(self.v_bias.unsqueeze(0).unsqueeze(0)) + query_layer += self.transpose_for_scores(self.q_bias[None, None, :]) + value_layer += self.transpose_for_scores(self.v_bias[None, None, :]) rel_att = None # Take the dot product between "query" and "key" to get the raw attention scores. - scale_factor = 1 - if "c2p" in self.pos_att_type: - scale_factor += 1 - if "p2c" in self.pos_att_type: - scale_factor += 1 - if "p2p" in self.pos_att_type: - scale_factor += 1 + scale_factor = 1 + len(self.pos_att_type) scale = math.sqrt(query_layer.size(-1) * scale_factor) query_layer = query_layer / scale attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -716,7 +683,7 @@ def __init__(self, config): if self.embedding_size != config.hidden_size: self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) - self.LayerNorm = BertLayerNorm(config.hidden_size, config.layer_norm_eps) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) self.dropout = StableDropout(config.hidden_dropout_prob) self.output_to_half = False self.config = config @@ -756,7 +723,16 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=N if self.embedding_size != self.config.hidden_size: embeddings = self.embed_proj(embeddings) - embeddings = MaskedLayerNorm(self.LayerNorm, embeddings, mask) + embeddings = self.LayerNorm(embeddings) + + if mask.dim() != input.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + embeddings = self.dropout(embeddings) return embeddings From 0e4b45bfa20b21f424ccfc4c29fab387ec680278 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 29 Sep 2020 17:21:35 +0200 Subject: [PATCH 15/18] Not everything can be a copy --- src/transformers/modeling_deberta.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 3ab0543311cf..8f9d09924ef0 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -224,7 +224,6 @@ def forward(self, hidden_states): return y -# Copied from transformers.modeling_bert.BertSelfOutput with Bert->Deberta class DebertaSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -291,7 +290,6 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.modeling_bert.BertOutpu with Bert->Deberta class DebertaOutput(nn.Module): def __init__(self, config): super(DebertaOutput, self).__init__() From 67e2e292c6e9a18bfcc6c9f1eebe37c40175956a Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 29 Sep 2020 17:32:13 +0200 Subject: [PATCH 16/18] Apply most of @sgugger's review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/model_doc/deberta.rst | 4 ++-- src/transformers/configuration_deberta.py | 8 ++++---- src/transformers/modeling_deberta.py | 18 +++++++----------- src/transformers/tokenization_deberta.py | 6 +++--- tests/test_modeling_deberta.py | 2 +- 5 files changed, 17 insertions(+), 21 deletions(-) diff --git a/docs/source/model_doc/deberta.rst b/docs/source/model_doc/deberta.rst index 56c925c38f86..20f608667e4f 100644 --- a/docs/source/model_doc/deberta.rst +++ b/docs/source/model_doc/deberta.rst @@ -4,7 +4,7 @@ DeBERTa Overview ~~~~~~~~~~~~~~~~~~~~~ -The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `_ +The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `__ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen It is based on Google's BERT model released in 2018 and Facebook's RoBERTa model released in 2019. @@ -23,7 +23,7 @@ on MNLI by +0.9% (90.2% vs. 91.1%), on SQuAD v2.0 by +2.3% (88.4% vs. 90.7%) and models will be made publicly available at https://github.com/microsoft/DeBERTa.* -The original code can be found `here `_. +The original code can be found `here `__. DeBERTaConfig diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 9a3813920ef8..420b7ebd91a2 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020, Microsoft +# Copyright 2020, Microsoft and the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -65,14 +65,14 @@ class DeBERTaConfig(PretrainedConfig): relative_attention (:obj:`bool`, `optional`, defaults to :obj:`False`): Whether use relative position encoding. max_relative_positions (:obj:`int`, `optional`, defaults to 1): - The range of relative positions [`-max_position_embeddings`, `max_position_embeddings`]. - Use the same value as `max_position_embeddings`. + The range of relative positions :obj:`[-max_position_embeddings, max_position_embeddings]`. + Use the same value as :obj:`max_position_embeddings`. pad_token_id (:obj:`int`, `optional`, defaults to 0): The value used to pad input_ids. position_biased_input (:obj:`bool`, `optional`, defaults to :obj:`True`): Whether add absolute position embedding to content embedding. pos_att_type (:obj:`List[str]`, `optional`): - The type of relative position attention, it can be a combination of [`p2c`, `c2p`, `p2p`], + The type of relative position attention, it can be a combination of :obj:`["p2c", "c2p", "p2p"]`, e.g. :obj:`["p2c"]`, :obj:`["p2c", "c2p"]`, :obj:`["p2c", "c2p", 'p2p"]`. layer_norm_eps (:obj:`float`, optional, defaults to 1e-12): The epsilon used by the layer normalization layers. diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 8f9d09924ef0..2fff1f5e0be3 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 Microsoft +# Copyright 2020 Microsoft and the Hugging Face Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -755,19 +755,15 @@ def _init_weights(self, module): DEBERTA_START_DOCSTRING = r""" The DeBERTa model was proposed in - `DeBERTa: Decoding-enhanced BERT with Disentangled Attention`_ + `DeBERTa: Decoding-enhanced BERT with Disentangled Attention `_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pre-trianing data. - This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and - refer to the PyTorch documentation for all matter related to general usage and behavior. + This model is also a PyTorch `torch.nn.Module `__ subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general + usage and behavior.``` - .. _`DeBERTa: Decoding-enhanced BERT with Disentangled Attention`: - https://arxiv.org/abs/2006.03654 - - .. _`torch.nn.Module`: - https://pytorch.org/docs/stable/nn.html#module Parameters: config (:class:`~transformers.DeBERTaConfig`): Model configuration class with all the parameters of the model. @@ -843,7 +839,7 @@ def _prune_heads(self, heads_to_prune): """ raise NotImplementedError("The prune function is not implemented in DeBERTa model.") - @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="microsoft/deberta-base", @@ -959,7 +955,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): self.deberta.set_input_embeddings(new_embeddings) - @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) + @add_start_docstrings_to_callable(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( tokenizer_class=_TOKENIZER_FOR_DOC, checkpoint="microsoft/deberta-base", diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index 8bf3cd6e50f5..ecf21d49bd75 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2020 Microsoft. +# Copyright 2020 Microsoft and the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -307,7 +307,7 @@ def load_vocab(name=None, tag=None, no_cache=False, cache_dir=None): class GPT2Tokenizer(object): - """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer + """ A wrapper of GPT2 tokenizer with similar interface as BERT tokenizer Args: vocab_file (:obj:`str`, optional): @@ -490,7 +490,7 @@ def save_pretrained(self, path: str): class DeBERTaTokenizer(PreTrainedTokenizer): r""" - Constructs a :class:`~transformers.DeBERTaTokenizer`, which runs end-to-end tokenization: punctuation + Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece Args: diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index 541cfa9684be..f0faf2e53e4e 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2018 Microsoft Authors. +# Copyright 2018 Microsoft Authors and the HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From f0846b19eb7156e2d7576e10a5349b086a62f8b2 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 29 Sep 2020 17:45:52 +0200 Subject: [PATCH 17/18] Last reviews --- src/transformers/modeling_deberta.py | 19 +++---- src/transformers/tokenization_deberta.py | 65 ++++++++++++++---------- 2 files changed, 49 insertions(+), 35 deletions(-) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 2fff1f5e0be3..2d34b31fcc40 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -266,7 +266,7 @@ def forward( self_output, att_matrix = self_output if query_states is None: query_states = hidden_states - attention_output = self.output(self_output, query_states, attention_mask) + attention_output = self.output(self_output, query_states) if return_att: return (attention_output, att_matrix) @@ -332,7 +332,7 @@ def forward( if return_att: attention_output, att_matrix = attention_output intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output, attention_mask) + layer_output = self.output(intermediate_output, attention_output) if return_att: return (layer_output, att_matrix) else: @@ -723,13 +723,14 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=N embeddings = self.LayerNorm(embeddings) - if mask.dim() != input.dim(): - if mask.dim() == 4: - mask = mask.squeeze(1).squeeze(1) - mask = mask.unsqueeze(2) - mask = mask.to(embeddings.dtype) + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) - embeddings = embeddings * mask + embeddings = embeddings * mask embeddings = self.dropout(embeddings) return embeddings @@ -940,7 +941,7 @@ def __init__(self, config): self.deberta = DeBERTaModel(config) self.pooler = ContextPooler(config) - output_dim = self.pooler.output_dim() + output_dim = self.pooler.output_dim self.classifier = torch.nn.Linear(output_dim, num_labels) drop_out = getattr(config, "cls_dropout", None) diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index ecf21d49bd75..dd6c5e5c2b8a 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -37,16 +37,8 @@ logger = logging.getLogger(__name__) -#################################################### -# Mapping from the keyword arguments names of Tokenizer `__init__` -# to file names for serializing Tokenizer instances -#################################################### VOCAB_FILES_NAMES = {"vocab_file": "bpe_encoder.bin"} -#################################################### -# Mapping from the keyword arguments names of Tokenizer `__init__` -# to pretrained vocabulary URL for all the model shortcut names. -#################################################### PRETRAINED_VOCAB_FILES_MAP = { "vocab_file": { "microsoft/deberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/microsoft/deberta-base/bpe_encoder.bin", @@ -54,19 +46,11 @@ } } -#################################################### -# Mapping from model shortcut names to max length of inputs -#################################################### PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { "microsoft/deberta-base": 512, "microsoft/deberta-large": 512, } -#################################################### -# Mapping from model shortcut names to a dictionary of additional -# keyword arguments for Tokenizer `__init__`. -# To be used for checkpoint specific configurations. -#################################################### PRETRAINED_INIT_CONFIGURATION = { "microsoft/deberta-base": {"do_lower_case": False}, "microsoft/deberta-large": {"do_lower_case": False}, @@ -494,8 +478,26 @@ class DeBERTaTokenizer(PreTrainedTokenizer): splitting + wordpiece Args: - vocab_file: Path to a one-wordpiece-per-line vocabulary file - do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True + vocab_file (:obj:`str`): + File containing the vocabulary. + do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not to lowercase the input when tokenizing. + unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences + for sequence classification or for a text and a question for question answering. + It is also used as the last token of a sequence built with special tokens. + pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): + The token used for padding, for example when batching sequences of different lengths. + cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): + The classifier token which is used when doing sequence classification (classification of the whole + sequence instead of per-token classification). It is the first token of the sequence when built with + special tokens. + mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. """ vocab_files_names = VOCAB_FILES_NAMES @@ -514,14 +516,6 @@ def __init__( mask_token="[MASK]", **kwargs ): - """Constructs a XxxTokenizer. - - Args: - **vocab_file**: Path to a one-wordpiece-per-line vocabulary file - **do_lower_case**: (`optional`) boolean (default False) - Whether to lower case the input - Only has an effect when do_basic_tokenize=True - """ super().__init__( unk_token=unk_token, sep_token=sep_token, @@ -578,6 +572,15 @@ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): - single sequence: [CLS] X [SEP] - pair of sequences: [CLS] A [SEP] B [SEP] + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. """ if token_ids_1 is None: @@ -627,6 +630,16 @@ def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None): | first sequence | second sequence if token_ids_1 is None, only returns the first portion of the mask (0's). + ~ + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs. + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given + sequence(s). """ sep = [self.sep_token_id] cls = [self.cls_token_id] From 0a08565ef293b14f30f35754139ace6d76f6f362 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Tue, 29 Sep 2020 17:57:09 +0200 Subject: [PATCH 18/18] DeBERTa -> Deberta --- docs/source/model_doc/deberta.rst | 20 ++++++------- src/transformers/__init__.py | 10 +++---- src/transformers/configuration_auto.py | 4 +-- src/transformers/configuration_deberta.py | 6 ++-- src/transformers/modeling_auto.py | 8 ++--- src/transformers/modeling_deberta.py | 30 +++++++++---------- src/transformers/tokenization_auto.py | 6 ++-- src/transformers/tokenization_deberta.py | 4 +-- tests/test_modeling_deberta.py | 36 +++++++++++------------ tests/test_tokenization_deberta.py | 8 ++--- 10 files changed, 66 insertions(+), 66 deletions(-) diff --git a/docs/source/model_doc/deberta.rst b/docs/source/model_doc/deberta.rst index 20f608667e4f..aeb7da69edfb 100644 --- a/docs/source/model_doc/deberta.rst +++ b/docs/source/model_doc/deberta.rst @@ -26,37 +26,37 @@ models will be made publicly available at https://github.com/microsoft/DeBERTa.* The original code can be found `here `__. -DeBERTaConfig +DebertaConfig ~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DeBERTaConfig +.. autoclass:: transformers.DebertaConfig :members: -DeBERTaTokenizer +DebertaTokenizer ~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DeBERTaTokenizer +.. autoclass:: transformers.DebertaTokenizer :members: build_inputs_with_special_tokens, get_special_tokens_mask, create_token_type_ids_from_sequences, save_vocabulary -DeBERTaModel +DebertaModel ~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DeBERTaModel +.. autoclass:: transformers.DebertaModel :members: -DeBERTaPreTrainedModel +DebertaPreTrainedModel ~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DeBERTaPreTrainedModel +.. autoclass:: transformers.DebertaPreTrainedModel :members: -DeBERTaForSequenceClassification +DebertaForSequenceClassification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autoclass:: transformers.DeBERTaForSequenceClassification +.. autoclass:: transformers.DebertaForSequenceClassification :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index f896048beaaf..c888157a85dc 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -35,7 +35,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig -from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig +from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -157,7 +157,7 @@ from .tokenization_bertweet import BertweetTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer -from .tokenization_deberta import DeBERTaTokenizer +from .tokenization_deberta import DebertaTokenizer from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast from .tokenization_dpr import ( DPRContextEncoderTokenizer, @@ -314,9 +314,9 @@ from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel, CTRLModel, CTRLPreTrainedModel from .modeling_deberta import ( DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - DeBERTaForSequenceClassification, - DeBERTaModel, - DeBERTaPreTrainedModel, + DebertaForSequenceClassification, + DebertaModel, + DebertaPreTrainedModel, ) from .modeling_distilbert import ( DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py index c374924c7a03..a5e76d923ef8 100644 --- a/src/transformers/configuration_auto.py +++ b/src/transformers/configuration_auto.py @@ -23,7 +23,7 @@ from .configuration_bert_generation import BertGenerationConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig -from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DeBERTaConfig +from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig from .configuration_dpr import DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, DPRConfig from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig @@ -102,7 +102,7 @@ ("reformer", ReformerConfig), ("longformer", LongformerConfig), ("roberta", RobertaConfig), - ("deberta", DeBERTaConfig), + ("deberta", DebertaConfig), ("flaubert", FlaubertConfig), ("fsmt", FSMTConfig), ("bert", BertConfig), diff --git a/src/transformers/configuration_deberta.py b/src/transformers/configuration_deberta.py index 420b7ebd91a2..a11527b7b57b 100644 --- a/src/transformers/configuration_deberta.py +++ b/src/transformers/configuration_deberta.py @@ -26,10 +26,10 @@ } -class DeBERTaConfig(PretrainedConfig): +class DebertaConfig(PretrainedConfig): r""" - :class:`~transformers.DeBERTaConfig` is the configuration class to store the configuration of a - :class:`~transformers.DeBERTaModel`. + :class:`~transformers.DebertaConfig` is the configuration class to store the configuration of a + :class:`~transformers.DebertaModel`. Arguments: vocab_size (:obj:`int`, `optional`, defaults to 30522): diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index c4c73bf60ff2..f28ab6952aeb 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -26,7 +26,7 @@ BertGenerationConfig, CamembertConfig, CTRLConfig, - DeBERTaConfig, + DebertaConfig, DistilBertConfig, DPRConfig, ElectraConfig, @@ -91,7 +91,7 @@ CamembertModel, ) from .modeling_ctrl import CTRLLMHeadModel, CTRLModel -from .modeling_deberta import DeBERTaForSequenceClassification, DeBERTaModel +from .modeling_deberta import DebertaForSequenceClassification, DebertaModel from .modeling_distilbert import ( DistilBertForMaskedLM, DistilBertForMultipleChoice, @@ -233,7 +233,7 @@ (FunnelConfig, FunnelModel), (LxmertConfig, LxmertModel), (BertGenerationConfig, BertGenerationEncoder), - (DeBERTaConfig, DeBERTaModel), + (DebertaConfig, DebertaModel), (DPRConfig, DPRQuestionEncoder), ] ) @@ -362,7 +362,7 @@ (XLMConfig, XLMForSequenceClassification), (ElectraConfig, ElectraForSequenceClassification), (FunnelConfig, FunnelForSequenceClassification), - (DeBERTaConfig, DeBERTaForSequenceClassification), + (DebertaConfig, DebertaForSequenceClassification), ] ) diff --git a/src/transformers/modeling_deberta.py b/src/transformers/modeling_deberta.py index 2d34b31fcc40..ec6661a4cc92 100644 --- a/src/transformers/modeling_deberta.py +++ b/src/transformers/modeling_deberta.py @@ -23,7 +23,7 @@ from torch.nn import CrossEntropyLoss from .activations import ACT2FN -from .configuration_deberta import DeBERTaConfig +from .configuration_deberta import DebertaConfig from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable from .modeling_outputs import BaseModelOutput, SequenceClassifierOutput from .modeling_utils import PreTrainedModel @@ -32,8 +32,8 @@ logger = logging.get_logger(__name__) -_CONFIG_FOR_DOC = "DeBERTaConfig" -_TOKENIZER_FOR_DOC = "DeBERTaTokenizer" +_CONFIG_FOR_DOC = "DebertaConfig" +_TOKENIZER_FOR_DOC = "DebertaTokenizer" DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "microsoft/deberta-base", @@ -339,7 +339,7 @@ def forward( return layer_output -class DeBERTaEncoder(nn.Module): +class DebertaEncoder(nn.Module): """Modified BertEncoder with relative position bias support""" def __init__(self, config): @@ -474,7 +474,7 @@ class DisentangledSelfAttention(torch.nn.Module): Parameters: config (:obj:`str`): A model config class instance with the configuration to build a new model. The schema is similar to `BertConfig`, \ - for more details, please refer :class:`~transformers.DeBERTaConfig` + for more details, please refer :class:`~transformers.DebertaConfig` """ @@ -661,7 +661,7 @@ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embedd return score -class DeBERTaEmbeddings(nn.Module): +class DebertaEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings.""" def __init__(self, config): @@ -736,12 +736,12 @@ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=N return embeddings -class DeBERTaPreTrainedModel(PreTrainedModel): +class DebertaPreTrainedModel(PreTrainedModel): """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = DeBERTaConfig + config_class = DebertaConfig base_model_prefix = "deberta" authorized_missing_keys = ["position_ids"] @@ -767,7 +767,7 @@ def _init_weights(self, module): Parameters: - config (:class:`~transformers.DeBERTaConfig`): Model configuration class with all the parameters of the model. + config (:class:`~transformers.DebertaConfig`): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights. """ @@ -777,7 +777,7 @@ def _init_weights(self, module): input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): Indices of input sequence tokens in the vocabulary. - Indices can be obtained using :class:`transformers.DeBERTaTokenizer`. + Indices can be obtained using :class:`transformers.DebertaTokenizer`. See :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for details. @@ -817,12 +817,12 @@ def _init_weights(self, module): "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", DEBERTA_START_DOCSTRING, ) -class DeBERTaModel(DeBERTaPreTrainedModel): +class DebertaModel(DebertaPreTrainedModel): def __init__(self, config): super().__init__(config) - self.embeddings = DeBERTaEmbeddings(config) - self.encoder = DeBERTaEncoder(config) + self.embeddings = DebertaEmbeddings(config) + self.encoder = DebertaEncoder(config) self.z_steps = 0 self.config = config self.init_weights() @@ -932,14 +932,14 @@ def forward( the pooled output) e.g. for GLUE tasks. """, DEBERTA_START_DOCSTRING, ) -class DeBERTaForSequenceClassification(DeBERTaPreTrainedModel): +class DebertaForSequenceClassification(DebertaPreTrainedModel): def __init__(self, config): super().__init__(config) num_labels = getattr(config, "num_labels", 2) self.num_labels = num_labels - self.deberta = DeBERTaModel(config) + self.deberta = DebertaModel(config) self.pooler = ContextPooler(config) output_dim = self.pooler.output_dim diff --git a/src/transformers/tokenization_auto.py b/src/transformers/tokenization_auto.py index e887ac517d48..052b133bb867 100644 --- a/src/transformers/tokenization_auto.py +++ b/src/transformers/tokenization_auto.py @@ -25,7 +25,7 @@ BertGenerationConfig, CamembertConfig, CTRLConfig, - DeBERTaConfig, + DebertaConfig, DistilBertConfig, DPRConfig, ElectraConfig, @@ -62,7 +62,7 @@ from .tokenization_bertweet import BertweetTokenizer from .tokenization_camembert import CamembertTokenizer from .tokenization_ctrl import CTRLTokenizer -from .tokenization_deberta import DeBERTaTokenizer +from .tokenization_deberta import DebertaTokenizer from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast from .tokenization_dpr import DPRQuestionEncoderTokenizer, DPRQuestionEncoderTokenizerFast from .tokenization_electra import ElectraTokenizer, ElectraTokenizerFast @@ -127,7 +127,7 @@ (CTRLConfig, (CTRLTokenizer, None)), (FSMTConfig, (FSMTTokenizer, None)), (BertGenerationConfig, (BertGenerationTokenizer, None)), - (DeBERTaConfig, (DeBERTaTokenizer, None)), + (DebertaConfig, (DebertaTokenizer, None)), (LayoutLMConfig, (LayoutLMTokenizer, None)), (RagConfig, (RagTokenizer, None)), ] diff --git a/src/transformers/tokenization_deberta.py b/src/transformers/tokenization_deberta.py index dd6c5e5c2b8a..086b5d506e29 100644 --- a/src/transformers/tokenization_deberta.py +++ b/src/transformers/tokenization_deberta.py @@ -56,7 +56,7 @@ "microsoft/deberta-large": {"do_lower_case": False}, } -__all__ = ["DeBERTaTokenizer"] +__all__ = ["DebertaTokenizer"] @lru_cache() @@ -472,7 +472,7 @@ def save_pretrained(self, path: str): torch.save(self.gpt2_encoder, path) -class DeBERTaTokenizer(PreTrainedTokenizer): +class DebertaTokenizer(PreTrainedTokenizer): r""" Constructs a DeBERTa tokenizer, which runs end-to-end tokenization: punctuation splitting + wordpiece diff --git a/tests/test_modeling_deberta.py b/tests/test_modeling_deberta.py index f0faf2e53e4e..33994074a083 100644 --- a/tests/test_modeling_deberta.py +++ b/tests/test_modeling_deberta.py @@ -30,21 +30,21 @@ import torch from transformers import ( # XxxForMaskedLM,; XxxForQuestionAnswering,; XxxForTokenClassification, - DeBERTaConfig, - DeBERTaForSequenceClassification, - DeBERTaModel, + DebertaConfig, + DebertaForSequenceClassification, + DebertaModel, ) from transformers.modeling_deberta import DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST @require_torch -class DeBERTaModelTest(ModelTesterMixin, unittest.TestCase): +class DebertaModelTest(ModelTesterMixin, unittest.TestCase): all_model_classes = ( ( - DeBERTaModel, - DeBERTaForSequenceClassification, - ) # , DeBERTaForMaskedLM, DeBERTaForQuestionAnswering, DeBERTaForTokenClassification) + DebertaModel, + DebertaForSequenceClassification, + ) # , DebertaForMaskedLM, DebertaForQuestionAnswering, DebertaForTokenClassification) if is_torch_available() else () ) @@ -54,7 +54,7 @@ class DeBERTaModelTest(ModelTesterMixin, unittest.TestCase): test_head_masking = False is_encoder_decoder = False - class DeBERTaModelTester(object): + class DebertaModelTester(object): def __init__( self, parent, @@ -128,7 +128,7 @@ def prepare_config_and_inputs(self): token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) choice_labels = ids_tensor([self.batch_size], self.num_choices) - config = DeBERTaConfig( + config = DebertaConfig( vocab_size=self.vocab_size, hidden_size=self.hidden_size, num_hidden_layers=self.num_hidden_layers, @@ -153,7 +153,7 @@ def check_loss_output(self, result): def create_and_check_deberta_model( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): - model = DeBERTaModel(config=config) + model = DebertaModel(config=config) model.to(torch_device) model.eval() sequence_output = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)[0] @@ -171,7 +171,7 @@ def create_and_check_deberta_for_sequence_classification( self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels ): config.num_labels = self.num_labels - model = DeBERTaForSequenceClassification(config) + model = DebertaForSequenceClassification(config) model.to(torch_device) model.eval() loss, logits = model( @@ -199,8 +199,8 @@ def prepare_config_and_inputs_for_common(self): return config, inputs_dict def setUp(self): - self.model_tester = DeBERTaModelTest.DeBERTaModelTester(self) - self.config_tester = ConfigTester(self, config_class=DeBERTaConfig, hidden_size=37) + self.model_tester = DebertaModelTest.DebertaModelTester(self) + self.config_tester = ConfigTester(self, config_class=DebertaConfig, hidden_size=37) def test_config(self): self.config_tester.run_common_tests() @@ -231,12 +231,12 @@ def test_for_token_classification(self): @slow def test_model_from_pretrained(self): for model_name in DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - model = DeBERTaModel.from_pretrained(model_name) + model = DebertaModel.from_pretrained(model_name) self.assertIsNotNone(model) @require_torch -class DeBERTaModelIntegrationTest(unittest.TestCase): +class DebertaModelIntegrationTest(unittest.TestCase): @unittest.skip(reason="Model not available yet") def test_inference_masked_lm(self): pass @@ -247,8 +247,8 @@ def test_inference_no_head(self): np.random.seed(0) torch.manual_seed(0) torch.cuda.manual_seed_all(0) - DeBERTaModel.base_model_prefix = "bert" - model = DeBERTaModel.from_pretrained("microsoft/deberta-base") + DebertaModel.base_model_prefix = "bert" + model = DebertaModel.from_pretrained("microsoft/deberta-base") input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) output = model(input_ids)[0] @@ -264,7 +264,7 @@ def test_inference_classification_head(self): np.random.seed(0) torch.manual_seed(0) torch.cuda.manual_seed_all(0) - model = DeBERTaForSequenceClassification.from_pretrained("microsoft/deberta-base") + model = DebertaForSequenceClassification.from_pretrained("microsoft/deberta-base") input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) output = model(input_ids)[0] expected_shape = torch.Size((1, 2)) diff --git a/tests/test_tokenization_deberta.py b/tests/test_tokenization_deberta.py index e2085396512b..29437d6cc804 100644 --- a/tests/test_tokenization_deberta.py +++ b/tests/test_tokenization_deberta.py @@ -19,21 +19,21 @@ from typing import Tuple from transformers.testing_utils import require_torch -from transformers.tokenization_deberta import DeBERTaTokenizer +from transformers.tokenization_deberta import DebertaTokenizer from .test_tokenization_common import TokenizerTesterMixin @require_torch -class DeBERTaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): +class DebertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): - tokenizer_class = DeBERTaTokenizer + tokenizer_class = DebertaTokenizer def setUp(self): super().setUp() def get_tokenizer(self, name="microsoft/deberta-base", **kwargs): - return DeBERTaTokenizer.from_pretrained(name, **kwargs) + return DebertaTokenizer.from_pretrained(name, **kwargs) def get_input_output_texts(self, tokenizer): input_text = "lower newer"