Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
cmake_minimum_required(VERSION 3.18)
project(tokenizers_cpp C CXX)

# Require strict C++17 for all targets — presumably needed by the msgpack-c
# C++ headers pulled in below; confirm before lowering.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# FetchContent is used further down to pull third-party sources at configure time.
include(FetchContent)

# update to contain more rust flags
set(TOKENIZERS_CPP_RUST_FLAGS "")
set(TOKENIZERS_CPP_CARGO_TARGET "")
Expand Down Expand Up @@ -71,6 +77,13 @@ endif ()
get_filename_component(TOKENIZERS_CPP_ROOT ${CMAKE_CURRENT_LIST_FILE} DIRECTORY)
set(TOKENIZERS_CPP_CARGO_SOURCE_PATH ${TOKENIZERS_CPP_ROOT}/rust)

# Pull msgpack-c (C++ flavour) at a pinned release; it is used by the RWKV
# world tokenizer to deserialize its vocab file.
# NOTE: repository URL restored to the canonical github.com host (the scraped
# diff carried a link-rewriter artifact).
FetchContent_Declare(
  msgpack
  GIT_REPOSITORY https://github.com/msgpack/msgpack-c
  GIT_TAG cpp-6.1.0
)
# msgpack-c can optionally depend on Boost; keep the dependency footprint
# minimal. Must be set before FetchContent_MakeAvailable to take effect.
option(MSGPACK_USE_BOOST "" OFF)
FetchContent_MakeAvailable(msgpack)

if(MSVC)
set(TOKENIZERS_RUST_LIB "${TOKENIZERS_CPP_CARGO_BINARY_DIR}/tokenizers_c.lib")
Expand Down Expand Up @@ -98,10 +111,12 @@ set(
TOKENIZER_CPP_SRCS
src/sentencepiece_tokenizer.cc
src/huggingface_tokenizer.cc
src/rwkv_world_tokenizer.cc
)
add_library(tokenizer_cpp_objs OBJECT ${TOKENIZER_CPP_SRCS})
target_include_directories(tokenizer_cpp_objs PRIVATE sentencepiece/src)
target_include_directories(tokenizer_cpp_objs PUBLIC ${TOKENIZERS_CPP_INCLUDE})
target_link_libraries(tokenizer_cpp_objs PRIVATE msgpack-cxx)

# sentencepiece config
option(SPM_ENABLE_SHARED "override sentence piece config" OFF)
Expand Down
4 changes: 4 additions & 0 deletions example/build_and_run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ fi
if [ ! -f "tokenizer.json" ]; then
wget https://huggingface.co/togethercomputer/RedPajama-INCITE-Chat-3B-v1/resolve/main/tokenizer.json
fi
# Fetch the prebuilt RWKV world vocab if it is not present yet.
# NOTE: release URL restored to the canonical github.com host (the scraped
# diff carried a link-rewriter artifact).
if [ ! -f "tokenizer_model" ]; then
  wget https://github.com/BBuf/rwkv-world-tokenizer/releases/download/v1.0.0/tokenizer_model.zip
  unzip tokenizer_model.zip
fi
cd ..

# run
Expand Down
17 changes: 17 additions & 0 deletions example/example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,24 @@ void HuggingFaceTokenizerExample() {
std::cout << "decode=\"" << decoded_prompt << "\"" << std::endl;
}

// RWKV world tokenizer
// - dist/tokenizer_model
void RWKVWorldTokenizerExample() {
  // Load the vocab shipped as dist/tokenizer_model.
  auto tokenizer = Tokenizer::FromBlobRwkvWorld("dist/tokenizer_model");

  const std::string prompt = "What is the capital of Canada?";

  // Round-trip the prompt: string -> token ids -> string.
  const std::vector<int> token_ids = tokenizer->Encode(prompt);
  const std::string round_tripped = tokenizer->Decode(token_ids);

  // Show the encoded ids and the reconstructed text.
  std::cout << "RWKV World tokenizer: " << std::endl;
  PrintEncodeResult(token_ids);
  std::cout << "decode=\"" << round_tripped << "\"" << std::endl;
}

// Run each tokenizer backend's round-trip demo in turn.
int main(int argc, char* argv[]) {
  SentencePieceTokenizerExample();
  HuggingFaceTokenizerExample();
  RWKVWorldTokenizerExample();
}
24 changes: 24 additions & 0 deletions include/rwkv_world_tokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*!
 * Copyright (c) 2023 by Contributors daquexian
 * \file rwkv_world_tokenizer.h
 * \brief Declarations for the RWKV world tokenizer.
 */

#pragma once

#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

namespace tokenizers {
// Standalone tokenizer for the RWKV "world" vocabulary.
// Holds two lookup tables (token string <-> id) built from a msgpack vocab
// file; see rwkv_world_tokenizer.cc for the loading and matching logic.
class RWKVWorldToolTokenizer {
 public:
  // Load the vocab from the msgpack file at `path`.
  RWKVWorldToolTokenizer(const std::string &path);
  // Greedily tokenize `str` into token ids.
  std::vector<int> encode(std::string_view str) const;
  // Concatenate the decoded pieces of every id in `ids`.
  std::string decode(const std::vector<int> &ids) const;
  // Decode a single id; unknown ids map to "<unk>".
  std::string decode(int id) const;

