Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d19e7c9
refactor
patrickvonplaten Nov 23, 2020
814cb3d
further refactor
patrickvonplaten Nov 23, 2020
8a384f1
fix the rest tomorrow
patrickvonplaten Nov 23, 2020
f754101
save intermediate
patrickvonplaten Nov 26, 2020
1497056
finish slow tokenizer
patrickvonplaten Nov 26, 2020
79fdcf8
make more tests pass
patrickvonplaten Nov 26, 2020
293d19a
finish refactor
patrickvonplaten Nov 26, 2020
4215881
fix comment
patrickvonplaten Nov 26, 2020
60bcfb3
clean further
patrickvonplaten Nov 26, 2020
c0a5983
fix name
patrickvonplaten Nov 26, 2020
898d29e
fix naming
patrickvonplaten Nov 26, 2020
cd7487b
Update src/transformers/models/reformer/tokenization_reformer.py
patrickvonplaten Nov 26, 2020
95e8cbd
Apply suggestions from code review
patrickvonplaten Nov 26, 2020
e249ed2
Apply suggestions from code review
patrickvonplaten Nov 26, 2020
bf325e6
refactor
patrickvonplaten Nov 26, 2020
3c2921c
fix init tokenizers
patrickvonplaten Nov 26, 2020
da39bee
refactor
patrickvonplaten Nov 26, 2020
b7a052a
improve convert
patrickvonplaten Nov 27, 2020
7e0720e
refactor
patrickvonplaten Nov 27, 2020
6915481
correct convert slow tokenizer
patrickvonplaten Nov 27, 2020
b39abb6
Merge remote-tracking branch 'main/master' into refactor_pegasus_tok
patrickvonplaten Nov 27, 2020
49e8cb7
Merge branch 'master' of https://github.com/huggingface/transformers …
patrickvonplaten Nov 27, 2020
c0e6663
final fix for Pegasus Tok
patrickvonplaten Nov 27, 2020
bb97982
remove ipdb
patrickvonplaten Nov 27, 2020
7b1a4a5
improve links
patrickvonplaten Nov 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/transformers/convert_slow_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,10 +531,12 @@ class BertGenerationConverter(SpmConverter):
class PegasusConverter(SpmConverter):
def vocab(self, proto):
vocab = [
(self.original_tokenizer.pad_token, 0),
(self.original_tokenizer.eos_token, 0),
(self.original_tokenizer.pad_token, 0.0),
(self.original_tokenizer.eos_token, 0.0),
(self.original_tokenizer.mask_token_sent, 0.0),
(self.original_tokenizer.mask_token, 0.0),
]
vocab += [(f"unk_{i}", -100) for i in range(2, 2 + self.original_tokenizer.offset)]
Copy link
Contributor Author

@patrickvonplaten patrickvonplaten Nov 27, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was wrong previously -> it should have been "<unk_{i}>"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok!

vocab += [(f"<unk_{i}>", -100.0) for i in range(2, self.original_tokenizer.offset)]
vocab += [(piece.piece, piece.score) for piece in proto.pieces[2:]]
return vocab

Expand All @@ -543,13 +545,10 @@ def unk_id(self, proto):

def post_processor(self):
eos = self.original_tokenizer.eos_token
return processors.TemplateProcessing(
single=["$A", eos],
pair=["$A", "$B", eos],
special_tokens=[
(eos, self.original_tokenizer.eos_token_id),
],
)
special_tokens = [
(eos, self.original_tokenizer.eos_token_id),
]
return processors.TemplateProcessing(single=["$A", eos], pair=["$A", "$B", eos], special_tokens=special_tokens)


class T5Converter(SpmConverter):
Expand Down
8 changes: 4 additions & 4 deletions src/transformers/models/albert/tokenization_albert_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@

class AlbertTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" ALBERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on `SentencePiece
<https://github.com/google/sentencepiece>`__. This tokenizer inherits from
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@thomwolf @LysandreJik @n1t0 - I don't think the fast tokenizers are based on google's sentencepiece anymore, so I removed this statement from all fast tokenizers.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed

