Merged
Changes from all commits
Commits
76 commits
f8f7487
Implemented fast version of tokenizers
mfuntowicz Jan 27, 2020
c435009
Bumped tokenizers version requirements to latest 0.2.1
mfuntowicz Jan 27, 2020
96bc6e6
Added matching tests
mfuntowicz Jan 27, 2020
c2a5805
Matching OpenAI GPT tokenization !
mfuntowicz Jan 28, 2020
92ce90d
Matching GPT2 on tokenizers
mfuntowicz Jan 29, 2020
0e19ed3
Expose add_prefix_space as constructor parameter for GPT2
mfuntowicz Jan 29, 2020
7f5e943
Matching Roberta tokenization !
mfuntowicz Jan 29, 2020
8d4322a
Removed fast implementation of CTRL.
mfuntowicz Feb 3, 2020
02dcd7c
Binding TransformerXL tokenizers to Rust.
mfuntowicz Feb 3, 2020
4a2ef66
Updating tests accordingly.
mfuntowicz Feb 3, 2020
3ad2ed6
Added tokenizers as top-level modules.
mfuntowicz Feb 3, 2020
233773f
Black & isort.
mfuntowicz Feb 3, 2020
0f31b31
Rename LookupTable to WordLevel to match Rust side.
mfuntowicz Feb 4, 2020
f806ec1
Black.
mfuntowicz Feb 4, 2020
e58a2ad
Use "fast" suffix instead of "ru" for rust tokenizers implementations.
mfuntowicz Feb 5, 2020
8ec45ee
Introduce tokenize() method on fast tokenizers.
mfuntowicz Feb 5, 2020
fa926ed
encode_plus dispatchs to batch_encode_plus
mfuntowicz Feb 5, 2020
da3a899
batch_encode_plus now dispatchs to encode if there is only one input …
mfuntowicz Feb 5, 2020
a63b25d
Bind all the encode_plus parameter to the forwarded batch_encode_plus…
mfuntowicz Feb 5, 2020
11414a6
Bump tokenizers dependency to 0.3.0
mfuntowicz Feb 5, 2020
b4cf279
Formatting.
mfuntowicz Feb 5, 2020
4b57478
Fix tokenization_auto with support for new (python, fast) mapping sch…
mfuntowicz Feb 6, 2020
5cdef87
Give correct fixtures path in test_tokenization_fast.py for the CLI.
mfuntowicz Feb 6, 2020
78d975a
Expose max_len_ properties on BertTokenizerFast
mfuntowicz Feb 6, 2020
285da47
Move max_len_ properties to PreTrainedTokenizerFast and override in s…
mfuntowicz Feb 6, 2020
30ce9ee
_convert_encoding should keep the batch axis tensor if only one sampl…
mfuntowicz Feb 6, 2020
cb59a27
Add warning message for RobertaTokenizerFast if used for MLM.
mfuntowicz Feb 7, 2020
7cd0858
Added use_fast (bool) parameter on AutoTokenizer.from_pretrained().
mfuntowicz Feb 7, 2020
18ca932
Let's tokenizers handle all the truncation and padding stuff.
mfuntowicz Feb 7, 2020
ea75afc
Allow to provide tokenizer arguments during pipeline creation.
mfuntowicz Feb 7, 2020
f2ccac3
Update test_fill_mask pipeline to not use fast tokenizers.
mfuntowicz Feb 7, 2020
1d7cdde
Fix too much parameters for convert_encoding.
mfuntowicz Feb 7, 2020
14dfb32
When enabling padding, max_length should be set to None.
mfuntowicz Feb 7, 2020
7be2a07
Avoid returning nested tensors of length 1 when calling encode_plus
mfuntowicz Feb 7, 2020
3ae3811
Ensure output is padded when return_tensor is not None.
mfuntowicz Feb 7, 2020
f29a103
Disable transfoxl unittest if pytorch is not available (required to l…
mfuntowicz Feb 7, 2020
e18396a
encode_plus should not remove the leading batch axis if return_tensor…
mfuntowicz Feb 10, 2020
c16df3b
Temporary disable fast tokenizers on QA pipelines.
mfuntowicz Feb 10, 2020
490e690
Fix formatting issues.
mfuntowicz Feb 10, 2020
2f0df23
Update tokenizers to 0.4.0
n1t0 Feb 10, 2020
1f25635
Update style
n1t0 Feb 10, 2020
8fd7e67
Enable truncation + stride unit test on fast tokenizers.
mfuntowicz Feb 11, 2020
0edb712
Add unittest ensuring special_tokens set match between Python and Rust.
mfuntowicz Feb 11, 2020
f934a2b
Ensure special_tokens are correctly set during construction.
mfuntowicz Feb 11, 2020
028a2ab
Give more warning feedback to the user in case of padding without pad…
mfuntowicz Feb 11, 2020
73aa1da
quality & format.
mfuntowicz Feb 11, 2020
5fcb4f0
Added possibility to add a single token as str
mfuntowicz Feb 11, 2020
1cad8f7
Added unittest for add_tokens and add_special_tokens on fast tokenizers.
mfuntowicz Feb 11, 2020
bce7676
Fix rebase mismatch on pipelines qa default model.
mfuntowicz Feb 11, 2020
1111567
Addressing review comment: Using offset mapping relative to the origi…
mfuntowicz Feb 12, 2020
84a8c80
Addressing review comment: save_vocabulary requires folder and file name
mfuntowicz Feb 12, 2020
bc38709
Addressing review comment: Simplify import for Bert.
mfuntowicz Feb 12, 2020
1f38b59
Addressing review comment: truncate_and_pad disables padding accordin…
mfuntowicz Feb 12, 2020
150f38c
Addressing review comment: Remove private member access in tokenize()
mfuntowicz Feb 12, 2020
e07af64
Addressing review comment: Bump tokenizers dependency to 0.4.2
mfuntowicz Feb 12, 2020
9ebefaf
format & quality.
mfuntowicz Feb 12, 2020
6c58a79
Addressing review comment: Use named arguments when applicable.
mfuntowicz Feb 13, 2020
43afcec
Addressing review comment: Add Github link to Roberta/GPT2 space issu…
mfuntowicz Feb 13, 2020
623064d
Addressing review comment: Move max_len_single_sentence / max_len_sen…
mfuntowicz Feb 13, 2020
adc9d59
Addressing review comment: Relax type checking to include tuple and l…
mfuntowicz Feb 13, 2020
60152ec
Addressing review comment: Document the truncate_and_pad manager beha…
mfuntowicz Feb 13, 2020
7c9d853
Raise an exception if return_offsets_mapping is not available with th…
mfuntowicz Feb 14, 2020
339175d
Ensure padding is set on the tokenizers before setting any padding st…
mfuntowicz Feb 17, 2020
c64b472
On pytorch we need to stack tensor to get proper new axis.
mfuntowicz Feb 17, 2020
d2ff615
Generalize tests to different framework removing hard written return_…
mfuntowicz Feb 17, 2020
2689cf0
Bump tokenizer dependency for num_special_tokens_to_add
mfuntowicz Feb 18, 2020
cc94880
Overflowing tokens in batch_encode_plus are now stacked over the batc…
mfuntowicz Feb 18, 2020
7d05684
Improved error message for padding strategy without pad token.
mfuntowicz Feb 18, 2020
a2043b0
Bumping tokenizers dependency to 0.5.0 for release.
mfuntowicz Feb 19, 2020
f8e3cf4
Optimizing convert_encoding around 4x improvement. :rocket:
mfuntowicz Feb 19, 2020
8499ddc
expose pad_to_max_length in encode_plus to avoid duplicating the para…
mfuntowicz Feb 19, 2020
ad61705
Generate a proper overflow_to_sampling_mapping when return_overflowin…
mfuntowicz Feb 19, 2020
4c50f20
Fix unittests for overflow_to_sampling_mapping not being returned as …
mfuntowicz Feb 19, 2020
3342897
Format & quality.
mfuntowicz Feb 19, 2020
590ceb5
Remove perfect alignment constraint for Roberta (allowing 1% differen…
mfuntowicz Feb 19, 2020
56748e8
Triggering final CI
mfuntowicz Feb 19, 2020
2 changes: 1 addition & 1 deletion setup.py
@@ -89,7 +89,7 @@
packages=find_packages("src"),
install_requires=[
"numpy",
"tokenizers == 0.0.11",
"tokenizers == 0.5.0",
# accessing files from S3 directly
"boto3",
# filesystem locks e.g. to prevent parallel downloads
8 changes: 4 additions & 4 deletions src/transformers/__init__.py
@@ -110,13 +110,13 @@
from .tokenization_bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast

# Tokenizers
from .tokenization_utils import PreTrainedTokenizer
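The fast tokenizer classes added above are exposed at the top level of the package, next to their pure-Python counterparts. A minimal usage sketch, with the "gpt2" checkpoint name assumed for illustration:

from transformers import GPT2Tokenizer, GPT2TokenizerFast

# Both classes load the same pretrained vocabulary files; only the backend differs.
slow = GPT2Tokenizer.from_pretrained("gpt2")      # pure-Python implementation
fast = GPT2TokenizerFast.from_pretrained("gpt2")  # Rust-backed `tokenizers` implementation

# This PR adds tests asserting that both produce matching tokenization.
print(slow.encode("Hello world"))
print(fast.encode("Hello world"))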
12 changes: 8 additions & 4 deletions src/transformers/pipelines.py
@@ -982,7 +982,7 @@ def span_to_answer(self, text: str, start: int, end: int):
"default": {
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
"config": None,
"tokenizer": "distilbert-base-cased",
"tokenizer": ("distilbert-base-cased", {"use_fast": False}),
},
},
"fill-mask": {
@@ -992,7 +992,7 @@ def span_to_answer(self, text: str, start: int, end: int):
"default": {
"model": {"pt": "distilroberta-base", "tf": "distilroberta-base"},
"config": None,
"tokenizer": "distilroberta-base",
"tokenizer": ("distilroberta-base", {"use_fast": False}),
},
},
}
@@ -1057,8 +1057,12 @@ def pipeline(
modelcard = config

# Instantiate tokenizer if needed
if isinstance(tokenizer, str):
tokenizer = AutoTokenizer.from_pretrained(tokenizer)
if isinstance(tokenizer, (str, tuple)):
if isinstance(tokenizer, tuple):
# For tuple we have (tokenizer name, {kwargs})
tokenizer = AutoTokenizer.from_pretrained(tokenizer[0], **tokenizer[1])
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer)

# Instantiate config if needed
if isinstance(config, str):
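With the tuple form handled above, tokenizer keyword arguments can be forwarded at pipeline creation; this is how the pipeline defaults in this diff opt out of the fast backend. A usage sketch (the call itself is not part of the diff):

from transformers import pipeline

# The tokenizer is given as (name, kwargs); kwargs are forwarded to
# AutoTokenizer.from_pretrained(), here to request the Python tokenizer.
fill_mask = pipeline(
    "fill-mask",
    model="distilroberta-base",
    tokenizer=("distilroberta-base", {"use_fast": False}),
)
print(fill_mask("HuggingFace is creating a <mask> that the community uses."))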
51 changes: 29 additions & 22 deletions src/transformers/tokenization_auto.py
@@ -37,17 +37,17 @@
)
from .configuration_utils import PretrainedConfig
from .tokenization_albert import AlbertTokenizer
from .tokenization_bert import BertTokenizer
from .tokenization_bert import BertTokenizer, BertTokenizerFast
from .tokenization_bert_japanese import BertJapaneseTokenizer
from .tokenization_camembert import CamembertTokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_distilbert import DistilBertTokenizer
from .tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .tokenization_flaubert import FlaubertTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_roberta import RobertaTokenizer
from .tokenization_gpt2 import GPT2Tokenizer, GPT2TokenizerFast
from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer
from .tokenization_transfo_xl import TransfoXLTokenizer, TransfoXLTokenizerFast
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
from .tokenization_xlnet import XLNetTokenizer
@@ -58,20 +58,20 @@

TOKENIZER_MAPPING = OrderedDict(
[
(T5Config, T5Tokenizer),
(DistilBertConfig, DistilBertTokenizer),
(AlbertConfig, AlbertTokenizer),
(CamembertConfig, CamembertTokenizer),
(XLMRobertaConfig, XLMRobertaTokenizer),
(RobertaConfig, RobertaTokenizer),
(BertConfig, BertTokenizer),
(OpenAIGPTConfig, OpenAIGPTTokenizer),
(GPT2Config, GPT2Tokenizer),
(TransfoXLConfig, TransfoXLTokenizer),
(XLNetConfig, XLNetTokenizer),
(FlaubertConfig, FlaubertTokenizer),
(XLMConfig, XLMTokenizer),
(CTRLConfig, CTRLTokenizer),
(T5Config, (T5Tokenizer, None)),
(DistilBertConfig, (DistilBertTokenizer, DistilBertTokenizerFast)),
(AlbertConfig, (AlbertTokenizer, None)),
(CamembertConfig, (CamembertTokenizer, None)),
(XLMRobertaConfig, (XLMRobertaTokenizer, None)),
(RobertaConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(BertConfig, (BertTokenizer, BertTokenizerFast)),
(OpenAIGPTConfig, (OpenAIGPTTokenizer, OpenAIGPTTokenizerFast)),
(GPT2Config, (GPT2Tokenizer, GPT2TokenizerFast)),
(TransfoXLConfig, (TransfoXLTokenizer, TransfoXLTokenizerFast)),
(XLNetConfig, (XLNetTokenizer, None)),
(FlaubertConfig, (FlaubertTokenizer, None)),
(XLMConfig, (XLMTokenizer, None)),
(CTRLConfig, (CTRLTokenizer, None)),
]
)

@@ -154,6 +154,9 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.

use_fast: (`optional`) boolean, default True:
Indicate if transformers should try to load the fast version of the tokenizer (True) or use the Python one (False).

inputs: (`optional`) positional arguments: will be passed to the Tokenizer ``__init__`` method.

kwargs: (`optional`) keyword arguments: will be passed to the Tokenizer ``__init__`` method. Can be used to set special tokens like ``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``. See parameters in the doc string of :class:`~transformers.PreTrainedTokenizer` for details.
@@ -177,9 +180,13 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
if "bert-base-japanese" in pretrained_model_name_or_path:
return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

for config_class, tokenizer_class in TOKENIZER_MAPPING.items():
use_fast = kwargs.pop("use_fast", True)
for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
if isinstance(config, config_class):
return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if tokenizer_class_fast and use_fast:
return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
else:
return tokenizer_class_py.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

raise ValueError(
"Unrecognized configuration class {} to build an AutoTokenizer.\n"
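The use_fast flag documented above selects between the (python, fast) entries of TOKENIZER_MAPPING. A small sketch of the expected behaviour, with the "bert-base-uncased" checkpoint name assumed for illustration:

from transformers import AutoTokenizer, BertTokenizer, BertTokenizerFast

# Default (use_fast=True): the Rust-backed class is returned when one is registered.
fast = AutoTokenizer.from_pretrained("bert-base-uncased")
assert isinstance(fast, BertTokenizerFast)

# use_fast=False falls back to the pure-Python implementation.
slow = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=False)
assert isinstance(slow, BertTokenizer)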
45 changes: 11 additions & 34 deletions src/transformers/tokenization_bert.py
@@ -20,7 +20,7 @@
import os
import unicodedata

import tokenizers as tk
from tokenizers import BertWordPieceTokenizer

from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -550,14 +550,19 @@ def __init__(
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
max_length=None,
pad_to_max_length=False,
stride=0,
truncation_strategy="longest_first",
add_special_tokens=True,
**kwargs
):
super().__init__(
BertWordPieceTokenizer(
vocab_file=vocab_file,
add_special_tokens=add_special_tokens,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
handle_chinese_chars=tokenize_chinese_chars,
lowercase=do_lower_case,
),
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
@@ -566,32 +571,4 @@ def __init__(
**kwargs,
)

self._tokenizer = tk.Tokenizer(tk.models.WordPiece.from_files(vocab_file, unk_token=unk_token))
self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(
tk.pre_tokenizers.BertPreTokenizer.new(
do_basic_tokenize=do_basic_tokenize,
do_lower_case=do_lower_case,
tokenize_chinese_chars=tokenize_chinese_chars,
never_split=never_split if never_split is not None else [],
)
)
self._tokenizer.with_decoder(tk.decoders.WordPiece.new())

if add_special_tokens:
self._tokenizer.with_post_processor(
tk.processors.BertProcessing.new(
(sep_token, self._tokenizer.token_to_id(sep_token)),
(cls_token, self._tokenizer.token_to_id(cls_token)),
)
)
if max_length is not None:
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
self._tokenizer.with_padding(
max_length=max_length if pad_to_max_length else None,
direction=self.padding_side,
pad_id=self.pad_token_id,
pad_type_id=self.pad_token_type_id,
pad_token=self.pad_token,
)
self._decoder = tk.decoders.WordPiece.new()
Review comment (Member):
Nice to have this upstream now!

Review comment on lines -569 to -597 (Member):
This is very satisfying

self.do_lower_case = do_lower_case
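Since BertTokenizerFast now delegates wordpiece tokenization, truncation and padding to the Rust BertWordPieceTokenizer, it can also return character offsets into the original text, one of the features this PR wires through encode_plus. A sketch under the assumptions that the checkpoint is "bert-base-uncased" and that offsets come back under the offset_mapping key:

from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Truncation and padding are handled by the Rust backend; offsets are only
# available on fast tokenizers (the Python ones raise for this argument).
encoding = tokenizer.encode_plus(
    "Fast tokenizers are written in Rust.",
    max_length=16,
    pad_to_max_length=True,
    return_offsets_mapping=True,
)
print(encoding["input_ids"])
print(encoding["offset_mapping"])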
9 changes: 8 additions & 1 deletion src/transformers/tokenization_distilbert.py
@@ -17,7 +17,7 @@

import logging

from .tokenization_bert import BertTokenizer
from .tokenization_bert import BertTokenizer, BertTokenizerFast


logger = logging.getLogger(__name__)
@@ -74,3 +74,10 @@ class DistilBertTokenizer(BertTokenizer):
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION


class DistilBertTokenizerFast(BertTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
31 changes: 12 additions & 19 deletions src/transformers/tokenization_gpt2.py
@@ -21,7 +21,7 @@
from functools import lru_cache

import regex as re
import tokenizers as tk
from tokenizers import ByteLevelBPETokenizer

from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -259,26 +259,19 @@ def __init__(
unk_token="<|endoftext|>",
bos_token="<|endoftext|>",
eos_token="<|endoftext|>",
pad_to_max_length=False,
add_prefix_space=False,
max_length=None,
stride=0,
truncation_strategy="longest_first",
**kwargs
):
super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
super().__init__(
ByteLevelBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, add_prefix_space=add_prefix_space),
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
**kwargs,
)

self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file))
self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space))
self._tokenizer.with_decoder(tk.decoders.ByteLevel.new())
if max_length:
self._tokenizer.with_truncation(max_length, stride=stride, strategy=truncation_strategy)
self._tokenizer.with_padding(
max_length=max_length if pad_to_max_length else None,
direction=self.padding_side,
pad_id=self.pad_token_id if self.pad_token_id is not None else 0,
pad_type_id=self.pad_token_type_id,
pad_token=self.pad_token if self.pad_token is not None else "",
logger.warning(
"RobertaTokenizerFast has an issue when working on mask language modeling "
"where it introduces an extra encoded space before the mask token."
"See https://github.com/huggingface/transformers/pull/2778 for more information."
)
self._decoder = tk.decoders.ByteLevel.new()
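GPT2TokenizerFast now wraps the Rust ByteLevelBPETokenizer and, per the commit history above, exposes add_prefix_space as a constructor parameter. A hedged usage sketch, with the "gpt2" checkpoint name assumed for illustration:

from transformers import GPT2TokenizerFast

# Constructor kwargs such as add_prefix_space are forwarded by from_pretrained().
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2", add_prefix_space=True)

# Byte-level BPE marks the injected leading space with the Ġ symbol.
print(tokenizer.tokenize("Hello world"))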
101 changes: 100 additions & 1 deletion src/transformers/tokenization_openai.py
@@ -19,9 +19,18 @@
import logging
import os
import re
from typing import List, Optional, Union

from tokenizers import Tokenizer
from tokenizers.decoders import BPEDecoder
from tokenizers.implementations import BaseTokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import BertNormalizer, Sequence, unicode_normalizer_from_str
from tokenizers.pre_tokenizers import BertPreTokenizer
from tokenizers.trainers import BpeTrainer

from .tokenization_bert import BasicTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import PreTrainedTokenizer, PreTrainedTokenizerFast


logger = logging.getLogger(__name__)
@@ -213,3 +222,93 @@ def save_vocabulary(self, save_directory):
index += 1

return vocab_file, merge_file


class _OpenAIGPTCharBPETokenizer(BaseTokenizer):
Review comment (Member):
Why do we have to have this class here?
Don't we have an implementation of char-level BPE in tokenizers now?
Here: https://github.com/huggingface/tokenizers/blob/master/bindings/python/tokenizers/implementations/char_level_bpe.py#L9

Review comment (Member, Author):
We do need a special OpenAI GPT implementation because it slightly differs from the char-level BPE we have in tokenizers:
  • The normalizer is the same as Bert (BertNormalizer)
  • The pre-tokenizer is not Whitespace; it is the same as Bert (BertPreTokenizer)

Review comment (Member, Author):
If we put TransformerXL into tokenizers.implementations, maybe this one can make its way into tokenizers too. cc @n1t0

Review comment (Contributor):
Honestly, I'm not too sure about this. I think tokenizers should stay a library with some generic implementations, with an easy way for everybody to build their own custom tokenizer when needed. So I'd like to avoid introducing specific implementations for each new model/tokenizer. Otherwise, the next thing we'll discuss is whether we should have default vocabularies downloaded automatically with each specific implementation, and then we'll have as many implementations as there are models in transformers... I think it makes more sense to have specific customization details in transformers, next to the model that actually uses the custom tokenizer.

Review comment (Member):
Re "Otherwise, the next thing we'll discuss is whether we should have default vocabularies downloaded automatically with each specific implementation": you mean have all the things that made the success of Transformers? 😜
Joking aside, ok for me to keep these in Transformers then.

"""
OpenAI character-level BPE Tokenizer
"""

def __init__(
self,
vocab_file: Optional[str] = None,
merges_file: Optional[str] = None,
unk_token: Optional[str] = "<unk>",
suffix: Optional[str] = "</w>",
dropout: Optional[float] = None,
unicode_normalizer: Optional[str] = None,
):
if vocab_file is not None and merges_file is not None:
tokenizer = Tokenizer(
BPE.from_files(
vocab_file, merges_file, dropout=dropout, unk_token=unk_token, end_of_word_suffix=suffix
)
)
else:
tokenizer = Tokenizer(BPE.empty())

# Check for Unicode normalization first (before everything else)
normalizers = []

if unicode_normalizer:
normalizers += [unicode_normalizer_from_str(unicode_normalizer)]

# OpenAI normalization is the same as Bert
normalizers += [BertNormalizer()]

# Create the normalizer structure
if len(normalizers) > 0:
if len(normalizers) > 1:
tokenizer.normalizer = Sequence(normalizers)
else:
tokenizer.normalizer = normalizers[0]

tokenizer.pre_tokenizer = BertPreTokenizer()
tokenizer.decoder = BPEDecoder(suffix=suffix)

parameters = {
"model": "BPE",
"unk_token": unk_token,
"suffix": suffix,
"dropout": dropout,
}

super().__init__(tokenizer, parameters)

def train(
self,
files: Union[str, List[str]],
vocab_size: int = 30000,
min_frequency: int = 2,
special_tokens: List[str] = ["<unk>"],
limit_alphabet: int = 1000,
initial_alphabet: List[str] = [],
suffix: Optional[str] = "</w>",
show_progress: bool = True,
):
""" Train the model using the given files """

trainer = BpeTrainer(
vocab_size=vocab_size,
min_frequency=min_frequency,
special_tokens=special_tokens,
limit_alphabet=limit_alphabet,
initial_alphabet=initial_alphabet,
end_of_word_suffix=suffix,
show_progress=show_progress,
)
if isinstance(files, str):
files = [files]
self._tokenizer.train(trainer, files)


class OpenAIGPTTokenizerFast(PreTrainedTokenizerFast):
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

def __init__(self, vocab_file, merges_file, unk_token="<unk>", **kwargs):
kwargs.setdefault("unk_token", unk_token)
super().__init__(
_OpenAIGPTCharBPETokenizer(vocab_file=vocab_file, merges_file=merges_file, unk_token=unk_token), **kwargs
)
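OpenAIGPTTokenizerFast simply wraps the _OpenAIGPTCharBPETokenizer defined above, which, as discussed in the review thread, pairs a Bert-style normalizer and pre-tokenizer with a char-level BPE model. A usage sketch, with the "openai-gpt" checkpoint name assumed for illustration:

from transformers import OpenAIGPTTokenizerFast

# Behaves like the Python OpenAIGPTTokenizer, but tokenization runs in Rust;
# the Bert-style normalizer lowercases the input before BPE is applied.
tokenizer = OpenAIGPTTokenizerFast.from_pretrained("openai-gpt")
print(tokenizer.tokenize("Fast tokenizers for OpenAI GPT"))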