# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
import os
from pathlib import Path
from typing import Dict, List, Optional

try:
    import tiktoken

    HAVE_TIKTOKEN = True
except ImportError:
    HAVE_TIKTOKEN = False

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

__all__ = ['TiktokenTokenizer']


def reload_mergeable_ranks(
    path: str,
    max_vocab: Optional[int] = None,
) -> Dict[bytes, int]:
    """
    Reload a tokenizer vocabulary JSON file (a list of {"rank", "token_bytes", "token_str"}
    entries) and convert it to the mergeable-ranks mapping expected by tiktoken.
    """
    assert path.endswith(".json")

    # reload vocab
    with open(path, "r") as f:
        vocab = json.load(f)
    assert isinstance(vocab, list)
    print(f"Vocab size: {len(vocab)}")
    if max_vocab is not None:
        vocab = vocab[:max_vocab]
        print(f"Cutting vocab to first {len(vocab)} tokens.")

    # build ranks
    ranks: Dict[bytes, int] = {}
    for i, x in enumerate(vocab):
        assert x.keys() == {"rank", "token_bytes", "token_str"}
        assert x["rank"] == i
        merge = base64.b64decode(x["token_bytes"])
        assert i >= 256 or merge == bytes([i])
        ranks[merge] = x["rank"]

    # sanity check
    assert len(ranks) == len(vocab)
    assert set(ranks.values()) == set(range(len(ranks)))

    return ranks

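# Illustrative sketch only (not part of the tokenizer): building a minimal synthetic vocab
# file in the layout reload_mergeable_ranks() expects. The file name "toy_vocab.json" is
# hypothetical.
#
#   toy_vocab = [
#       {"rank": i, "token_bytes": base64.b64encode(bytes([i])).decode(), "token_str": chr(i)}
#       for i in range(256)
#   ]
#   with open("toy_vocab.json", "w") as f:
#       json.dump(toy_vocab, f)
#   ranks = reload_mergeable_ranks("toy_vocab.json")
#   assert ranks[b"A"] == 65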

PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17  # 131072
SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]
SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
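# Note on the defaults above: SPECIAL_TOKENS occupy ids 0..2 ("<unk>", "<s>", "</s>"); with
# num_special_tokens=1000 (the TiktokenTokenizer default) ids 3..999 become filler tokens
# "<SPECIAL_3>".."<SPECIAL_999>", and regular BPE ids start at 1000. PATTERN_TIKTOKEN
# pre-splits text into word-like chunks, single digits, punctuation runs, and whitespace
# before byte-pair merging.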


class TiktokenTokenizer(TokenizerSpec):
    """
    TiktokenTokenizer, based on https://github.com/openai/tiktoken.

    Special tokens occupy the first `num_special_tokens` ids; regular BPE ids are shifted
    up by that amount and handled by the underlying tiktoken Encoding.

    Args:
        vocab_file: path to the tokenizer vocabulary JSON file
        pattern: regex pattern used to split text before byte-pair encoding
        vocab_size: total vocabulary size, including the reserved special-token ids
        num_special_tokens: number of special-token ids reserved at the start of the vocabulary
        special_tokens: list of user-defined special tokens (must include SPECIAL_TOKENS)
    """

    def __init__(
        self,
        vocab_file: str,
        pattern: str = PATTERN_TIKTOKEN,
        vocab_size: int = DEFAULT_TIKTOKEN_MAX_VOCAB,  # 131072
        num_special_tokens: int = 1000,
        special_tokens: Optional[List[str]] = None,
    ):
        if not HAVE_TIKTOKEN:
            raise ImportError("TiktokenTokenizer requires the `tiktoken` package, which could not be imported.")
        if not vocab_file or not os.path.exists(vocab_file):
            raise ValueError(f"vocab_file: {vocab_file} is invalid")

        if special_tokens is None:
            special_tokens = SPECIAL_TOKENS.copy()

        assert len(special_tokens) == len(set(special_tokens)), f"Special tokens should be unique: {special_tokens}"
        assert len(special_tokens) <= num_special_tokens < vocab_size
        assert set(SPECIAL_TOKENS) <= set(special_tokens), f"Custom special tokens should include {SPECIAL_TOKENS}"

        self._unk_id = special_tokens.index("<unk>")
        self._bos_id = special_tokens.index("<s>")
        self._eos_id = special_tokens.index("</s>")

        self._vocab_size = vocab_size
        print(f'{self._vocab_size = }')
        self.num_special_tokens = num_special_tokens
        special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
        if special_filler:
            print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}")
        self.special_tokens = special_tokens + special_filler
        assert len(set(self.special_tokens)) == len(self.special_tokens) == num_special_tokens, self.special_tokens
        self.inner_vocab_size = vocab_size - num_special_tokens

        # reload vocab
        self.token2id = reload_mergeable_ranks(vocab_file, max_vocab=self.inner_vocab_size)
        self.id2token = {v: k for k, v in self.token2id.items()}
        assert set(range(self.inner_vocab_size)) == set(self.id2token.keys())

        # special tokens occupy ids [0, num_special_tokens); regular BPE ids are shifted up by that amount
        self.shifted_id2token = {i: tok for i, tok in enumerate(self.special_tokens)}
        for key, value in self.id2token.items():
            self.shifted_id2token[key + self.num_special_tokens] = value

        self.tokenizer = tiktoken.Encoding(
            name=Path(vocab_file).parent.name,
            pat_str=pattern,
            mergeable_ranks=self.token2id,
            special_tokens={},  # special tokens are handled manually
        )

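    # text_to_tokens() returns raw byte pieces (e.g. chunks like b"hello" and b" world",
    # depending on the vocabulary); tokens_to_text() maps such pieces back to a string.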
    def text_to_tokens(self, text: str):
        token_ids = self.tokenizer.encode(text)
        return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]

    def tokens_to_text(self, tokens: List[bytes]):
        token_ids = [self.tokenizer.encode_single_token(token) for token in tokens]
        return self.tokenizer.decode(token_ids)

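    # Note: token_to_id()/tokens_to_ids() below return the raw (un-shifted) tiktoken ranks;
    # they do not add the num_special_tokens offset that text_to_ids() applies.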
    def token_to_id(self, token):
        return self.tokenizer.encode_single_token(token)

    def tokens_to_ids(self, tokens):
        return [self.tokenizer.encode_single_token(token) for token in tokens]

    def ids_to_tokens(self, token_ids):
        tokens = []
        for token_id in token_ids:
            if token_id < self.num_special_tokens:
                tokens.append(self.special_tokens[token_id])
            else:
                token_id -= self.num_special_tokens
                token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
                tokens.append(token_bytes.decode('utf-8', errors='replace'))
        return tokens

    def text_to_ids(self, text: str):
        tokens = self.tokenizer.encode(text)
        tokens = [t + self.num_special_tokens for t in tokens]
        return tokens

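    # ids_to_text() reverses the offset applied in text_to_ids(). Rough round-trip sketch,
    # assuming `tok` is an already-constructed TiktokenTokenizer (hypothetical variable name):
    #   ids = tok.text_to_ids("hello world")
    #   assert tok.ids_to_text(ids) == "hello world"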
    def ids_to_text(self, tokens: List[int]):
        # Filter out special tokens and adjust the remaining tokens
        adjusted_tokens = [
            t - self.num_special_tokens
            for t in tokens
            if t not in {self.bos_id, self.eos_id} and t >= self.num_special_tokens
        ]

        # Decode only if there are tokens left after filtering
        if adjusted_tokens:
            return self.tokenizer.decode(adjusted_tokens)
        else:
            return ""  # Return an empty string if all tokens were filtered out

    @property
    def bos_id(self):
        return self._bos_id

    @property
    def eos_id(self):
        return self._eos_id

    @property
    def unk_id(self):
        return self._unk_id

    @property
    def vocab(self):
        return self.token2id

    @property
    def decoder(self):
        return self.shifted_id2token

    @property
    def encoder(self):
        return self.vocab

    @property
    def vocab_size(self) -> int:
        return self._vocab_size
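

# Minimal usage sketch (illustrative only; the vocabulary path is hypothetical and must
# point to a JSON vocab in the format described in reload_mergeable_ranks()):
#   tok = TiktokenTokenizer(vocab_file="/path/to/vocab.json")
#   ids = tok.text_to_ids("hello world")   # BPE ids, shifted by num_special_tokens
#   print(tok.ids_to_tokens(ids))          # human-readable token pieces
#   print(tok.ids_to_text(ids))            # decoded text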