44 changes: 30 additions & 14 deletions megatron/tokenizer/gpt2_tokenization.py
@@ -78,7 +78,6 @@ def bytes_to_unicode():
 
 def get_pairs(word):
     """Return set of symbol pairs in a word.
-
     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
     pairs = set()
@@ -153,8 +152,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
                         **kwargs)
         return tokenizer
 
+
     def __init__(self, vocab_file, merges_file, errors='replace',
-                 special_tokens=None, max_len=None):
+                 special_tokens=None, max_len=None, max_token_len_cache=9):
+        """
+        max_token_len_cache caps which normalized tokens get cached: only tokens of at most this
+        many characters are stored, on the heuristic that shorter tokens are more frequent.
+        Increasing it may speed up tokenization but uses more memory; the cache is bounded at 1_000_000 entries.
+        """
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
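
Assuming the class keeps its usual name in this file (GPT2Tokenizer) and given placeholder vocab/merges paths, a hypothetical call site for the new knob might look like this; a larger max_token_len_cache lets longer tokens be cached, trading memory for speed:

```python
from megatron.tokenizer.gpt2_tokenization import GPT2Tokenizer

# Paths are placeholders; max_token_len_cache defaults to 9.
tokenizer = GPT2Tokenizer('gpt2-vocab.json', 'gpt2-merges.txt',
                          max_token_len_cache=12)
ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize("Hello world"))
```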
@@ -164,8 +169,6 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}
-
         # Should have added re.IGNORECASE so BPE merges can happen for
         # capitalized versions of contractions
         self.pat = re.compile(
@@ -174,6 +177,7 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         self.special_tokens = {}
         self.special_tokens_decoder = {}
         self.set_special_tokens(special_tokens)
+        self.max_token_len_cache = max_token_len_cache
 
     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)
@@ -191,10 +195,9 @@ def set_special_tokens(self, special_tokens):
                                    for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))
 
 
+    @lru_cache(1_000_000)
     def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
         word = tuple(token)
         pairs = get_pairs(word)
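
The decorator above replaces the hand-rolled self.cache dict with functools.lru_cache, which bounds the cache and evicts least-recently-used entries instead of growing without limit. A minimal standalone sketch of that pattern (illustrative, not the PR's code):

```python
from functools import lru_cache

@lru_cache(1_000_000)  # bounded; least-recently-used entries are evicted
def slow_upper(s):
    # stand-in for an expensive computation such as the BPE merge loop
    return s.upper()

slow_upper('hello')             # computed on the first call
slow_upper('hello')             # served from the cache on repeat calls
print(slow_upper.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=1000000, currsize=1)
```

One caveat of decorating a bound method like bpe this way: self becomes part of the cache key, and the cache holds a reference to the tokenizer instance for the life of the process.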

@@ -230,20 +233,33 @@ def bpe(self, token):
             else:
                 pairs = get_pairs(word)
         word = ' '.join(word)
-        self.cache[token] = word
         return word
 
+    @lru_cache(1_000_000)
+    def normalize_token_and_cache(self, token):
+        return self.normalize_token(token)
+
+    def normalize_token(self, token):
+        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+        ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
+        return ret
+
     def tokenize(self, text):
         """ Tokenize a string. """
+        max_token_len_cache = self.max_token_len_cache
         bpe_tokens = []
+        if sys.version_info[0] == 2:
+            for token in re.findall(self.pat, text):
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+                bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
Comment on lines +253 to +254

Member — suggested change:
-                token = ''.join(self.byte_encoder[ord(b)] for b in token)
-                bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+                bpe_tokens.extend(self.normalize_token(token))

Contributor (author): I think having the if statement outside the for loop is better.

Member: I think a better version would be to just ignore that case, since Python 2 is not supported.

Contributor (author): Yeah, but then it would be inconsistent, since we STILL have code that tests for Python 2.

Contributor (@stas00, Aug 4, 2021): Remove that code, and anywhere else where similar branches are found?
+            return bpe_tokens
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = ''.join(self.byte_encoder[ord(b)] for b in token)
-            else:
-                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+            if len(token) <= max_token_len_cache:
+                bpe_tokens.extend(self.normalize_token_and_cache(token))
+            else:
+                bpe_tokens.extend(self.normalize_token(token))
         return bpe_tokens
 
     def convert_tokens_to_ids(self, tokens):
         """ Converts a sequence of tokens into ids using the vocab. """
         ids = []
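Taken together, the new tokenize() gates caching on token length: a token of at most max_token_len_cache characters goes through the lru_cache-wrapped normalize_token_and_cache(), while a longer one calls normalize_token() directly, so rare long strings never crowd short, frequent ones out of the cache. A self-contained sketch of that gating pattern, with a toy normalizer standing in for the byte-encode + BPE step (all names below are illustrative, not the PR's):

```python
from functools import lru_cache

MAX_CACHED_LEN = 9  # mirrors the PR's max_token_len_cache default

def normalize(token):
    # stand-in for the expensive byte-encode + BPE computation
    return token.split('-')

@lru_cache(1_000_000)
def normalize_cached(token):
    # return a tuple: lru_cache hands the same object back to every
    # caller, so the cached value should be immutable
    return tuple(normalize(token))

def tokenize(raw_tokens):
    out = []
    for token in raw_tokens:
        if len(token) <= MAX_CACHED_LEN:   # short tokens recur often: cache
            out.extend(normalize_cached(token))
        else:                              # long tokens are likely one-offs
            out.extend(normalize(token))
    return out

print(tokenize(['a-b', 'a-b', 'antidisestablishment-arianism']))
# ['a', 'b', 'a', 'b', 'antidisestablishment', 'arianism']
```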
3 changes: 1 addition & 2 deletions tools/preprocess_data.py
@@ -21,10 +21,9 @@
 import os
 import sys
 
-from megatron.data.indexed_dataset import best_fitting_dtype
-
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                              os.path.pardir)))
+from megatron.data.indexed_dataset import best_fitting_dtype
 import time
 
 import torch
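The preprocess_data.py change is purely about import order: Python resolves an import at the line where it executes, so the sys.path tweak that makes the repository root importable must run before anything is imported from megatron. The corrected ordering, shown standalone:

```python
import os
import sys

# Make the repo root importable *before* importing from it; until this
# runs, the megatron package cannot be found from inside tools/.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
from megatron.data.indexed_dataset import best_fitting_dtype  # noqa: E402
```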