diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt
index 87a8f15..faf2d75 100644
--- a/example/CMakeLists.txt
+++ b/example/CMakeLists.txt
@@ -25,4 +25,3 @@ target_include_directories(example PRIVATE ${TOKENZIER_CPP_PATH}/include)
 # You can link tokenizers_cpp, it will automatically link tokenizers_c
 # and sentencepiece libary
 target_link_libraries(example PRIVATE tokenizers_cpp)
-
diff --git a/example/build_and_run.sh b/example/build_and_run.sh
index a1cd1f8..b2bae14 100755
--- a/example/build_and_run.sh
+++ b/example/build_and_run.sh
@@ -11,7 +11,7 @@ cd ..
 mkdir -p dist
 cd dist
 if [ ! -f "tokenizer.model" ]; then
-  wget https://huggingface.co/decapoda-research/llama-7b-hf/resolve/main/tokenizer.model
+  wget https://huggingface.co/lmsys/vicuna-7b-v1.5/resolve/main/tokenizer.model
 fi
 if [ ! -f "tokenizer.json" ]; then
   wget https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1/resolve/main/tokenizer.json
diff --git a/example/example.cc b/example/example.cc
index 530bc68..d8e90d2 100644
--- a/example/example.cc
+++ b/example/example.cc
@@ -1,5 +1,7 @@
 #include <tokenizers_cpp.h>
 
+#include <cassert>
+#include <chrono>
 #include <fstream>
 #include <iostream>
 #include <string>
@@ -30,60 +32,92 @@ void PrintEncodeResult(const std::vector<int>& ids) {
   std::cout << "]" << std::endl;
 }
 
+void TestTokenizer(std::unique_ptr<Tokenizer> tok, bool print_vocab = false,
+                   bool check_id_back = true) {
+  // Check #1. Encode and Decode
+  std::string prompt = "What is the capital of Canada?";
+  std::vector<int> ids = tok->Encode(prompt);
+  std::string decoded_prompt = tok->Decode(ids);
+  PrintEncodeResult(ids);
+  std::cout << "decode=\"" << decoded_prompt << "\"" << std::endl;
+  assert(decoded_prompt == prompt);
+
+  // Check #2. IdToToken and TokenToId
+  std::vector<int32_t> ids_to_test = {0, 1, 2, 3, 32, 33, 34, 130, 131, 1000};
+  for (auto id : ids_to_test) {
+    auto token = tok->IdToToken(id);
+    auto id_new = tok->TokenToId(token);
+    std::cout << "id=" << id << ", token=\"" << token << "\", id_new=" << id_new << std::endl;
+    if (check_id_back) {
+      assert(id == id_new);
+    }
+  }
+
+  // Check #3. GetVocabSize
+  auto vocab_size = tok->GetVocabSize();
+  std::cout << "vocab_size=" << vocab_size << std::endl;
+
+  std::cout << std::endl;
+}
+
 // Sentencepiece tokenizer
 // - dist/tokenizer.model
 void SentencePieceTokenizerExample() {
+  std::cout << "Tokenizer: SentencePiece" << std::endl;
+
+  auto start = std::chrono::high_resolution_clock::now();
+
   // Read blob from file.
   auto blob = LoadBytesFromFile("dist/tokenizer.model");
   // Note: all the current factory APIs takes in-memory blob as input.
   // This gives some flexibility on how these blobs can be read.
   auto tok = Tokenizer::FromBlobSentencePiece(blob);
-  std::string prompt = "What is the capital of Canada?";
-  // call Encode to turn prompt into token ids
-  std::vector<int> ids = tok->Encode(prompt);
-  // call Decode to turn ids into string
-  std::string decoded_prompt = tok->Decode(ids);
-  // print encoded result
-  std::cout << "SetencePiece tokenizer: " << std::endl;
-  PrintEncodeResult(ids);
-  std::cout << "decode=\"" << decoded_prompt << "\"" << std::endl;
+  auto end = std::chrono::high_resolution_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+
+  std::cout << "Load time: " << duration << " ms" << std::endl;
+
+  TestTokenizer(std::move(tok), false, true);
 }
 
 // HF tokenizer
 // - dist/tokenizer.json
 void HuggingFaceTokenizerExample() {
+  std::cout << "Tokenizer: Huggingface" << std::endl;
+
+  auto start = std::chrono::high_resolution_clock::now();
+
   // Read blob from file.
   auto blob = LoadBytesFromFile("dist/tokenizer.json");
   // Note: all the current factory APIs takes in-memory blob as input.
   // This gives some flexibility on how these blobs can be read.
   auto tok = Tokenizer::FromBlobJSON(blob);
-  std::string prompt = "What is the capital of Canada?";
-  // call Encode to turn prompt into token ids
-  std::vector<int> ids = tok->Encode(prompt);
-  // call Decode to turn ids into string
-  std::string decoded_prompt = tok->Decode(ids);
-  // print encoded result
-  std::cout << "HF tokenizer: " << std::endl;
-  PrintEncodeResult(ids);
-  std::cout << "decode=\"" << decoded_prompt << "\"" << std::endl;
+  auto end = std::chrono::high_resolution_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+
+  std::cout << "Load time: " << duration << " ms" << std::endl;
+
+  TestTokenizer(std::move(tok), false, true);
 }
 
 // RWKV world tokenizer
 // - dist/tokenizer_model
 void RWKVWorldTokenizerExample() {
+  std::cout << "Tokenizer: RWKVWorld" << std::endl;
+
+  auto start = std::chrono::high_resolution_clock::now();
+
   auto tok = Tokenizer::FromBlobRWKVWorld("dist/tokenizer_model");
-  std::string prompt = "What is the capital of Canada?";
-  // call Encode to turn prompt into token ids
-  std::vector<int> ids = tok->Encode(prompt);
-  // call Decode to turn ids into string
-  std::string decoded_prompt = tok->Decode(ids);
-  // print encoded result
-  std::cout << "RWKV World tokenizer: " << std::endl;
-  PrintEncodeResult(ids);
-  std::cout << "decode=\"" << decoded_prompt << "\"" << std::endl;
+  auto end = std::chrono::high_resolution_clock::now();
+  auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count();
+
+  std::cout << "Load time: " << duration << " ms" << std::endl;
+
+  // We cannot check id back for RWKVWorldTokenizer yet.
+  TestTokenizer(std::move(tok), false, false);
 }
 
 int main(int argc, char* argv[]) {
diff --git a/include/rwkv_world_tokenizer.h b/include/rwkv_world_tokenizer.h
deleted file mode 100644
index a2f06ee..0000000
--- a/include/rwkv_world_tokenizer.h
+++ /dev/null
@@ -1,50 +0,0 @@
-/*!
- *  Copyright (c) 2023 by Contributors daquexian
- * \file rwkv_world_tokenizer.h
- * \brief Implementation of llm chat.
- */
-
-#include <memory>
-#include <sstream>
-#include <stdexcept>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#define STRINGIFY(...) STRINGIFY_(__VA_ARGS__)
-#define STRINGIFY_(...) #__VA_ARGS__
-#define RV_CHECK(...)                                                          \
-  for (bool _rv_check_status = (__VA_ARGS__); !_rv_check_status;)             \
-  throw FRException() << ("Check \"" STRINGIFY(__VA_ARGS__) "\" failed at " + \
-                          std::to_string(__LINE__) +                          \
-                          " in " __FILE__ "\n > Error msg: ")
-struct FRException : public std::runtime_error {
-  FRException() : std::runtime_error("") {}
-  const char *what() const noexcept override { return msg.c_str(); }
-  template <typename T> FRException &operator<<(const T &s) {
-    std::stringstream ss;
-    ss << s;
-    msg += ss.str();
-    return *this;
-  }
-  std::string msg;
-};
-
-namespace tokenizers {
-struct TrieTree;
-
-class RWKVWorldToolTokenizer{
-public:
-  RWKVWorldToolTokenizer(const std::string &path);
-  std::vector<int> encode(const std::string &str) const;
-  std::string decode(const std::vector<int> &ids) const;
-  std::string decode(int id) const;
-
-private:
-  std::unordered_map<std::string, int> _word2idx;
-  std::unordered_map<int, std::string> _idx2word;
-  std::unique_ptr<TrieTree> _tree;
-};
-
-} // namespace tokenizers
-
diff --git a/include/tokenizers_c.h b/include/tokenizers_c.h
index e1b77ab..6563f53 100644
--- a/include/tokenizers_c.h
+++ b/include/tokenizers_c.h
@@ -32,6 +32,13 @@ void tokenizers_get_decode_str(TokenizerHandle handle, const char** data, size_t* len);
 
 void tokenizers_get_encode_ids(TokenizerHandle handle, const uint32_t** id_data, size_t* len);
 
+void tokenizers_get_vocab_size(TokenizerHandle handle, size_t* size);
+
+void tokenizers_id_to_token(TokenizerHandle handle, uint32_t id, const char** data, size_t* len);
+
+// tokenizers_token_to_id stores -1 to *id if the token is not in the vocab
+void tokenizers_token_to_id(TokenizerHandle handle, const char* token, size_t len, int32_t* id);
+
 void tokenizers_free(TokenizerHandle handle);
 
 #ifdef __cplusplus
diff --git a/include/tokenizers_cpp.h b/include/tokenizers_cpp.h
index 6480a70..7de6721 100644
--- a/include/tokenizers_cpp.h
+++ b/include/tokenizers_cpp.h
@@ -36,6 +36,22 @@ class Tokenizer {
    */
   virtual std::string Decode(const std::vector<int32_t>& ids) = 0;
 
+  /*!
+   * \brief Returns the vocabulary size. Special tokens are considered.
+   */
+  virtual size_t GetVocabSize() = 0;
+
+  /*!
+   * \brief Convert the given id to its corresponding token if it exists. If not, return an
+   * empty string.
+   */
+  virtual std::string IdToToken(int32_t token_id) = 0;
+
+  /*!
+   * \brief Convert the given token to its corresponding id if it exists. If not, return -1.
+   */
+  virtual int32_t TokenToId(const std::string& token) = 0;
+
   //---------------------------------------------------
   // Factory functions from byte-blobs
   // These factory function takes in in-memory blobs
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 9ca91f2..10206a0 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -9,6 +9,7 @@ pub struct TokenizerWrapper {
     tokenizer: Tokenizer,
     encode_ids: Vec<u32>,
     decode_str: String,
+    id_to_token_result: String,
 }
 
 pub type Vocab = HashMap<String, u32>;
@@ -20,6 +21,7 @@ impl TokenizerWrapper {
             tokenizer: Tokenizer::from_str(json).unwrap().into(),
             encode_ids: Vec::new(),
             decode_str: String::new(),
+            id_to_token_result: String::new(),
         }
     }
 
@@ -77,6 +79,7 @@ impl TokenizerWrapper {
             tokenizer: tokenizer,
             encode_ids: Vec::new(),
             decode_str: String::new(),
+            id_to_token_result: String::new(),
         }
     }
 
@@ -182,3 +185,46 @@ extern "C" fn tokenizers_free(wrapper: *mut TokenizerWrapper) {
         drop(Box::from_raw(wrapper));
     }
 }
+
+#[no_mangle]
+extern "C" fn tokenizers_get_vocab_size(handle: *mut TokenizerWrapper, size: *mut usize) {
+    unsafe {
+        *size = (*handle).tokenizer.get_vocab_size(true);
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_id_to_token(
+    handle: *mut TokenizerWrapper,
+    id: u32,
+    out_cstr: *mut *mut u8,
+    out_len: *mut usize,
+) {
+    unsafe {
+        let str = (*handle).tokenizer.id_to_token(id);
+        (*handle).id_to_token_result = match str {
+            Some(s) => s,
+            None => String::from(""),
+        };
+
+        *out_cstr = (*handle).id_to_token_result.as_mut_ptr();
+        *out_len = (*handle).id_to_token_result.len();
+    }
+}
+
+#[no_mangle]
+extern "C" fn tokenizers_token_to_id(
+    handle: *mut TokenizerWrapper,
+    token: *const u8,
+    len: usize,
+    out_id: *mut i32,
+) {
+    unsafe {
+        let token: &str = std::str::from_utf8(std::slice::from_raw_parts(token, len)).unwrap();
+        let id = (*handle).tokenizer.token_to_id(token);
+        *out_id = match id {
+            Some(id) => id as i32,
+            None => -1,
+        };
+    }
+}
diff --git a/src/huggingface_tokenizer.cc b/src/huggingface_tokenizer.cc
index 82cb441..1f96fb3 100644
--- a/src/huggingface_tokenizer.cc
+++ b/src/huggingface_tokenizer.cc
@@ -7,6 +7,8 @@
 #include <tokenizers_c.h>
 #include <tokenizers_cpp.h>
 
+#include <cassert>
+
 namespace tokenizers {
 /*!
  * \brief A simple c++ header of tokenizer via C API.
@@ -31,7 +33,9 @@ class HFTokenizer : public Tokenizer {
     const uint32_t* data;
     size_t len;
     tokenizers_get_encode_ids(handle_, &data, &len);
-    return std::vector<int32_t>(data, data + len);
+    const int32_t* data_i32 = reinterpret_cast<const int32_t*>(data);
+    auto res = std::vector<int32_t>(data_i32, data_i32 + len);
+    return res;
   }
 
   // use i32 to be consistent with sentencepiece
@@ -45,6 +49,26 @@ class HFTokenizer : public Tokenizer {
     return std::string(data, len);
   }
 
+  size_t GetVocabSize() final {
+    size_t size;
+    tokenizers_get_vocab_size(handle_, &size);
+    assert(size > 0);
+    return size;
+  }
+
+  std::string IdToToken(int32_t id) final {
+    const char* data;
+    size_t len;
+    tokenizers_id_to_token(handle_, static_cast<uint32_t>(id), &data, &len);
+    return std::string(data, len);
+  }
+
+  int32_t TokenToId(const std::string& token) final {
+    int32_t id;
+    tokenizers_token_to_id(handle_, token.data(), token.length(), &id);
+    return id;
+  }
+
  private:
   // internal handle
   TokenizerHandle handle_{nullptr};
diff --git a/src/rwkv_world_tokenizer.cc b/src/rwkv_world_tokenizer.cc
index 10af82c..dab70a7 100644
--- a/src/rwkv_world_tokenizer.cc
+++ b/src/rwkv_world_tokenizer.cc
@@ -3,12 +3,11 @@
  * \file rwkv_world_tokenizer.cpp
  * \brief Implementation of llm chat.
  */
-#include <msgpack.hpp>
 #include "rwkv_world_tokenizer.h"
-#include <iostream>
+#include <msgpack.hpp>
+
 #include <fstream>
-#include <unordered_map>
 #include <optional>
 
 namespace tokenizers {
@@ -17,7 +16,7 @@ struct TrieTree {
   std::unordered_map<char, std::unique_ptr<TrieTree>> children;
   std::string word;
   std::optional<int> token_id;
-  
+
   TrieTree(const std::unordered_map<std::string, int>& word2id) {
     for (auto& pair : word2id) {
       add_word(pair.first, pair.second);
@@ -45,11 +44,9 @@
     return {prefix, token_id};
   }
 
- private: 
+ private:
   TrieTree() = default;
-  void add_word(const std::string& word, int token_id) {
-    return _add_word(word, token_id, 0);
-  }
+  void add_word(const std::string& word, int token_id) { return _add_word(word, token_id, 0); }
 
   void _add_word(const std::string& word, int token_id, int idx) {
     if (idx == word.size()) {
       this->word = word;
@@ -64,78 +61,83 @@
   }
 };
 
-RWKVWorldToolTokenizer::RWKVWorldToolTokenizer(const std::string &path) {
-  std::ifstream infile;
-  infile.open(path, std::ios::binary | std::ios::in);
-  infile.seekg(0, std::ios::end);
-  int64_t length = infile.tellg();
-  infile.seekg(0, std::ios::beg);
-  char *data = new char[length];
-  infile.read(data, length);
-  infile.close();
-
-  auto unpacker = msgpack::unpack(data, length);
-  auto obj = unpacker.get();
-  _idx2word = obj.as<std::unordered_map<int, std::string>>();
-  for (auto &pair : _idx2word) {
-    _word2idx[pair.second] = pair.first;
+class RWKVWorldTokenizer : public Tokenizer {
+ public:
+  explicit RWKVWorldTokenizer(const std::string& path) {
+    std::ifstream infile;
+    infile.open(path, std::ios::binary | std::ios::in);
+    infile.seekg(0, std::ios::end);
+    int64_t length = infile.tellg();
+    infile.seekg(0, std::ios::beg);
+    char* data = new char[length];
+    infile.read(data, length);
+    infile.close();
+
+    auto unpacker = msgpack::unpack(data, length);
+    auto obj = unpacker.get();
+    delete[] data;
+    _idx2word = obj.as<std::unordered_map<int, std::string>>();
+    for (auto& pair : _idx2word) {
+      _word2idx[pair.second] = pair.first;
+    }
+    _tree = std::make_unique<TrieTree>(_word2idx);
   }
-  _tree = std::make_unique<TrieTree>(_word2idx);
-}
 
-std::vector<int> RWKVWorldToolTokenizer::encode(const std::string &str) const {
-  std::vector<int> ids;
-  int str_idx = 0;
-
-  while (str_idx < str.size()) {
-    auto [prefix, token_id] = _tree->find_longest_prefix(str.substr(str_idx));
-    ids.push_back(token_id);
-    str_idx += prefix.size();
-  }
-  return ids;
-}
+  std::vector<int32_t> Encode(const std::string& str) final {
+    std::vector<int32_t> ids;
+    int str_idx = 0;
 
-std::string RWKVWorldToolTokenizer::decode(int id) const {
-  auto it = _idx2word.find(id);
-  if (it == _idx2word.end()) {
-    return "";
-  } else {
-    return it->second;
+    while (str_idx < str.size()) {
+      auto [prefix, token_id] = _tree->find_longest_prefix(str.substr(str_idx));
+      ids.push_back(token_id);
+      str_idx += prefix.size();
+    }
+    return ids;
   }
-}
 
-std::string RWKVWorldToolTokenizer::decode(const std::vector<int> &ids) const {
-  std::string str;
-  for (auto id : ids) {
-    str += decode(id);
+  std::string Decode(const std::vector<int32_t>& ids) final {
+    std::string str;
+    for (auto id : ids) {
+      str += IdToToken(id);
+    }
+    return str;
   }
-  return str;
-}
-
-RWKVWorldToolTokenizer createRWKVWorldToolTokenizer(const std::string &path) {
-  return RWKVWorldToolTokenizer(path);
-}
 
-class RWKVWorldTokenizer : public Tokenizer {
- public:
-  explicit RWKVWorldTokenizer(const std::string& model_blob) : rwkv_world_tokenizer_(model_blob) {
+  size_t GetVocabSize() final {
+    auto size = _idx2word.size();
+    RV_CHECK(size > 0);
+    return size;
   }
 
-  std::vector<int32_t> Encode(const std::string& text) final {
-    return rwkv_world_tokenizer_.encode(text);
+  virtual std::string IdToToken(int32_t token_id) final {
+    RV_CHECK(_idx2word.size() > 0);
+    auto it = _idx2word.find(token_id);
+    if (it == _idx2word.end()) {
+      return "";
+    } else {
+      return it->second;
+    }
   }
 
-  std::string Decode(const std::vector<int32_t>& ids) final {
-    return rwkv_world_tokenizer_.decode(ids);
+  int32_t TokenToId(const std::string& token) final {
+    RV_CHECK(_word2idx.size() > 0);
+    auto it = _word2idx.find(token);
+    if (it == _word2idx.end()) {
+      return -1;
+    } else {
+      return it->second;
+    }
   }
 
  private:
   // the tokenizer
-  RWKVWorldToolTokenizer rwkv_world_tokenizer_;
+  std::unordered_map<std::string, int> _word2idx;
+  std::unordered_map<int, std::string> _idx2word;
+  std::unique_ptr<TrieTree> _tree;
 };
 
 std::unique_ptr<Tokenizer> Tokenizer::FromBlobRWKVWorld(const std::string& model_blob) {
   return std::make_unique<RWKVWorldTokenizer>(model_blob);
 }
-} // namespace tokenizers
+}  // namespace tokenizers
diff --git a/src/rwkv_world_tokenizer.h b/src/rwkv_world_tokenizer.h
new file mode 100644
index 0000000..46bfbed
--- /dev/null
+++ b/src/rwkv_world_tokenizer.h
@@ -0,0 +1,33 @@
+/*!
+ *  Copyright (c) 2023 by Contributors daquexian
+ * \file rwkv_world_tokenizer.h
+ * \brief Implementation of llm chat.
+ */
+#ifndef RWKV_WORLD_TOKENIZER_H_
+#define RWKV_WORLD_TOKENIZER_H_
+
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+
+#define STRINGIFY(...) STRINGIFY_(__VA_ARGS__)
+#define STRINGIFY_(...) #__VA_ARGS__
+#define RV_CHECK(...)                                                         \
+  for (bool _rv_check_status = (__VA_ARGS__); !_rv_check_status;)             \
+  throw FRException() << ("Check \"" STRINGIFY(__VA_ARGS__) "\" failed at " + \
+                          std::to_string(__LINE__) + " in " __FILE__ "\n > Error msg: ")
+struct FRException : public std::runtime_error {
+  FRException() : std::runtime_error("") {}
+  const char* what() const noexcept override { return msg.c_str(); }
+  template <typename T>
+  FRException& operator<<(const T& s) {
+    std::stringstream ss;
+    ss << s;
+    msg += ss.str();
+    return *this;
+  }
+  std::string msg;
+};
+
+#endif  // RWKV_WORLD_TOKENIZER_H_
diff --git a/src/sentencepiece_tokenizer.cc b/src/sentencepiece_tokenizer.cc
index d6dfb9f..ed188df 100644
--- a/src/sentencepiece_tokenizer.cc
+++ b/src/sentencepiece_tokenizer.cc
@@ -3,9 +3,10 @@
  * \file sentencepiece_tokenizer.cc
  * \brief Sentencepice tokenizer
  */
+#include <cassert>
 #include <tokenizers_cpp.h>
 
-#include "sentencepiece_processor.h"
+#include <sentencepiece_processor.h>
 
 namespace tokenizers {
 
@@ -27,6 +28,16 @@ class SentencePieceTokenizer : public Tokenizer {
     return text;
   }
 
+  size_t GetVocabSize() final {
+    auto size = sentence_piece_.GetPieceSize();
+    assert(size > 0);
+    return size;
+  }
+
+  std::string IdToToken(int32_t id) final { return sentence_piece_.IdToPiece(id); }
+
+  int32_t TokenToId(const std::string& token) final { return sentence_piece_.PieceToId(token); }
+
  private:
   // the tokenizer
   sentencepiece::SentencePieceProcessor sentence_piece_;
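
For reference, below is a minimal sketch of how downstream code might exercise the new GetVocabSize/IdToToken/TokenToId APIs introduced by this patch. The file name, the probed id 100, and the dist/tokenizer.json path are illustrative only; it assumes the library has been built and linked as in example/CMakeLists.txt.

// vocab_roundtrip.cc -- hypothetical consumer, not part of this patch.
#include <tokenizers_cpp.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

using tokenizers::Tokenizer;

int main() {
  // All factory APIs take an in-memory blob, so read tokenizer.json ourselves.
  std::ifstream fs("dist/tokenizer.json", std::ios::in | std::ios::binary);
  std::ostringstream os;
  os << fs.rdbuf();
  auto tok = Tokenizer::FromBlobJSON(os.str());

  // GetVocabSize includes special tokens (the Rust side passes true
  // to tokenizer.get_vocab_size).
  std::cout << "vocab_size=" << tok->GetVocabSize() << std::endl;

  // Round-trip one id: IdToToken returns "" for an unknown id,
  // and TokenToId returns -1 for an unknown token.
  std::string token = tok->IdToToken(100);
  int32_t id = tok->TokenToId(token);
  std::cout << "token=\"" << token << "\", id=" << id << std::endl;
  return 0;
}

One design note on the FFI boundary: tokenizers_id_to_token stores the looked-up string in the wrapper's id_to_token_result field and returns a pointer into it, so the data/len pair is only valid until the next tokenizers_id_to_token call on the same handle. HFTokenizer::IdToToken copies it into a std::string immediately, which is why the C++ side is safe.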