 private:
  std::unordered_map<std::string, int> _word2idx;  // token string -> id
  std::unordered_map<int, std::string> _idx2word;  // id -> token string
};
}  // namespace tokenizers

7 changes: 7 additions & 0 deletions include/tokenizers_cpp.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ class Tokenizer {
* \return The created tokenizer.
*/
static std::unique_ptr<Tokenizer> FromBlobSentencePiece(const std::string& model_blob);
/*!
 * \brief Create RWKVWorldTokenizer.
 *
 * \param model_blob The blob that contains vocabs.
 *        NOTE(review): the current implementation opens this argument as a
 *        filesystem path rather than reading it as in-memory blob contents —
 *        confirm the intended contract against the other FromBlob* factories.
 * \return The created tokenizer.
 */
static std::unique_ptr<Tokenizer> FromBlobRwkvWorld(const std::string& model_blob);
};

} // namespace tokenizers
Expand Down
101 changes: 101 additions & 0 deletions src/rwkv_world_tokenizer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*!
 * Copyright (c) 2023 by Contributors
 * \file rwkv_world_tokenizer.cc
 * \brief Implementation of the RWKV world tokenizer.
 */
#include <tokenizers_cpp.h>

#include "rwkv_world_tokenizer.h"

#include <fstream>
#include <iostream>
#include <iterator>
#include <msgpack.hpp>
#include <stdexcept>
#include <string_view>

namespace tokenizers {

// Load the vocab from the msgpack file at `path` and build both lookup
// directions. Throws std::runtime_error if the file cannot be opened.
// NOTE(review): although reached via Tokenizer::FromBlobRwkvWorld, the
// argument is used as a filesystem path, not as in-memory blob contents.
RWKVWorldToolTokenizer::RWKVWorldToolTokenizer(const std::string &path) {
  std::ifstream infile(path, std::ios::binary | std::ios::in);
  if (!infile) {
    throw std::runtime_error("RWKVWorldToolTokenizer: failed to open " + path);
  }
  // Read the whole msgpack payload into an RAII-managed buffer. The original
  // code leaked a raw `new char[length]` allocation.
  std::string data((std::istreambuf_iterator<char>(infile)),
                   std::istreambuf_iterator<char>());
  infile.close();

  // The file holds a msgpack map of {token id -> token bytes}; keep it as
  // _idx2word and invert it to get _word2idx for encoding.
  auto unpacker = msgpack::unpack(data.data(), data.size());
  auto obj = unpacker.get();
  _idx2word = obj.as<std::unordered_map<int, std::string>>();
  for (const auto &pair : _idx2word) {
    _word2idx[pair.second] = pair.first;
  }
}

// Greedy tokenization: at each position, grow the candidate substring until a
// vocab lookup misses, then emit the id of the longest match found.
// NOTE(review): growth stops at the first miss, so this is not a full
// longest-match over non-prefix-closed vocabs — behavior kept from the
// original; confirm the RWKV world vocab is prefix-closed.
std::vector<int> RWKVWorldToolTokenizer::encode(std::string_view str) const {
  std::vector<int> ids;
  size_t str_idx = 0;  // size_t: avoid signed/unsigned comparison with size()
  while (str_idx < str.size()) {
    size_t match_len = 0;  // length of the longest vocab match at str_idx
    int match_id = -1;     // id of that match
    for (size_t len = 1; str_idx + len <= str.size(); ++len) {
      auto it = _word2idx.find(std::string(str.substr(str_idx, len)));
      if (it == _word2idx.end()) {
        break;  // cannot extend the match any further
      }
      match_id = it->second;
      match_len = len;
    }
    if (match_len == 0) {
      // BUG FIX: the original advanced by `word_len - 1 == 0` when even a
      // single character was absent from the vocab, looping forever and
      // emitting a stale id. Skip the unmatched byte instead.
      ++str_idx;
      continue;
    }
    ids.push_back(match_id);
    str_idx += match_len;
  }
  return ids;
}

// Map a single token id back to its string; unknown ids become "<unk>".
std::string RWKVWorldToolTokenizer::decode(int id) const {
  const auto found = _idx2word.find(id);
  return found != _idx2word.end() ? found->second : "<unk>";
}

std::string RWKVWorldToolTokenizer::decode(const std::vector<int> &ids) const {
std::string str;
for (auto id : ids) {
str += decode(id);
}
return str;
}

// Factory helper: build a tokenizer from the vocab file at `path`.
RWKVWorldToolTokenizer createRWKVWorldToolTokenizer(const std::string &path) {
  return RWKVWorldToolTokenizer{path};
}

class RWKVWorldTokenizer : public Tokenizer {
public:
explicit RWKVWorldTokenizer(const std::string& model_blob) : rwkv_world_tokenizer_(model_blob) {
}

std::vector<int32_t> Encode(const std::string& text) final {
return rwkv_world_tokenizer_.encode(text);
}

std::string Decode(const std::vector<int32_t>& ids) final {
return rwkv_world_tokenizer_.decode(ids);
}

private:
// the tokenizer
RWKVWorldToolTokenizer rwkv_world_tokenizer_;
};

// Create the RWKV world backend behind the generic Tokenizer interface.
std::unique_ptr<Tokenizer> Tokenizer::FromBlobRwkvWorld(const std::string& model_blob) {
  auto tokenizer = std::make_unique<RWKVWorldTokenizer>(model_blob);
  return tokenizer;
}

} // namespace tokenizers