diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
index c348184658e7e..456097ff9db9a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
@@ -27,6 +27,8 @@ public class OrtModelCompilationOptions : SafeHandle
         /// <summary>
         /// Create a new OrtModelCompilationOptions object from SessionOptions.
         /// </summary>
+        /// By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel()
+        /// to enable graph optimizations.
         /// <param name="sessionOptions">SessionOptions instance to read settings from.</param>
         public OrtModelCompilationOptions(SessionOptions sessionOptions)
             : base(IntPtr.Zero, true)
@@ -130,6 +132,33 @@ public void SetFlags(OrtCompileApiFlags flags)
                 NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
         }
 
+        /// <summary>
+        /// Sets information related to the EP context binary file. The EP uses this information to decide the
+        /// location and context binary file name when compiling with both the input and output models
+        /// stored in buffers.
+        /// </summary>
+        /// <param name="outputDirectory">Path to the model directory.</param>
+        /// <param name="modelName">The name of the model.</param>
+        public void SetEpContextBinaryInformation(string outputDirectory, string modelName)
+        {
+            var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory);
+            var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName);
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
+                    handle, platformOutputDirectory, platformModelName));
+        }
+
+        /// <summary>
+        /// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified.
+        /// </summary>
+        /// <param name="graphOptimizationLevel">The graph optimization level to set.</param>
+        public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel)
+        {
+            NativeApiStatus.VerifySuccess(
+                NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
+                    handle, graphOptimizationLevel));
+        }
+
         internal IntPtr Handle => handle;
 
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
index 3edc25b307a21..9d25d96bdaa5a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
@@ -21,6 +21,8 @@ public struct OrtCompileApi
         public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
         public IntPtr CompileModel;
         public IntPtr ModelCompilationOptions_SetFlags;
+        public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
+        public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
     }
 
     internal class NativeMethods
@@ -101,6 +103,21 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
            uint flags);
        public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags;
 
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation(
+            IntPtr /* OrtModelCompilationOptions* */ options,
+            byte[] /* const ORTCHAR_T* */ outputDirectory,
+            byte[] /* const ORTCHAR_T* */ modelName);
+        public DOrtModelCompilationOptions_SetEpContextBinaryInformation
+            OrtModelCompilationOptions_SetEpContextBinaryInformation;
+
+        [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+        public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel(
+            IntPtr /* OrtModelCompilationOptions* */ options,
+            GraphOptimizationLevel graphOptimizationLevel);
+        public DOrtModelCompilationOptions_SetGraphOptimizationLevel
+            OrtModelCompilationOptions_SetGraphOptimizationLevel;
+
        internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
        {
@@ -161,6 +178,16 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
                _compileApi.ModelCompilationOptions_SetFlags,
                typeof(DOrtModelCompilationOptions_SetFlags));
 
+            OrtModelCompilationOptions_SetEpContextBinaryInformation =
+                (DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer(
+                    _compileApi.ModelCompilationOptions_SetEpContextBinaryInformation,
+                    typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation));
+
+            OrtModelCompilationOptions_SetGraphOptimizationLevel =
+                (DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(
+                    _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
+                    typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));
+
        }
    }
 }
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
index bf576b54d8b45..f1eef57e03ea5 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
@@ -30,6 +30,7 @@ public void BasicUsage()
            compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
            compileOptions.SetEpContextEmbedMode(true);
+            compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
        }
 
@@ -45,6 +46,7 @@ public void BasicUsage()
            UIntPtr bytesSize = new UIntPtr();
            var allocator = OrtAllocator.DefaultInstance;
            compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
+            compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
 
            compileOptions.CompileModel();
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index f137d88e5fb8a..76fa9ca4d5c5f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -7074,6 +7074,9 @@ struct OrtCompileApi {
   * ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling
   * CompileModel.
   *
+  * \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
+  * ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
+  *
   * \param[in] env OrtEnv object.
   * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
   * \param[out] out The created OrtModelCompilationOptions instance.
@@ -7230,7 +7233,7 @@ struct OrtCompileApi {
   * \since Version 1.23.
   */
  ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
-                 size_t flags);
+                 uint32_t flags);
 
  /** Sets information related to EP context binary file.
   *
@@ -7249,6 +7252,19 @@ struct OrtCompileApi {
                  _In_ OrtModelCompilationOptions* model_compile_options,
                  _In_ const ORTCHAR_T* output_directory,
                  _In_ const ORTCHAR_T* model_name);
+
+  /** Set the graph optimization level.
+   *
+   * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+   * \param[in] graph_optimization_level The graph optimization level.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.23.
+   */
+  ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
+                  _In_ OrtModelCompilationOptions* model_compile_options,
+                  _In_ GraphOptimizationLevel graph_optimization_level);
 };
 
 /*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 13675ab447ab1..3df6cb68a633a 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1424,7 +1424,9 @@ struct ModelCompilationOptions : detail::Base<OrtModelCompilationOptions> {
                                               size_t* output_model_buffer_size_ptr);  ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
   ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory, const ORTCHAR_T* model_name);  ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
-  ModelCompilationOptions& SetFlags(size_t flags);  ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+  ModelCompilationOptions& SetFlags(uint32_t flags);  ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+
+  ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);  ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
 };
 
 /** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModels.
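For illustration only (not part of the patch), here is a minimal C++ sketch of how the new setter fits into the existing compile-API flow exercised by the tests below. The model paths are placeholders, and it assumes an execution provider that supports EPContext compilation (e.g. QNN) has already been appended to the session options:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions session_options;
  // Assumption: a compile-capable EP is registered here, e.g.
  // session_options.AppendExecutionProvider("QNN", /*provider options*/ {});

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelPath(ORT_TSTR("model.onnx"));       // placeholder input path
  compile_options.SetOutputModelPath(ORT_TSTR("model_ctx.onnx"));  // placeholder output path

  // New in this change: opt in to graph optimizations before compiling.
  // Without this call, the compile-time optimization level stays at ORT_DISABLE_ALL.
  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);

  Ort::Status status = Ort::CompileModel(env, compile_options);
  return status.IsOK() ? 0 : 1;
}
```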
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 05c86ae4e0c58..75bb20a82d897 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1019,11 +1019,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
   return *this;
 }
 
-inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) {
+inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) {
   Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags));
   return *this;
 }
 
+inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
+    GraphOptimizationLevel graph_optimization_level) {
+  Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_,
+                                                                                      graph_optimization_level));
+  return *this;
+}
+
 namespace detail {
 
 template <typename T>
diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc
index 59b0992d827e1..759773042debb 100644
--- a/onnxruntime/core/session/compile_api.cc
+++ b/onnxruntime/core/session/compile_api.cc
@@ -231,7 +231,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode
 }
 
 ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
-                    _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) {
+                    _In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) {
   API_IMPL_BEGIN
 #if !defined(ORT_MINIMAL_BUILD)
   auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
@@ -245,6 +245,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
   API_IMPL_END
 }
 
+ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
+                    _In_ OrtModelCompilationOptions* ort_model_compile_options,
+                    _In_ GraphOptimizationLevel graph_optimization_level) {
+  API_IMPL_BEGIN
+#if !defined(ORT_MINIMAL_BUILD)
+  auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
+  ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level));
+  return nullptr;
+#else
+  ORT_UNUSED_PARAMETER(ort_model_compile_options);
+  ORT_UNUSED_PARAMETER(graph_optimization_level);
+  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
+#endif  // !defined(ORT_MINIMAL_BUILD)
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
                     _In_ const OrtModelCompilationOptions* ort_model_compile_options) {
   API_IMPL_BEGIN
@@ -278,6 +294,7 @@ static constexpr OrtCompileApi ort_compile_api = {
     &OrtCompileAPI::ModelCompilationOptions_SetFlags,
     &OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
+    &OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
 };
 
 // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h
index 93cc5dbf20fce..51cf71cd6ec61 100644
--- a/onnxruntime/core/session/compile_api.h
+++ b/onnxruntime/core/session/compile_api.h
@@ -29,8 +29,11 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel
                     bool embed_ep_context_in_model);
 ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);
 ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options,
-                    size_t flags);
+                    uint32_t flags);
 ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation,
                     _In_ OrtModelCompilationOptions* model_compile_options,
                     _In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);
+ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
+                    _In_ OrtModelCompilationOptions* model_compile_options,
+                    _In_ GraphOptimizationLevel graph_optimization_level);
 }  // namespace OrtCompileAPI
diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc
index bbb110033f54c..3ada35eeaff63 100644
--- a/onnxruntime/core/session/model_compilation_options.cc
+++ b/onnxruntime/core/session/model_compilation_options.cc
@@ -27,6 +27,8 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
   // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
   ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
   ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());
+
+  session_options_.value.graph_optimization_level = TransformerLevel::Default;  // L0: required transformers only
 }
 
 void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
@@ -135,7 +137,7 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m
   return Status::OK();
 }
 
-Status ModelCompilationOptions::SetFlags(size_t flags) {
+Status ModelCompilationOptions::SetFlags(uint32_t flags) {
   EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options;
   options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS;
   options.action_if_no_compiled_nodes =
@@ -170,6 +172,34 @@ void ModelCompilationOptions::ResetInputModelSettings() {
   input_model_data_size_ = 0;
 }
 
+Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
+  switch (graph_optimization_level) {
+    case ORT_DISABLE_ALL:
+      // TransformerLevel::Default means that we only run required transformers.
+      session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
+      break;
+    case ORT_ENABLE_BASIC:
+      session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
+      break;
+    case ORT_ENABLE_EXTENDED:
+      session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
+      break;
+    case ORT_ENABLE_LAYOUT:
+      session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
+      break;
+    case ORT_ENABLE_ALL:
+      session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel;
+      break;
+    default:
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ",
+                             static_cast<int>(graph_optimization_level), " is invalid. Valid values are: ",
+                             "ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ",
+                             "ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99).");
+  }
+
+  return Status::OK();
+}
+
 Status ModelCompilationOptions::ResetOutputModelSettings() {
   EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
   ep_context_gen_options.output_model_file_path.clear();
diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h
index 2824df863013d..cd9091561af79 100644
--- a/onnxruntime/core/session/model_compilation_options.h
+++ b/onnxruntime/core/session/model_compilation_options.h
@@ -95,7 +95,7 @@ class ModelCompilationOptions {
   /// </summary>
   /// <param name="flags">unsigned integer set to the bitwise OR of enabled flags.</param>
   /// <returns>Status indicating success or an error</returns>
-  Status SetFlags(size_t flags);
+  Status SetFlags(uint32_t flags);
 
   /// <summary>
   /// Returns a reference to the session options object.
   /// </summary>
@@ -129,6 +129,13 @@ class ModelCompilationOptions {
   /// <returns>input model buffer's size in bytes</returns>
   size_t GetInputModelDataSize() const;
 
+  /// <summary>
+  /// Sets the graph optimization level for the underlying session that compiles the model.
+  /// </summary>
+  /// <param name="graph_optimization_level">The optimization level</param>
+  /// <returns></returns>
+  Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
+
   /// <summary>
   /// Checks if the compilation options described by this object are valid.
   /// </summary>
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 64c4ada07f28f..8c8ba214eb714 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -647,6 +647,7 @@ def __init__(
         external_initializers_file_path: str | os.PathLike | None = None,
         external_initializers_size_threshold: int = 1024,
         flags: int = C.OrtCompileApiFlags.NONE,
+        graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
     ):
         """
         Creates a ModelCompiler instance.
@@ -663,6 +664,8 @@ def __init__(
             is None or empty. Initializers larger than this threshold are stored in the external initializers file.
         :param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of flags in
             onnxruntime.OrtCompileApiFlags.
+        :param graph_optimization_level: The graph optimization level.
+            Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
""" input_model_path: str | os.PathLike | None = None input_model_bytes: bytes | None = None @@ -694,6 +697,7 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, ) else: self._model_compiler = C.ModelCompiler( @@ -704,6 +708,7 @@ def __init__( external_initializers_file_path, external_initializers_size_threshold, flags, + graph_optimization_level, ) def compile_to_file(self, output_model_path: str | None = None): diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc index e2b069b01f95b..69929cb68a775 100644 --- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc +++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc @@ -20,7 +20,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr(env, sess_options, PrivateConstructorTag{}); ModelCompilationOptions& compile_options = model_compiler->model_compile_options_; @@ -43,6 +44,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptrTrue to embed compiled binary data into EPContext nodes. /// The file into which to store initializers for non-compiled /// nodes. - /// Flags from OrtCompileApiFlags /// Ignored if 'external_initializers_file_path' is empty. /// Initializers with a size greater than this threshold are dumped into the external file. + /// Flags from OrtCompileApiFlags + /// Optimization level for graph transformations on the model. + /// Defaults to ORT_DISABLE_ALL to allow EP to get the original loaded model. /// A Status indicating error or success. static onnxruntime::Status Create(/*out*/ std::unique_ptr& out, onnxruntime::Environment& env, @@ -46,7 +48,8 @@ class PyModelCompiler { bool embed_compiled_data_into_model = false, const std::string& external_initializers_file_path = {}, size_t external_initializers_size_threshold = 1024, - size_t flags = 0); + uint32_t flags = 0, + GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL); // Note: Creation should be done via Create(). This constructor is public so that it can be called from // std::make_shared(). 
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index eb06a65ad5330..27c76f7f5c482 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -2743,7 +2743,8 @@ including arg name, arg type (contains both type and shape).)pbdoc")
                        bool embed_compiled_data_into_model = false,
                        std::string external_initializers_file_path = {},
                        size_t external_initializers_size_threshold = 1024,
-                       size_t flags = OrtCompileApiFlags_NONE) {
+                       uint32_t flags = OrtCompileApiFlags_NONE,
+                       GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL) {
 #if !defined(ORT_MINIMAL_BUILD)
                       std::unique_ptr<PyModelCompiler> result;
                       OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options,
@@ -2751,7 +2752,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
                                                                     embed_compiled_data_into_model,
                                                                     external_initializers_file_path,
                                                                     external_initializers_size_threshold,
-                                                                    flags));
+                                                                    flags, graph_optimization_level));
                       return result;
 #else
                       ORT_UNUSED_PARAMETER(sess_options);
@@ -2761,6 +2762,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
                       ORT_UNUSED_PARAMETER(external_initializers_file_path);
                       ORT_UNUSED_PARAMETER(external_initializers_size_threshold);
                       ORT_UNUSED_PARAMETER(flags);
+                      ORT_UNUSED_PARAMETER(graph_optimization_level);
                       ORT_THROW("Compile API is not supported in this build.");
 #endif
                     }))
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 739e39a6975e2..a42a56492b04a 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -317,6 +317,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) {
   Ort::ModelCompilationOptions compile_options(*ort_env, so);
   compile_options.SetInputModelPath(input_model_file);
   compile_options.SetOutputModelPath(output_model_file);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -355,6 +356,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) {
   Ort::ModelCompilationOptions compile_options(*ort_env, so);
   compile_options.SetInputModelPath(input_model_file);
   compile_options.SetOutputModelPath(output_model_file);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -393,6 +395,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe
   compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
   compile_options.SetOutputModelPath(output_model_file);
   compile_options.SetEpContextEmbedMode(true);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -427,6 +430,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {
   // Create model compilation options from the session options. Output model is stored in a buffer.
   Ort::ModelCompilationOptions compile_options(*ort_env, so);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
   compile_options.SetInputModelPath(input_model_file);
 
   Ort::AllocatorWithDefaultOptions allocator;
@@ -482,6 +486,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
   compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
   compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
   compile_options.SetEpContextEmbedMode(true);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -515,6 +520,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
   std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin";
   compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str());
   compile_options.SetEpContextEmbedMode(false);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -573,6 +579,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu
   compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
   compile_options.SetOutputModelExternalInitializersFile(output_initializers_file, 0);
   compile_options.SetEpContextEmbedMode(true);
+  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
 
   // Compile the model.
   Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
index b102676860444..ed3cd882d7e00 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
@@ -225,6 +225,56 @@ def test_compile_from_buffer_to_buffer(self):
         self.assertTrue(isinstance(output_model_bytes, bytes))
         self.assertGreater(len(output_model_bytes), 0)
 
+    def test_compile_graph_optimization_level(self):
+        """
+        Tests compiling a model with no optimizations (default) vs all optimizations.
+        """
+        input_model_path = get_name("test_cast_back_to_back_non_const_mixed_types_origin.onnx")
+        output_model_path_0 = os.path.join(self._tmp_dir_path, "cast.disable_all.compiled.onnx")
+        output_model_path_1 = os.path.join(self._tmp_dir_path, "cast.enable_all.compiled.onnx")
+
+        # Local function that compiles a model with a given graph optimization level and returns
+        # the count of operator types in the compiled model.
+        def compile_and_get_op_counts(
+            output_model_path: str,
+            graph_opt_level: onnxrt.GraphOptimizationLevel | None,
+        ) -> dict[str, int]:
+            session_options = onnxrt.SessionOptions()
+            if graph_opt_level is not None:
+                model_compiler = onnxrt.ModelCompiler(
+                    session_options,
+                    input_model_path,
+                    graph_optimization_level=graph_opt_level,
+                )
+            else:
+                # graph optimization level defaults to ORT_DISABLE_ALL if not provided.
+                model_compiler = onnxrt.ModelCompiler(session_options, input_model_path)
+
+            model_compiler.compile_to_file(output_model_path)
+            self.assertTrue(os.path.exists(output_model_path))
+
+            model: onnx.ModelProto = onnx.load(get_name(output_model_path))
+            op_counts = {}
+            for node in model.graph.node:
+                if node.op_type not in op_counts:
+                    op_counts[node.op_type] = 1
+                else:
+                    op_counts[node.op_type] += 1
+
+            return op_counts
+
+        # Compile model on CPU with no graph optimizations (default).
+        # Model should have 9 Casts.
+        op_counts_0 = compile_and_get_op_counts(output_model_path_0, graph_opt_level=None)
+        self.assertEqual(op_counts_0["Cast"], 9)
+
+        # Compile model on CPU with basic graph optimizations.
+        # Model should have fewer Casts (some are optimized out).
+        op_counts_1 = compile_and_get_op_counts(
+            output_model_path_1, graph_opt_level=onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC
+        )
+        self.assertEqual(op_counts_1["Cast"], 8)
+
     def test_fail_load_uncompiled_model_and_then_compile(self):
         """
         Tests compiling scenario: