Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f6ae66f
add conversion script
patrickvonplaten Aug 17, 2020
525f6db
improve conversion script
patrickvonplaten Aug 18, 2020
ce2af43
make style
patrickvonplaten Aug 20, 2020
213353d
add tryout files
patrickvonplaten Sep 6, 2020
c5b7efd
fix
patrickvonplaten Sep 7, 2020
15fea8e
update
patrickvonplaten Sep 7, 2020
da477b6
add causal bert
patrickvonplaten Sep 7, 2020
749e96c
better names
patrickvonplaten Sep 7, 2020
06f1517
add tokenizer file as well
patrickvonplaten Sep 7, 2020
74e38e3
finish causal_bert
patrickvonplaten Sep 7, 2020
4c5340e
fix small bugs
patrickvonplaten Sep 7, 2020
8f2e8c3
improve generate
patrickvonplaten Sep 8, 2020
bf4f425
change naming
patrickvonplaten Sep 8, 2020
1b9a716
renaming
patrickvonplaten Sep 8, 2020
355b8a5
renaming
patrickvonplaten Sep 8, 2020
7d0ea85
renaming
patrickvonplaten Sep 8, 2020
738320e
remove leftover files
patrickvonplaten Sep 8, 2020
a33d06c
clean files
patrickvonplaten Sep 8, 2020
b15c962
add fix tokenizer
patrickvonplaten Sep 9, 2020
a6392eb
finalize
patrickvonplaten Sep 9, 2020
db06984
correct slow test
patrickvonplaten Sep 9, 2020
ed36280
update docs
patrickvonplaten Sep 9, 2020
e88e005
small fixes
patrickvonplaten Sep 9, 2020
e782bf8
fix link
patrickvonplaten Sep 9, 2020
550beed
adapt check repo
patrickvonplaten Sep 9, 2020
85e1294
apply sams and sylvains recommendations
patrickvonplaten Sep 10, 2020
d69d0da
fix import
patrickvonplaten Sep 10, 2020
d432196
implement Lysandres recommendations
patrickvonplaten Sep 10, 2020
aa953cb
fix logger warn
patrickvonplaten Sep 10, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ conversion utilities for the following models:
26. `Funnel Transformer <https://github.com/laiguokun/Funnel-Transformer>`_ (from CMU/Google Brain) released with the paper
`Funnel-Transformer: Filtering out Sequential Redundancy for Efficient Language Processing
<https://arxiv.org/abs/2006.03236>`_ by Zihang Dai, Guokun Lai, Yiming Yang, Quoc V. Le.
27. `Other community models <https://huggingface.co/models>`_, contributed by the `community
27. `Bert For Sequence Generation <https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder>`_ (from Google) released with the paper
`Leveraging Pre-trained Checkpoints for Sequence Generation Tasks
<https://arxiv.org/abs/1907.12461>`_ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
28. `Other community models <https://huggingface.co/models>`_, contributed by the `community
<https://huggingface.co/users>`_.

.. toctree::
Expand Down Expand Up @@ -221,6 +224,7 @@ conversion utilities for the following models:
model_doc/mbart
model_doc/funnel
model_doc/lxmert
model_doc/bertgeneration
internal/modeling_utils
internal/tokenization_utils
internal/pipelines_utils
82 changes: 82 additions & 0 deletions docs/source/model_doc/bertgeneration.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
BertForSeqGeneration
----------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~

The BertForSeqGeneration model is a BERT model that can be leveraged for sequence-to-sequence tasks using :class:`~transformers.EncoderDecoderModel` as proposed in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.

The abstract from the paper is the following:

*Unsupervised pre-training of large neural models has recently revolutionized Natural Language Processing. By warm-starting from the publicly released checkpoints, NLP practitioners have pushed the state-of-the-art on multiple benchmarks while saving significant amounts of compute time. So far the focus has been mainly on the Natural Language Understanding tasks. In this paper, we demonstrate the efficacy of pre-trained checkpoints for Sequence Generation. We developed a Transformer-based sequence-to-sequence model that is compatible with publicly available pre-trained BERT, GPT-2 and RoBERTa checkpoints and conducted an extensive empirical study on the utility of initializing our model, both encoder and decoder, with these checkpoints. Our models result in new state-of-the-art results on Machine Translation, Text Summarization, Sentence Splitting, and Sentence Fusion.*

Usage:

- The model can be used in combination with the :class:`~transformers.EncoderDecoderModel` to leverage two bert pretrained bert checkpoints for subsequent fine-tuning.

::

# leverage checkpoints for Bert2Bert model...
encoder = BertGenerationEncoder.from_pretrained("bert-large-uncased", bos_token_id=101, eos_token_id=102) # use BERT's cls token as BOS token and sep token as EOS token
decoder = BertGenerationDecoder.from_pretrained("bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102) # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# create tokenizer...
tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")

input_ids = tokenizer('This is a long article to summarize', add_special_tokens=False, return_tensors="pt").input_ids
labels = tokenizer('This is a short summary', return_tensors="pt").input_ids

# train...
loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels, return_dict=True).loss
loss.backward()


- Pretrained :class:`~transformers.EncoderDecoderModel` are also directly available in the model hub, *e.g.*:


::

# instantiate sentence fusion model
sentence_fuser = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse")
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

input_ids = tokenizer('This is the first sentence. This is the second sentence.', add_special_tokens=False, return_tensors="pt").input_ids

outputs = sentence_fuser.generate(input_ids)

print(tokenizer.decode(outputs[0]))


Tips:

- :class:`~transformers.BertGenerationEncoder` and :class:`~transformers.BertGenerationDecoder` should be used in combination with :class:`~transformers.EncoderDecoder`.
- For summarization, sentence splitting, sentence fusion and translation, no special tokens are required for the input. Therefore, no EOS token should be added to the end of the input.

The original code can be found `here <https://tfhub.dev/s?module-type=text-generation&subtype=module,placeholder>`__.

BertGenerationConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationConfig
:members:


BertGenerationTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationTokenizer
:members:

BertGenerationEncoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationEncoder
:members:


BertGenerationDecoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.BertGenerationDecoder
:members:
7 changes: 7 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_bert_generation import BertGenerationConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
Expand Down Expand Up @@ -142,6 +143,7 @@
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer, BartTokenizerFast
from .tokenization_bert import BasicTokenizer, BertTokenizer, BertTokenizerFast, WordpieceTokenizer
from .tokenization_bert_generation import BertGenerationTokenizer
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
Expand Down Expand Up @@ -277,6 +279,11 @@
BertPreTrainedModel,
load_tf_weights_in_bert,
)
from .modeling_bert_generation import (
BertGenerationDecoder,
BertGenerationEncoder,
load_tf_weights_in_bert_generation,
)
from .modeling_camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
CamembertForCausalLM,
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_bart import BART_PRETRAINED_CONFIG_ARCHIVE_MAP, BartConfig
from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
from .configuration_bert_generation import BertGenerationConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
Expand Down Expand Up @@ -97,6 +98,10 @@
"albert",
AlbertConfig,
),
(
"bert-for-seq-generation",
BertGenerationConfig,
),
(
"camembert",
CamembertConfig,
Expand Down
106 changes: 106 additions & 0 deletions src/transformers/configuration_bert_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2020 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertForSeqGeneration model configuration """

from .configuration_utils import PretrainedConfig


class BertGenerationConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.BertGenerationPreTrainedModel`.
It is used to instantiate a BertGenerationConfig model according to the specified arguments, defining the model architecture.

Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
for more information.


Args:
vocab_size (:obj:`int`, `optional`, defaults to 50358):
Vocabulary size of the BertForSeqGeneration model. Defines the different tokens that
can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertForSeqGeneration`.
hidden_size (:obj:`int`, `optional`, defaults to 1024):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 24):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 16):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler.
If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with.
Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.

Example::

>>> from transformers import BertGenerationConfig, BertGenerationEncoder

>>> # Initializing a BertForSeqGeneration config
>>> configuration = BertGenerationConfig()

>>> # Initializing a modelfrom the config
>>> model = BertGenerationEncoder(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "bert-for-seq-generation"

def __init__(
self,
vocab_size=50358,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
88 changes: 88 additions & 0 deletions src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Seq2Seq TF Hub checkpoint."""


import argparse

from transformers import (
BertConfig,
BertGenerationConfig,
BertGenerationDecoder,
BertGenerationEncoder,
load_tf_weights_in_bert_generation,
logging,
)


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_hub_path, pytorch_dump_path, is_encoder_named_decoder, vocab_size, is_encoder):
# Initialise PyTorch model
bert_config = BertConfig.from_pretrained(
"bert-large-cased",
vocab_size=vocab_size,
max_position_embeddings=512,
is_decoder=True,
add_cross_attention=True,
)
bert_config_dict = bert_config.to_dict()
del bert_config_dict["type_vocab_size"]
config = BertGenerationConfig(**bert_config_dict)
if is_encoder:
model = BertGenerationEncoder(config)
else:
model = BertGenerationDecoder(config)
print("Building PyTorch model from configuration: {}".format(str(config)))

# Load weights from tf checkpoint
load_tf_weights_in_bert_generation(
model,
tf_hub_path,
model_class="bert",
is_encoder_named_decoder=is_encoder_named_decoder,
is_encoder=is_encoder,
)

# Save pytorch-model
print("Save PyTorch model and config to {}".format(pytorch_dump_path))
model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--tf_hub_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
)
parser.add_argument(
"--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
)
parser.add_argument(
"--is_encoder_named_decoder",
action="store_true",
help="If decoder has to be renamed to encoder in PyTorch model.",
)
parser.add_argument("--is_encoder", action="store_true", help="If model is an encoder.")
parser.add_argument("--vocab_size", default=50358, type=int, help="Vocab size of model")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(
args.tf_hub_path,
args.pytorch_dump_path,
args.is_encoder_named_decoder,
args.vocab_size,
is_encoder=args.is_encoder,
)
2 changes: 1 addition & 1 deletion src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def docstring_decorator(fn):
code_sample = TF_MASKED_LM_SAMPLE if is_tf_class else PT_MASKED_LM_SAMPLE
elif "LMHead" in model_class:
code_sample = TF_CAUSAL_LM_SAMPLE if is_tf_class else PT_CAUSAL_LM_SAMPLE
elif "Model" in model_class:
elif "Model" in model_class or "Encoder" in model_class:
code_sample = TF_BASE_MODEL_SAMPLE if is_tf_class else PT_BASE_MODEL_SAMPLE
else:
raise ValueError(f"Docstring can't be built for model {model_class}")
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,11 @@ def generate(
# see if BOS token can be used for decoder_start_token_id
if bos_token_id is not None:
decoder_start_token_id = bos_token_id
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need one for more check for this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(out of scope)
I would be down for a helper method
determine_decoder_start_token_id to get this out of the main block.

):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
AutoConfig,
BartConfig,
BertConfig,
BertGenerationConfig,
CamembertConfig,
CTRLConfig,
DistilBertConfig,
Expand Down Expand Up @@ -73,6 +74,7 @@
BertLMHeadModel,
BertModel,
)
from .modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
from .modeling_camembert import (
CamembertForCausalLM,
CamembertForMaskedLM,
Expand Down Expand Up @@ -213,6 +215,7 @@
(ReformerConfig, ReformerModel),
(FunnelConfig, FunnelModel),
(LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder),
]
)

Expand Down Expand Up @@ -284,6 +287,7 @@
), # XLM can be MLM and CLM => model should be split similar to BERT; leave here for now
(CTRLConfig, CTRLLMHeadModel),
(ReformerConfig, ReformerModelWithLMHead),
(BertGenerationConfig, BertGenerationDecoder),
]
)

Expand Down
Loading