@@ -1702,12 +1702,13 @@ struct llama_mlock {
 };
 using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>;
 
-static std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+// NOTE: avoid ever using this except for building the token_to_piece caches
+static std::string llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) {
     std::vector<char> result(8, 0);
-    const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+    const int n_tokens = llama_token_to_piece(model, token, result.data(), result.size(), special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size(), special);
+        int check = llama_token_to_piece(model, token, result.data(), result.size(), special);
         GGML_ASSERT(check == -n_tokens);
     }
     else {
@@ -2162,7 +2163,11 @@ struct llama_vocab {
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
-    std::vector<id> special_tokens_cache;
+    bool has_cache = false;
+
+    std::vector<id> cache_special_tokens;
+    std::unordered_map<id, token> cache_token_to_piece;         // llama_token_to_piece(special = false);
+    std::unordered_map<id, token> cache_token_to_piece_special; // llama_token_to_piece(special = true);
 
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
@@ -4833,18 +4838,26 @@ static void llm_load_vocab(
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
             if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
-                vocab.special_tokens_cache.push_back(id);
+                vocab.cache_special_tokens.push_back(id);
             }
         }
 
-        std::sort( vocab.special_tokens_cache.begin(), vocab.special_tokens_cache.end(),
+        std::sort( vocab.cache_special_tokens.begin(), vocab.cache_special_tokens.end(),
             [&] (const llama_vocab::id a, const llama_vocab::id b) {
                 return vocab.id_to_token[a].text.size() > vocab.id_to_token[b].text.size();
             }
         );
 
-        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.special_tokens_cache.size());
+        LLAMA_LOG_INFO("%s: special tokens cache size = %u.\n", __func__, (uint32_t)vocab.cache_special_tokens.size());
+    }
+
+    // build token to piece caches
+    for (llama_token id = 0; id < (llama_token) n_vocab; ++id) {
+        vocab.cache_token_to_piece[id]         = llama_token_to_piece(&model, id, false);
+        vocab.cache_token_to_piece_special[id] = llama_token_to_piece(&model, id, true);
+    }
+
+    vocab.has_cache = true;
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -13233,7 +13246,7 @@ struct fragment_buffer_variant {
 
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
-    for (const llama_vocab::id special_id : vocab.special_tokens_cache) {
+    for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
         const auto & special_token = vocab.id_to_token[special_id].text;
 
         // for each text fragment
@@ -14392,7 +14405,7 @@ void llama_sample_repetition_penalties(
 
 void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
     GGML_ASSERT(ctx);
-    const int64_t t_start_sample_us = ggml_time_us();
+    int64_t t_start_sample_us = ggml_time_us();
 
     bool allow_eog = false;
     for (const auto & stack : grammar->stacks) {
@@ -14408,8 +14421,8 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
     candidates_grammar.reserve(candidates->size);
 
     for (size_t i = 0; i < candidates->size; ++i) {
-        const llama_token id    = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id, false);
+        const llama_token id      = candidates->data[i].id;
+        const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(id);
 
         if (llama_token_is_eog(&ctx->model, id)) {
             if (!allow_eog) {
@@ -14609,7 +14622,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
         GGML_ASSERT(false);
     }
 
-    const std::string piece = llama_token_to_piece(ctx, token, false);
+    const std::string & piece = ctx->model.vocab.cache_token_to_piece.at(token);
 
     // Note terminating 0 in decoded string
     const auto   decoded     = decode_utf8(piece, grammar->partial_utf8);
@@ -18292,69 +18305,79 @@ static std::string llama_decode_text(const std::string & text) {
 
 // does not write null-terminator to buf
 int32_t llama_token_to_piece(const struct llama_model * model, llama_token token, char * buf, int32_t length, bool special) {
+    if (model->vocab.has_cache) {
+        const auto & cache = special ? model->vocab.cache_token_to_piece_special : model->vocab.cache_token_to_piece;
+        const auto & res = cache.at(token);
+        if (length < (int) res.size()) {
+            return -(int) res.size();
+        }
+        memcpy(buf, res.c_str(), res.size());
+        return res.size();
+    }
+
     if (0 <= token && token < llama_n_vocab(model)) {
         switch (llama_vocab_get_type(model->vocab)) {
-        case LLAMA_VOCAB_TYPE_WPM:
-        case LLAMA_VOCAB_TYPE_SPM: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                llama_unescape_whitespace(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
-                if (length < 3) {
-                    return -3;
-                }
-                memcpy(buf, "\xe2\x96\x85", 3);
-                return 3;
-            } else if (llama_is_byte_token(model->vocab, token)) {
-                if (length < 1) {
-                    return -1;
+            case LLAMA_VOCAB_TYPE_WPM:
+            case LLAMA_VOCAB_TYPE_SPM: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    llama_unescape_whitespace(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT
+                    if (length < 3) {
+                        return -3;
+                    }
+                    memcpy(buf, "\xe2\x96\x85", 3);
+                    return 3;
+                } else if (llama_is_byte_token(model->vocab, token)) {
+                    if (length < 1) {
+                        return -1;
+                    }
+                    buf[0] = llama_token_to_byte(model->vocab, token);
+                    return 1;
                 }
-                buf[0] = llama_token_to_byte(model->vocab, token);
-                return 1;
+                break;
             }
-            break;
-        }
-        case LLAMA_VOCAB_TYPE_BPE: {
-            // NOTE: we accept all unsupported token types,
-            // suppressing them like CONTROL tokens.
-            if (llama_is_normal_token(model->vocab, token)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                result = llama_decode_text(result);
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
-                }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
-            } else if (
-                    (llama_is_user_defined_token(model->vocab, token)) ||
-                    (llama_is_control_token     (model->vocab, token) && special)) {
-                std::string result = model->vocab.id_to_token[token].text;
-                if (length < (int) result.length()) {
-                    return -(int) result.length();
+            case LLAMA_VOCAB_TYPE_BPE: {
+                // NOTE: we accept all unsupported token types,
+                // suppressing them like CONTROL tokens.
+                if (llama_is_normal_token(model->vocab, token)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    result = llama_decode_text(result);
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
+                } else if (
+                        (llama_is_user_defined_token(model->vocab, token)) ||
+                        (llama_is_control_token     (model->vocab, token) && special)) {
+                    std::string result = model->vocab.id_to_token[token].text;
+                    if (length < (int) result.length()) {
+                        return -(int) result.length();
+                    }
+                    memcpy(buf, result.c_str(), result.length());
+                    return result.length();
                 }
-                memcpy(buf, result.c_str(), result.length());
-                return result.length();
+                break;
             }
-            break;
-        }
-        default:
-            GGML_ASSERT(false);
+            default:
+                GGML_ASSERT(false);
         }
     }
     return 0;
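
For reference, a minimal caller-side sketch of the `llama_token_to_piece` contract that the cached fast path above preserves: a negative return value is the required buffer size, so the caller resizes and retries. The helper name is illustrative only, and a loaded `llama_model *` is assumed.

```cpp
#include <string>
#include <vector>
#include "llama.h"

// Hypothetical helper: convert one token to its text piece using the public API.
static std::string token_piece_example(const struct llama_model * model, llama_token token, bool special) {
    std::vector<char> buf(8, 0);
    int32_t n = llama_token_to_piece(model, token, buf.data(), buf.size(), special);
    if (n < 0) {
        // negative result encodes the required size; retry with a larger buffer
        buf.resize(-n);
        n = llama_token_to_piece(model, token, buf.data(), buf.size(), special);
    }
    return std::string(buf.data(), n);
}
```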