From 85d1e5b18b295bb8b6e25fbf72387bccd4ddb287 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Wed, 21 May 2025 13:19:58 -0700
Subject: [PATCH] Use dependency injection for runner (#10326)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10326

X-link: https://github.com/pytorch-labs/tokenizers/pull/53

Pass in runner components and move most of the instantiation logic from
`load()` to a new static API `create()`. This makes the runner components
testable. The next step is to move most of this logic into
`extension/llm/runner/` so that it can be reused by non-llama models. The
tokenizer-loading logic currently assumes llama; a follow-up diff will make it
model-agnostic.

Reviewed By: kirklandsign, iseeyuan

Differential Revision: D73165546
---
 .../LLaMARunner/Exported/LLaMARunner.mm       |   2 +-
 examples/models/llama/main.cpp                |  11 +-
 examples/models/llama/runner/runner.cpp       | 225 ++++++------
 examples/models/llama/runner/runner.h         |  39 ++-
 .../models/llama/runner/test/CMakeLists.txt   |  28 ++
 examples/models/llama/runner/test/TARGETS     |  14 +
 examples/models/llama/runner/test/targets.bzl |  25 ++
 .../models/llama/runner/test/test_runner.cpp  | 323 ++++++++++++++++++
 extension/android/jni/jni_layer_llama.cpp     |  17 +-
 .../apple/Benchmark/Tests/LLaMA/LLaMATests.mm |   6 +-
 extension/llm/runner/text_prefiller.h         |  17 +
 extension/llm/runner/text_token_generator.h   |  17 +
 extension/llm/tokenizers                      |   2 +-
 13 files changed, 598 insertions(+), 128 deletions(-)
 create mode 100644 examples/models/llama/runner/test/CMakeLists.txt
 create mode 100644 examples/models/llama/runner/test/TARGETS
 create mode 100644 examples/models/llama/runner/test/targets.bzl
 create mode 100644 examples/models/llama/runner/test/test_runner.cpp

diff --git a/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm b/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
index 3618d05ec6c..c2f01bf17b1 100644
--- a/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
+++ b/examples/demo-apps/apple_ios/LLaMA/LLaMARunner/LLaMARunner/Exported/LLaMARunner.mm
@@ -31,7 +31,7 @@ - (instancetype)initWithModelPath:(NSString*)modelPath
   self = [super init];
   if (self) {
     [ExecuTorchLog.sharedLog addSink:self];
-    _runner = std::make_unique<example::Runner>(
+    _runner = example::Runner::create(
         modelPath.UTF8String, tokenizerPath.UTF8String);
   }
   return self;
diff --git a/examples/models/llama/main.cpp b/examples/models/llama/main.cpp
index d75b152be1f..1c1b6f62dc1 100644
--- a/examples/models/llama/main.cpp
+++ b/examples/models/llama/main.cpp
@@ -4,6 +4,7 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */

 #include
@@ -80,18 +81,16 @@ int32_t main(int32_t argc, char** argv) {
   }
 #endif
   // create llama runner
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  example::Runner runner(model_path, tokenizer_path, data_path);
+  std::unique_ptr<example::Runner> runner =
+      example::Runner::create(model_path, tokenizer_path, data_path);

   if (warmup) {
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    runner.warmup(prompt, /*max_new_tokens=*/seq_len);
+    runner->warmup(prompt, /*max_new_tokens=*/seq_len);
   }

   // generate
   executorch::extension::llm::GenerationConfig config{
       .seq_len = seq_len, .temperature = temperature};
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  runner.generate(prompt, config);
+  runner->generate(prompt, config);

   return 0;
 }
diff --git a/examples/models/llama/runner/runner.cpp b/examples/models/llama/runner/runner.cpp
index ef3681b74bc..119eedc704e 100644
--- a/examples/models/llama/runner/runner.cpp
+++ b/examples/models/llama/runner/runner.cpp
@@ -4,6 +4,7 @@
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
  */

 // A simple llama2 runner that includes preprocessing and post processing logic.
@@ -11,9 +12,6 @@

 #include

-#include
-#include
-
 #include
 #include
@@ -62,125 +60,162 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
 }
 } // namespace

-Runner::Runner(
+std::unique_ptr<Runner> Runner::create(
     const std::string& model_path,
     const std::string& tokenizer_path,
-    std::optional<std::string> data_path)
-    // NOTE: we observed ~2x loading performance increase on iPhone 15
-    // and a ~5% improvement on Galaxy S22 by switching to
-    // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
-    : tokenizer_path_(tokenizer_path),
-      metadata_({
-          {kEnableDynamicShape, false},
-          {kMaxSeqLen, 128},
-          {kMaxContextLen, 128},
-          {kUseKVCache, true},
-          {kUseSDPAWithKVCache, false},
-      }) {
-  if (data_path.has_value()) {
-    module_ = std::make_unique<Module>(
-        model_path, data_path.value(), Module::LoadMode::File);
-  } else {
-    module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
-  }
+    std::optional<std::string> data_path,
+    float temperature) {
   ET_LOG(
       Info,
       "Creating LLaMa runner: model_path=%s, tokenizer_path=%s",
       model_path.c_str(),
       tokenizer_path.c_str());
-}

-[[deprecated(
-    "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
-Runner::Runner(
-    const std::string& model_path,
-    const std::string& tokenizer_path,
-    const float temperature,
-    std::optional<std::string> data_path)
-    : Runner(model_path, tokenizer_path, std::move(data_path)) {
-  temperature_ = temperature;
-}
-
-bool Runner::is_loaded() const {
-  return module_->is_loaded() && tokenizer_ && text_decoder_runner_ &&
-      text_prefiller_ && text_token_generator_;
-}
-
-Error Runner::load() {
-  if (is_loaded()) {
-    return Error::Ok;
+  // Create the Module
+  std::unique_ptr<Module> module;
+  if (data_path.has_value()) {
+    module = std::make_unique<Module>(
+        model_path, data_path.value(), Module::LoadMode::File);
+  } else {
+    module = std::make_unique<Module>(model_path, Module::LoadMode::File);
   }
-  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));

-  // Load tokenizer.
-  tokenizer_ = load_tokenizer(tokenizer_path_);
-  if (tokenizer_ == nullptr) {
+  // Initialize metadata with default values
+  std::unordered_map<std::string, int64_t> metadata({
+      {kEnableDynamicShape, false},
+      {kMaxSeqLen, 128},
+      {kMaxContextLen, 128},
+      {kUseKVCache, true},
+      {kUseSDPAWithKVCache, false},
+  });
+
+  // Create and load tokenizer
+  std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
+      load_tokenizer(tokenizer_path);
+
+  // Fallback to BPE tokenizer if tiktoken fails
+  if (tokenizer == nullptr) {
     ET_LOG(
         Info,
-        "Failed to load %s as a Tiktoken artifact, trying BPE tokenizer",
-        tokenizer_path_.c_str());
-    tokenizer_.reset();
-    // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-    tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
-    auto err = tokenizer_->load(tokenizer_path_);
-    ET_CHECK_TK_OK_OR_RETURN_ERROR(
-        err,
-        "Failed to load %s as a llama2.c tokenizer artifact",
-        tokenizer_path_.c_str());
-    return ::executorch::runtime::Error::InvalidArgument;
+        "Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types",
+        tokenizer_path.c_str());
+    return nullptr;
   }

   ET_LOG(Info, "Reading metadata from model");

-  metadata_[kBosId] = tokenizer_->bos_tok();
+  // Set tokenizer-related metadata
+  metadata[kBosId] = tokenizer->bos_tok();
   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
-      std::unordered_set<uint64_t>{tokenizer_->eos_tok()});
-  metadata_[kVocabSize] = tokenizer_->vocab_size();
-
-  const auto method_names =
-      ET_UNWRAP(module_->method_names(), "Failed reading method names");
+      std::unordered_set<uint64_t>{tokenizer->eos_tok()});
+  metadata[kVocabSize] = tokenizer->vocab_size();
+
+  // Read metadata from the model
+  auto method_names_result = module->method_names();
+  if (method_names_result.error() != Error::Ok) {
+    ET_LOG(Error, "Failed reading method names");
+    return nullptr;
+  }
+  const auto method_names = method_names_result.get();

-  for (auto& pair : metadata_) {
+  for (auto& pair : metadata) {
     const auto& method_name = pair.first;
     auto& value = pair.second;

     if (method_names.count(method_name)) {
-      value = ET_UNWRAP(module_->get(method_name))
-                  .toScalar()
-                  .to<int64_t>();
+      auto get_result = module->get(method_name);
+      value = get_result.get().toScalar().to<int64_t>();
     } else {
       ET_LOG(
           Info,
-          "Methond %s not found, using the default value %" PRId64,
+          "Method %s not found, using the default value %" PRId64,
           method_name.c_str(),
           value);
     }
     ET_LOG(Info, "Metadata: %s = %" PRId64, method_name.c_str(), value);
   }
+
+  // Get EOS IDs if available
   if (method_names.count(kEosIds)) {
     eos_ids->clear();
-    for (const auto& eos_id : ET_UNWRAP(module_->execute(kEosIds))) {
+    auto execute_result = module->execute(kEosIds);
+    if (execute_result.error() != Error::Ok) {
+      ET_LOG(Error, "Failed to execute %s", kEosIds);
+      return nullptr;
+    }
+    for (const auto& eos_id : execute_result.get()) {
       auto value = eos_id.toScalar().to<int64_t>();
       eos_ids->emplace(value);
       ET_LOG(Info, "eos_id = %" PRId64, value);
     }
   }
-  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-  text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
-      module_.get(), metadata_.at(kUseKVCache));
-  text_prefiller_ = std::make_unique<llm::TextPrefiller>(
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
-      metadata_.at(kEnableDynamicShape),
-      metadata_.at(kMaxSeqLen));
-
-  text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
-      tokenizer_.get(),
-      text_decoder_runner_.get(),
-      metadata_.at(kUseKVCache),
+
+  // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
+  // TextPrefiller and TextTokenGenerator
+  auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
+      module.get(), metadata.at(kUseKVCache));
+
+  // Create text_prefiller
+  auto text_prefiller = std::make_unique<llm::TextPrefiller>(
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
+      metadata.at(kEnableDynamicShape),
+      metadata.at(kMaxSeqLen));
+
+  // Create text_token_generator with stats
+  auto stats = std::make_unique<llm::Stats>();
+  auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
+      tokenizer.get(),
+      text_decoder_runner.get(),
+      metadata.at(kUseKVCache),
       std::move(eos_ids),
-      &stats_);
+      stats.get());
+
+  // Create and return the Runner instance
+  return std::make_unique<Runner>(
+      std::move(metadata),
+      std::move(tokenizer),
+      std::move(module),
+      std::move(text_decoder_runner),
+      std::move(text_prefiller),
+      std::move(text_token_generator),
+      std::move(stats),
+      temperature);
+}
+
+Runner::Runner(
+    std::unordered_map<std::string, int64_t> metadata,
+    std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+    std::unique_ptr<::executorch::extension::Module> module,
+    std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+        text_decoder_runner,
+    std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
+    std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+        text_token_generator,
+    std::unique_ptr<::executorch::extension::llm::Stats> stats,
+    float temperature)
+    : tokenizer_(std::move(tokenizer)),
+      metadata_(std::move(metadata)),
+      module_(std::move(module)),
+      text_decoder_runner_(std::move(text_decoder_runner)),
+      text_prefiller_(std::move(text_prefiller)),
+      text_token_generator_(std::move(text_token_generator)),
+      stats_(std::move(stats)),
+      temperature_(temperature) {
+  // Note: This constructor assumes that text_prefiller and text_token_generator
+  // already have references to the Module and TextDecoderRunner they need
+}
+
+bool Runner::is_loaded() const {
+  return text_prefiller_->is_loaded() && text_token_generator_->is_loaded();
+}
+
+Error Runner::load() {
+  if (is_loaded()) {
+    return Error::Ok;
+  }
+  ET_CHECK_OK_OR_RETURN_ERROR(text_prefiller_->load());
+  ET_CHECK_OK_OR_RETURN_ERROR(text_token_generator_->load());
   return Error::Ok;
 }
@@ -201,9 +236,9 @@ Error Runner::generate(
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    stats_.model_load_start_ms = llm::time_in_ms();
+    stats_->model_load_start_ms = llm::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    stats_.model_load_end_ms = llm::time_in_ms();
+    stats_->model_load_end_ms = llm::time_in_ms();
   }

   if (config.warming) {
@@ -229,7 +264,7 @@

   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
-  stats_.inference_start_ms = llm::time_in_ms();
+  stats_->inference_start_ms = llm::time_in_ms();
   shouldStop_ = false;

   ::tokenizers::Result<std::vector<uint64_t>> encode_res = tokenizer_->encode(
@@ -270,8 +305,8 @@
   auto prefill_res = text_prefiller_->prefill(prompt_tokens, pos);
   ET_CHECK_OK_OR_RETURN_ERROR(prefill_res.error());
   uint64_t cur_token = prefill_res.get();
-  stats_.first_token_ms = llm::time_in_ms();
-  stats_.prompt_eval_end_ms = llm::time_in_ms();
+  stats_->first_token_ms = llm::time_in_ms();
+  stats_->prompt_eval_end_ms = llm::time_in_ms();

   // print the first token from prefill. No prev_token so use cur_token for it.
   wrapped_callback(
@@ -292,7 +327,7 @@
       temperature_ == -1.0f ? config.temperature : temperature_,
       wrapped_callback));

-  stats_.inference_end_ms = llm::time_in_ms();
+  stats_->inference_end_ms = llm::time_in_ms();
   if (!config.warming) {
     printf("\n");
   }
@@ -305,17 +340,17 @@
     RUNNER_ET_LOG(config.warming, "Max new tokens %i reached!", max_new_tokens);
   }

-  stats_.num_prompt_tokens = num_prompt_tokens;
-  stats_.num_generated_tokens = num_generated_tokens;
+  stats_->num_prompt_tokens = num_prompt_tokens;
+  stats_->num_generated_tokens = num_generated_tokens;

   if (config.warming) {
     ET_LOG(Info, "Warmup run finished!");
   } else {
     // Do not print report during warmup
-    ::executorch::llm::print_report(stats_);
+    ::executorch::llm::print_report(*stats_);
   }
   if (stats_callback) {
-    stats_callback(stats_);
+    stats_callback(*stats_);
   }

   return Error::Ok;
@@ -329,8 +364,8 @@
   // Call generate with the warmup config
   Error err = generate(prompt, config);

-  // Reset stats after warmup
-  stats_.reset();
+  // Reset stats after warmup, not resetting the std::unique_ptr!
+  stats_->reset();
   return err;
 }
diff --git a/examples/models/llama/runner/runner.h b/examples/models/llama/runner/runner.h
index 97ffe4b98b7..e4e91db37d5 100644
--- a/examples/models/llama/runner/runner.h
+++ b/examples/models/llama/runner/runner.h
@@ -30,18 +30,26 @@ namespace example {

 class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  public:
-  explicit Runner(
+  // Static factory method to create a Runner instance
+  static std::unique_ptr<Runner> create(
       const std::string& model_path,
       const std::string& tokenizer_path,
-      std::optional<std::string> data_path = std::nullopt);
+      std::optional<std::string> data_path = std::nullopt,
+      float temperature = -1.0f);

-  [[deprecated(
-      "This constructor is deprecated. Use the constructor without temperature parameter instead.")]]
+  // Constructor with dependency injection
   explicit Runner(
-      const std::string& model_path,
-      const std::string& tokenizer_path,
-      const float temperature,
-      std::optional<std::string> data_path = std::nullopt);
+      std::unordered_map<std::string, int64_t> metadata,
+      std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
+      std::unique_ptr<::executorch::extension::Module> module,
+      std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
+          text_decoder_runner,
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>
+          text_prefiller,
+      std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
+          text_token_generator,
+      std::unique_ptr<::executorch::extension::llm::Stats> stats,
+      float temperature = -1.0f);

   bool is_loaded() const override;
   ::executorch::runtime::Error load() override;
@@ -59,19 +67,22 @@ class ET_EXPERIMENTAL Runner : public executorch::extension::llm::IRunner {
  private:
   bool shouldStop_{false};

-  // model
-  std::unique_ptr<::executorch::extension::Module> module_;
-  std::string tokenizer_path_;
+  // Components
   std::unique_ptr<::tokenizers::Tokenizer> tokenizer_;
   std::unordered_map<std::string, int64_t> metadata_;
+  std::unique_ptr<::executorch::extension::Module>
+      module_; // Manage module's lifecycle, make sure it outlives
+               // text_decoder_runner_.
   std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
-      text_decoder_runner_;
+      text_decoder_runner_; // Manage text_decoder_runner_'s lifecycle, make
+                            // sure it outlives text_prefiller_ &
+                            // text_token_generator_.
   std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller_;
   std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
       text_token_generator_;

-  // stats
-  ::executorch::extension::llm::Stats stats_;
+  // Stats
+  std::unique_ptr<::executorch::extension::llm::Stats> stats_;

   // temperature.
   // Deprecated, we should rely on the temperature in GenerationConfig instead.
diff --git a/examples/models/llama/runner/test/CMakeLists.txt b/examples/models/llama/runner/test/CMakeLists.txt
new file mode 100644
index 00000000000..aa754b96da6
--- /dev/null
+++ b/examples/models/llama/runner/test/CMakeLists.txt
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This file should be formatted with
+# ~~~
+# cmake-format -i CMakeLists.txt
+# ~~~
+# It should also be cmake-lint clean.
+#
+
+cmake_minimum_required(VERSION 3.19)
+
+set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
+
+include(${EXECUTORCH_ROOT}/tools/cmake/Test.cmake)
+
+set(_test_srcs test_runner.cpp)
+
+et_cxx_test(
+  test_runner
+  SOURCES
+  ${_test_srcs}
+  EXTRA_LIBS
+  executorch
+)
diff --git a/examples/models/llama/runner/test/TARGETS b/examples/models/llama/runner/test/TARGETS
new file mode 100644
index 00000000000..97de7abe9b1
--- /dev/null
+++ b/examples/models/llama/runner/test/TARGETS
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/examples/models/llama/runner/test/targets.bzl b/examples/models/llama/runner/test/targets.bzl
new file mode 100644
index 00000000000..3b02360da08
--- /dev/null
+++ b/examples/models/llama/runner/test/targets.bzl
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    runtime.cxx_test(
+        name = "test_runner",
+        srcs = ["test_runner.cpp"],
+        deps = [
+            "//executorch/examples/models/llama/runner:runner",
+            "//executorch/extension/llm/runner:irunner",
+            "//executorch/extension/llm/runner:stats",
+            "//executorch/extension/llm/runner:text_token_generator",
+            "//executorch/extension/llm/runner:text_decoder_runner",
+            "//executorch/extension/llm/runner:text_prefiller",
+            "//executorch/extension/module:module",
+            "//executorch/runtime/core:core",
+            "//executorch/runtime/platform:platform",
+            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
+        ],
+    )
diff --git a/examples/models/llama/runner/test/test_runner.cpp b/examples/models/llama/runner/test/test_runner.cpp
new file mode 100644
index 00000000000..f158ca8515d
--- /dev/null
+++ b/examples/models/llama/runner/test/test_runner.cpp
@@ -0,0 +1,323 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace ::testing;
+using namespace example;
+using executorch::extension::llm::GenerationConfig;
+using executorch::extension::llm::Stats;
+using executorch::extension::llm::TextDecoderRunner;
+using executorch::extension::llm::TextPrefiller;
+using executorch::extension::llm::TextTokenGenerator;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
+using executorch::runtime::testing::TensorFactory;
+
+// Mock classes for dependencies
+class MockTokenizer : public ::tokenizers::Tokenizer {
+ public:
+  MOCK_METHOD(::tokenizers::Error, load, (const std::string&), ());
+  MOCK_METHOD(bool, is_loaded, (), (const));
+  MOCK_METHOD(
+      ::tokenizers::Result<std::vector<uint64_t>>,
+      encode,
+      (const std::string&, int8_t, int8_t),
+      (const));
+  MOCK_METHOD(
+      ::tokenizers::Result<std::string>,
+      decode,
+      (uint64_t, uint64_t),
+      (const));
+  MOCK_METHOD(uint64_t, bos_tok, (), (const));
+  MOCK_METHOD(uint64_t, eos_tok, (), (const));
+  MOCK_METHOD(uint64_t, vocab_size, (), (const));
+};
+
+class MockModule : public ::executorch::extension::Module {
+ public:
+  MockModule() : Module("") {}
+  MOCK_METHOD(
+      Error,
+      load,
+      (const executorch::runtime::Program::Verification),
+      (override));
+  MOCK_METHOD(bool, is_loaded, (), (const, override));
+  MOCK_METHOD(
+      Result<std::vector<executorch::runtime::EValue>>,
+      execute,
+      (const std::string&, const std::vector<executorch::runtime::EValue>&),
+      (override));
+};
+
+class MockTextDecoderRunner : public TextDecoderRunner {
+ public:
+  MockTextDecoderRunner() : TextDecoderRunner(nullptr, false) {}
+  MOCK_METHOD(
+      Result<executorch::aten::Tensor>,
+      step,
+      (executorch::extension::TensorPtr&, executorch::extension::TensorPtr&),
+      ());
+  MOCK_METHOD(bool, is_method_loaded, (), ());
+  MOCK_METHOD(Result<uint64_t>, prefill, (std::vector<uint64_t>&, int64_t), ());
+  MOCK_METHOD(::executorch::runtime::Error, load, (), ());
+};
+
+class MockTextPrefiller : public TextPrefiller {
+ public:
+  explicit MockTextPrefiller(TextDecoderRunner* text_decoder_runner)
+      : TextPrefiller(text_decoder_runner, false, false, 0) {}
+  MOCK_METHOD(
+      Result<uint64_t>,
+      prefill,
+      (std::vector<uint64_t>&, int64_t&),
+      ());
+  MOCK_METHOD(::executorch::runtime::Error, load, (), ());
+  MOCK_METHOD(bool, is_loaded, (), ());
+};
+
+// Callback counter class for tests
+class CallbackCounter {
+ public:
+  CallbackCounter() : count_(0) {}
+
+  void callback(const std::string& token) {
+    (void)token;
+    count_++;
+  }
+
+  int getCount() const {
+    return count_;
+  }
+
+ private:
+  int count_;
+};
+
+// Test fixture for Runner tests - minimal setup
+class RunnerTest : public Test {
+ protected:
+  // Helper functions to create and set up mock objects
+  std::unique_ptr<MockTokenizer> createMockTokenizer() {
+    auto tokenizer = std::make_unique<MockTokenizer>();
+
+    // Set up default behavior for the tokenizer
+    ON_CALL(*tokenizer, is_loaded).WillByDefault(Return(true));
+    ON_CALL(*tokenizer, encode)
+        .WillByDefault([](const std::string&, int8_t, int8_t) {
+          return ::tokenizers::Result<std::vector<uint64_t>>(
+              std::vector<uint64_t>{1, 2, 3});
+        });
+
+    ON_CALL(*tokenizer, decode).WillByDefault([](uint64_t, uint64_t) {
+      return ::tokenizers::Result<std::string>("token");
+    });
+
+    ON_CALL(*tokenizer, bos_tok()).WillByDefault(Return(1));
+    ON_CALL(*tokenizer, eos_tok()).WillByDefault(Return(2));
+    ON_CALL(*tokenizer, vocab_size()).WillByDefault(Return(100));
+
+    return tokenizer;
+  }
+
+  std::unique_ptr<MockTextDecoderRunner> createMockTextDecoderRunner() {
+    auto text_decoder_runner = std::make_unique<MockTextDecoderRunner>();
+    ON_CALL(*text_decoder_runner, step)
+        .WillByDefault([&](executorch::extension::TensorPtr&,
+                           executorch::extension::TensorPtr&) {
+          return Result<executorch::aten::Tensor>(tensor);
+        });
+    ON_CALL(*text_decoder_runner, is_method_loaded())
+        .WillByDefault(Return(true));
+    return text_decoder_runner;
+  }
+
+  std::unique_ptr<MockTextPrefiller> createMockTextPrefiller(
+      TextDecoderRunner* text_decoder_runner) {
+    auto text_prefiller =
+        std::make_unique<MockTextPrefiller>(text_decoder_runner);
+    ON_CALL(*text_prefiller, is_loaded()).WillByDefault(Return(true));
+    // Set up default behavior for the text prefiller
+    ON_CALL(*text_prefiller, prefill)
+        .WillByDefault([](const std::vector<uint64_t>&, int64_t) {
+          return Result<uint64_t>(4);
+        });
+
+    return text_prefiller;
+  }
+
+  std::unique_ptr<TextTokenGenerator> createTextTokenGenerator(
+      ::tokenizers::Tokenizer* tokenizer,
+      TextDecoderRunner* text_decoder_runner,
+      Stats* stats) {
+    auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
+        std::unordered_set<uint64_t>{100});
+    return std::make_unique<TextTokenGenerator>(
+        tokenizer,
+        text_decoder_runner,
+        true, // use_kv_cache
+        std::move(eos_ids),
+        stats);
+  }
+
+  std::unordered_map<std::string, int64_t> createDefaultMetadata() {
+    return {
+        {"enable_dynamic_shape", false},
+        {"get_max_seq_len", 128},
+        {"get_max_context_len", 128},
+        {"use_kv_cache", true},
+    };
+  }
+
+ protected:
+  Stats stats_;
+  std::vector<float> return_logits_ = {0.1f, 0.2f, 0.3f, 0.4f};
+  TensorFactory<executorch::aten::ScalarType::Float> tf;
+  executorch::aten::Tensor tensor = tf.make({1, 4}, return_logits_);
+};
+
+// Test that generate() calls the token callback exactly max_new_tokens times
+TEST_F(RunnerTest, GenerateCallsCallbackExactlyMaxNewTokensTimes) {
+  // Create mock instances using helper functions
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  // Set up expectations for the tokenizer encode method
+  EXPECT_CALL(*tokenizer, encode(_, _, _))
+      .WillOnce(Return(::tokenizers::Result<std::vector<uint64_t>>(
+          std::vector<uint64_t>{1, 2, 3})));
+
+  // Set up expectations for the text prefiller
+  EXPECT_CALL(*text_prefiller, prefill(_, _))
+      .WillOnce(Return(Result<uint64_t>(4)));
+
+  // Set up expectations for load methods
+  EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true));
+
+  std::unique_ptr<Stats> stats = std::make_unique<Stats>();
+  // Create a real TextTokenGenerator
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  // Create a Runner with our mocked components
+  Runner runner(
+      createDefaultMetadata(),
+      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
+      std::make_unique<MockModule>(),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  // Load
+  runner.load();
+
+  // Set up the generation config with a specific max_new_tokens value
+  GenerationConfig config;
+  config.max_new_tokens = 10;
+  config.echo = false;
+
+  // Create a callback counter
+  CallbackCounter counter;
+
+  // Call generate with our callback
+  Error err = runner.generate(
+      "test prompt", config, [&counter](const std::string& token) {
+        counter.callback(token);
+      });
+
+  // Verify the callback was called exactly max_new_tokens times
+  // The first token is generated by prefill, and the rest by the token
+  // generator
+  EXPECT_EQ(counter.getCount(), config.max_new_tokens);
+  EXPECT_EQ(err, Error::Ok);
+}
+
+// Test that warmup() calls generate with the warming flag set
+TEST_F(RunnerTest, WarmupCallsGenerateWithWarmingFlag) {
+  // Create mock instances using helper functions
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  // Set up expectations for the tokenizer encode method
+  EXPECT_CALL(*tokenizer, encode(_, _, _))
+      .WillOnce(Return(::tokenizers::Result<std::vector<uint64_t>>(
+          std::vector<uint64_t>{1, 2, 3})));
+
+  // Set up expectations for the text prefiller
+  EXPECT_CALL(*text_prefiller, prefill(_, _))
+      .WillOnce(Return(Result<uint64_t>(4)));
+
+  EXPECT_CALL(*text_prefiller, is_loaded()).WillRepeatedly(Return(true));
+
+  std::unique_ptr<Stats> stats = std::make_unique<Stats>();
+  // Create a TextTokenGenerator
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  // Create a Runner with our mocked components
+  Runner runner(
+      createDefaultMetadata(),
+      std::move(tokenizer),
+      std::make_unique<MockModule>(),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  // Load
+  runner.load();
+
+  // Call warmup
+  Error err = runner.warmup("test prompt", 5);
+
+  // Verify the result
+  EXPECT_EQ(err, Error::Ok);
+}
+
+// Test that is_loaded() returns true when components are initialized
+TEST_F(RunnerTest, IsLoadedReturnsTrueWhenComponentsInitialized) {
+  // Create mock instances using helper functions
+  auto tokenizer = createMockTokenizer();
+  auto text_decoder_runner = createMockTextDecoderRunner();
+  auto text_prefiller = createMockTextPrefiller(text_decoder_runner.get());
+
+  std::unique_ptr<Stats> stats = std::make_unique<Stats>();
+  // Create a real TextTokenGenerator
+  auto text_token_generator = createTextTokenGenerator(
+      tokenizer.get(), text_decoder_runner.get(), stats.get());
+
+  // Create a Runner with our mocked components
+  Runner runner(
+      createDefaultMetadata(),
+      std::unique_ptr<::tokenizers::Tokenizer>(tokenizer.release()),
+      std::make_unique<MockModule>(),
+      std::move(text_decoder_runner),
+      std::unique_ptr<::executorch::extension::llm::TextPrefiller>(
+          text_prefiller.release()),
+      std::move(text_token_generator),
+      std::move(stats));
+
+  // Load
+  runner.load();
+
+  // Verify is_loaded returns true
+  EXPECT_TRUE(runner.is_loaded());
+}
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp
index 0e6731dfcd5..03e26f089db 100644
--- a/extension/android/jni/jni_layer_llama.cpp
+++ b/extension/android/jni/jni_layer_llama.cpp
@@ -165,16 +165,13 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass<ExecuTorchLlmJni> {
           tokenizer_path->toStdString().c_str(),
           temperature);
     } else if (model_type_category == MODEL_TYPE_CATEGORY_LLM) {
-      if (data_path != nullptr) {
-        runner_ = std::make_unique<example::Runner>(
-            model_path->toStdString().c_str(),
-            tokenizer_path->toStdString().c_str(),
-            data_path->toStdString().c_str());
-      } else {
-        runner_ = std::make_unique<example::Runner>(
-            model_path->toStdString().c_str(),
-            tokenizer_path->toStdString().c_str());
-      }
+      std::optional<std::string> data_path_str = data_path
+          ? std::optional<std::string>{data_path->toStdString()}
+          : std::nullopt;
+      runner_ = example::Runner::create(
+          model_path->toStdString(),
+          tokenizer_path->toStdString(),
+          data_path_str);
 #if defined(EXECUTORCH_BUILD_MEDIATEK)
     } else if (model_type_category == MODEL_TYPE_MEDIATEK_LLAMA) {
       runner_ = std::make_unique<MTKLlamaRunner>(
diff --git a/extension/benchmark/apple/Benchmark/Tests/LLaMA/LLaMATests.mm b/extension/benchmark/apple/Benchmark/Tests/LLaMA/LLaMATests.mm
index fe5e816d6a7..e53a457939c 100644
--- a/extension/benchmark/apple/Benchmark/Tests/LLaMA/LLaMATests.mm
+++ b/extension/benchmark/apple/Benchmark/Tests/LLaMA/LLaMATests.mm
@@ -74,8 +74,12 @@ @implementation LLaMATests
   NSString *tokenizerPath = resources[@"tokenizer"];
   return @{
     @"generate" : ^(XCTestCase *testCase){
-      auto __block runner = std::make_unique<example::Runner>(
+      auto __block runner = example::Runner::create(
           modelPath.UTF8String, tokenizerPath.UTF8String);
+      if (!runner) {
+        XCTFail("Failed to create runner");
+        return;
+      }
       const auto status = runner->load();
       if (status != Error::Ok) {
         XCTFail("Load failed with error %i", status);
diff --git a/extension/llm/runner/text_prefiller.h b/extension/llm/runner/text_prefiller.h
index 28632ad856a..49b2c867167 100644
--- a/extension/llm/runner/text_prefiller.h
+++ b/extension/llm/runner/text_prefiller.h
@@ -49,6 +49,23 @@ class ET_EXPERIMENTAL TextPrefiller {
       std::vector<uint64_t>& prompt_tokens,
       int64_t& start_pos);

+  /**
+   * Load the necessary resources for the TextPrefiller.
+   * This method should be called before using the prefill methods.
+   */
+  ::executorch::runtime::Error load() {
+    return text_decoder_runner_->load();
+  }
+
+  /**
+   * Check if the TextPrefiller has been successfully loaded.
+   * @return True if the resources are loaded, false otherwise.
+   */
+  bool inline is_loaded() const {
+    // Implementation to check if resources are loaded
+    return text_decoder_runner_->is_method_loaded();
+  }
+
  private:
   /**
    * Note: TextPrefiller does not own the TextDecoderRunner instance.
diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h
index 38873e25fc1..839ad195c7e 100644
--- a/extension/llm/runner/text_token_generator.h
+++ b/extension/llm/runner/text_token_generator.h
@@ -137,6 +137,23 @@ class ET_EXPERIMENTAL TextTokenGenerator {
     should_stop_ = true;
   }

+  /**
+   * Load the necessary resources for TextTokenGenerator.
+   * This method should be called before using the generate() method.
+   */
+  ::executorch::runtime::Error load() {
+    return text_decoder_runner_->load();
+  }
+
+  /**
+   * Check if the TextTokenGenerator has been successfully loaded.
+   * @return True if the resources are loaded, false otherwise.
+   */
+  bool inline is_loaded() const {
+    // Implementation to check if resources are loaded
+    return tokenizer_->is_loaded() && text_decoder_runner_->is_method_loaded();
+  }
+
  private:
   /**
    * Note: TextTokenGenerator does not own the tokenizer_ and
diff --git a/extension/llm/tokenizers b/extension/llm/tokenizers
index 9ceef562d5c..57eb76d71d6 160000
--- a/extension/llm/tokenizers
+++ b/extension/llm/tokenizers
@@ -1 +1 @@
-Subproject commit 9ceef562d5c941eb6aea5476c768d0419962bc0c
+Subproject commit 57eb76d71d6dde5396520c7d35142eb868994e06
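
Usage sketch (not part of the diff above; the model/tokenizer paths and config
values are illustrative assumptions). It shows the calling convention this
patch introduces: production code obtains a fully wired Runner from the static
factory, which returns nullptr if the tokenizer or model metadata cannot be
loaded, while tests inject mocks through the new dependency-injection
constructor, as test_runner.cpp does above.

    #include <executorch/examples/models/llama/runner/runner.h>

    int main() {
      // create() builds the Module, tokenizer, TextDecoderRunner,
      // TextPrefiller and TextTokenGenerator and wires them together.
      std::unique_ptr<example::Runner> runner = example::Runner::create(
          "/path/to/llama.pte", // model_path (illustrative)
          "/path/to/tokenizer.model", // tokenizer_path (illustrative)
          std::nullopt, // data_path
          -1.0f); // temperature: -1 defers to GenerationConfig
      if (!runner) {
        return 1; // e.g. the tokenizer artifact could not be loaded
      }
      executorch::extension::llm::GenerationConfig config{
          .seq_len = 128, .temperature = 0.8f};
      // generate() lazily calls load() on first use, which now just loads
      // the injected prefiller and token generator.
      const auto err = runner->generate("Hello, world!", config);
      return err == executorch::runtime::Error::Ok ? 0 : 1;
    }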