@@ -287,41 +287,57 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
     return tokens;
 }
 
+// TODO: Calculate this constant from the vocabulary
+#define MAX_TOKEN_LEN 18
+// SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece
 std::vector<gpt_vocab::id> llama_tokenize(const gpt_vocab & vocab, const std::string & text, bool bos) {
-    // auto res = gpt_tokenize(vocab, text);
-
-    // if (bos) {
-    //     res.insert(res.begin(), 1); // TODO: replace with vocab.bos
-    // }
-
     std::vector<gpt_vocab::id> res;
-
-    if (bos) {
-        res.push_back(1); // TODO: replace with vocab.bos
-    }
-
-    // find the longest token that matches the text
-    int pos = 0;
-    while (true) {
-        int l = 0;
-        int t = 0;
-        for (const auto & kv : vocab.id_to_token) {
-            if (kv.second.size() < l) continue;
-            if (kv.second.size() > text.size() - pos) continue;
-            if (text.substr(pos, kv.second.size()) == kv.second) {
-                l = kv.second.size();
-                t = kv.first;
+    std::vector<int> score;
+    std::vector<gpt_vocab::id> prev;
+    int len = text.length();
+
+    score.resize(len + 1);
+    prev.resize(len + 1);
+
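+    // score[i] holds the best score found so far for tokenizing text[0, i);
+    // prev[i] holds the id of the last token in that best tokenization (0 = no token ends here)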
+    // Forward pass
+    for (int i = 0; i < len; i++) {
+        int max_len = std::min(len - i, MAX_TOKEN_LEN);
+        for (int sub_len = 1; sub_len <= max_len; sub_len++) {
+            auto sub = text.substr(i, sub_len);
+            auto token = vocab.token_to_id.find(sub);
+            if (token != vocab.token_to_id.end()) {
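+                // Score a piece by its length squared, so longer matches are preferred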
+                int token_score = sub.length() * sub.length();
+                int local_score = score[i] + token_score;
+                int next = i + sub_len;
+                if (score[next] < local_score) {
+                    score[next] = local_score;
+                    prev[next] = (*token).second;
+                }
             }
         }
+    }
 
-        if (l == 0) {
-            break;
+    // Backward pass
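+    // Walk back from the end of the string, picking up the token that ends at each position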
+    int i = len;
+    while (i > 0) {
+        gpt_vocab::id token_id = prev[i];
+        if (token_id == 0) {
+            // TODO: Return error or something more meaningful
+            printf("failed to tokenize string!\n");
+            break;
         }
+        res.push_back(token_id);
+        auto token = (*vocab.id_to_token.find(token_id)).second;
+        i -= token.length();
+    }
 
-        res.push_back(t);
-        pos += l;
+    if (bos) {
+        res.push_back(1); // TODO: replace with vocab.bos
     }
 
+    // Pieces are in reverse order, so correct that
+    std::reverse(res.begin(), res.end());
+
     return res;
 }
 
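For reference, a minimal usage sketch of the new tokenizer on a toy vocabulary. It assumes `gpt_vocab` exposes the `token_to_id` / `id_to_token` maps used in the code above and that `utils.h` declares both the struct and `llama_tokenize`; the pieces and ids below are made up for illustration, not taken from a real model.

// Toy example (hypothetical vocabulary; ids chosen for illustration only)
#include <cassert>
#include <string>
#include <vector>
#include "utils.h" // assumed header declaring gpt_vocab and llama_tokenize

int main() {
    gpt_vocab vocab;
    // Id 0 is skipped: the backward pass treats prev[i] == 0 as "no token ends here"
    const char * pieces[] = { "un", "believ", "able", "unbeliev" };
    for (int i = 0; i < 4; i++) {
        vocab.token_to_id[pieces[i]] = i + 1;
        vocab.id_to_token[i + 1]     = pieces[i];
    }

    // The length-squared score prefers "unbeliev" + "able" (64 + 16 = 80)
    // over "un" + "believ" + "able" (4 + 36 + 16 = 56)
    std::vector<gpt_vocab::id> ids = llama_tokenize(vocab, "unbelievable", /*bos=*/false);
    assert(ids.size() == 2);
    assert(vocab.id_to_token[ids[0]] == "unbeliev");
    assert(vocab.id_to_token[ids[1]] == "able");
    return 0;
}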