diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index b61397311aa..04183efc4d0 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -754,7 +754,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector & output) { // normalize and split by whitespace - std::vector words = preprocess(text); + std::vector words = preprocess(text, vocab.get_normalizer_lowercase()); // bos token prepended already // find the longest tokens that form the words @@ -799,7 +799,7 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector preprocess(const std::string & text) { + static std::vector preprocess(const std::string & text, bool lowercase) { const std::vector cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); std::vector words(1, ""); @@ -818,7 +818,7 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + const std::string s = unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -2159,6 +2159,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "whitespace") { pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; + normalizer_lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -2339,9 +2340,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } - ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false); - ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false); - ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + ml.get_key(LLM_KV_TOKENIZER_ADD_PREFIX, add_space_prefix, false); + ml.get_key(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, remove_extra_whitespaces, false); } const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); @@ -2511,6 +2511,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } } + // Lowercase normalizer flag (consulted by WPM / whitespace BPE) + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_lowercase, false); + // auto-detect special tokens by text // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_... // for now, we apply this workaround to find the tokens based on their text