diff --git a/megatron/tokenizer/gpt2_tokenization.py b/megatron/tokenizer/gpt2_tokenization.py
index 3f37e4490..bce13e8ec 100644
--- a/megatron/tokenizer/gpt2_tokenization.py
+++ b/megatron/tokenizer/gpt2_tokenization.py
@@ -78,7 +78,6 @@ def bytes_to_unicode():
 
 def get_pairs(word):
     """Return set of symbol pairs in a word.
-
     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
     pairs = set()
@@ -153,8 +152,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None,
                         *inputs, **kwargs)
         return tokenizer
 
+
     def __init__(self, vocab_file, merges_file, errors='replace',
-                 special_tokens=None, max_len=None):
+                 special_tokens=None, max_len=None, max_token_len_cache=9):
+        """
+        max_token_len_cache is the maximum length of token that will be stored in the normalized-token
+        cache; only short tokens are cached, on the heuristic that they are the most frequent. Raising
+        it may make tokenization faster at the cost of memory. The cache is bounded at 1_000_000 entries.
+        """
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
@@ -164,8 +169,6 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
         # Should haved added re.IGNORECASE so BPE merges can happen for
         # capitalized versions of contractions
         self.pat = re.compile(
@@ -174,6 +177,7 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         self.special_tokens = {}
         self.special_tokens_decoder = {}
         self.set_special_tokens(special_tokens)
+        self.max_token_len_cache = max_token_len_cache
 
     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)
@@ -191,10 +195,9 @@ def set_special_tokens(self, special_tokens):
                                    for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))
-
+
+    @lru_cache(1_000_000)
     def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
         word = tuple(token)
         pairs = get_pairs(word)
 
@@ -230,20 +233,33 @@ def bpe(self, token):
             else:
                 pairs = get_pairs(word)
         word = ' '.join(word)
-        self.cache[token] = word
         return word
+
+    @lru_cache(1_000_000)
+    def normalize_token_and_cache(self, token):
+        return self.normalize_token(token)
+
+    def normalize_token(self, token):
+        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+        ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
+        return ret
 
     def tokenize(self, text):
         """ Tokenize a string. """
+        max_token_len_cache = self.max_token_len_cache
         bpe_tokens = []
+        if sys.version_info[0] == 2:
+            for token in re.findall(self.pat, text):
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+                bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+            return bpe_tokens
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = ''.join(self.byte_encoder[ord(b)] for b in token)
-            else:
-                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+            if len(token) <= max_token_len_cache:
+                bpe_tokens.extend(self.normalize_token_and_cache(token))
+            else:
+                bpe_tokens.extend(self.normalize_token(token))
         return bpe_tokens
-
+
     def convert_tokens_to_ids(self, tokens):
         """ Converts a sequence of tokens into ids using the vocab. """
         ids = []
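The hunks above replace the hand-rolled `self.cache` dict with `functools.lru_cache` and gate caching on token length. Below is a minimal, self-contained sketch of that length-gated memoization pattern, for illustration only: the names (`MAX_CACHED_TOKEN_LEN`, `normalize`, `normalize_cached`, `tokenize_one`) are not from the patch, and `normalize` is a trivial stand-in for the real byte-encode + BPE work.

    from functools import lru_cache

    MAX_CACHED_TOKEN_LEN = 9   # mirrors the max_token_len_cache default above
    CACHE_SIZE = 1_000_000     # mirrors the fixed lru_cache bound above

    def normalize(token):
        # Stand-in for the real work (byte-encoding + BPE); returns sub-tokens.
        return list(token)

    @lru_cache(CACHE_SIZE)
    def normalize_cached(token):
        # Memoized wrapper. lru_cache hands every caller the *same* list
        # object, so callers must treat the result as read-only.
        return normalize(token)

    def tokenize_one(token):
        # Short tokens go through the bounded cache, on the heuristic that
        # they recur often; long tokens are assumed rare and are recomputed,
        # so they never evict the more useful short entries.
        if len(token) <= MAX_CACHED_TOKEN_LEN:
            return normalize_cached(token)
        return normalize(token)

One caveat on the patch itself: applied to instance methods, `lru_cache` includes `self` in the cache key and keeps a strong reference to the tokenizer for the lifetime of the cache. That is harmless for a single long-lived tokenizer instance, which is the usage pattern here.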
""" + max_token_len_cache = self.max_token_len_cache bpe_tokens = [] + if sys.version_info[0] == 2: + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[ord(b)] for b in token) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens for token in re.findall(self.pat, text): - if sys.version_info[0] == 2: - token = ''.join(self.byte_encoder[ord(b)] for b in token) - else: - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + if len(token) <= max_token_len_cache: + bpe_tokens.extend(self.normalize_token_and_cache(token)) + else: + bpe_tokens.extend(self.normalize_token(token)) return bpe_tokens - + def convert_tokens_to_ids(self, tokens): """ Converts a sequence of tokens into ids using the vocab. """ ids = [] diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 9001af5cd..73e268459 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -21,10 +21,9 @@ import os import sys -from megatron.data.indexed_dataset import best_fitting_dtype - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) +from megatron.data.indexed_dataset import best_fitting_dtype import time import torch