diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index aee6d2ff7655c..64b53c2912be0 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -1029,7 +1029,7 @@ if (onnxruntime_USE_QNN) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - $ + $ $/onnxruntime/capi/ ) if (EXISTS "${onnxruntime_QNN_HOME}/Qualcomm AI Hub Proprietary License.pdf") diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index cb5a28f82de66..2ed7923941643 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1289,31 +1289,34 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if(onnxruntime_USE_QNN) #qnn ctx generator - set(onnxruntime_qnn_ctx_gen_src_dir ${TEST_SRC_DIR}/qnn_ctx_gen) - set(onnxruntime_qnn_ctx_gen_src_patterns - "${onnxruntime_qnn_ctx_gen_src_dir}/*.cc" - "${onnxruntime_qnn_ctx_gen_src_dir}/*.h") + set(ep_weight_sharing_ctx_gen_src_dir ${TEST_SRC_DIR}/ep_weight_sharing_ctx_gen) + set(ep_weight_sharing_ctx_gen_src_patterns + "${ep_weight_sharing_ctx_gen_src_dir}/*.cc" + "${ep_weight_sharing_ctx_gen_src_dir}/*.h") - file(GLOB onnxruntime_qnn_ctx_gen_src CONFIGURE_DEPENDS - ${onnxruntime_qnn_ctx_gen_src_patterns} + file(GLOB ep_weight_sharing_ctx_gen_src CONFIGURE_DEPENDS + ${ep_weight_sharing_ctx_gen_src_patterns} ) - onnxruntime_add_executable(onnxruntime_qnn_ctx_gen ${onnxruntime_qnn_ctx_gen_src}) - target_include_directories(onnxruntime_qnn_ctx_gen PRIVATE ${onnx_test_runner_src_dir} ${ONNXRUNTIME_ROOT} - ${onnxruntime_graph_header} ${onnxruntime_exec_src_dir} - ${CMAKE_CURRENT_BINARY_DIR}) + onnxruntime_add_executable(ep_weight_sharing_ctx_gen ${ep_weight_sharing_ctx_gen_src}) + target_include_directories(ep_weight_sharing_ctx_gen PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) if (WIN32) - target_compile_options(onnxruntime_qnn_ctx_gen PRIVATE ${disabled_warnings}) + target_compile_options(ep_weight_sharing_ctx_gen PRIVATE ${disabled_warnings}) if (NOT DEFINED SYS_PATH_LIB) set(SYS_PATH_LIB shlwapi) endif() endif() - if(WIN32) - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE debug dbghelp advapi32) + if (onnxruntime_BUILD_SHARED_LIB) + set(ep_weight_sharing_ctx_gen_libs onnxruntime_common onnxruntime ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE ${ep_weight_sharing_ctx_gen_libs}) + if (WIN32) + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE debug dbghelp advapi32) + endif() + else() + target_link_libraries(ep_weight_sharing_ctx_gen PRIVATE onnxruntime_session ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE}) endif() - target_link_libraries(onnxruntime_qnn_ctx_gen PRIVATE onnx_test_runner_common onnxruntime_test_utils onnxruntime_common onnxruntime_graph onnxruntime_session onnxruntime_providers onnxruntime_framework onnxruntime_util onnxruntime_mlas onnxruntime_optimizer onnxruntime_flatbuffers onnx_test_data_proto ${onnxruntime_test_providers_libs} ${onnxruntime_EXTERNAL_LIBRARIES} ${GETOPT_LIB_WIDE} ${SYS_PATH_LIB} ${CMAKE_DL_LIBS}) - set_target_properties(onnxruntime_qnn_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") + set_target_properties(ep_weight_sharing_ctx_gen PROPERTIES FOLDER "ONNXRuntimeTest") endif() # shared lib diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 
117a2cdabca2f..af1f9c04b2831 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -315,9 +315,12 @@ static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed // in case user need to merge/connect multiple EPContext nodes in one model static const char* const kOrtSessionOptionEpContextNodeNamePrefix = "ep.context_node_name_prefix"; -// Share EP related resources across EPs +// Share EP related resources across sessions static const char* const kOrtSessionOptionShareEpContexts = "ep.share_ep_contexts"; +// Stop sharing EP related resources across sessions from this point on +static const char* const kOrtSessionOptionStopShareEpContexts = "ep.stop_share_ep_contexts"; + // Use this config when dumping EP context model with an external initializers file // All initializers will be inside the external data file if specified, otherwise all in Onnx file static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFileName = diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index bcde69beceef7..26d792c008edc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -470,8 +470,10 @@ Status QnnBackendManager::InitializeProfiling() { QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to basic."; } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; } Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! 
Error: ", QnnErrorHandleToString(result)); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 99a6f51f6f712..1ad17d96e9322 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -195,6 +195,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio share_ep_contexts_ = config_options->GetConfigOrDefault(kOrtSessionOptionShareEpContexts, "0") == "1"; LOGS_DEFAULT(VERBOSE) << "User specified option - share EP contexts across sessions: " << share_ep_contexts_; + + stop_share_ep_contexts_ = + config_options->GetConfigOrDefault(kOrtSessionOptionStopShareEpContexts, "0") == "1"; + LOGS_DEFAULT(VERBOSE) << "User specified option - stop share EP contexts across sessions: " << stop_share_ep_contexts_; } static const std::string BACKEND_PATH = "backend_path"; @@ -384,17 +388,27 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } - qnn_backend_manager_ = qnn::QnnBackendManager::Create( - qnn::QnnBackendManagerConfig{backend_path, - profiling_level_etw, - profiling_level, - profiling_file_path, - context_priority, - qnn_saver_path, - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing}); + // For context binary generation with weight sharing enabled, use the QnnBackendManager from the shared context if it exits + // So that all graphs from later sessions will be compiled into the same QNN context + if (context_cache_enabled_ && share_ep_contexts_ && SharedContext::GetInstance().GetSharedQnnBackendManager()) { + qnn_backend_manager_ = SharedContext::GetInstance().GetSharedQnnBackendManager(); + // Clear the QnnBackendManager from singleton to stop the resource share + if (stop_share_ep_contexts_) { + SharedContext::GetInstance().ResetSharedQnnBackendManager(); + } + } else { + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); + } #if defined(_WIN32) if (onnxruntime::logging::EtwRegistrationManager::SupportsETW()) { @@ -1037,6 +1051,12 @@ Status QNNExecutionProvider::Compile(const std::vector& fused qnn_context_embed_mode_, max_spill_fill_buffer_size, logger)); + + if (share_ep_contexts_ && !stop_share_ep_contexts_ && + nullptr == SharedContext::GetInstance().GetSharedQnnBackendManager()) { + ORT_RETURN_IF_NOT(SharedContext::GetInstance().SetSharedQnnBackendManager(qnn_backend_manager_), + "Failed to set shared QnnBackendManager."); + } } return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 31c34855ca4c0..0f40e40c2fa36 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -90,6 +90,7 @@ class QNNExecutionProvider : public IExecutionProvider { uint32_t default_rpc_control_latency_ = 0; bool enable_HTP_FP16_precision_ = true; bool share_ep_contexts_ = false; + bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h 
index 81de357dbe677..277a484ad8528 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -61,13 +61,39 @@ class SharedContext { return graph_exist; } + bool SetSharedQnnBackendManager(std::shared_ptr<qnn::QnnBackendManager>& qnn_backend_manager) { + const std::lock_guard<std::mutex> lock(mtx_); + + if (qnn_backend_manager_ != nullptr) { + if (qnn_backend_manager_ == qnn_backend_manager) { + return true; + } + return false; + } + qnn_backend_manager_ = qnn_backend_manager; + return true; + } + + std::shared_ptr<qnn::QnnBackendManager> GetSharedQnnBackendManager() { + const std::lock_guard<std::mutex> lock(mtx_); + return qnn_backend_manager_; + } + + void ResetSharedQnnBackendManager() { + const std::lock_guard<std::mutex> lock(mtx_); + qnn_backend_manager_.reset(); + } + private: SharedContext() = default; ~SharedContext() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); + // Used for passing through QNN models (deserialized from context binary) across sessions std::vector<std::unique_ptr<qnn::QnnModel>> shared_qnn_models_; + // Used for compiling multiple models into the same QNN context binary + std::shared_ptr<qnn::QnnBackendManager> qnn_backend_manager_; // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; diff --git a/onnxruntime/test/qnn_ctx_gen/README.md b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md similarity index 82% rename from onnxruntime/test/qnn_ctx_gen/README.md rename to onnxruntime/test/ep_weight_sharing_ctx_gen/README.md index 97ab89d79cbd2..be1a1fe039366 100644 --- a/onnxruntime/test/qnn_ctx_gen/README.md +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/README.md @@ -2,17 +2,19 @@ This tool provides a way to generate Onnx models that wrap a QNN context binary, with weight sharing enabled. The options to use with the tool are listed below: -`onnxruntime_qnn_ctx_gen [options...] model_path,model_path` +`ep_weight_sharing_ctx_gen [options...] model_1_path,model_2_path` -./onnxruntime_qnn_ctx_gen -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" -C "ep.context_enable|1 ep.context_embed_mode|0" /mnt/c/model1.onnx,/mnt/c/model2.onnx +./ep_weight_sharing_ctx_gen -e qnn -v -i "soc_model|60 htp_graph_finalization_optimization_mode|3" /mnt/c/model1.onnx,/mnt/c/model2.onnx Options: - + + -e [qnn|tensorrt|openvino|vitisai]: Specifies the compile-based provider: qnn, tensorrt, openvino, vitisai. Default is qnn. + -v: Show verbose information. -C: [session_config_entries]: Specify session configuration entries as key-value pairs: -C "<key1>|<value1> <key2>|<value2>" Refer to onnxruntime_session_options_config_keys.h for valid keys and values. - [Example] -C "ep.context_enable|1 ep.context_embed_mode|0" + [Example] -C "ep.context_enable|1 ep.context_embed_mode|0". These are set by default, so they can be omitted. -i: [provider_options]: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: [Usage]: -i '<key1>|<value1> <key2>|<value2>' diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc similarity index 68% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.cc rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index 24c343c7b9541..bf21d54ccde41 100644 --- a/onnxruntime/test/qnn_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) 2023 NVIDIA Corporation. // Licensed under the MIT License. 
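Taken together, the new SharedContext accessors and the two session-option keys give a producer-side flow that can be driven entirely through the public API. The following is a minimal sketch, assuming a QNN-enabled build; the model paths are placeholders:

```cpp
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "weight_sharing");
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");     // dump EPContext models
  so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0");  // non-embed mode
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");     // share the QnnBackendManager

  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";  // "libQnnHtp.so" on Linux
  so.AppendExecutionProvider("QNN", qnn_options);

  // First producer session creates the QNN context and publishes its
  // QnnBackendManager through SharedContext.
  Ort::Session session1(env, ORT_TSTR("model1.onnx"), so);

  // Last producer session reuses the shared manager, then tells the EP to
  // release the singleton's reference so sharing stops here.
  so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1");
  Ort::Session session2(env, ORT_TSTR("model2.onnx"), so);
  return 0;
}
```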
#include "command_args_parser.h" @@ -29,28 +28,30 @@ namespace qnnctxgen { /*static*/ void CommandLineParser::ShowUsage() { printf( - "onnxruntime_qnn_ctx_gen [options...] model1_path,model2_path\n" - "Example: ./onnxruntime_qnn_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" + "ep_weight_sharing_ctx_gen [options...] model1_path,model2_path\n" + "Example: ./ep_weight_sharing_ctx_gen -i \"soc_model|60 htp_graph_finalization_optimization_mode|3\" -C \"ep.context_node_name_prefix|_part1\" ./model1.onnx,./model2.onnx\n" "Options:\n" + "\t-e [qnn|tensorrt|openvino|vitisai]: Specifies the compile based provider 'qnn','tensorrt','openvino', 'vitisai'. " + "Default:'qnn'.\n" "\t-v: Show verbose information.\n" "\t-C: Specify session configuration entries as key-value pairs: -C \"| |\" \n" "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t Force ep.context_enable to 1 and ep.context_embed_mode to 0. Change ep.context_file_path is not allowed." "\t [Example] -C \"ep.context_node_name_prefix|_part1\" \n" - "\t-i: Specify QNN EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" "\t [Usage]: -i '| |'\n" "\n" - "\t [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" - "\t [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" - "\t [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" - "\t [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" - "\t [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" - "\t [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" + "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. default to HTP backend\n" + "\t [QNN only] [vtcm_mb]: QNN VTCM size in MB. default to 0(not set).\n" + "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. eg: '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [enable_htp_fp16_precision]: Enable the HTP_FP16 precision so that the float32 model will be inferenced with fp16 precision. \n" "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" - "\t [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" - "\t [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" - "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). 
\n" - "\t [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." + "\t [QNN only] [enable_htp_weight_sharing]: Allows common weights across graphs to be shared and stored in a single context binary. Defaults to '1' (enabled).\n" + "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" + "\t Defaults to '1' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary." "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -109,8 +110,22 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(TestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("o:u:i:C:vh"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("e:o:u:i:C:vh"))) != -1) { switch (ch) { + case 'e': + if (!CompareCString(optarg, ORT_TSTR("qnn"))) { + test_config.machine_config.provider_type_name = onnxruntime::kQnnExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("openvino"))) { + test_config.machine_config.provider_type_name = onnxruntime::kOpenVINOExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("tensorrt"))) { + test_config.machine_config.provider_type_name = onnxruntime::kTensorrtExecutionProvider; + } else if (!CompareCString(optarg, ORT_TSTR("vitisai"))) { + test_config.machine_config.provider_type_name = onnxruntime::kVitisAIExecutionProvider; + } else { + fprintf(stderr, "The execution provider is not included in this tool.\n"); + return false; + } + break; case 'v': test_config.run_config.f_verbose = true; break; @@ -162,7 +177,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer'])"); } - test_config.run_config.qnn_options[key] = value; + test_config.run_config.provider_options[key] = value; } break; } diff --git a/onnxruntime/test/qnn_ctx_gen/command_args_parser.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h similarity index 100% rename from onnxruntime/test/qnn_ctx_gen/command_args_parser.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.h diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc new file mode 100644 index 0000000000000..104cdbdfd5abc --- /dev/null +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/main.cc @@ -0,0 +1,247 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test_configuration.h" +#include "command_args_parser.h" + +// onnxruntime dependencies +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" + +// onnx dependencies +#include "onnx/onnx_pb.h" +#include + +using namespace onnxruntime; +using ProviderOptions = std::unordered_map; + +// from the last context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models +// get the max spill fill buffer size +static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, + std::string& last_ctx_bin_file, + int64_t& max_size) { + max_size = 0; + + onnx::ModelProto model; + std::ifstream onnx_file_stream(last_onnx_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + + for (auto& node : model.graph().node()) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + for (auto& attr : node.attribute()) { + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + } + if (attr.name() == "ep_cache_context") { + last_ctx_bin_file = attr.s(); + } + } + if (is_main_context) { + return; + } + } + } + + onnx_file_stream.close(); +} + +// Update generated context cache Onnx model to make the main EPContext node point to +// the last QNN context binary file +// Remove not used QNN context binary file, only keep the last one which contains all graphs +static void UpdateEpContextModel(const std::vector>& ep_ctx_files, + const std::string& last_qnn_ctx_binary_file_name, + int64_t max_size) { + for (auto ep_ctx_file : ep_ctx_files) { + onnx::ModelProto model; + std::ifstream onnx_file_stream(ep_ctx_file, std::ios_base::binary); + model.ParseFromIstream(&onnx_file_stream); + onnx_file_stream.close(); + + for (auto& node : *(model.mutable_graph()->mutable_node())) { + if (node.op_type() == "EPContext") { + int64_t is_main_context = 0; + std::string old_qnn_ctx_binary_file_name; + int max_size_index = 0; + int ep_context_index = 0; + for (auto i = 0; i < node.attribute_size(); ++i) { + auto& attr = node.attribute()[i]; + if (attr.name() == "main_context") { + is_main_context = attr.i(); + } + if (attr.name() == "max_size") { + max_size = attr.i(); + max_size_index = i; + } + if (attr.name() == "ep_cache_context") { + old_qnn_ctx_binary_file_name = attr.s(); + ep_context_index = 0; + } + } + if (is_main_context) { + auto path_str = ToPathString(ep_ctx_file); + auto path = std::filesystem::path(path_str); + auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); + std::remove(file_path.string().c_str()); + + node.mutable_attribute(max_size_index)->set_i(max_size); + node.mutable_attribute(ep_context_index)->set_s(last_qnn_ctx_binary_file_name); + } + } + } + + // re-write the onnx ctx file + std::ofstream onnx_file_ostream(ep_ctx_file, std::ios_base::binary); + model.SerializeToOstream(&onnx_file_ostream); + onnx_file_ostream.close(); + } +} + +#ifdef _WIN32 +int real_main(int argc, wchar_t* argv[]) { +#else +int real_main(int argc, char* argv[]) { +#endif + qnnctxgen::TestConfig test_config; + if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { + qnnctxgen::CommandLineParser::ShowUsage(); + return -1; + } + + OrtLoggingLevel logging_level = test_config.run_config.f_verbose + ? 
ORT_LOGGING_LEVEL_VERBOSE + : ORT_LOGGING_LEVEL_ERROR; + Ort::Env env(logging_level, "ep_weight_sharing"); + + ORT_TRY { + Ort::SessionOptions so; + so.SetLogId("ep_weight_sharing_ctx_gen_session_logger"); + // Set default session option to dump EPContext model with non-embed mode + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + ProviderOptions provider_options; + + for (auto it : test_config.run_config.provider_options) { + provider_options[it.first] = it.second; + } + + for (auto it : test_config.run_config.session_config_entries) { + if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { + std::cerr << "Need to enable ep context cache." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { + std::cerr << "Only the non-embed mode is supported for weight sharing." << std::endl; + continue; + } + if (it.first == kOrtSessionOptionEpContextFilePath) { + std::cout << "Specifying the generated Onnx context cache file name is not supported." << std::endl; + continue; + } + so.AddConfigEntry(it.first.c_str(), it.second.c_str()); + } + + for (auto model_path : test_config.model_file_paths) { + std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; + } + + // Generate context cache model files with QNN context binary files + // The context binary file generated later includes all graphs from previous models + { + std::string provider_name_ = test_config.machine_config.provider_type_name; + if (provider_name_ == onnxruntime::kQnnExecutionProvider) { +#ifdef USE_QNN +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // set default QNN EP option to enable weight sharing if not set by user + const std::string enable_htp_weight_sharing = "enable_htp_weight_sharing"; + if (provider_options.find(enable_htp_weight_sharing) == provider_options.end()) { + provider_options[enable_htp_weight_sharing] = "1"; + } + so.AppendExecutionProvider("QNN", provider_options); +#else + ORT_THROW("QNN is not supported in this build\n"); +#endif + } else if (!provider_name_.empty()) { + ORT_THROW("This execution provider is not included in this tool.\n"); + } + + size_t total_file_count = test_config.model_file_paths.size(); + for (size_t i = 0; i < total_file_count; ++i) { + auto model_path = test_config.model_file_paths[i]; + std::cout << "Generating context cache model for: " << ToUTF8String(model_path) << std::endl; + if (i == total_file_count - 1) { + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + } + Ort::Session session(env, model_path.c_str(), so); + } + } + + std::cout << "Start updating the generated Onnx models." 
<< std::endl; + std::vector<std::basic_string<ORTCHAR_T>> ep_ctx_files; + ep_ctx_files.reserve(test_config.model_file_paths.size()); + for (auto model_path : test_config.model_file_paths) { + auto pos = model_path.find_last_of(ORT_TSTR(".")); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + ORT_TSTR("_ctx.onnx"); + } else { + model_path = model_path + ORT_TSTR("_ctx.onnx"); + } + ep_ctx_files.push_back(model_path); + } + + // Get the last context binary file name + std::string last_qnn_ctx_binary_file_name; + int64_t max_size = 0; + GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); + std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; + if (last_qnn_ctx_binary_file_name.empty()) { + throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); + } + ep_ctx_files.pop_back(); + + // Update the generated context cache Onnx models to make the main EPContext node point to + // the last QNN context binary file. + // Remove unused QNN context binary files; only keep the last one, which contains all graphs. + UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); + } + ORT_CATCH(const Ort::Exception& e) { + std::cerr << "Failed to generate context cache file: " << e.what(); + return -1; + } + + std::cout << "Generation done!"; + return 0; +} + +#ifdef _WIN32 +int wmain(int argc, wchar_t* argv[]) { +#else +int main(int argc, char* argv[]) { +#endif + int retval = -1; + ORT_TRY { + retval = real_main(argc, argv); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + fprintf(stderr, "%s\n", ex.what()); + retval = -1; + }); + } + + ::google::protobuf::ShutdownProtobufLibrary(); + + return retval; +} diff --git a/onnxruntime/test/qnn_ctx_gen/test_configuration.h b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h similarity index 75% rename from onnxruntime/test/qnn_ctx_gen/test_configuration.h rename to onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h index bf4c7061a3484..198d03211f561 100644 --- a/onnxruntime/test/qnn_ctx_gen/test_configuration.h +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/test_configuration.h @@ -14,15 +14,20 @@ namespace onnxruntime { namespace qnnctxgen { +struct MachineConfig { + std::string provider_type_name{onnxruntime::kQnnExecutionProvider}; +}; + struct RunConfig { bool f_verbose{false}; std::unordered_map<std::string, std::string> session_config_entries; - std::unordered_map<std::string, std::string> qnn_options; + std::unordered_map<std::string, std::string> provider_options; }; struct TestConfig { std::vector<std::basic_string<ORTCHAR_T>> model_file_paths; RunConfig run_config; + MachineConfig machine_config; }; } // namespace qnnctxgen diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index e50dd7c214240..3dec74599abdf 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -43,6 +43,35 @@ static const std::string& GetNodeAttr(const Node& node, const std::string& attr_ return default_val; } +// From the context cache Onnx model, find the EPContext node with main_context=1, +// and get the QNN context binary file name +static void GetContextBinaryFileName(const std::string onnx_ctx_file, + std::string& last_ctx_bin_file, + const Logger& logger) { + std::shared_ptr<Model> ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(onnx_ctx_file), ctx_model, nullptr, logger)); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : 
ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast<int64_t>(0)); + if (1 == is_main_context) { + last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); + return; + } + } + } +} + +// Get the context binary file name from the context model file, then remove both the binary and the model file +void CleanUpCtxFile(std::string context_file_path) { + std::string qnn_ctx_binary_file_name; + GetContextBinaryFileName(context_file_path, qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); + + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name.c_str()), 0); + ASSERT_EQ(std::remove(context_file_path.c_str()), 0); +} + // Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | @@ -123,22 +152,22 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); int ep_context_node_count = 0; int non_ep_context_node_count = 0; std::shared_ptr<Model> ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); auto& ctx_graph = ctx_model->MainGraph(); for (auto& node : ctx_graph.Nodes()) { if (node.OpType() == "EPContext") { @@ -156,7 +185,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::SessionOptions so2; // context file path is required if it's non-embed mode and the model is loaded from memory - so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so2.AppendExecutionProvider("QNN", provider_options); std::string ctx_model_data; @@ -164,7 +193,7 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary @@ -237,7 +266,7 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file) { // clean up ASSERT_EQ(std::remove(model_with_ext.c_str()), 0); ASSERT_EQ(std::remove(model_ext_file_full_path.c_str()), 0); - ASSERT_EQ(std::remove(ep_context_model_file.c_str()), 0); + CleanUpCtxFile(ep_context_model_file); } // Set the external initializer size threshold to 1024 so FusedMatMul 
(which fallback on CPU) @@ -444,21 +473,21 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_model_file.c_str()); Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Generate context cache model from the ONNX models with 2 inputs. @@ -481,26 +510,26 @@ TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); auto inputs = model->MainGraph().GetInputs(); EXPECT_TRUE(inputs.size() == 2); EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { @@ -519,20 +548,20 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { auto& logging_manager = DefaultLoggingManager(); logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + const std::string context_model_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; Ort::SessionOptions so; so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_model_file.c_str()); so.AddConfigEntry(kOrtSessionOptionEpContextNodeNamePrefix, node_name_prefix.c_str()); 
so.AppendExecutionProvider("QNN", provider_options); Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + ASSERT_STATUS_OK(Model::Load(ToPathString(context_model_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); for (auto& node : model->MainGraph().Nodes()) { if (node.OpType() == "EPContext") { EXPECT_TRUE(node.Name().find(node_name_prefix) != std::string::npos); @@ -540,7 +569,7 @@ TEST_F(QnnHTPBackendTests, QnnContextGenerationNodeNamePrefix) { } // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -554,12 +583,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); const std::string op_type = "Atan"; @@ -577,11 +606,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model std::unordered_map session_option_pairs2; - session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, @@ -589,10 +618,10 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file, + context_model_file, session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Run QDQ model on HTP 3 times @@ -889,12 +918,12 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { provider_options["backend_path"] = "libQnnHtp.so"; #endif provider_options["offload_graph_io_quantization"] = "0"; - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::remove(context_binary_file.c_str()); + const std::string context_model_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_model_file.c_str()); std::unordered_map session_option_pairs; session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, 
context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); const TestInputDef<float> input_def1({1, 2, 3}, false, -10.0f, 10.0f); const TestInputDef<float> input_def2({1, 2, 3}, false, -10.0f, 10.0f); @@ -913,11 +942,11 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { session_option_pairs); // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + EXPECT_TRUE(std::filesystem::exists(context_model_file.c_str())); // 2nd run directly loads and run from Qnn context cache model std::unordered_map<std::string, std::string> session_option_pairs2; - session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_model_file); TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def1, input_def2}, {}, {}), BuildQDQOpTestCase<uint8_t>(op_type, {input_def1, input_def2}, {}, {}), provider_options, 21, ExpectedEPNodeAssignment::All, QDQTolerance(), logging::Severity::kERROR, - context_binary_file, + context_model_file, session_option_pairs2); // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + CleanUpCtxFile(context_model_file); } // Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node @@ -1062,44 +1091,20 @@ static void CreateQdqModel(const std::string& model_file_name, const Logger& log static void DumpModelWithSharedCtx(const ProviderOptions& provider_options, const std::string& onnx_model_path1, const std::string& onnx_model_path2) { - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - auto qnn_ep = QnnExecutionProviderWithOptions(provider_options, &so); - std::shared_ptr<IExecutionProvider> qnn_ep_shared(std::move(qnn_ep)); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNN EP shares the QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); - InferenceSessionWrapper session_object1{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object1.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object1.Load(ToPathString(onnx_model_path1))); - ASSERT_STATUS_OK(session_object1.Initialize()); + so.AppendExecutionProvider("QNN", provider_options); - InferenceSessionWrapper session_object2{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object2.RegisterExecutionProvider(qnn_ep_shared)); - ASSERT_STATUS_OK(session_object2.Load(ToPathString(onnx_model_path2))); - ASSERT_STATUS_OK(session_object2.Initialize()); -} + // Create 2 sessions to generate context binary models, the 1st session will share the QnnBackendManager + // with the 2nd session, so graphs from these 2 models are all included in the 2nd context binary + Ort::Session session1(*ort_env, ToPathString(onnx_model_path1).c_str(), so); -// from the last context ache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, thie context binary contains all graphs from all Onnx models -static void 
GetLastContextBinaryFileName(const std::string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - const Logger& logger) { - std::shared_ptr<Model> ctx_model; - ASSERT_STATUS_OK(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, logger)); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast<int64_t>(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so); } // Update generated context cache Onnx model to make the main EPContext node point to @@ -1187,10 +1192,10 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions1) { DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); - // Get the last context binary file name + // Get the last context binary file name, the latest context binary file holds all graphs generated from all models std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1293,8 +1298,8 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { // Get the last context binary file name std::string last_qnn_ctx_binary_file_name; - GetLastContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, - DefaultLoggingManager().DefaultLogger()); + GetContextBinaryFileName(ctx_model_paths.back(), last_qnn_ctx_binary_file_name, + DefaultLoggingManager().DefaultLogger()); EXPECT_TRUE(!last_qnn_ctx_binary_file_name.empty()); // Update generated context cache Onnx model to make the main EPContext node point to @@ -1357,6 +1362,69 @@ TEST_F(QnnHTPBackendTests, QnnContextShareAcrossSessions2) { } std::remove(last_qnn_ctx_binary_file_name.c_str()); } + +// For Ort sessions that generate the context binary, with session option ep.share_ep_contexts enabled, +// the sessions share the QnnBackendManager so that all graphs from all models are compiled into the same Qnn context +TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["offload_graph_io_quantization"] = "0"; + + // Create QDQ models + std::vector<std::string> onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + std::vector<std::string> ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0"); + // enable ep.share_ep_contexts so that QNN EP shares the 
QnnBackendManager across sessions + so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session1(*ort_env, ToPathString(onnx_model_paths[0]).c_str(), so); + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + // Tell the EP to stop sharing the QnnBackendManager from this session on + so.AddConfigEntry(kOrtSessionOptionStopShareEpContexts, "1"); + Ort::Session session2(*ort_env, ToPathString(onnx_model_paths[1]).c_str(), so); + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + auto file_size_2 = std::filesystem::file_size(qnn_ctx_binary_file_name2); + EXPECT_TRUE(file_size_2 > file_size_1); + + // clean up + for (auto model_path : onnx_model_paths) { + ASSERT_EQ(std::remove(model_path.c_str()), 0); + } + for (auto ctx_model_path : ctx_model_paths) { + ASSERT_EQ(std::remove(ctx_model_path.c_str()), 0); + } + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name1.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_binary_file_name2.c_str()), 0); +} #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/qnn_ctx_gen/main.cc b/onnxruntime/test/qnn_ctx_gen/main.cc deleted file mode 100644 index bb5007b40b072..0000000000000 --- a/onnxruntime/test/qnn_ctx_gen/main.cc +++ /dev/null @@ -1,250 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
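The consumer side mirrors the test above: once both context models reference the single shared binary, inference sessions can load them with the same sharing switch so the QNN graphs deserialized from that binary are reused across sessions. A sketch, using the hypothetical file names from the test:

```cpp
#include <string>
#include <unordered_map>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_session_options_config_keys.h"

// Consumer sessions must be created after the producer flow above has
// rewritten both *_ctx.onnx models to reference the shared context binary.
void LoadSharedCtxModels(Ort::Env& env) {
  Ort::SessionOptions so;
  so.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1");

  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";  // "libQnnHtp.so" on Linux
  so.AppendExecutionProvider("QNN", qnn_options);

  // Both sessions resolve their EPContext nodes against the same binary.
  Ort::Session s1(env, ORT_TSTR("weight_share1_ctx.onnx"), so);
  Ort::Session s2(env, ORT_TSTR("weight_share2_ctx.onnx"), so);
}
```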
- -// onnxruntime dependencies -#include "test_configuration.h" -#include -#include -#include -#include "command_args_parser.h" -#include - -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/inference_session.h" -#include "core/session/ort_env.h" -#include "core/providers/provider_factory_creators.h" -#include "core/common/logging/sinks/clog_sink.h" - -#include "core/graph/model.h" -#include "core/session/environment.h" -#include "core/common/logging/logging.h" - -using namespace onnxruntime; -const OrtApi* g_ort = NULL; -std::unique_ptr ort_env; - -static void CheckStatus(const Status& status) { - if (status.Code() != common::StatusCode::OK) { - std::string msg = status.ErrorMessage(); - throw Ort::Exception(std::move(msg), OrtErrorCode::ORT_FAIL); - } -} - -static int64_t GetNodeAttr(const Node& node, const std::string& attr_name, int64_t default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.i(); - } - - return default_val; -} - -static const std::string& GetNodeAttr(const Node& node, const std::string& attr_name, const std::string& default_val) { - const auto& attributes = node.GetAttributes(); - if (auto entry = attributes.find(attr_name); entry != attributes.end()) { - return entry->second.s(); - } - - return default_val; -} - -// from the last context cache Onnx model, find the EPContext node with main_context=1, -// and get the QNN context binary file name, this context binary contains all graphs from all Onnx models -// get the max spill fill buffer size -static void GetLastContextBinaryFileName(const std::basic_string last_onnx_ctx_file, - std::string& last_ctx_bin_file, - int64_t& max_size) { - max_size = 0; - std::shared_ptr ctx_model; - CheckStatus(Model::Load(ToPathString(last_onnx_ctx_file), ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - max_size = GetNodeAttr(node, "max_size", static_cast(0)); - if (1 == is_main_context) { - last_ctx_bin_file = GetNodeAttr(node, "ep_cache_context", ""); - return; - } - } - } -} - -// Update generated context cache Onnx model to make the main EPContext node point to -// the last QNN context binary file -// Remove not used QNN context binary file, only keep the last one which contains all graphs -static void UpdateEpContextModel(const std::vector>& ep_ctx_files, - const std::string& last_qnn_ctx_binary_file_name, - int64_t max_size) { - for (auto ep_ctx_file : ep_ctx_files) { - std::shared_ptr ctx_model; - auto path_str = ToPathString(ep_ctx_file); - CheckStatus(Model::Load(path_str, ctx_model, nullptr, - (*((OrtEnv*)*ort_env.get())->GetEnvironment().GetLoggingManager()).DefaultLogger())); - auto& ctx_graph = ctx_model->MainGraph(); - GraphViewer graph_viewer(ctx_graph); - auto path = std::filesystem::path(path_str); - - for (auto& node : ctx_graph.Nodes()) { - if (node.OpType() == "EPContext") { - int64_t is_main_context = GetNodeAttr(node, "main_context", static_cast(0)); - if (1 == is_main_context) { - std::string old_qnn_ctx_binary_file_name = GetNodeAttr(node, "ep_cache_context", ""); - auto file_path = path.replace_filename(old_qnn_ctx_binary_file_name); - std::remove(file_path.string().c_str()); - 
node.ClearAttribute("ep_cache_context"); - node.AddAttribute("ep_cache_context", last_qnn_ctx_binary_file_name); - node.ClearAttribute("max_size"); - node.AddAttribute("max_size", max_size); - } - } - } - std::remove(ToUTF8String(ep_ctx_file).c_str()); - CheckStatus(Model::Save(*ctx_model.get(), ToPathString(ep_ctx_file))); - } -} - -#ifdef _WIN32 -int real_main(int argc, wchar_t* argv[]) { -#else -int real_main(int argc, char* argv[]) { -#endif - g_ort = OrtGetApiBase()->GetApi(ORT_API_VERSION); - qnnctxgen::TestConfig test_config; - if (!qnnctxgen::CommandLineParser::ParseArguments(test_config, argc, argv)) { - qnnctxgen::CommandLineParser::ShowUsage(); - return -1; - } - - { - bool failed = false; - ORT_TRY { - OrtLoggingLevel logging_level = test_config.run_config.f_verbose - ? ORT_LOGGING_LEVEL_VERBOSE - : ORT_LOGGING_LEVEL_WARNING; - - ort_env = std::make_unique(logging_level, "Default"); - } - ORT_CATCH(const Ort::Exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "Error creating environment. Error-> %s \n", e.what()); - failed = true; - }); - } - - if (failed) - return -1; - } - - ORT_TRY { - SessionOptions so; - so.session_logid = "qnn_ctx_gen_session_logger"; - // Set default session option to dump QNN context model with non-embed mode - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1")); - CheckStatus(so.config_options.AddConfigEntry(kOrtSessionOptionEpContextEmbedMode, "0")); - RunOptions run_options; - run_options.run_tag = so.session_logid; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - // set default QNN EP option to enable weight sharing - provider_options["enable_htp_weight_sharing"] = "1"; - - for (auto it : test_config.run_config.qnn_options) { - provider_options[it.first] = it.second; - } - - for (auto it : test_config.run_config.session_config_entries) { - if (it.first == kOrtSessionOptionEpContextEnable && it.second != "1") { - std::cerr << "Need to enable ep context cache." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextEmbedMode && it.second != "0") { - std::cerr << "Only support non-embed model for weight sharing." << std::endl; - continue; - } - if (it.first == kOrtSessionOptionEpContextFilePath) { - std::cout << "Not support to specify the generated Onnx context cache file name." << std::endl; - continue; - } - CheckStatus(so.config_options.AddConfigEntry(it.first.c_str(), it.second.c_str())); - } - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Model file path: " << ToUTF8String(model_path) << std::endl; - } - - // Generate context cache model files with QNN context binary files - // The context binary file generated later includes all graphs from previous models - { - auto ep = QNNProviderFactoryCreator::Create(provider_options, &so)->CreateProvider(); - std::shared_ptr qnn_ep(std::move(ep)); - - for (auto model_path : test_config.model_file_paths) { - std::cout << "Generate context cache model for: " << ToUTF8String(model_path) << std::endl; - InferenceSession session_object{so, ((OrtEnv*)*ort_env.get())->GetEnvironment()}; - CheckStatus(session_object.RegisterExecutionProvider(qnn_ep)); - CheckStatus(session_object.Load(ToPathString(model_path))); - CheckStatus(session_object.Initialize()); - } - } - - std::cout << "Start to update the generated Onnx model." 
<< std::endl; - std::vector> ep_ctx_files; - ep_ctx_files.reserve(test_config.model_file_paths.size()); - for (auto model_path : test_config.model_file_paths) { - ep_ctx_files.push_back(model_path + ORT_TSTR("_ctx.onnx")); - } - - // Get the last context binary file name - std::string last_qnn_ctx_binary_file_name; - int64_t max_size = 0; - GetLastContextBinaryFileName(ep_ctx_files.back(), last_qnn_ctx_binary_file_name, max_size); - std::cout << "The last context binary file: " << last_qnn_ctx_binary_file_name << std::endl; - if (last_qnn_ctx_binary_file_name.empty()) { - throw Ort::Exception("Can't find QNN context binary file from the Onnx model.", OrtErrorCode::ORT_FAIL); - } - ep_ctx_files.pop_back(); - - // Update generated context cache Onnx model to make the main EPContext node point to - // the last QNN context binary file - // Remove not used QNN context binary file, only keep the last one which contains all graphs - UpdateEpContextModel(ep_ctx_files, last_qnn_ctx_binary_file_name, max_size); - } - ORT_CATCH(const Ort::Exception& e) { - fprintf(stderr, "Failed to generate context cache file: %s \n", e.what()); - - ort_env.reset(); - return -1; - } - - ort_env.reset(); - - return 0; -} - -#ifdef _WIN32 -int wmain(int argc, wchar_t* argv[]) { -#else -int main(int argc, char* argv[]) { -#endif - int retval = -1; - ORT_TRY { - retval = real_main(argc, argv); - } - ORT_CATCH(const std::exception& ex) { - ORT_HANDLE_EXCEPTION([&]() { - fprintf(stderr, "%s\n", ex.what()); - retval = -1; - }); - } - - ::google::protobuf::ShutdownProtobufLibrary(); - - return retval; -} diff --git a/setup.py b/setup.py index ced2f28e38778..53e533050b245 100644 --- a/setup.py +++ b/setup.py @@ -356,7 +356,7 @@ def finalize_options(self): "libQnnSaver.so", "libQnnSystem.so", "libHtpPrepare.so", - "onnxruntime_qnn_ctx_gen", + "ep_weight_sharing_ctx_gen", ] dl_libs.extend(qnn_deps) if nightly_build: