diff --git a/examples/c/CMakeLists.txt b/examples/c/CMakeLists.txt index 226e88d994..ffa2e2d4ad 100644 --- a/examples/c/CMakeLists.txt +++ b/examples/c/CMakeLists.txt @@ -24,6 +24,7 @@ FetchContent_MakeAvailable(CLI11) option(USE_CXX "Invoke the C++ example" ON) option(MODEL_CHAT "Build the Model Chat example" OFF) option(MODEL_QA "Build the Model Q&A example" OFF) +option(MODEL_COMPILE "Build the Model Compile example" OFF) option(MODEL_MM "Build the Model Multimodal example" OFF) option(WHISPER "Build the Whisper example" OFF) @@ -113,6 +114,13 @@ if(MODEL_QA) target_link_libraries(model_qa PRIVATE CLI11::CLI11) endif() +if(MODEL_COMPILE) + add_executable(model_compile ${EXAMPLES_SOURCE_DIR}/model_compile.cpp ${EXAMPLES_SOURCE_DIR}/common.cpp) + prepare_executable(model_compile) + target_link_libraries(model_compile PRIVATE nlohmann_json::nlohmann_json) + target_link_libraries(model_compile PRIVATE CLI11::CLI11) +endif() + if(MODEL_MM) add_executable(model_mm ${EXAMPLES_SOURCE_DIR}/model_mm.cpp ${EXAMPLES_SOURCE_DIR}/common.cpp) prepare_executable(model_mm) diff --git a/examples/c/src/common.cpp b/examples/c/src/common.cpp index 15159e779a..62fc487d81 100644 --- a/examples/c/src/common.cpp +++ b/examples/c/src/common.cpp @@ -236,18 +236,17 @@ void RegisterEP(const std::string& ep, const std::string& ep_path) { return; // No library path specified, skip registration } - std::cout << "Registering execution provider: " << ep_path << std::endl; - auto env = Ort::Env(); + // Must register on GenAI's OrtEnv (via OgaRegisterExecutionProviderLibrary) so + // GetEpDevices() in ValidateCompiledModel sees the plugin; Ort::Env() is a different env. if (ep.compare("cuda") == 0) { - env.RegisterExecutionProviderLibrary("CUDAExecutionProvider", std::filesystem::path(ep_path).c_str()); + OgaRegisterExecutionProviderLibrary("CUDAExecutionProvider", ep_path.c_str()); } else if (ep.compare("NvTensorRtRtx") == 0) { - env.RegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", std::filesystem::path(ep_path).c_str()); + OgaRegisterExecutionProviderLibrary("NvTensorRTRTXExecutionProvider", ep_path.c_str()); } else { std::cout << "Warning: EP registration not supported for " << ep << std::endl; std::cout << "Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries." << std::endl; return; } - std::cout << "Registered " << ep << " successfully!" << std::endl; } diff --git a/examples/c/src/model_compile.cpp b/examples/c/src/model_compile.cpp new file mode 100644 index 0000000000..714c7d8109 --- /dev/null +++ b/examples/c/src/model_compile.cpp @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Model Compile example: runs the same model under different EP and compile configurations +// (CPU, CPU+overlay, NvTensorRtRtx no-compile / 4 options / all options). Use -v for verbose, +// -d for ORT verbose logging (ORTGENAI_ORT_VERBOSE_LOGGING=1). + +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" + +namespace fs = std::filesystem; + +// Enable ONNX Runtime verbose logging. Must be set before any Oga/ORT API use. +// Alternatively set env ORTGENAI_ORT_VERBOSE_LOGGING=1 before launching. +static void SetOrtVerboseLogging() { +#ifdef _WIN32 + _putenv("ORTGENAI_ORT_VERBOSE_LOGGING=1"); +#else + setenv("ORTGENAI_ORT_VERBOSE_LOGGING", "1", 1); +#endif +} + +static const char* kCpuEp = "cpu"; +static const char* kNvTensorRtRtxEp = "NvTensorRtRtx"; + +static const char* kDefaultPrompt = "Tell me about AI and ML"; + +static double RunOneGeneration(OgaModel& model, OgaTokenizer& tokenizer, bool verbose) { + auto stream = OgaTokenizerStream::Create(tokenizer); + auto sequences = OgaSequences::Create(); + tokenizer.Encode(kDefaultPrompt, *sequences); + + auto params = OgaGeneratorParams::Create(model); + params->SetSearchOption("max_length", 128); + params->SetSearchOption("batch_size", 1); + + auto generator = OgaGenerator::Create(model, *params); + generator->AppendTokenSequences(*sequences); + + if (verbose) std::cout << "Prompt: " << kDefaultPrompt << std::endl; + std::cout << "Output: " << std::flush; + auto t0 = Clock::now(); + while (!generator->IsDone()) { + generator->GenerateNextToken(); + std::cout << stream->Decode(generator->GetNextTokens()[0]) << std::flush; + } + std::cout << std::endl; + return std::chrono::duration(Clock::now() - t0).count(); +} + +static void PrintTimings(const char* label, double load_time_sec, double inference_time_sec) { + const auto default_precision = std::cout.precision(); + std::cout << " " << label << ": " + << std::fixed << std::setprecision(3) + << "model load " << load_time_sec << "s, " + << "inference " << inference_time_sec << "s" + << std::setprecision(default_precision) << std::endl; +} + +// 1) Run model with CPU execution provider only (no compile overlay). +void RunWithCpu(const std::string& model_path, const std::string& ep_path, bool verbose) { + (void)ep_path; + if (verbose) std::cout << "[RunWithCpu] Creating config (CPU, no compile overlay)..." << std::endl; + std::unordered_map ep_options; + GeneratorParamsArgs search_options; + auto config = GetConfig(model_path, kCpuEp, ep_options, search_options); + if (verbose) std::cout << "[RunWithCpu] Creating model..." << std::endl; + auto load_t0 = Clock::now(); + auto model = OgaModel::Create(*config); + double load_time = std::chrono::duration(Clock::now() - load_t0).count(); + if (verbose) std::cout << "[RunWithCpu] Creating tokenizer..." << std::endl; + auto tokenizer = OgaTokenizer::Create(*model); + double inference_time = RunOneGeneration(*model, *tokenizer, verbose); + PrintTimings("RunWithCpu (CPU, no overlay)", load_time, inference_time); +} + +// 2) Run model with CPU execution provider and compile config passed via config_overlay. +void RunWithCpuAndCompileOverlay(const std::string& model_path, const std::string& ep_path, bool verbose) { + (void)ep_path; + if (verbose) std::cout << "[RunWithCpuAndCompileOverlay] Creating config (CPU + compile overlay)..." << std::endl; + std::unordered_map ep_options; + GeneratorParamsArgs search_options; + auto config = GetConfig(model_path, kCpuEp, ep_options, search_options); + config->Overlay(R"({ + "model": { + "decoder": { + "compile_options": { + "enable_ep_context": true, + "ep_context_embed_mode": false, + "force_compile_if_needed": true, + "graph_optimization_level": 99 + } + } + } + })"); + if (verbose) std::cout << "[RunWithCpuAndCompileOverlay] Creating model..." << std::endl; + auto load_t0 = Clock::now(); + auto model = OgaModel::Create(*config); + double load_time = std::chrono::duration(Clock::now() - load_t0).count(); + if (verbose) std::cout << "[RunWithCpuAndCompileOverlay] Creating tokenizer..." << std::endl; + auto tokenizer = OgaTokenizer::Create(*model); + double inference_time = RunOneGeneration(*model, *tokenizer, verbose); + PrintTimings("RunWithCpuAndCompileOverlay (CPU + overlay)", load_time, inference_time); +} + +// 3) Run model with NvTensorRtRtx EP without compile options. +void RunWithNvTensorRtRtxNoCompile(const std::string& model_path, const std::string& ep_path, bool verbose) { + if (ep_path.empty() && verbose) { + std::cout << "Warning: --ep_path not set; NvTensorRTRTX may not be available (only CPU)." << std::endl; + } + if (verbose) std::cout << "[RunWithNvTensorRtRtxNoCompile] Creating config (NvTensorRtRtx, no compile)..." << std::endl; + std::unordered_map ep_options; + GeneratorParamsArgs search_options; + auto config = GetConfig(model_path, kNvTensorRtRtxEp, ep_options, search_options); + if (verbose) std::cout << "[RunWithNvTensorRtRtxNoCompile] Creating model..." << std::endl; + auto load_t0 = Clock::now(); + auto model = OgaModel::Create(*config); + double load_time = std::chrono::duration(Clock::now() - load_t0).count(); + if (verbose) std::cout << "[RunWithNvTensorRtRtxNoCompile] Creating tokenizer..." << std::endl; + auto tokenizer = OgaTokenizer::Create(*model); + double inference_time = RunOneGeneration(*model, *tokenizer, verbose); + PrintTimings("RunWithNvTensorRtRtxNoCompile (NvTensorRtRtx, no compile)", load_time, inference_time); +} + +// 4) Run model with NvTensorRtRtx EP and minimum compile options. +void RunWithNvTensorRtRtxMinimumCompileOptions(const std::string& model_path, const std::string& ep_path, bool verbose) { + if (ep_path.empty() && verbose) { + std::cout << "Warning: --ep_path not set; NvTensorRTRTX may not be available (only CPU)." << std::endl; + } + if (verbose) std::cout << "[RunWithNvTensorRtRtxMinimumCompileOptions] Creating config (NvTensorRtRtx + minimum compile options)..." << std::endl; + std::unordered_map ep_options; + GeneratorParamsArgs search_options; + auto config = GetConfig(model_path, kNvTensorRtRtxEp, ep_options, search_options); + // ep_context_embed_mode must be false for larger models(>2GB) or compilation will error + config->Overlay(R"({ + "model": { + "decoder": { + "compile_options": { + "enable_ep_context": true, + "ep_context_embed_mode": false + } + } + } + })"); + if (verbose) std::cout << "[RunWithNvTensorRtRtxMinimumCompileOptions] Creating model..." << std::endl; + auto load_t0 = Clock::now(); + auto model = OgaModel::Create(*config); + double load_time = std::chrono::duration(Clock::now() - load_t0).count(); + if (verbose) std::cout << "[RunWithNvTensorRtRtxMinimumCompileOptions] Creating tokenizer..." << std::endl; + auto tokenizer = OgaTokenizer::Create(*model); + double inference_time = RunOneGeneration(*model, *tokenizer, verbose); + PrintTimings("RunWithNvTensorRtRtxMinimumCompileOptions (minimum options)", load_time, inference_time); +} + +// 5) Run model with NvTensorRtRtx EP and all compile options. +void RunWithNvTensorRtRtxCompileAllOptions(const std::string& model_path, const std::string& ep_path, bool verbose) { + if (ep_path.empty() && verbose) { + std::cout << "Warning: --ep_path not set; NvTensorRTRTX may not be available (only CPU)." << std::endl; + } + if (verbose) std::cout << "[RunWithNvTensorRtRtxCompileAllOptions] Creating config (NvTensorRtRtx + all compile options)..." << std::endl; + std::unordered_map ep_options; + GeneratorParamsArgs search_options; + auto config = GetConfig(model_path, kNvTensorRtRtxEp, ep_options, search_options); + // Single config: ep_context_file_path is full path (relative to model dir) including filename, e.g. "contexts/model_ctx.onnx" + config->Overlay(R"({ + "model": { + "decoder": { + "compile_options": { + "enable_ep_context": true, + "graph_optimization_level": 99, + "ep_context_file_path": "contexts/ep_context_output/model_ctx.onnx", + "ep_context_embed_mode": false, + "force_compile_if_needed": true + } + } + } + })"); + if (verbose) std::cout << "[RunWithNvTensorRtRtxCompileAllOptions] Creating model..." << std::endl; + auto load_t0 = Clock::now(); + auto model = OgaModel::Create(*config); + double load_time = std::chrono::duration(Clock::now() - load_t0).count(); + if (verbose) std::cout << "[RunWithNvTensorRtRtxCompileAllOptions] Creating tokenizer..." << std::endl; + auto tokenizer = OgaTokenizer::Create(*model); + double inference_time = RunOneGeneration(*model, *tokenizer, verbose); + PrintTimings("RunWithNvTensorRtRtxCompileAllOptions (all options)", load_time, inference_time); +} + +int main(int argc, char** argv) { + GeneratorParamsArgs generator_params_args; + GuidanceArgs guidance_args; + std::string model_path, ep = "follow_config", ep_path, system_prompt, user_prompt; + bool verbose = false, debug = false, interactive = false, rewind = true; + std::vector image_paths, audio_paths; + + if (!ParseArgs(argc, argv, generator_params_args, guidance_args, model_path, ep, ep_path, system_prompt, user_prompt, verbose, debug, interactive, rewind, image_paths, audio_paths)) { + return -1; + } + + if (ep.compare(kNvTensorRtRtxEp) == 0 && ep_path.empty()) { +#if defined(_WIN32) + ep_path = (fs::current_path() / "onnxruntime_providers_nv_tensorrt_rtx.dll").string(); +#else + ep_path = (fs::current_path() / "libonnxruntime_providers_nv_tensorrt_rtx.so").string(); +#endif + } + + if (debug) { + SetOrtVerboseLogging(); + SetLogger(); + } + + if (!ep_path.empty()) { + RegisterEP(kNvTensorRtRtxEp, ep_path); + } + + OgaHandle handle; + + if (verbose) { + std::cout << "Model path: " << model_path << std::endl; + std::cout << "EP path: " << (ep_path.empty() ? "(none)" : ep_path) << std::endl; + } + std::cout << "Timings (model load, inference):" << std::endl; + + try { + // RunWithCpu(model_path, ep_path, verbose); + // RunWithCpuAndCompileOverlay(model_path, ep_path, verbose); + //First run the no-compile case + RunWithNvTensorRtRtxNoCompile(model_path, ep_path, verbose); + //Then run for first time compile case, Model load time will be load time at no compile + compile time + RunWithNvTensorRtRtxMinimumCompileOptions(model_path, ep_path, verbose); + //Then run for second time compile case, Model load time must be very less as it is already compiled + RunWithNvTensorRtRtxMinimumCompileOptions(model_path, ep_path, verbose); + //Then run for all compile options,With different ep_context_file_path, ep_context_embed_mode, force_compile_if_needed, graph_optimization_level + RunWithNvTensorRtRtxCompileAllOptions(model_path, ep_path, verbose); + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << std::endl; + return -1; + } + + return 0; +} diff --git a/src/config.cpp b/src/config.cpp index 94417c6782..6d0f5c81c1 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -227,6 +227,39 @@ struct RunOptions_Element : JSON::Element { Config::RunOptions& v_; }; +struct CompileOptions_Element : JSON::Element { + explicit CompileOptions_Element(Config::CompileOptions& v) : v_{v} {} + + void OnValue(std::string_view name, JSON::Value value) override { + if (name == "enable_ep_context") { + v_.enable_ep_context = JSON::Get(value); + } else if (name == "force_compile_if_needed") { + v_.force_compile_if_needed = JSON::Get(value); + } else if (name == "graph_optimization_level") { + auto level = static_cast(JSON::Get(value)); + if (level < ORT_DISABLE_ALL || level > ORT_ENABLE_ALL) { + throw std::runtime_error("Invalid graph_optimization_level value: " + std::to_string(level)); + } + v_.graph_optimization_level = static_cast(level); + } else if (name == "ep_context_file_path") { + v_.ep_context_file_path = JSON::Get(value); + } else if (name == "ep_context_embed_mode") { + v_.ep_context_embed_mode = JSON::Get(value); + } else if (name == "flags") { + v_.flags = static_cast(JSON::Get(value)); + } else if (name == "external_initializers_file_path") { + v_.external_initializers_file_path = JSON::Get(value); + } else if (name == "external_initializers_size_threshold") { + v_.external_initializers_size_threshold = static_cast(JSON::Get(value)); + } else { + throw JSON::unknown_value_error{}; + } + } + + private: + Config::CompileOptions& v_; +}; + struct EncoderInputs_Element : JSON::Element { explicit EncoderInputs_Element(Config::Model::Encoder::Inputs& v) : v_{v} {} @@ -412,6 +445,11 @@ struct PipelineModel_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "output_names_forwarder") { return output_names_forwarder_; } @@ -431,6 +469,7 @@ struct PipelineModel_Element : JSON::Element { Config::Model::Decoder::PipelineModel& v_; std::unique_ptr session_options_; std::unique_ptr run_options_; + std::unique_ptr compile_options_; StringArray_Element inputs_{v_.inputs}; StringArray_Element outputs_{v_.outputs}; StringStringMap_Element output_names_forwarder_{v_.output_names_forwarder}; @@ -530,6 +569,11 @@ struct Encoder_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "inputs") { return inputs_; } @@ -543,6 +587,7 @@ struct Encoder_Element : JSON::Element { Config::Model::Encoder& v_; std::unique_ptr session_options_; std::unique_ptr run_options_; + std::unique_ptr compile_options_; EncoderInputs_Element inputs_{v_.inputs}; EncoderOutputs_Element outputs_{v_.outputs}; }; @@ -577,6 +622,11 @@ struct Decoder_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "inputs") { return inputs_; } @@ -606,6 +656,7 @@ struct Decoder_Element : JSON::Element { Config::Model::Decoder& v_; SessionOptions_Element session_options_{v_.session_options}; std::unique_ptr run_options_; + std::unique_ptr compile_options_; DecoderInputs_Element inputs_{v_.inputs}; DecoderOutputs_Element outputs_{v_.outputs}; Pipeline_Element pipeline_{v_.pipeline}; @@ -750,6 +801,11 @@ struct Vision_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "inputs") { return inputs_; } @@ -775,6 +831,7 @@ struct Vision_Element : JSON::Element { Config::Model::Vision& v_; std::unique_ptr session_options_; std::unique_ptr run_options_; + std::unique_ptr compile_options_; VisionInputs_Element inputs_{v_.inputs}; VisionOutputs_Element outputs_{v_.outputs}; VisionPipeline_Element pipeline_element_{v_.pipeline}; @@ -843,6 +900,11 @@ struct Speech_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "inputs") { return inputs_; } @@ -856,6 +918,7 @@ struct Speech_Element : JSON::Element { Config::Model::Speech& v_; std::unique_ptr session_options_; std::unique_ptr run_options_; + std::unique_ptr compile_options_; SpeechInputs_Element inputs_{v_.inputs}; SpeechOutputs_Element outputs_{v_.outputs}; }; @@ -916,6 +979,11 @@ struct Embedding_Element : JSON::Element { run_options_ = std::make_unique(*v_.run_options); return *run_options_; } + if (name == "compile_options") { + v_.compile_options = Config::CompileOptions{}; + compile_options_ = std::make_unique(*v_.compile_options); + return *compile_options_; + } if (name == "inputs") { return inputs_; } @@ -929,6 +997,7 @@ struct Embedding_Element : JSON::Element { Config::Model::Embedding& v_; std::unique_ptr session_options_; std::unique_ptr run_options_; + std::unique_ptr compile_options_; EmbeddingInputs_Element inputs_{v_.inputs}; EmbeddingOutputs_Element outputs_{v_.outputs}; }; diff --git a/src/config.h b/src/config.h index 51358b8192..bb46310061 100644 --- a/src/config.h +++ b/src/config.h @@ -99,6 +99,19 @@ struct Config { using RunOptions = std::vector; // Entries go into OrtRunOptions::AddConfigEntry + struct CompileOptions { + std::optional enable_ep_context; // Whether to enable model compilation + std::optional force_compile_if_needed; // If true, treat PREFER_RECOMPILATION as invalid and recompile; if false, accept OPTIMAL or PREFER_RECOMPILATION as valid + std::optional graph_optimization_level; + std::optional ep_context_file_path; // Full path (relative to config path) for compiled EP context model, e.g. "contexts/model_ctx.onnx" (default: "contexts/{model_name}_{ep_name}_ctx.onnx") + std::optional ep_context_embed_mode; + std::optional flags; + std::optional external_initializers_file_path; + std::optional external_initializers_size_threshold; + // Note: Function pointers for write_func and get_initializer_location_func + // cannot be configured via JSON and must be set programmatically + }; + struct Model { std::string type; @@ -120,6 +133,7 @@ struct Config { std::string filename; std::optional session_options; std::optional run_options; + std::optional compile_options; int hidden_size{}; int num_attention_heads{}; @@ -146,6 +160,7 @@ struct Config { std::string filename; std::optional session_options; std::optional run_options; + std::optional compile_options; struct Inputs { std::string input_ids{Defaults::InputIdsName}; @@ -162,6 +177,7 @@ struct Config { std::string filename; std::optional session_options; std::optional run_options; + std::optional compile_options; // Qwen2.5-VL specific vision config values int spatial_merge_size{2}; @@ -198,6 +214,7 @@ struct Config { std::string filename; std::optional session_options; std::optional run_options; + std::optional compile_options; std::string config_filename{"audio_processor_config.json"}; std::optional adapter_filename{}; @@ -218,6 +235,7 @@ struct Config { std::string filename; SessionOptions session_options; std::optional run_options; + std::optional compile_options; int hidden_size{}; // Not currently used, potentially useful for embeddings in the future int num_attention_heads{}; // Not currently used, potentially useful if num_key_value_heads isn't set @@ -270,6 +288,7 @@ struct Config { std::string filename; std::optional session_options; std::optional run_options; + std::optional compile_options; std::string model_id; std::vector inputs; diff --git a/src/filesystem.h b/src/filesystem.h index f7c2c2262e..c85b1a8cf7 100644 --- a/src/filesystem.h +++ b/src/filesystem.h @@ -22,6 +22,11 @@ #include #include +#ifndef _WIN32 +#include +#include +#endif + namespace fs { class path { @@ -71,7 +76,7 @@ class path { return join(path); } - path operator/(const path& path) { + path operator/(const path& path) const { return join(path.path_); } @@ -131,6 +136,14 @@ class path { return path(path_.substr(0, pos)); } + std::string filename() const { + size_t pos = path_.find_last_of("/\\"); + if (pos == std::string::npos) { + return path_; // No separator found, entire path is the filename + } + return path_.substr(pos + 1); + } + private: std::string path_; @@ -178,4 +191,42 @@ inline bool exists(const path& p) { return p.exists(); } +inline bool create_directories(const path& p) { +#ifdef _WIN32 + // On Windows, create directory recursively using CreateDirectoryW + if (p.exists()) { + return true; // Already exists + } + + // First create parent directory if needed + path parent = p.parent_path(); + if (!parent.string().empty() && !parent.exists()) { + if (!create_directories(parent)) { + return false; + } + } + + // Create the directory + return CreateDirectoryW(p.c_str(), nullptr) != 0 || GetLastError() == ERROR_ALREADY_EXISTS; +#else + // On Unix-like systems, use mkdir with recursive creation + if (p.exists()) { + return true; // Already exists + } + + // First create parent directory if needed + path parent = p.parent_path(); + if (!parent.string().empty() && !parent.exists()) { + if (!create_directories(parent)) { + return false; + } + } + + // Create the directory with 0755 permissions + errno = 0; + int result = mkdir(p.c_str(), 0755); + return result == 0 || errno == EEXIST; +#endif +} + } // namespace fs diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index b7a571d586..7708c421d4 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -4,7 +4,8 @@ namespace Generators { DecoderOnly_Model::DecoderOnly_Model(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { - session_decoder_ = CreateSession(ort_env, config_->model.decoder.filename, session_options_.get()); + std::string decoder_model_path = CompileModel(ort_env, config_->model.decoder.filename, session_options_.get(), true, config_->model.decoder.compile_options); + session_decoder_ = CreateSession(ort_env, decoder_model_path, session_options_.get()); session_info_.Add(*session_decoder_); } diff --git a/src/models/decoder_only_pipeline.cpp b/src/models/decoder_only_pipeline.cpp index 497bd1b295..ffb86de2fc 100644 --- a/src/models/decoder_only_pipeline.cpp +++ b/src/models/decoder_only_pipeline.cpp @@ -12,7 +12,13 @@ namespace Generators { DecoderOnlyPipelineModel::DecoderOnlyPipelineModel(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)}, ort_env_{ort_env} { for (const auto& model : config_->model.decoder.pipeline) { - sessions_.emplace_back(CreateSession(ort_env, model.filename, GetSessionOptions(model.model_id))); + // Get the compiled model path if it was compiled, otherwise use full path from config + filename + std::string model_path = GetPipelineCompiledModelPath(model.model_id); + if (model_path.empty()) { + // Use full path to original model if not compiled + model_path = (config_->config_path / fs::path(model.filename)).string(); + } + sessions_.emplace_back(CreateSession(ort_env, model_path, GetSessionOptions(model.model_id))); } for (auto& session : sessions_) { diff --git a/src/models/gpt.cpp b/src/models/gpt.cpp index d48cfd2a83..ab4c1d064a 100644 --- a/src/models/gpt.cpp +++ b/src/models/gpt.cpp @@ -5,7 +5,8 @@ namespace Generators { Gpt_Model::Gpt_Model(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { - session_decoder_ = CreateSession(ort_env, config_->model.decoder.filename, session_options_.get()); + std::string decoder_model_path = CompileModel(ort_env, config_->model.decoder.filename, session_options_.get(), true, config_->model.decoder.compile_options); + session_decoder_ = CreateSession(ort_env, decoder_model_path, session_options_.get()); session_info_.Add(*session_decoder_); } diff --git a/src/models/marian.cpp b/src/models/marian.cpp index 1da12617e9..9d9c0ed1fe 100644 --- a/src/models/marian.cpp +++ b/src/models/marian.cpp @@ -8,10 +8,13 @@ namespace Generators { MarianModel::MarianModel(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { encoder_session_options_ = OrtSessionOptions::Create(); - CreateSessionOptionsFromConfig(config_->model.encoder.session_options.has_value() ? config_->model.encoder.session_options.value() : config_->model.decoder.session_options, *encoder_session_options_, true, false); + CreateSessionOptionsFromConfig(config_->model.encoder.session_options.has_value() ? config_->model.encoder.session_options.value() : config_->model.decoder.session_options, *encoder_session_options_, false, false); - session_encoder_ = CreateSession(ort_env, config_->model.encoder.filename, encoder_session_options_.get()); - session_decoder_ = CreateSession(ort_env, config_->model.decoder.filename, session_options_.get()); + std::string encoder_model_path = CompileModel(ort_env, config_->model.encoder.filename, encoder_session_options_.get(), false, config_->model.encoder.compile_options); + session_encoder_ = CreateSession(ort_env, encoder_model_path, encoder_session_options_.get()); + + std::string decoder_model_path = CompileModel(ort_env, config_->model.decoder.filename, session_options_.get(), true, config_->model.decoder.compile_options); + session_decoder_ = CreateSession(ort_env, decoder_model_path, session_options_.get()); session_info_.Add(*session_decoder_); session_info_.Add(*session_encoder_); diff --git a/src/models/model.cpp b/src/models/model.cpp index a79bdb55bf..ae207b8ae1 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -3,6 +3,7 @@ // // Modifications Copyright(C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved. #include +#include #include #include #include @@ -1014,7 +1015,7 @@ std::vector SessionInfo::GetOutputSymbolicShape(const std::string& } Model::Model(std::unique_ptr config) : config_{std::move(config)} { - CreateSessionOptions(); + CreateSessionOptions(); EnsureDeviceOrtInit(*p_device_, *config_, arena_cfg_); // Only CUDA, TRT-RTX, RyzenAI and DML does every input on the device @@ -1213,7 +1214,272 @@ OrtSessionOptions* Model::GetSessionOptions(const std::string& model_id) const { return session_options_.get(); } -std::unique_ptr Model::CreateSession(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options) { +std::string Model::GetPipelineCompiledModelPath(const std::string& model_id) const { + auto it = pipeline_compiled_model_paths_.find(model_id); + if (it != pipeline_compiled_model_paths_.end()) { + return it->second; + } + return ""; // Return empty string if not found +} + +std::unique_ptr Model::CreateModelCompilationOptions(OrtEnv& ort_env, OrtSessionOptions* session_options) { + // Create model compilation options from the provided session options + if (!session_options) { + return nullptr; + } + + OrtModelCompilationOptions* p; + Ort::ThrowOnError(Ort::GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(&ort_env, session_options, &p)); + return std::unique_ptr(p); +} + +bool Model::ValidateCompiledModel(OrtEnv& ort_env, const fs::path& compiled_model_path, bool force_compile_if_needed) { + const std::string ep_name = EPContextSupportedProviders(p_device_->GetType()); + // EP context applicability already checked in CompileModel; ep_name is non-empty here. + + Ort::Allocator& alloc = Ort::Allocator::GetWithDefaultOptions(); + char* compat_info = nullptr; + OrtStatus* st = Ort::api->GetCompatibilityInfoFromModel( + compiled_model_path.c_str(), ep_name.c_str(), &alloc, &compat_info); + if (st != nullptr) { + Ort::api->ReleaseStatus(st); + return false; // Error reading model (e.g. invalid file) -> recompile + } + // Context valid only if compatibility info is present for this EP + if (compat_info == nullptr) { + return false; + } + std::string compat_str(compat_info); + Ort::api->AllocatorFree(&alloc, compat_info); + + const OrtEpDevice* const* devices = nullptr; + size_t num_devices = 0; + st = Ort::api->GetEpDevices(&ort_env, &devices, &num_devices); + if (st != nullptr) { + Ort::api->ReleaseStatus(st); + return false; // Cannot enumerate devices -> recompile to be safe + } + + const OrtEpDevice* ep_device = nullptr; + for (size_t i = 0; i < num_devices; ++i) { + const char* device_ep = Ort::api->EpDevice_EpName(devices[i]); + if (device_ep && std::string(device_ep) == ep_name) { + ep_device = devices[i]; + break; + } + } + if (ep_device == nullptr) { + return false; // No matching EP device -> all other cases return false + } + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + OrtStatus* result = Ort::api->GetModelCompatibilityForEpDevices(&ep_device, 1, compat_str.c_str(), &status); + if (result != nullptr) { + Ort::api->ReleaseStatus(result); + return false; // API error -> recompile + } + // EPContext is valid and optimal, no need to recompile + if (status == OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL) { + return true; + } + if (status == OrtCompiledModelCompatibility_EP_UNSUPPORTED) { + return false; // Context not compatible with this EP -> recompile + } + if (status == OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION) { + if (force_compile_if_needed) { + Log("info", "Found existing EP Context for " + ep_name + " but performance is sub-optimal; Force compile is enabled, going to recompile EPContext."); + return false; + } + Log("warning", "Found existing EP Context for " + ep_name + " but its performance is sub-optimal in this EP, recommended to recompile EPContext."); + return true; + } + + return false; // NOT_APPLICABLE or unknown +} + +bool Model::CheckCompiledModelExists(OrtEnv& ort_env, + const std::string& model_filename, + const Config::CompileOptions& compile_options_config, + fs::path& out_compiled_model_path) { + if (compile_options_config.ep_context_file_path.has_value() && !compile_options_config.ep_context_file_path.value().empty()) { + // Single path: full path (relative to config path) including filename, e.g. "contexts/model_ctx.onnx" + out_compiled_model_path = config_->config_path / compile_options_config.ep_context_file_path.value(); + } else { + // Default: "contexts/{model_name}_{ep_name}_ctx.onnx" + std::string model_name = model_filename; + size_t ext_pos = model_name.find_last_of('.'); + if (ext_pos != std::string::npos) { + model_name = model_name.substr(0, ext_pos); + } + std::string ep_name = to_string(p_device_->GetType()); + std::transform(ep_name.begin(), ep_name.end(), ep_name.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::string output_filename = model_name + "_" + ep_name + "_ctx.onnx"; + out_compiled_model_path = config_->config_path / "contexts" / output_filename; + } + + // Check if the compiled model file exists + if (!fs::exists(out_compiled_model_path)) { + return false; // File doesn't exist, need to compile + } + + // Validate the compiled model (EP compatibility) + const bool force_compile = compile_options_config.force_compile_if_needed.value_or(false); + return ValidateCompiledModel(ort_env, out_compiled_model_path, force_compile); +} + +std::string Model::CompileModel(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options, + bool is_primary_session_option, const std::optional& compile_options) { + + // Check if compilation is enabled for the specified model + if (!compile_options.has_value()) { + // No compile options provided, return full path to original model + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + const auto& comp_opts = compile_options.value(); + if (!comp_opts.enable_ep_context.has_value() || !comp_opts.enable_ep_context.value()) { + // Compilation not enabled, return full path to original model + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + // EP Context is only applicable for certain EPs (e.g. NvTensorRTRTX); if not applicable, use original model + if (EPContextSupportedProviders(p_device_->GetType()).empty()) { + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + // Helper lambda to configure and compile a model + auto compile_model_helper = [this, &ort_env](OrtModelCompilationOptions* compilation_options, + const std::string& model_filename, + const std::optional& config_compilation_options) -> std::string { + if (!compilation_options) { + // Return full path to original model + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + // Check if compilation is enabled for this specific model + if (!config_compilation_options.has_value()) { + // No compile options, return full path to original model + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + const auto& comp_opts = config_compilation_options.value(); + if (!comp_opts.enable_ep_context.has_value() || !comp_opts.enable_ep_context.value()) { + // Compilation not enabled, return full path to original model + fs::path full_path = config_->config_path / model_filename; + return full_path.string(); + } + + // Check if compiled model already exists and is valid + fs::path compiled_model_path; + if (CheckCompiledModelExists(ort_env, model_filename, comp_opts, compiled_model_path)) { + // Compiled model exists and is valid, return compiled path + return compiled_model_path.string(); + } + + // Set input model (from buffer if available, otherwise from file) + if (auto model_data_it = config_->model_data_spans_.find(model_filename); + model_data_it != config_->model_data_spans_.end()) { + // Compile from buffer + if (model_data_it->second.empty()) { + throw std::runtime_error("Failed to load model data from memory for " + model_filename); + } + compilation_options->SetInputModelFromBuffer( + model_data_it->second.data(), + model_data_it->second.size()); + } else { + // Compile from file + fs::path input_path = config_->config_path / model_filename; + compilation_options->SetInputModelPath(input_path.c_str()); + } + + // Set output model path - use the path from CheckCompiledModelExists + // Ensure the output directory exists + fs::path output_dir = compiled_model_path.parent_path(); + if (!fs::exists(output_dir)) { + if (!fs::create_directories(output_dir)) { + throw std::runtime_error("Failed to create output directory: " + output_dir.string()); + } + } + + compilation_options->SetOutputModelPath(compiled_model_path.c_str()); + + // Apply configuration options from config + // Set graph optimization level + if (comp_opts.graph_optimization_level.has_value()) { + compilation_options->SetGraphOptimizationLevel(comp_opts.graph_optimization_level.value()); + } + + // Set EP context embed mode + if (comp_opts.ep_context_embed_mode.has_value()) { + compilation_options->SetEpContextEmbedMode(comp_opts.ep_context_embed_mode.value()); + } + + // Set flags + if (comp_opts.flags.has_value()) { + compilation_options->SetFlags(comp_opts.flags.value()); + } + + // Set external initializers file + if (comp_opts.external_initializers_file_path.has_value() && + comp_opts.external_initializers_size_threshold.has_value()) { + fs::path external_init_path = config_->config_path / comp_opts.external_initializers_file_path.value(); + compilation_options->SetOutputModelExternalInitializersFile( + external_init_path.c_str(), + comp_opts.external_initializers_size_threshold.value()); + } + + // Compile the model + Ort::CompileModel(ort_env, *compilation_options); + + // Return the compiled model path + return compiled_model_path.string(); + }; + + // Compile the specified model with the provided compile_options + auto compilation_options = CreateModelCompilationOptions(ort_env, session_options); + std::string main_model_path = compile_model_helper(compilation_options.get(), model_filename, compile_options); + + // Additionally, compile all pipeline models that have compile_options (if primary session option) + // Use explicit pipeline session_options when present, otherwise fallback to main session_options_ + // (consistent with GetSessionOptions() at runtime). + if (is_primary_session_option) { + for (auto& pipeline_model : config_->model.decoder.pipeline) { + if (!pipeline_model.compile_options.has_value()) { + continue; + } + OrtSessionOptions* opts_to_use = nullptr; + auto session_options_it = pipeline_session_options_.find(pipeline_model.model_id); + if (session_options_it != pipeline_session_options_.end()) { + opts_to_use = session_options_it->second.get(); + } else { + opts_to_use = session_options; + } + auto pipeline_compile_options = CreateModelCompilationOptions(ort_env, opts_to_use); + if (pipeline_compile_options) { + std::string pipeline_model_path = compile_model_helper(pipeline_compile_options.get(), + pipeline_model.filename, + pipeline_model.compile_options); + pipeline_compiled_model_paths_[pipeline_model.model_id] = pipeline_model_path; + } + } + } + + // Return the main model path (original or compiled) + return main_model_path; +} + +std::unique_ptr Model::CreateSession(OrtEnv& ort_env, const std::string& model_path, OrtSessionOptions* session_options) { + + // Extract just the filename from the path for model_data_spans lookup + fs::path path_obj(model_path); + std::string model_filename = path_obj.filename(); + if (auto model_data_it = config_->model_data_spans_.find(model_filename); model_data_it != config_->model_data_spans_.end()) { // If model data was provided, load the model from memory @@ -1238,7 +1504,7 @@ std::unique_ptr Model::CreateSession(OrtEnv& ort_env, const std::str } // Otherwise, load the model from the file system - return OrtSession::Create(ort_env, (config_->config_path / fs::path(model_filename)).c_str(), session_options); + return OrtSession::Create(ort_env, fs::path(model_path).c_str(), session_options); } std::shared_ptr Model::CreateTokenizer() const { diff --git a/src/models/model.h b/src/models/model.h index 906e7e2db2..229b9013a7 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -156,10 +156,84 @@ struct Model : std::enable_shared_from_this, LeakChecked, External OrtSessionOptions* GetSessionOptions(const std::string& model_id) const; + /** \brief Gets the compiled model path for a pipeline model + * + * \param model_id The pipeline model ID + * \return The compiled model path if available, empty string otherwise + */ + std::string GetPipelineCompiledModelPath(const std::string& model_id) const; + std::unique_ptr CreateSession(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options); bool IsPruned() const; + /** \brief Returns the ORT execution provider name for the given device type if it supports EP context; empty string otherwise. + * If EP Context is enabled for any provider, please add the provider name here. + */ + static std::string EPContextSupportedProviders(DeviceType device_type) { + switch (device_type) { + case DeviceType::NvTensorRtRtx: + return "NvTensorRTRTXExecutionProvider"; + default: + return ""; + } + } + + /** \brief Compiles the specified model and optionally all pipeline models + * + * Creates compilation options from session options and compiles the models. + * Automatically configures compilation based on config settings: + * - Input: Uses model data from buffer (if available via AddModelData), otherwise from file path + * - Output: Creates "contexts" folder and saves as "{model_name}_{ep_name}_ctx.onnx", or as configured + * - Reads compilation options from config.model.*.compile_options: + * * enable_ep_context - Controls whether model compilation is performed (default: not set, no compilation) + * * graph_optimization_level + * * ep_context_file_path - Full path (relative to config path) for compiled EP context model, e.g. "contexts/model_ctx.onnx" + * * ep_context_embed_mode - How EP context is stored (embedded vs external files) + * * flags + * * external_initializers_file_path and external_initializers_size_threshold + * + * Function pointers (write_func, get_initializer_location_func) must be set programmatically. + * + * Throws an exception on error. + * + * \param ort_env The OrtEnv object + * \param model_filename The model filename to compile + * \param session_options The session options to create compilation options from + * \param is_primary_session_option If true, also compiles all pipeline models + * \param compile_options The compile options from config for the specified model + * \return The model path to use for creating session (original if not compiled, compiled path if compiled) + */ + std::string CompileModel(OrtEnv& ort_env, const std::string& model_filename, OrtSessionOptions* session_options, + bool is_primary_session_option, const std::optional& compile_options = std::nullopt); + + private: + /** \brief Checks if a compiled model exists and is valid + * + * \param ort_env OrtEnv (used for EP device / compatibility validation) + * \param model_filename The original model filename + * \param compile_options_config The compile options from config (output path, force_compile_if_needed, etc.) + * \param out_compiled_model_path Output parameter for the compiled model path (default or from config) + * \return true if compiled model exists and is valid, false otherwise + */ + bool CheckCompiledModelExists(OrtEnv& ort_env, + const std::string& model_filename, + const Config::CompileOptions& compile_options_config, + fs::path& out_compiled_model_path); + + /** \brief Validates a compiled model using EP compatibility APIs. + * Context is valid only if: (1) compatibility info is present for this EP, and + * (2) GetModelCompatibilityForEpDevices returns OPTIMAL or (PREFER_RECOMPILATION when force_compile_if_needed is false). + * All other cases return false. + * + * \param ort_env OrtEnv (for GetEpDevices) + * \param compiled_model_path Path to the compiled model file + * \param force_compile_if_needed If true, PREFER_RECOMPILATION is treated as invalid (recompile); if false, it is accepted as valid with a warning + * \return true if the compiled model is valid for the current EP (or validation not applicable) + */ + bool ValidateCompiledModel(OrtEnv& ort_env, const fs::path& compiled_model_path, bool force_compile_if_needed); + + public: std::unique_ptr config_; std::unique_ptr session_options_; std::unique_ptr arena_cfg_; @@ -174,6 +248,7 @@ struct Model : std::enable_shared_from_this, LeakChecked, External protected: void CreateSessionOptions(); + std::unique_ptr CreateModelCompilationOptions(OrtEnv& ort_env, OrtSessionOptions* session_options); void CreateSessionOptionsFromConfig(const Config::SessionOptions& config_session_options, OrtSessionOptions& session_options, @@ -181,6 +256,7 @@ struct Model : std::enable_shared_from_this, LeakChecked, External bool disable_graph_capture); std::map> pipeline_session_options_; + std::map pipeline_compiled_model_paths_; // Maps pipeline model_id to compiled model path }; } // namespace Generators diff --git a/src/models/multi_modal.cpp b/src/models/multi_modal.cpp index 384de54234..7426ae5b6e 100644 --- a/src/models/multi_modal.cpp +++ b/src/models/multi_modal.cpp @@ -65,21 +65,26 @@ MultiModalLanguageModel::MultiModalLanguageModel(std::unique_ptr config, // The non-decoder models don't support graph capture because of control flow nodes, so disable graph capture for them if (vision) { vision_session_options_ = OrtSessionOptions::Create(); - CreateSessionOptionsFromConfig(config_->model.vision.session_options.has_value() ? config_->model.vision.session_options.value() : config_->model.decoder.session_options, *vision_session_options_, true, true); - vision_session_ = CreateSession(ort_env, config_->model.vision.filename, vision_session_options_.get()); + CreateSessionOptionsFromConfig(config_->model.vision.session_options.has_value() ? config_->model.vision.session_options.value() : config_->model.decoder.session_options, *vision_session_options_, false, true); + std::string vision_model_path = CompileModel(ort_env, config_->model.vision.filename, vision_session_options_.get(), false, config_->model.vision.compile_options); + vision_session_ = CreateSession(ort_env, vision_model_path, vision_session_options_.get()); } if (speech) { speech_session_options_ = OrtSessionOptions::Create(); - CreateSessionOptionsFromConfig(config_->model.speech.session_options.has_value() ? config_->model.speech.session_options.value() : config_->model.decoder.session_options, *speech_session_options_, true, true); - speech_session_ = CreateSession(ort_env, config_->model.speech.filename, speech_session_options_.get()); + CreateSessionOptionsFromConfig(config_->model.speech.session_options.has_value() ? config_->model.speech.session_options.value() : config_->model.decoder.session_options, *speech_session_options_, false, true); + std::string speech_model_path = CompileModel(ort_env, config_->model.speech.filename, speech_session_options_.get(), false, config_->model.speech.compile_options); + speech_session_ = CreateSession(ort_env, speech_model_path, speech_session_options_.get()); } embedding_session_options_ = OrtSessionOptions::Create(); - CreateSessionOptionsFromConfig(config_->model.embedding.session_options.has_value() ? config_->model.embedding.session_options.value() : config_->model.decoder.session_options, *embedding_session_options_, true, true); + CreateSessionOptionsFromConfig(config_->model.embedding.session_options.has_value() ? config_->model.embedding.session_options.value() : config_->model.decoder.session_options, *embedding_session_options_, false, true); - embedding_session_ = CreateSession(ort_env, config_->model.embedding.filename, embedding_session_options_.get()); - decoder_session_ = CreateSession(ort_env, config_->model.decoder.filename, session_options_.get()); + std::string embedding_model_path = CompileModel(ort_env, config_->model.embedding.filename, embedding_session_options_.get(), false, config_->model.embedding.compile_options); + embedding_session_ = CreateSession(ort_env, embedding_model_path, embedding_session_options_.get()); + + std::string decoder_model_path = CompileModel(ort_env, config_->model.decoder.filename, session_options_.get(), true, config_->model.decoder.compile_options); + decoder_session_ = CreateSession(ort_env, decoder_model_path, session_options_.get()); session_info_.Add(*decoder_session_); session_info_.Add(*embedding_session_); diff --git a/src/models/onnxruntime_api.h b/src/models/onnxruntime_api.h index a9aa493131..7946761fb4 100644 --- a/src/models/onnxruntime_api.h +++ b/src/models/onnxruntime_api.h @@ -137,6 +137,15 @@ inline const OrtModelEditorApi& GetModelEditorApi() { return *model_editor_api; } +/// Returns a reference to the ORT C Compile API. Used if compiling a model at runtime. +inline const OrtCompileApi& GetCompileApi() { + auto* compile_api = api->GetCompileApi(); + if (compile_api == nullptr) { + throw std::runtime_error("Compile API is not available in this build"); + } + return *compile_api; +} + #if defined(__linux__) || defined(MACOS_USE_DLOPEN) inline std::string GetCurrentModuleDir() { Dl_info dl_info; @@ -443,6 +452,16 @@ void RegisterExecutionProviderLibrary(OrtEnv* env, const char* registration_name void UnregisterExecutionProviderLibrary(OrtEnv* env, const char* registration_name); +/** \brief Compiles an input model to generate a model with EPContext nodes + * + * Wraps OrtCompileApi::CompileModel + * Throws an exception on error + * + * \param env OrtEnv object + * \param model_compilation_options Compilation options for the model + */ +void CompileModel(OrtEnv& env, const OrtModelCompilationOptions& model_compilation_options); + } // namespace Ort /** \brief The Status that holds ownership of OrtStatus received from C API @@ -649,6 +668,94 @@ struct OrtSessionOptions { Ort::Abstract make_abstract; }; +/** \brief Options object used for model compilation + * + * Wraps ::OrtModelCompilationOptions object and methods from the Compile API + */ +struct OrtModelCompilationOptions { + /** \brief Creates OrtModelCompilationOptions from an OrtEnv and OrtSessionOptions + * + * Wraps OrtCompileApi::CreateModelCompilationOptionsFromSessionOptions + */ + static std::unique_ptr Create(OrtEnv& env, const OrtSessionOptions& session_options); + + /** \brief Sets the input ONNX model file path to compile + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetInputModelPath + */ + OrtModelCompilationOptions& SetInputModelPath(const ORTCHAR_T* input_model_path); + + /** \brief Sets the input ONNX model from a memory buffer + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetInputModelFromBuffer + */ + OrtModelCompilationOptions& SetInputModelFromBuffer(const void* input_model_data, size_t input_model_data_size); + + /** \brief Sets the output compiled model file path + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetOutputModelPath + */ + OrtModelCompilationOptions& SetOutputModelPath(const ORTCHAR_T* output_model_path); + + /** \brief Sets external initializers file path and size threshold for the compiled model + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile + */ + OrtModelCompilationOptions& SetOutputModelExternalInitializersFile(const ORTCHAR_T* external_initializers_file_path, + size_t external_initializers_size_threshold); + + /** \brief Sets the output compiled model to be written to a buffer + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetOutputModelBuffer + */ + OrtModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, + void** output_model_buffer_ptr, + size_t* output_model_buffer_size_ptr); + + /** \brief Sets a custom write function for the output compiled model + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetOutputModelWriteFunc + */ + OrtModelCompilationOptions& SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state); + + /** \brief Sets a function to determine initializer locations in the output model + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc + */ + OrtModelCompilationOptions& SetOutputModelGetInitializerLocationFunc(OrtGetInitializerLocationFunc get_initializer_location_func, + void* state); + + /** \brief Enables or disables embedding of EPContext binary data in the compiled model + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetEpContextEmbedMode + */ + OrtModelCompilationOptions& SetEpContextEmbedMode(bool embed_ep_context_in_model); + + /** \brief Sets EP context binary information (output directory and model name) + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetEpContextBinaryInformation + */ + OrtModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name); + + /** \brief Sets flags for model compilation + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetFlags + */ + OrtModelCompilationOptions& SetFlags(uint32_t flags); + + /** \brief Sets the graph optimization level for model compilation + * + * Wraps OrtCompileApi::ModelCompilationOptions_SetGraphOptimizationLevel + */ + OrtModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); + + static void operator delete(void* p) { + if (p) Ort::GetCompileApi().ReleaseModelCompilationOptions(reinterpret_cast(p)); + } + Ort::Abstract make_abstract; +}; + /** \brief Wrapper around ::OrtModelMetadata * */ diff --git a/src/models/onnxruntime_inline.h b/src/models/onnxruntime_inline.h index e717d2a9bf..6490a6b4dd 100644 --- a/src/models/onnxruntime_inline.h +++ b/src/models/onnxruntime_inline.h @@ -193,6 +193,10 @@ inline void UnregisterExecutionProviderLibrary(OrtEnv* env, const char* registra ThrowOnError(Ort::api->UnregisterExecutionProviderLibrary(env, registration_name)); } +inline void CompileModel(OrtEnv& env, const OrtModelCompilationOptions& model_compilation_options) { + ThrowOnError(Ort::GetCompileApi().CompileModel(&env, &model_compilation_options)); +} + } // namespace Ort inline std::unique_ptr OrtStatus::Create(OrtErrorCode code, const std::string& what) { @@ -759,6 +763,79 @@ inline OrtSessionOptions& OrtSessionOptions::AppendExecutionProvider_V2(OrtEnv& return *this; } +/// OrtModelCompilationOptions +inline std::unique_ptr OrtModelCompilationOptions::Create(OrtEnv& env, const OrtSessionOptions& session_options) { + OrtModelCompilationOptions* p; + Ort::ThrowOnError(Ort::GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(&env, &session_options, &p)); + return std::unique_ptr(p); +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetInputModelPath(const ORTCHAR_T* input_model_path) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetInputModelPath(this, input_model_path)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetInputModelFromBuffer(const void* input_model_data, + size_t input_model_data_size) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetInputModelFromBuffer(this, input_model_data, input_model_data_size)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetOutputModelPath(const ORTCHAR_T* output_model_path) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetOutputModelPath(this, output_model_path)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetOutputModelExternalInitializersFile( + const ORTCHAR_T* external_initializers_file_path, + size_t external_initializers_size_threshold) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile( + this, external_initializers_file_path, external_initializers_size_threshold)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetOutputModelBuffer(OrtAllocator* allocator, + void** output_model_buffer_ptr, + size_t* output_model_buffer_size_ptr) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetOutputModelBuffer( + this, allocator, output_model_buffer_ptr, output_model_buffer_size_ptr)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetOutputModelWriteFunc(OrtWriteBufferFunc write_func, void* state) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetOutputModelWriteFunc(this, write_func, state)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetOutputModelGetInitializerLocationFunc( + OrtGetInitializerLocationFunc get_initializer_location_func, + void* state) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetOutputModelGetInitializerLocationFunc( + this, get_initializer_location_func, state)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetEpContextEmbedMode(this, embed_ep_context_in_model)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, + const ORTCHAR_T* model_name) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation(this, output_directory, model_name)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetFlags(uint32_t flags) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetFlags(this, flags)); + return *this; +} + +inline OrtModelCompilationOptions& OrtModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) { + Ort::ThrowOnError(Ort::GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this, graph_optimization_level)); + return *this; +} + /// Session inline std::unique_ptr OrtSession::Create(OrtEnv& env, const ORTCHAR_T* model_path, const OrtSessionOptions* options) { OrtSession* p; diff --git a/src/models/whisper.cpp b/src/models/whisper.cpp index 903b227c14..9929ae58b0 100644 --- a/src/models/whisper.cpp +++ b/src/models/whisper.cpp @@ -8,10 +8,13 @@ namespace Generators { WhisperModel::WhisperModel(std::unique_ptr config, OrtEnv& ort_env) : Model{std::move(config)} { encoder_session_options_ = OrtSessionOptions::Create(); - CreateSessionOptionsFromConfig(config_->model.encoder.session_options.has_value() ? config_->model.encoder.session_options.value() : config_->model.decoder.session_options, *encoder_session_options_, true, false); + CreateSessionOptionsFromConfig(config_->model.encoder.session_options.has_value() ? config_->model.encoder.session_options.value() : config_->model.decoder.session_options, *encoder_session_options_, false, false); - session_encoder_ = CreateSession(ort_env, config_->model.encoder.filename, encoder_session_options_.get()); - session_decoder_ = CreateSession(ort_env, config_->model.decoder.filename, session_options_.get()); + std::string encoder_model_path = CompileModel(ort_env, config_->model.encoder.filename, encoder_session_options_.get(), false, config_->model.encoder.compile_options); + session_encoder_ = CreateSession(ort_env, encoder_model_path, encoder_session_options_.get()); + + std::string decoder_model_path = CompileModel(ort_env, config_->model.decoder.filename, session_options_.get(), true, config_->model.decoder.compile_options); + session_decoder_ = CreateSession(ort_env, decoder_model_path, session_options_.get()); session_info_.Add(*session_decoder_); session_info_.Add(*session_encoder_);