diff --git a/.gitmodules b/.gitmodules
index 44137b27a71..42deca0a6bb 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -62,3 +62,9 @@
 [submodule "examples/third-party/LLaVA"]
 	path = examples/third-party/LLaVA
 	url = https://github.com/haotian-liu/LLaVA.git
+[submodule "examples/models/llama2/third-party/re2"]
+	path = examples/models/llama2/third-party/re2
+	url = https://github.com/google/re2.git
+[submodule "examples/models/llama2/third-party/abseil-cpp"]
+	path = examples/models/llama2/third-party/abseil-cpp
+	url = https://github.com/abseil/abseil-cpp.git
diff --git a/examples/models/llama2/CMakeLists.txt b/examples/models/llama2/CMakeLists.txt
index ad6a2c78f9d..cf8bc96bf23 100644
--- a/examples/models/llama2/CMakeLists.txt
+++ b/examples/models/llama2/CMakeLists.txt
@@ -21,6 +21,8 @@ project(llama_runner)
 # Duplicating options as root CMakeLists.txt
 option(EXECUTORCH_BUILD_OPTIMIZED "Build the optimized kernels" OFF)
 
+option(EXECUTORCH_BUILD_RE2 "Build RE2" OFF)
+
 include(CMakeDependentOption)
 #
 # pthreadpool: build pthreadpool library. Disable on unsupported platforms
@@ -86,8 +88,19 @@ endif()
 
 # llama_runner library
 add_subdirectory(runner)
-
-set(link_libraries)
+if(EXECUTORCH_BUILD_RE2)
+  # find RE2 for tokenizer
+  set(ABSL_ENABLE_INSTALL ON)
+  set(_pic_flag
+      ${CMAKE_POSITION_INDEPENDENT_CODE})
+  set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/abseil-cpp)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third-party/re2)
+  set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
+  target_link_libraries(llama_runner PUBLIC re2::re2)
+endif()
+
+set(link_libraries gflags)
 set(_srcs main.cpp)
 
 if(EXECUTORCH_BUILD_OPTIMIZED)
@@ -162,7 +175,7 @@ if(CMAKE_BUILD_TYPE EQUAL "RELEASE")
 endif()
 
 target_include_directories(llama_main PUBLIC ${_common_include_directories})
-target_link_libraries(llama_main PUBLIC gflags llama_runner ${link_libraries})
+target_link_libraries(llama_main PUBLIC llama_runner ${link_libraries})
 target_compile_options(llama_main PUBLIC ${_common_compile_options})
 
 if(APPLE)
diff --git a/examples/models/llama2/main.cpp b/examples/models/llama2/main.cpp
index 10a355a6037..a8cb048f88b 100644
--- a/examples/models/llama2/main.cpp
+++ b/examples/models/llama2/main.cpp
@@ -39,6 +39,11 @@ DEFINE_int32(
     -1,
     "Number of CPU threads for inference. Defaults to -1, which implies we'll use a heuristic to derive the # of performant cores for a specific device.");
 
+DEFINE_bool(
+    use_tiktoken,
+    false,
+    "Use Tiktoken tokenizer instead of the default BPE tokenizer.");
+
 int32_t main(int32_t argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
@@ -57,6 +62,8 @@ int32_t main(int32_t argc, char** argv) {
 
   int32_t cpu_threads = FLAGS_cpu_threads;
 
+  bool use_tiktoken = FLAGS_use_tiktoken;
+
 #if defined(ET_USE_THREADPOOL)
   uint32_t num_performant_cores = cpu_threads == -1
       ? torch::executorch::cpuinfo::get_num_performant_cores()
@@ -69,7 +76,8 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  ::torch::executor::Runner runner(model_path, tokenizer_path, temperature);
+  ::torch::executor::Runner runner(
+      model_path, tokenizer_path, temperature, use_tiktoken);
 
   // generate
   runner.generate(prompt, seq_len);
diff --git a/examples/models/llama2/runner/generation.py b/examples/models/llama2/runner/generation.py
index ed6fd7445b8..56a15005ef1 100644
--- a/examples/models/llama2/runner/generation.py
+++ b/examples/models/llama2/runner/generation.py
@@ -71,7 +71,7 @@ def generate(  # noqa: C901
         self,
         prompt_tokens: List[List[int]],
         max_gen_len: int,
-        temperature: float = 0.6,
+        temperature: float = 0.8,
         top_p: float = 0.9,
         logprobs: bool = False,
         echo: bool = False,
diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 61a5ea66bdc..c6889a150dd 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -11,6 +11,7 @@
 #include <executorch/examples/models/llama2/runner/runner.h>
 #include <executorch/examples/models/llama2/tokenizer/bpe_tokenizer.h>
+#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
 #include <executorch/extension/evalue_util/print_evalue.h>
 #include <executorch/extension/runner_util/managed_tensor.h>
 
@@ -37,8 +38,10 @@ std::string statsToJsonString(const Runner::Stats& stats);
 Runner::Runner(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    const float temperature)
-    : module_(std::make_unique<Module>(
+    const float temperature,
+    bool use_tiktoken)
+    : use_tiktoken_(use_tiktoken),
+      module_(std::make_unique<Module>(
           model_path,
           Module::MlockConfig::UseMlockIgnoreErrors)),
       tokenizer_path_(tokenizer_path),
@@ -77,7 +80,11 @@ Error Runner::load() {
   append_eos_ = getMetadataHelper("append_eos_to_prompt", false);
 
   // Load tokenizer
-  tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
+  if (use_tiktoken_) {
+    tokenizer_ = std::make_unique<Tiktoken>(vocab_size_, bos_id_, eos_id_);
+  } else {
+    tokenizer_ = std::make_unique<BPETokenizer>(vocab_size_, bos_id_, eos_id_);
+  }
   tokenizer_->load(tokenizer_path_);
   if (tokenizer_->bos_tok() != bos_id_) {
     ET_LOG(
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 4e200d5e6ca..f15cdb636d0 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -29,7 +29,8 @@ class Runner {
   explicit Runner(
       const std::string& model_path,
      const std::string& tokenizer_path,
-      const float temperature = 0.8f);
+      const float temperature = 0.8f,
+      bool use_tiktoken = false);
 
   struct Stats {
     // Scaling factor for timestamps - in this case, we use ms.
@@ -85,6 +86,7 @@ class Runner {
   int32_t n_bos_;
   int32_t n_eos_;
   int32_t max_seq_len_;
+  bool use_tiktoken_;
   bool use_kv_cache_;
   bool use_sdpa_with_kv_cache_;
   bool append_eos_;
diff --git a/examples/models/llama2/third-party/abseil-cpp b/examples/models/llama2/third-party/abseil-cpp
new file mode 160000
index 00000000000..85419307149
--- /dev/null
+++ b/examples/models/llama2/third-party/abseil-cpp
@@ -0,0 +1 @@
+Subproject commit 854193071498f330b71083d7e06a7cd18e02a4cc
diff --git a/examples/models/llama2/third-party/re2 b/examples/models/llama2/third-party/re2
new file mode 160000
index 00000000000..ac82d4f628a
--- /dev/null
+++ b/examples/models/llama2/third-party/re2
@@ -0,0 +1 @@
+Subproject commit ac82d4f628a2045d89964ae11c48403d3b091af1
diff --git a/examples/models/llama2/tokenizer/base64.h b/examples/models/llama2/tokenizer/base64.h
new file mode 100644
index 00000000000..9fb1b5129b3
--- /dev/null
+++ b/examples/models/llama2/tokenizer/base64.h
@@ -0,0 +1,180 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+// @lint-ignore-every LICENSELINT
+/**************************************************************************
+   Copyright (c) 2023 sewenew
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+ *************************************************************************/
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include <string>
+#include <string_view>
+
+#include <executorch/runtime/platform/assert.h>
+
+namespace torch {
+namespace executor {
+namespace base64 {
+
+std::string decode(const std::string_view& input);
+
+namespace detail {
+
+constexpr uint32_t DECODE_TABLE[] = {
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62,  255,
+    255, 255, 63,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  255, 255,
+    255, 255, 255, 255, 255, 0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
+    10,  11,  12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,
+    25,  255, 255, 255, 255, 255, 255, 26,  27,  28,  29,  30,  31,  32,  33,
+    34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,  48,
+    49,  50,  51,  255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
+    255};
+
+inline void validate(uint32_t v) {
+  ET_CHECK_MSG(v != 255, "invalid char");
+}
+
+inline void decode(const std::string_view& input, std::string& output) {
+  ET_CHECK_MSG(
+      input.size() == 4, "input length must be 4, got %zu", input.size());
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  validate(v);
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  c = input[2];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  c = input[3];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 16) & 0xFF));
+  output.push_back(static_cast<char>((val >> 8) & 0xFF));
+  output.push_back(static_cast<char>(val & 0xFF));
+}
+
+inline void decode_1_padding(
+    const std::string_view& input,
+    std::string& output) {
+  ET_CHECK_MSG(
+      input.size() == 3, "input length must be 3, got %zu", input.size());
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  validate(v);
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  c = input[2];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 10) & 0xFF));
+  output.push_back(static_cast<char>((val >> 2) & 0xFF));
+}
+
+inline void decode_2_padding(
+    const std::string_view& input,
+    std::string& output) {
+  assert(input.size() == 2);
+
+  uint32_t val = 0;
+
+  uint8_t c = input[0];
+  auto v = DECODE_TABLE[c];
+  validate(v);
+  val = v;
+
+  c = input[1];
+  v = DECODE_TABLE[c];
+  validate(v);
+  val = (val << 6) | v;
+
+  output.push_back(static_cast<char>((val >> 4) & 0xFF));
+}
+
+} // namespace detail
+
+inline std::string decode(const std::string_view& input) {
+  ET_CHECK_MSG(!input.empty(), "empty input");
+
+  // Faster than `input.size() % 4`.
+  ET_CHECK_MSG(
+      (input.size() & 3) == 0 && input.size() >= 4,
+      "input length must be at least 4 and a multiple of 4, got %zu",
+      input.size());
+
+  std::string output;
+  output.reserve(input.size() / 4 * 3);
+  auto idx = 0U;
+  for (; idx < input.size() - 4; idx += 4) {
+    detail::decode(input.substr(idx, 4), output);
+  }
+
+  // Last 4 bytes. Might contain paddings.
+  if (input[idx + 3] == '=') {
+    if (input[idx + 2] == '=') {
+      // Two paddings.
+      detail::decode_2_padding(input.substr(idx, 2), output);
+    } else {
+      // One padding.
+      detail::decode_1_padding(input.substr(idx, 3), output);
+    }
+  } else {
+    // No padding.
+    detail::decode(input.substr(idx, 4), output);
+  }
+
+  return output;
+}
+
+} // namespace base64
+
+} // namespace executor
+} // namespace torch
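
For orientation, a minimal self-contained sketch (not part of the patch) of how the decoder above behaves; the include path matches the header added above, and the sample strings are ordinary base64:

```cpp
#include <executorch/examples/models/llama2/tokenizer/base64.h>

#include <cstdio>

int main() {
  // "aGVsbG8h" encodes "hello!": two full 4-character groups, no padding.
  std::string a = torch::executor::base64::decode("aGVsbG8h");
  // "aGVsbG8=" encodes "hello": the trailing '=' takes the one-padding path.
  std::string b = torch::executor::base64::decode("aGVsbG8=");
  std::printf("%s|%s\n", a.c_str(), b.c_str()); // prints "hello!|hello"
  return 0;
}
```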
diff --git a/examples/models/llama2/tokenizer/targets.bzl b/examples/models/llama2/tokenizer/targets.bzl
index b25693558ae..2ac2b483991 100644
--- a/examples/models/llama2/tokenizer/targets.bzl
+++ b/examples/models/llama2/tokenizer/targets.bzl
@@ -5,10 +5,13 @@ def define_common_targets():
         name = "tokenizer",
         srcs = [
             "bpe_tokenizer.cpp",
+            "tiktoken.cpp",
         ],
         exported_headers = [
             "tokenizer.h",
             "bpe_tokenizer.h",
+            "tiktoken.h",
+            "base64.h",
         ],
         exported_deps = [
             "//executorch/runtime/core/exec_aten:lib",
@@ -17,6 +20,9 @@ def define_common_targets():
         visibility = [
             "@EXECUTORCH_CLIENTS",
         ],
+        exported_external_deps = [
+            "re2",
+        ],
     )
 
     runtime.python_library(
diff --git a/examples/models/llama2/tokenizer/test/targets.bzl b/examples/models/llama2/tokenizer/test/targets.bzl
index 7ed15b81b9e..3642ceca66f 100644
--- a/examples/models/llama2/tokenizer/test/targets.bzl
+++ b/examples/models/llama2/tokenizer/test/targets.bzl
@@ -20,6 +20,22 @@ def define_common_targets():
         },
     )
 
+    runtime.cxx_test(
+        name = "test_tiktoken",
+        srcs = [
+            "test_tiktoken.cpp",
+        ],
+        deps = [
+            "//executorch/examples/models/llama2/tokenizer:tokenizer",
+        ],
+        env = {
+            "RESOURCES_PATH": "$(location :resources_fb_only)/resources",
+        },
+        external_deps = [
+            "re2",
+        ],
+    )
+
     runtime.filegroup(
         name = "resources",
         srcs = native.glob([
@@ -27,6 +43,13 @@ def define_common_targets():
         ]),
     )
 
+    runtime.filegroup(
+        name = "resources_fb_only",
+        srcs = native.glob([
+            "resources/fb/**",
+        ]),
+    )
+
     runtime.python_test(
         name = "test_tokenizer_py",
         srcs = [
diff --git a/examples/models/llama2/tokenizer/test/test_tiktoken.cpp b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp
new file mode 100644
index 00000000000..2f08e2a1aa7
--- /dev/null
+++ b/examples/models/llama2/tokenizer/test/test_tiktoken.cpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstdlib>
+#include <memory>
+#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>
+#include <executorch/runtime/platform/runtime.h>
+#include <gtest/gtest.h>
+
+using namespace ::testing;
+
+namespace torch {
+namespace executor {
+
+class TiktokenExtensionTest : public Test {
+ public:
+  void SetUp() override {
+    torch::executor::runtime_init();
+    tokenizer_ = std::make_unique<Tiktoken>(128256, 128000, 128001);
+    modelPath_ =
+        std::getenv("RESOURCES_PATH") + std::string("/fb/tokenizer.model");
+  }
+
+  std::unique_ptr<Tokenizer> tokenizer_;
+  std::string modelPath_;
+};
+
+TEST_F(TiktokenExtensionTest, EncodeWithoutLoadFails) {
+  Result<std::vector<uint64_t>> res = tokenizer_->encode("hello world", 0, 0);
+  EXPECT_EQ(res.error(), Error::NotSupported);
+}
+
+TEST_F(TiktokenExtensionTest, DecodeWithoutLoadFails) {
+  auto result = tokenizer_->decode(0, 0);
+  EXPECT_EQ(result.error(), Error::NotSupported);
+}
+
+TEST_F(TiktokenExtensionTest, TokenizerVocabSizeIsExpected) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  // The tokenizer reports the vocab size, BOS, and EOS ids passed to its
+  // constructor, adding placeholder tokens beyond the ranks in the file.
+  EXPECT_EQ(tokenizer_->vocab_size(), 128256);
+  EXPECT_EQ(tokenizer_->bos_tok(), 128000);
+  EXPECT_EQ(tokenizer_->eos_tok(), 128001);
+}
+
+TEST_F(TiktokenExtensionTest, TokenizerEncodeCorrectly) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  Result<std::vector<uint64_t>> out = tokenizer_->encode("hello world", 1, 0);
+  EXPECT_EQ(out.error(), Error::Ok);
+  EXPECT_EQ(out.get().size(), 3);
+  EXPECT_EQ(out.get()[0], 128000);
+  EXPECT_EQ(out.get()[1], 15339);
+  EXPECT_EQ(out.get()[2], 1917);
+}
+
+TEST_F(TiktokenExtensionTest, TokenizerDecodeCorrectly) {
+  Error res = tokenizer_->load(modelPath_.c_str());
+  EXPECT_EQ(res, Error::Ok);
+  std::vector<std::string> expected = {"<|begin_of_text|>", "hello", " world"};
+  std::vector<uint64_t> tokens = {128000, 15339, 1917};
+  for (size_t i = 0; i < tokens.size(); i++) {
+    Result<std::string> out = tokenizer_->decode(0, tokens[i]);
+    EXPECT_EQ(out.error(), Error::Ok);
+    EXPECT_EQ(out.get(), expected[i]);
+  }
+}
+
+} // namespace executor
+} // namespace torch
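
A rough usage sketch mirroring the test above; the constructor constants come straight from the test fixture, while the model path is a placeholder:

```cpp
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>

#include <cstdio>

using namespace torch::executor;

int main() {
  // Constants as in TiktokenExtensionTest: vocab size, BOS id, EOS id.
  Tiktoken tokenizer(128256, 128000, 128001);
  if (tokenizer.load("/path/to/tokenizer.model") != Error::Ok) { // placeholder
    return 1;
  }
  // bos=1 prepends one BOS token, eos=0 appends none; the tests expect
  // {128000, 15339, 1917} for "hello world".
  Result<std::vector<uint64_t>> tokens = tokenizer.encode("hello world", 1, 0);
  for (uint64_t t : tokens.get()) {
    std::printf("%llu ", static_cast<unsigned long long>(t));
  }
  return 0;
}
```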
+ *************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace torch { +namespace executor { + +// ------------------------------Util start------------------------------------ + +static uint64_t _max_size() { + return std::numeric_limits::max(); +} + +static Re2UPtr _create_regex(const std::string& pattern) { + assert(!pattern.empty()); + + return std::make_unique("(" + pattern + ")"); +} + +static Re2UPtr _build_special_token_regex(const Encoder& special_encoder) { + std::string special_pattern; + for (const auto& ele : special_encoder) { + if (!special_pattern.empty()) { + special_pattern += "|"; + } + special_pattern += re2::RE2::QuoteMeta(ele.first); + } + + if (special_pattern.empty()) { + return nullptr; + } + + return _create_regex(special_pattern); +} + +static std::pair _parse(const std::string& line) { + auto pos = line.find(" "); + ET_CHECK_MSG( + pos != std::string::npos, "invalid encoder line: %s", line.c_str()); + + auto token = base64::decode({line.data(), pos}); + uint64_t rank = 0; + try { + rank = std::stoul(line.substr(pos + 1)); + } catch (const std::exception&) { + ET_CHECK_MSG(false, "invalid encoder rank: %s", line.c_str()); + } + + return {std::move(token), rank}; +} + +static Encoder _load_encoder(const std::string& path) { + std::ifstream file(path); + ET_CHECK_MSG(file, "failed to open encoder file: %s", path.c_str()); + + Encoder encoder; + std::string line; + while (std::getline(file, line)) { + auto [token, rank] = _parse(line); + + ET_CHECK_MSG( + encoder.emplace(std::move(token), rank).second, + "duplicate item: %s", + line.c_str()); + } + + return encoder; +} + +static Decoder _build_decoder(const Encoder& encoder) { + Decoder decoder; + for (const auto& [k, v] : encoder) { + decoder.emplace(v, k); + } + + ET_CHECK_MSG(encoder.size() == decoder.size(), "duplicate items in encoder"); + + return decoder; +} + +static std::vector _byte_pair_merge( + const std::string& piece, + const std::unordered_map& ranks, + std::function func) { + // This is a vector of (start, rank). + // The rank is of the byte pair starting at position start. + // The rank of the last item in the vector is not a valid value. + std::vector> parts; + parts.reserve(piece.size() + 1); + for (auto idx = 0U; idx < piece.size() + 1; ++idx) { + parts.emplace_back(idx, _max_size()); + } + + auto get_rank = [&piece, &ranks]( + const std::vector>& parts, + uint64_t start_idx, + uint64_t skip) -> std::optional { + if (start_idx + skip + 2 < parts.size()) { + auto s = parts[start_idx].first; + auto e = parts[start_idx + skip + 2].first; + auto key = piece.substr(s, e - s); + auto iter = ranks.find(key); + if (iter != ranks.end()) { + return iter->second; + } + } + return std::nullopt; + }; + + // We look up the ranks once in the beginning and iteratively update + // them during each merge, which reduces the number of rank lookups. + for (auto i = 0U; i < parts.size() - 2; ++i) { + auto rank = get_rank(parts, i, 0); + if (rank) { + // usize::MAX is a sentinel value and cannot be a valid rank + ET_CHECK_MSG(*rank != _max_size(), "rank is too large"); + parts[i].second = *rank; + } + } + + // If you have n parts and m merges, this does O(mn) work. + // We could do something with a heap and do O(m log n) work. 
+
+static Encoder _load_encoder(const std::string& path) {
+  std::ifstream file(path);
+  ET_CHECK_MSG(file, "failed to open encoder file: %s", path.c_str());
+
+  Encoder encoder;
+  std::string line;
+  while (std::getline(file, line)) {
+    auto [token, rank] = _parse(line);
+
+    ET_CHECK_MSG(
+        encoder.emplace(std::move(token), rank).second,
+        "duplicate item: %s",
+        line.c_str());
+  }
+
+  return encoder;
+}
+
+static Decoder _build_decoder(const Encoder& encoder) {
+  Decoder decoder;
+  for (const auto& [k, v] : encoder) {
+    decoder.emplace(v, k);
+  }
+
+  ET_CHECK_MSG(encoder.size() == decoder.size(), "duplicate items in encoder");
+
+  return decoder;
+}
+
+static std::vector<uint64_t> _byte_pair_merge(
+    const std::string& piece,
+    const std::unordered_map<std::string, uint64_t>& ranks,
+    std::function<uint64_t(uint64_t, uint64_t)> func) {
+  // This is a vector of (start, rank).
+  // The rank is of the byte pair starting at position start.
+  // The rank of the last item in the vector is not a valid value.
+  std::vector<std::pair<uint64_t, uint64_t>> parts;
+  parts.reserve(piece.size() + 1);
+  for (auto idx = 0U; idx < piece.size() + 1; ++idx) {
+    parts.emplace_back(idx, _max_size());
+  }
+
+  auto get_rank = [&piece, &ranks](
+                      const std::vector<std::pair<uint64_t, uint64_t>>& parts,
+                      uint64_t start_idx,
+                      uint64_t skip) -> std::optional<uint64_t> {
+    if (start_idx + skip + 2 < parts.size()) {
+      auto s = parts[start_idx].first;
+      auto e = parts[start_idx + skip + 2].first;
+      auto key = piece.substr(s, e - s);
+      auto iter = ranks.find(key);
+      if (iter != ranks.end()) {
+        return iter->second;
+      }
+    }
+    return std::nullopt;
+  };
+
+  // We look up the ranks once in the beginning and iteratively update
+  // them during each merge, which reduces the number of rank lookups.
+  for (auto i = 0U; i < parts.size() - 2; ++i) {
+    auto rank = get_rank(parts, i, 0);
+    if (rank) {
+      // usize::MAX is a sentinel value and cannot be a valid rank
+      ET_CHECK_MSG(*rank != _max_size(), "rank is too large");
+      parts[i].second = *rank;
+    }
+  }
+
+  // If you have n parts and m merges, this does O(mn) work.
+  // We could do something with a heap and do O(m log n) work.
+  // It is important to consider that n is often small (<100), and as such
+  // the cache-locality benefits outweigh the algorithmic complexity downsides
+  // of the `parts` vector data structure above.
+
+  // Note that we hash bytes, not token pairs. As long as we train BPE the way
+  // we currently do, this is equivalent. An easy way to break this would be
+  // to decouple merge priority from token index or to prevent specific token
+  // merges.
+  while (true) {
+    if (parts.size() == 1) {
+      break;
+    }
+
+    // usize::MAX is a sentinel rank value allowing us to
+    // take the min more quickly
+    auto min_rank = std::make_pair<uint64_t, uint64_t>(_max_size(), 0);
+    for (auto i = 0U; i < parts.size() - 1; ++i) {
+      auto rank = parts[i].second;
+      if (rank < min_rank.first) {
+        min_rank.first = rank;
+        min_rank.second = i;
+      }
+    }
+
+    if (min_rank.first != _max_size()) {
+      auto i = min_rank.second;
+
+      // NOTE: We are about to remove parts[i + 1]. We do not do it
+      // yet because there are cache-locality benefits to updating
+      // parts[i] and parts[i - 1] before removing, which could thrash
+      // the cache. Thus, we update the rank calculation by skipping over
+      // parts[i + 1], by invoking `get_rank` with `skip = 1`.
+      auto rank = get_rank(parts, i, 1);
+      if (rank) {
+        parts[i].second = *rank;
+      } else {
+        parts[i].second = _max_size();
+      }
+      if (i > 0) {
+        rank = get_rank(parts, i - 1, 1);
+        if (rank) {
+          parts[i - 1].second = *rank;
+        } else {
+          parts[i - 1].second = _max_size();
+        }
+      }
+
+      parts.erase(parts.begin() + (i + 1));
+    } else {
+      break;
+    }
+  }
+  std::vector<uint64_t> out;
+  out.reserve(parts.size() - 1);
+  for (auto i = 0U; i < parts.size() - 1; ++i) {
+    auto s = parts[i].first;
+    auto e = parts[i + 1].first;
+    out.push_back(func(s, e));
+  }
+  return out;
+}
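+
+// A hand-traced example of the loop above (hypothetical ranks, not from a
+// real vocabulary): piece = "abcd" with ranks {"ab": 0, "cd": 1}.
+//   parts starts as the byte boundaries {0, 1, 2, 3, 4}.
+//   Initial pair ranks: "ab" -> 0, "bc" -> none, "cd" -> 1.
+//   Merge 1: lowest rank is "ab", so boundary 1 is erased, leaving
+//     {0, 2, 3, 4}; the pair at 0 becomes "abc", which has no rank.
+//   Merge 2: lowest rank is "cd", so boundary 3 is erased, leaving
+//     {0, 2, 4}; the pair at 0 becomes "abcd", which has no rank.
+//   No ranked pairs remain, so func maps the final spans (0, 2) = "ab"
+//   and (2, 4) = "cd" to the output token ids.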
+
+static std::vector<uint64_t> _byte_pair_encode(
+    const std::string& piece,
+    const Encoder& encoder) {
+  if (piece.size() == 1) {
+    auto iter = encoder.find(piece);
+    if (iter != encoder.end()) {
+      return std::vector<uint64_t>({iter->second});
+    } else {
+      // TODO: is it possible?
+      return {};
+    }
+  }
+
+  return _byte_pair_merge(
+      piece, encoder, [&piece, &encoder](uint64_t start, uint64_t stop) {
+        std::string key = piece.substr(start, stop - start);
+        auto iter = encoder.find(key);
+        if (iter != encoder.end()) {
+          return iter->second;
+        } else {
+          // TODO: what if key does not exist? Should we return `unknown`?
+          // assert(false); // ??
+          return uint64_t(0);
+        }
+      });
+}
+// ------------------------------Util end------------------------------------
+// -------------------------private method start-------------------------------
+
+template <typename T>
+std::pair<std::optional<std::string>, re2::StringPiece>
+Tiktoken::_split_with_allowed_special_token(
+    re2::StringPiece& input,
+    const T& allowed_special) {
+  if (!_special_token_regex) {
+    return std::make_pair(std::nullopt, input);
+  }
+
+  auto start = input.begin();
+  std::string special;
+  while (true) {
+    if (!re2::RE2::FindAndConsume(&input, *_special_token_regex, &special)) {
+      // No special token.
+      break;
+    }
+
+    if (allowed_special.count(special) == 1) {
+      // Found an allowed special token, split the text with it.
+      return std::make_pair(
+          special,
+          re2::StringPiece(start, input.begin() - start - special.size()));
+    } // else try to find the next special token
+  }
+
+  return std::make_pair(std::nullopt, input);
+}
+
+void Tiktoken::_encode(
+    re2::StringPiece& input,
+    std::vector<uint64_t>& ret,
+    uint64_t& last_piece_token_len) {
+  std::string piece;
+  assert(_regex);
+  while (re2::RE2::FindAndConsume(&input, *_regex, &piece)) {
+    auto iter = _encoder.find(piece);
+    if (iter != _encoder.end()) {
+      last_piece_token_len = 1;
+      ret.push_back(iter->second);
+      continue;
+    }
+    auto tokens = _byte_pair_encode(piece, _encoder);
+    last_piece_token_len = tokens.size();
+    ret.insert(ret.end(), tokens.begin(), tokens.end());
+  }
+}
+
+template <typename T>
+std::pair<std::vector<uint64_t>, uint64_t>
+Tiktoken::_encode_with_special_token(
+    const std::string& text,
+    const T& allowed_special) {
+  std::vector<uint64_t> tokens;
+  uint64_t last_piece_token_len = 0;
+  re2::StringPiece input(text);
+  while (true) {
+    auto [special, sub_input] =
+        _split_with_allowed_special_token(input, allowed_special);
+
+    _encode(sub_input, tokens, last_piece_token_len);
+
+    if (special) {
+      uint64_t token = 0;
+      try {
+        token = _special_token_encoder.at(*special);
+      } catch (const std::out_of_range&) {
+        // Should never go here, since special pattern includes all special
+        // chars.
+        ET_CHECK_MSG(false, "unknown special token: %s", special->c_str());
+      }
+
+      tokens.push_back(token);
+      last_piece_token_len = 0;
+    } else {
+      break;
+    }
+  }
+
+  // last_piece_token_len is how many tokens came from the last regex split.
+  // This is used for determining unstable tokens, since you can't merge
+  // across (stable) regex splits
+  return std::make_pair(tokens, last_piece_token_len);
+}
+
+// -------------------------private method end-------------------------------
+// -------------------------public method start-------------------------------
+
+Error Tiktoken::load(const std::string& path) {
+  _encoder = _load_encoder(path);
+  _special_token_encoder = _get_special_tokens(_encoder.size());
+
+  _decoder = _build_decoder(_encoder);
+  _special_token_decoder = _build_decoder(_special_token_encoder);
+
+  _regex = _create_regex(_pattern);
+
+  _special_token_regex = _build_special_token_regex(_special_token_encoder);
+
+  initialized_ = true;
+  return Error::Ok;
+}
+
+Result<std::vector<uint64_t>>
+Tiktoken::encode(const std::string& text, int8_t bos, int8_t eos) {
+  if (!initialized_) {
+    return Error::NotSupported;
+  }
+  auto res = _encode_with_special_token(text, _special_token_encoder).first;
+  for (auto i = 0; i < bos; ++i) {
+    res.insert(res.begin(), bos_tok_);
+  }
+  for (auto i = 0; i < eos; ++i) {
+    res.push_back(eos_tok_);
+  }
+  return Result(res);
+}
+
+Result<std::string> Tiktoken::decode(uint64_t prev, uint64_t cur) {
+  (void)prev;
+  if (!initialized_) {
+    return Error::NotSupported;
+  }
+  std::string ret;
+
+  std::string token_bytes;
+  auto iter = _decoder.find(cur);
+  if (iter != _decoder.end()) {
+    token_bytes = iter->second;
+  } else {
+    iter = _special_token_decoder.find(cur);
+    if (iter != _special_token_decoder.end()) {
+      token_bytes = iter->second;
+    } else {
+      ET_CHECK_MSG(false, "unknown token: %" PRIu64, cur);
+    }
+  }
+  ret += token_bytes;
+
+  return ret;
+}
+// -------------------------public method end-------------------------------
+
+} // namespace executor
+} // namespace torch
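
One behavioral note, sketched below: `encode` passes the full `_special_token_encoder` as the allowed set, so a special-token string occurring literally in the input is matched by `_special_token_regex` and emitted as its single id instead of being byte-pair encoded. A hedged sketch (placeholder model path, constants as in the tests):

```cpp
#include <executorch/examples/models/llama2/tokenizer/tiktoken.h>

using namespace torch::executor;

int main() {
  Tiktoken tokenizer(128256, 128000, 128001);
  if (tokenizer.load("/path/to/tokenizer.model") != Error::Ok) { // placeholder
    return 1;
  }
  // "<|eot_id|>" is in _special_token_encoder, so it is consumed whole by
  // _split_with_allowed_special_token and pushed as one id; only "hello"
  // goes through the regex split and byte-pair encoding path.
  auto res = tokenizer.encode("hello<|eot_id|>", /*bos=*/0, /*eos=*/0);
  return res.error() == Error::Ok ? 0 : 1;
}
```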
diff --git a/examples/models/llama2/tokenizer/tiktoken.h b/examples/models/llama2/tokenizer/tiktoken.h
new file mode 100644
index 00000000000..e00efdf99e3
--- /dev/null
+++ b/examples/models/llama2/tokenizer/tiktoken.h
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <re2/re2.h>
+
+#include <cstdint>
+#include <memory>
+#include <optional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <executorch/examples/models/llama2/tokenizer/tokenizer.h>
+
+namespace torch {
+namespace executor {
+
+using Encoder = std::unordered_map<std::string, uint64_t>;
+using Decoder = std::unordered_map<uint64_t, std::string>;
+using Re2UPtr = std::unique_ptr<re2::RE2>;
+
+class Tiktoken : public Tokenizer {
+ public:
+  explicit Tiktoken(int32_t vocab_size, uint64_t bos_tok, uint64_t eos_tok)
+      : Tokenizer(vocab_size, bos_tok, eos_tok){};
+  ~Tiktoken(){};
+
+  Error load(const std::string& tokenizer_path);
+
+  Result<std::vector<uint64_t>>
+  encode(const std::string& input, int8_t bos, int8_t eos);
+
+  Result<std::string> decode(uint64_t prev_token, uint64_t token);
+
+ private:
+  static inline const Encoder _get_special_tokens(ssize_t num_base_tokens) {
+    Encoder special_tokens;
+    special_tokens.emplace("<|begin_of_text|>", num_base_tokens++);
+    special_tokens.emplace("<|end_of_text|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_0|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_1|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_2|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_3|>", num_base_tokens++);
+    special_tokens.emplace("<|start_header_id|>", num_base_tokens++);
+    special_tokens.emplace("<|end_header_id|>", num_base_tokens++);
+    special_tokens.emplace("<|reserved_special_token_4|>", num_base_tokens++);
+    special_tokens.emplace("<|eot_id|>", num_base_tokens++);
+    for (auto i = 5; i < 251; ++i) {
+      special_tokens.emplace(
+          "<|reserved_special_token_" + std::to_string(i) + "|>",
+          num_base_tokens++);
+    }
+    return special_tokens;
+  }
+
+  template <typename T>
+  std::pair<std::optional<std::string>, re2::StringPiece>
+  _split_with_allowed_special_token(
+      re2::StringPiece& input,
+      const T& allowed_special);
+
+  void _encode(
+      re2::StringPiece& input,
+      std::vector<uint64_t>& ret,
+      uint64_t& last_piece_token_len);
+
+  template <typename T>
+  std::pair<std::vector<uint64_t>, uint64_t> _encode_with_special_token(
+      const std::string& text,
+      const T& allowed_special);
+
+  // Removed negative lookahead \s+(?!\S) since it's not supported by RE2.
+  const std::string _pattern =
+      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+)";
+  Encoder _encoder;
+  Encoder _special_token_encoder;
+  Decoder _decoder;
+  Decoder _special_token_decoder;
+
+  Re2UPtr _regex;
+  Re2UPtr _special_token_regex;
+};
+} // namespace executor
+} // namespace torch
diff --git a/shim/xplat/executorch/build/env_interface.bzl b/shim/xplat/executorch/build/env_interface.bzl
index 5035521dbbd..6d82cc4e4b5 100644
--- a/shim/xplat/executorch/build/env_interface.bzl
+++ b/shim/xplat/executorch/build/env_interface.bzl
@@ -42,6 +42,7 @@ _EXTERNAL_DEPS = {
     "libtorch_python": "//third-party:libtorch_python",
     "prettytable": "//third-party:prettytable",
     "pybind11": "//third-party:pybind11",
+    "re2": [],  # TODO(larryliu0820): Add support
     # Core C++ PyTorch functionality like Tensor and ScalarType.
     "torch-core-cpp": "//third-party:libtorch",
     "torchgen": "//third-party:torchgen",
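
Finally, a quick consistency check on the special-token layout: `_get_special_tokens` above emplaces 10 named tokens and then `reserved_special_token_5` through `_250`, i.e. 256 special ids appended after the base ranks. With the 128000 base ranks implied by the tests' BOS id, that reproduces the 128256 vocab size asserted in test_tiktoken.cpp:

```cpp
#include <cstdio>

int main() {
  const int named_specials = 10;         // explicit emplace calls above
  const int reserved_specials = 251 - 5; // reserved_special_token_5 .. _250
  const int base_ranks = 128000;         // assumption: first special id == BOS
  // 128000 + 10 + 246 = 128256, matching the tested vocab_size.
  std::printf("%d\n", base_ranks + named_specials + reserved_specials);
  return 0;
}
```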