Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6856,6 +6856,24 @@ struct OrtCompileApi {
*/
ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
size_t flags);

/** Sets information related to EP context binary file.
*
* EP uses this information to decide the location and context binary file name.
* Used while compiling model with input and output in memory buffer
*
* \param[in] model_compile_options The OrtModelCompilationOptions instance.
* \param[in] output_directory Null terminated string of the path (wchar on Windows, char otherwise).
* \param[in] model_name Null terminated string of the model name (wchar on Windows, char otherwise).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetEpContextBinaryInformation,
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_directory,
_In_ const ORTCHAR_T* model_name);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,8 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
size_t initializer_size_threshold); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelExternalInitializersFile
ModelCompilationOptions& SetOutputModelBuffer(OrtAllocator* allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
};

Expand Down
9 changes: 9 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,15 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelPath(
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextBinaryInformation(
const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetEpContextBinaryInformation(
this->p_,
output_directory,
model_name));
return *this;
}

inline ModelCompilationOptions& ModelCompilationOptions::SetOutputModelExternalInitializersFile(
const ORTCHAR_T* file_path, size_t initializer_size_threshold) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetOutputModelExternalInitializersFile(
Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/core/session/compile_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,35 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelPath,
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
_In_ OrtModelCompilationOptions* ort_model_compile_options,
const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name) {
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);

std::string output_dir = PathToUTF8String(output_directory);
if (output_dir.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid output directory: path is empty");
}

std::string model_name_str = ToUTF8String(model_name);
if (model_name_str.empty()) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid model name: string is empty");
}

ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetEpContextBinaryInformation(output_dir, model_name_str));
return nullptr;
#else
ORT_UNUSED_PARAMETER(ort_model_compile_options);
ORT_UNUSED_PARAMETER(output_directory);
ORT_UNUSED_PARAMETER(model_name);
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
#endif // !defined(ORT_MINIMAL_BUILD)
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetOutputModelExternalInitializersFile,
_In_ OrtModelCompilationOptions* ort_model_compile_options,
const ORTCHAR_T* external_initializers_file_path,
Expand Down Expand Up @@ -248,6 +277,7 @@ static constexpr OrtCompileApi ort_compile_api = {
// End of Version 22 - DO NOT MODIFY ABOVE

&OrtCompileAPI::ModelCompilationOptions_SetFlags,
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
};

// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/compile_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,7 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel
ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options,
size_t flags);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);

} // namespace OrtCompileAPI
36 changes: 33 additions & 3 deletions onnxruntime/core/session/model_compilation_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ Status ModelCompilationOptions::SetOutputModelPath(const std::string& output_mod
if (log_manager != nullptr && log_manager->HasDefaultLogger()) {
const logging::Logger& logger = log_manager->DefaultLogger();
LOGS(logger, WARNING) << "Output model path length (" << ep_context_gen_options.output_model_file_path.size()
<< ") exceeds limit of " << ConfigOptions::kMaxKeyLength << " characters."
<< "ORT will still generated the expected output file, but EPs will see an empty "
<< ") exceeds limit of " << ConfigOptions::kMaxValueLength << " characters."
<< "ORT will still generate the expected output file, but EPs will see an empty "
<< "output model path in SessionOption's ConfigOptions.";
}
}
Expand All @@ -98,6 +98,36 @@ Status ModelCompilationOptions::SetOutputModelBuffer(onnxruntime::AllocatorPtr a
return Status::OK();
}

Status ModelCompilationOptions::SetEpContextBinaryInformation(const std::string& output_directory,
const std::string& model_name) {
if (output_directory.empty() || model_name.empty()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir or model_name is empty.");
}

std::filesystem::path output_dir_path(output_directory);
if (output_dir_path.has_filename() && output_dir_path.extension() == "") {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "output_dir is not a valid directory.");
}

std::filesystem::path ctx_model_path = output_directory / std::filesystem::path(model_name);

if (ctx_model_path.string().size() <= ConfigOptions::kMaxValueLength) {
ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath,
ctx_model_path.string().c_str()));
} else {
logging::LoggingManager* log_manager = env_.GetLoggingManager();
if (log_manager != nullptr && log_manager->HasDefaultLogger()) {
const logging::Logger& logger = log_manager->DefaultLogger();
LOGS(logger, WARNING) << "output_directory length with model_name length together exceeds limit of "
<< ConfigOptions::kMaxValueLength << " characters."
<< "ORT will still generate the expected output file, but EPs will see an empty "
<< "output path in SessionOption's ConfigOptions.";
}
}

return Status::OK();
}

Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_model) {
ORT_RETURN_IF_ERROR(session_options_.value.config_options.AddConfigEntry(
kOrtSessionOptionEpContextEmbedMode, embed_ep_context_in_model ? "1" : "0"));
Expand Down Expand Up @@ -146,7 +176,7 @@ Status ModelCompilationOptions::ResetOutputModelSettings() {
ep_context_gen_options.output_model_buffer_ptr = nullptr;
ep_context_gen_options.output_model_buffer_size_ptr = nullptr;
ep_context_gen_options.output_model_buffer_allocator = nullptr;
return session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, "");
return Status::OK();
}

Status ModelCompilationOptions::CheckInputModelSettings() const {
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/model_compilation_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ class ModelCompilationOptions {
Status SetOutputModelBuffer(onnxruntime::AllocatorPtr allocator, void** output_model_buffer_ptr,
size_t* output_model_buffer_size_ptr);

/// <summary>
/// Sets information relate to EP context binary file.
/// EP use this information to decide the location and context binary file name.
/// Used while compiling model with input and output in memory buffer
/// </summary>
/// <param name="output_directory">The folder path to the generated context binary file</param>
/// <param name="model_name">Model name used to decide the context binary file name: [model_name]_[ep].bin</param>
/// <returns>Status indicating potential error</returns>
Status SetEpContextBinaryInformation(const std::string& output_directory, const std::string& model_name);

/// <summary>
/// Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute of EPContext
/// nodes. Defaults to false (dumped to file).
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,11 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
std::string target_dir = "./testdata/";
std::string model_name = "test_model_in_mem.onnx";
auto pos = model_name.rfind(".onnx");
std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin";
compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str());
compile_options.SetEpContextEmbedMode(false);

// Compile the model.
Expand All @@ -519,12 +524,18 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
ASSERT_TRUE(output_model_buffer != nullptr);
ASSERT_TRUE(output_model_buffer_size > 0);

ASSERT_TRUE(std::filesystem::exists(target_dir + bin_file_name)) << "expected context binary file should exist";

// Check that the compiled model has the expected number of EPContext nodes.
CheckEpContextNodeCounts(output_model_buffer, output_model_buffer_size, 2, 2);

// Add session option "ep.context_file_path" so that the session can use it to locate the [model_name]_qnn.bin file
std::string ctx_model = target_dir + model_name;
session_options.AddConfigEntry(kOrtSessionOptionEpContextFilePath, ctx_model.c_str());
// Should be able to create a session with the compiled model and the original session options.
EXPECT_NO_THROW((Ort::Session(*ort_env, output_model_buffer, output_model_buffer_size, session_options)));

std::filesystem::remove(target_dir + bin_file_name);
allocator.Free(output_model_buffer);
}
}
Expand Down
Loading