:class:`~transformers.PreTrainedTokenizerFast` which contains most of the main methods. Users should refer to this
superclass for more information regarding those methods
Construct a "fast" ALBERT tokenizer (backed by HuggingFace's `tokenizers` library). Based on `Unigram
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=unigram#models>`__. This tokenizer
inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main methods. Users should
refer to this superclass for more information regarding those methods

Args:
vocab_file (:obj:`str`):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
class CamembertTokenizerFast(PreTrainedTokenizerFast):
"""
Construct a "fast" CamemBERT tokenizer (backed by HuggingFace's `tokenizers` library). Adapted from
:class:`~transformers.RobertaTokenizer` and :class:`~transformers.XLNetTokenizer`. Based on `SentencePiece
<https://github.com/google/sentencepiece>`__.
:class:`~transformers.RobertaTokenizer` and :class:`~transformers.XLNetTokenizer`. Based on `BPE
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.

This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the main
methods. Users should refer to this superclass for more information regarding those methods.
Expand Down
3 changes: 2 additions & 1 deletion src/transformers/models/mbart/tokenization_mbart_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@

class MBartTokenizerFast(XLMRobertaTokenizerFast):
"""
Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library).
Construct a "fast" MBART tokenizer (backed by HuggingFace's `tokenizers` library). Based on `BPE
<https://huggingface.co/docs/tokenizers/python/latest/components.html?highlight=BPE#models>`__.

:class:`~transformers.MBartTokenizerFast` is a subclass of :class:`~transformers.XLMRobertaTokenizerFast` and adds
a new :meth:`~transformers.MBartTokenizerFast.prepare_seq2seq_batch`.
Expand Down
182 changes: 159 additions & 23 deletions src/transformers/models/pegasus/tokenization_pegasus.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional
import os
from shutil import copyfile
from typing import Dict, List, Optional, Tuple

import sentencepiece as spm

from ...file_utils import add_start_docstrings
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from ..reformer.tokenization_reformer import ReformerTokenizer
from ...utils import logging


SPIECE_UNDERLINE = "▁"
Expand All @@ -32,31 +37,145 @@
}


class PegasusTokenizer(ReformerTokenizer):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pegasus has nothing to do with Reformer, so decouple it here.

logger = logging.get_logger(__name__)


class PegasusTokenizer(PreTrainedTokenizer):
r"""
Construct a Pegasus tokenizer.
Construct a PEGASUS tokenizer. Based on `SentencePiece <https://github.com/google/sentencepiece>`__.

This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
Users should refer to this superclass for more information regarding those methods.

:class:`~transformers.PegasusTokenizer` is identical to :class:`~transformers.ReformerTokenizer` and adds a new
:meth:`~transformers.PegasusTokenizer.prepare_seq2seq_batch`
Args:
vocab_file (:obj:`str`):
`SentencePiece <https://github.com/google/sentencepiece>`__ file (generally has a `.spm` extension) that
contains the vocabulary necessary to instantiate a tokenizer.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sequence token.

Refer to superclass :class:`~transformers.ReformerTokenizer` for usage examples and documentation concerning the
initialization parameters and other methods.
.. note::

When building a sequence using special tokens, this is not the token that is used for the end of
sequence. The token used is the :obj:`sep_token`.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
mask_token (:obj:`str`, `optional`, defaults to :obj:`"<mask_2>"`):
The token used for masking single token values. This is the token used when training this model with masked
language modeling (MLM). This is the token that the PEGASUS encoder will try to predict during pretraining.
It corresponds to `[MASK2]` in `PEGASUS: Pre-training with Extracted Gap-sentences for Abstractive
Summarization <https://arxiv.org/pdf/1912.08777.pdf>`__.
mask_token_sent (:obj:`str`, `optional`, defaults to :obj:`"<mask_1>"`):
The token used for masking whole target sentences. This is the token used when training this model with gap
sentences generation (GSG). This is the sentence that the PEGASUS decoder will try to predict during
pretraining. It corresponds to `[MASK1]` in `PEGASUS: Pre-training with Extracted Gap-sentences for
Abstractive Summarization <https://arxiv.org/pdf/1912.08777.pdf>`__.
Comment on lines +66 to +75
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great docs here.

additional_special_tokens (:obj:`List[str]`, `optional`):
Additional special tokens used by the tokenizer. If no additional_special_tokens are provided <mask_2> and
<unk_2, ..., unk_102> are used as additional special tokens corresponding to the `original PEGASUS
tokenizer
<https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66>`__
that uses the tokens 2 - 104 only for pretraining
"""
offset = 103 # entries 2-104 are only used for pretraining
vocab_files_names = VOCAB_FILES_NAMES

offset = 103 # entries 2 - 104 are only used for pretraining
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
model_input_names = ["attention_mask"]

def __init__(
self,
vocab_file,
pad_token="<pad>",
eos_token="</s>",
unk_token="<unk>",
mask_token="<mask_2>",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pegasus has two masked tokens that were previously not added to the tokenizer. There are defined as the 2nd and 3rd token according to the original implementation: https://github.com/google-research/pegasus/blob/939830367bcf411193d2b5eca2f2f90f3f9260ca/pegasus/ops/pretrain_parsing_ops.cc#L66

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This resolves both this #8536 and this #8594 issue.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome!

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the bos_token which is supposed to be passed into the decoder is missing?

mask_token_sent="<mask_1>",
additional_special_tokens=None,
**kwargs
):
if additional_special_tokens is not None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As Sam pointed out before tokens 2-104 were only used for pre-training. I think to add them to the additional_special_tokens in this case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes indeed, that the good place to put them.

