Add LRU cache, add faster tokenization #37
Changes from all commits: 160e2bd, e7c3d51, dc80008, a405b9e, 54ab4e3, e729aba, 3501149, 75cce0b, cc57925, 1811892, 02b2d2f
```diff
@@ -78,7 +78,6 @@ def bytes_to_unicode():
 def get_pairs(word):
     """Return set of symbol pairs in a word.

     Word is represented as tuple of symbols (symbols being variable-length strings).
     """
     pairs = set()
```
```diff
@@ -153,8 +152,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,
                          **kwargs)
         return tokenizer

     def __init__(self, vocab_file, merges_file, errors='replace',
-                 special_tokens=None, max_len=None):
+                 special_tokens=None, max_len=None, max_token_len_cache=9):
+        """
+        max_token_len_cache determines whether a normalized token will be cached. It tries to only store shorter tokens in the cache,
+        with the heuristic that they are more frequent. Increasing this may make tokenization faster but will also take more memory.
+        The upper bound of the normalized token cache is fixed at 1_000_000 tokens.
+        """
         self.max_len = max_len if max_len is not None else int(1e12)
         self.encoder = json.load(open(vocab_file))
         self.decoder = {v: k for k, v in self.encoder.items()}
```
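To make the heuristic in the new docstring concrete, here is a minimal standalone sketch of the same pattern: a size-bounded `functools.lru_cache` in front of a normalization step, with a length gate so that only short (presumably frequent) tokens are cached. All names in this sketch are illustrative and not part of the PR.

```python
from functools import lru_cache

MAX_CACHED_TOKEN_LEN = 9   # mirrors the max_token_len_cache default above
CACHE_SIZE = 1_000_000     # the fixed upper bound mentioned in the docstring


def _normalize(token):
    # stand-in for the real byte-encoding + BPE work; illustrative only
    return [token.lower()]


@lru_cache(CACHE_SIZE)
def _normalize_cached(token):
    # same work, but memoized; lru_cache evicts least-recently-used entries
    # once CACHE_SIZE distinct tokens have been seen
    return _normalize(token)


def normalize(token):
    # short tokens are assumed to repeat often, so they go through the cache;
    # longer (rarer) tokens are normalized directly to keep memory in check
    if len(token) <= MAX_CACHED_TOKEN_LEN:
        return _normalize_cached(token)
    return _normalize(token)
```

Calling `_normalize_cached.cache_info()` reports hits, misses, and current size, which is a quick way to check whether the length cutoff is actually catching the frequent tokens.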
```diff
@@ -164,8 +169,6 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         bpe_merges = [tuple(merge.split()) for merge in bpe_data]
         self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
-        self.cache = {}

         # Should haved added re.IGNORECASE so BPE merges can happen for
         # capitalized versions of contractions
         self.pat = re.compile(
```
```diff
@@ -174,6 +177,7 @@ def __init__(self, vocab_file, merges_file, errors='replace',
         self.special_tokens = {}
         self.special_tokens_decoder = {}
         self.set_special_tokens(special_tokens)
+        self.max_token_len_cache = max_token_len_cache

     def __len__(self):
         return len(self.encoder) + len(self.special_tokens)
```
```diff
@@ -191,10 +195,9 @@ def set_special_tokens(self, special_tokens):
                                    for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v: k for k, v in self.special_tokens.items()}
         logger.info("Special tokens {}".format(self.special_tokens))

+    @lru_cache(1_000_000)
     def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
         word = tuple(token)
         pairs = get_pairs(word)
```
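This hunk replaces the hand-rolled `self.cache` dict, which grows without bound, with `functools.lru_cache`, which caps the number of cached entries and evicts the least recently used ones. A tiny self-contained illustration of that eviction behaviour, using a deliberately small bound:

```python
from functools import lru_cache

@lru_cache(maxsize=3)          # deliberately tiny bound to show eviction
def square(x):
    return x * x

for x in (1, 2, 3, 1, 4):      # inserting 4 evicts 2, the least recently used entry
    square(x)

print(square.cache_info())     # CacheInfo(hits=1, misses=4, maxsize=3, currsize=3)
```

One general property of `lru_cache` worth noting here: when it decorates an instance method such as `bpe`, `self` becomes part of the cache key, so entries are per tokenizer instance even though the underlying cache object is shared by all instances of the class.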
```diff
@@ -230,20 +233,33 @@ def bpe(self, token):
             else:
                 pairs = get_pairs(word)
         word = ' '.join(word)
-        self.cache[token] = word
         return word

+    @lru_cache(1_000_000)
+    def normalize_token_and_cache(self, token):
+        return self.normalize_token(token)
+
+    def normalize_token(self, token):
+        token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+        ret = [bpe_token for bpe_token in self.bpe(token).split(' ')]
+        return ret
+
     def tokenize(self, text):
         """ Tokenize a string. """
+        max_token_len_cache = self.max_token_len_cache
         bpe_tokens = []
+        if sys.version_info[0] == 2:
+            for token in re.findall(self.pat, text):
+                token = ''.join(self.byte_encoder[ord(b)] for b in token)
+                bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
```
Comment on lines +253 to +254

Member: Suggested change

Contributor (author): I think the if statement being outside the for loop is better.

Member: I think a better version would be to just ignore that case, since Python 2 is not supported.

Contributor (author): Yeah, but then it would be inconsistent, since we STILL have code that tests for Python 2.

Contributor: Remove that code, and anywhere else where similar branches are found?
```diff
+            return bpe_tokens
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = ''.join(self.byte_encoder[ord(b)] for b in token)
-            else:
-                token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+            if len(token) <= max_token_len_cache:
+                bpe_tokens.extend(self.normalize_token_and_cache(token))
+            else:
+                bpe_tokens.extend(self.normalize_token(token))
         return bpe_tokens

     def convert_tokens_to_ids(self, tokens):
         """ Converts a sequence of tokens into ids using the vocab. """
         ids = []
```
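For reference, if the Python 2 branch were dropped entirely, as suggested in the review thread above, the whole method could collapse to something like the sketch below. This is a hypothetical simplification, not what the PR merges; it assumes it lives on the same tokenizer class, with `re` imported at module level as in the existing file.

```python
def tokenize(self, text):
    """ Tokenize a string (Python 3 only; hypothetical simplification). """
    bpe_tokens = []
    for token in re.findall(self.pat, text):
        # short tokens go through the memoized path, long ones are normalized directly
        if len(token) <= self.max_token_len_cache:
            bpe_tokens.extend(self.normalize_token_and_cache(token))
        else:
            bpe_tokens.extend(self.normalize_token(token))
    return bpe_tokens
```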