assert isinstance(
additional_special_tokens, list
), f"additional_special_tokens should be of type {type(list)}, but is {type(additional_special_tokens)}"

additional_special_tokens_extended = (
([mask_token_sent] + additional_special_tokens)
if mask_token_sent not in additional_special_tokens
else additional_special_tokens
)
# fill additional tokens with ..., <unk_token_102> in case not all additional tokens are already taken
additional_special_tokens_extended += [
f"<unk_{i}>" for i in range(len(additional_special_tokens_extended), self.offset - 1)
]

if len(set(additional_special_tokens_extended)) != len(additional_special_tokens_extended):
raise ValueError(
f"Please make sure that the provided additional_special_tokens do not contain an incorrectly shifted list of <unk_x> tokens. Found {additional_special_tokens_extended}."
)
additional_special_tokens = additional_special_tokens_extended
else:
additional_special_tokens = [mask_token_sent]
additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]

def __init__(self, *args, pad_token="<pad>", **kwargs):
super().__init__(*args, **kwargs, pad_token="<pad>")
# Don't use reserved words added_token_encoder, added_tokens_decoder because of
# AssertionError: Non-consecutive added token '1' found. in from_pretrained
assert len(self.added_tokens_decoder) == 0
self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
# entries 2-104 are only used for pretraining and called unk_2, ...unk_104
self.encoder.update({i: f"unk_{i}" for i in range(2, self.offset + 2)})
super().__init__(
eos_token=eos_token,
unk_token=unk_token,
mask_token=mask_token,
pad_token=pad_token,
mask_token_sent=mask_token_sent,
additional_special_tokens=additional_special_tokens,
**kwargs,
)
self.vocab_file = vocab_file
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(vocab_file)
self.mask_token_sent = mask_token_sent

# add special tokens to encoder dict
self.encoder: Dict[int, str] = {
0: self.pad_token,
1: self.eos_token,
2: self.mask_token_sent,
3: self.mask_token,
}
# entries 2-104 are only used for pretraining and called <mask_1>, <mask_2>, unk_2, ...unk_102
# mask_token_sent is already added to list -> so start at 1
self.encoder.update({i + 3: additional_special_tokens[i] for i in range(1, self.offset - 1)})
self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}

@property
def vocab_size(self) -> int:
return len(self.sp_model) + self.offset

def get_vocab(self) -> Dict[str, int]:
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab

def __getstate__(self):
state = self.__dict__.copy()
state["sp_model"] = None
return state

def __setstate__(self, d):
self.__dict__ = d
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)

def _tokenize(self, text, sample=False):
"""Take as input a string and return a list of strings (tokens) for words/sub-words"""
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
else:
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
return pieces

def _convert_token_to_id(self, token: str) -> int:
""" Converts a token (str) to an id using the vocab. """
if token in self.decoder:
Expand All @@ -73,13 +192,13 @@ def _convert_id_to_token(self, index: int) -> str:
elif index in self.added_tokens_encoder:
return self.added_tokens_encoder[index]
else:
# assert index > self.offset, f"cannot decode ids between 2 and {self.offset}. Got {index}"
token = self.sp_model.IdToPiece(index - self.offset)
return token

@property
def vocab_size(self) -> int:
return len(self.sp_model) + self.offset
def convert_tokens_to_string(self, tokens):
""" Converts a sequence of tokens (string) in a single string. """
out_string = self.sp_model.decode_pieces(tokens)
return out_string

def num_special_tokens_to_add(self, pair=False):
"""Just EOS"""
Expand All @@ -88,7 +207,11 @@ def num_special_tokens_to_add(self, pair=False):
def _special_token_mask(self, seq):
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
all_special_ids.remove(self.unk_token_id) # <unk> is only sometimes special
assert all_special_ids == set([0, 1])

assert all_special_ids == set(
range(len(self.additional_special_tokens) + 3)
), f"There should be 3 special tokens: mask_token, pad_token, and eos_token + {len(self.additional_special_tokens)} additional_special_tokens, but got {all_special_ids}"

return [1 if x in all_special_ids else 0 for x in seq]

def get_special_tokens_mask(
Expand All @@ -105,7 +228,7 @@ def get_special_tokens_mask(
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
and adding special tokens. A Pegasus sequence has the following format, where ``X`` represents the sequence:
and adding special tokens. A PEGASUS sequence has the following format, where ``X`` represents the sequence:

- single sequence: ``X </s>``
- pair of sequences: ``A B </s>`` (not intended use)
Expand Down Expand Up @@ -156,3 +279,16 @@ def prepare_seq2seq_batch(
labels: BatchEncoding = self(tgt_texts, **tokenizer_kwargs)["input_ids"]
model_inputs["labels"] = labels
return model_inputs

def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
)

if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)

return (out_vocab_file,)
Loading