diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
index c348184658e7e..456097ff9db9a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/CompileModel.shared.cs
@@ -27,6 +27,8 @@ public class OrtModelCompilationOptions : SafeHandle
/// <summary>
/// Create a new OrtModelCompilationOptions object from SessionOptions.
/// </summary>
+ /// By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use SetGraphOptimizationLevel()
+ /// to enable graph optimizations.
/// <param name="sessionOptions">SessionOptions instance to read settings from.</param>
public OrtModelCompilationOptions(SessionOptions sessionOptions)
: base(IntPtr.Zero, true)
@@ -130,6 +132,33 @@ public void SetFlags(OrtCompileApiFlags flags)
NativeMethods.CompileApi.OrtModelCompilationOptions_SetFlags(handle, (uint)flags));
}
+ /// <summary>
+ /// Sets information related to the EP context binary file. The EP uses this information to decide the
+ /// location and name of the context binary file when compiling with both the input and output models
+ /// stored in buffers.
+ /// </summary>
+ /// <param name="outputDirectory">Path to the model directory.</param>
+ /// <param name="modelName">The name of the model.</param>
+ public void SetEpContextBinaryInformation(string outputDirectory, string modelName)
+ {
+ var platformOutputDirectory = NativeOnnxValueHelper.GetPlatformSerializedString(outputDirectory);
+ var platformModelName = NativeOnnxValueHelper.GetPlatformSerializedString(modelName);
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetEpContextBinaryInformation(
+ handle, platformOutputDirectory, platformModelName));
+ }
+
+ /// <summary>
+ /// Sets the graph optimization level. Defaults to ORT_DISABLE_ALL if not specified.
+ /// </summary>
+ /// <param name="graphOptimizationLevel">The graph optimization level to set.</param>
+ public void SetGraphOptimizationLevel(GraphOptimizationLevel graphOptimizationLevel)
+ {
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.CompileApi.OrtModelCompilationOptions_SetGraphOptimizationLevel(
+ handle, graphOptimizationLevel));
+ }
+
internal IntPtr Handle => handle;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
index 3edc25b307a21..9d25d96bdaa5a 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeCompileApiMethods.shared.cs
@@ -21,6 +21,8 @@ public struct OrtCompileApi
public IntPtr ModelCompilationOptions_SetEpContextEmbedMode;
public IntPtr CompileModel;
public IntPtr ModelCompilationOptions_SetFlags;
+ public IntPtr ModelCompilationOptions_SetEpContextBinaryInformation;
+ public IntPtr ModelCompilationOptions_SetGraphOptimizationLevel;
}
internal class NativeMethods
@@ -101,6 +103,21 @@ public DOrtModelCompilationOptions_SetOutputModelExternalInitializersFile
uint flags);
public DOrtModelCompilationOptions_SetFlags OrtModelCompilationOptions_SetFlags;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetEpContextBinaryInformation(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ byte[] /* const ORTCHAR_T* */ outputDirectory,
+ byte[] /* const ORTCHAR_T* */ modelName);
+ public DOrtModelCompilationOptions_SetEpContextBinaryInformation
+ OrtModelCompilationOptions_SetEpContextBinaryInformation;
+
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DOrtModelCompilationOptions_SetGraphOptimizationLevel(
+ IntPtr /* OrtModelCompilationOptions* */ options,
+ GraphOptimizationLevel graphOptimizationLevel);
+ public DOrtModelCompilationOptions_SetGraphOptimizationLevel
+ OrtModelCompilationOptions_SetGraphOptimizationLevel;
+
internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi)
{
@@ -161,6 +178,16 @@ internal NativeMethods(OnnxRuntime.NativeMethods.DOrtGetCompileApi getCompileApi
_compileApi.ModelCompilationOptions_SetFlags,
typeof(DOrtModelCompilationOptions_SetFlags));
+ OrtModelCompilationOptions_SetEpContextBinaryInformation =
+ (DOrtModelCompilationOptions_SetEpContextBinaryInformation)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetEpContextBinaryInformation,
+ typeof(DOrtModelCompilationOptions_SetEpContextBinaryInformation));
+
+ OrtModelCompilationOptions_SetGraphOptimizationLevel =
+ (DOrtModelCompilationOptions_SetGraphOptimizationLevel)Marshal.GetDelegateForFunctionPointer(
+ _compileApi.ModelCompilationOptions_SetGraphOptimizationLevel,
+ typeof(DOrtModelCompilationOptions_SetGraphOptimizationLevel));
+
}
}
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
index bf576b54d8b45..f1eef57e03ea5 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/CompileApiTests.cs
@@ -30,6 +30,7 @@ public void BasicUsage()
compileOptions.SetOutputModelExternalInitializersFile("external_data.bin", 512);
compileOptions.SetEpContextEmbedMode(true);
+ compileOptions.SetGraphOptimizationLevel(GraphOptimizationLevel.ORT_ENABLE_BASIC);
}
@@ -45,6 +46,7 @@ public void BasicUsage()
UIntPtr bytesSize = new UIntPtr();
var allocator = OrtAllocator.DefaultInstance;
compileOptions.SetOutputModelBuffer(allocator, ref bytePtr, ref bytesSize);
+ compileOptions.SetEpContextBinaryInformation("./", "squeezenet.onnx");
compileOptions.CompileModel();
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index f137d88e5fb8a..76fa9ca4d5c5f 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -7074,6 +7074,9 @@ struct OrtCompileApi {
* ReleaseOrtModelCompilationOptions must be called to free the OrtModelCompilationOptions after calling
* CompileModel.
*
+ * \note By default, the GraphOptimizationLevel is set to ORT_DISABLE_ALL. Use
+ * ModelCompilationOptions_SetGraphOptimizationLevel to enable graph optimizations.
+ *
* \param[in] env OrtEnv object.
* \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions.
* \param[out] out The created OrtModelCompilationOptions instance.
@@ -7230,7 +7233,7 @@ struct OrtCompileApi {
* \since Version 1.23.
*/
ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options,
- size_t flags);
+ uint32_t flags);
/** Sets information related to EP context binary file.
*
@@ -7249,6 +7252,19 @@ struct OrtCompileApi {
_In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_directory,
_In_ const ORTCHAR_T* model_name);
+
+ /** Set the graph optimization level.
+ *
+ * \param[in] model_compile_options The OrtModelCompilationOptions instance.
+ * \param[in] graph_optimization_level The graph optimization level.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.23.
+ */
+ ORT_API2_STATUS(ModelCompilationOptions_SetGraphOptimizationLevel,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ GraphOptimizationLevel graph_optimization_level);
};
/*
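For orientation, a minimal sketch of calling the new slot through the C compile API. It assumes an `OrtModelCompilationOptions* compile_options` that was already created from an OrtEnv and OrtSessionOptions (the creation and release calls are omitted here), and fetches the compile API through the OrtApi accessor:

```cpp
// Sketch only: `compile_options` is assumed to exist; creation/release calls are omitted.
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
const OrtCompileApi* compile_api = api->GetCompileApi();

// Compilation defaults to ORT_DISABLE_ALL, so optimizations must be enabled explicitly.
OrtStatus* status = compile_api->ModelCompilationOptions_SetGraphOptimizationLevel(
    compile_options, ORT_ENABLE_BASIC);
if (status != nullptr) {
  // api->GetErrorMessage(status) describes the failure; release it with api->ReleaseStatus(status).
}
```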
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 13675ab447ab1..3df6cb68a633a 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1424,7 +1424,9 @@ struct ModelCompilationOptions : detail::Base {
size_t* output_model_buffer_size_ptr); ///< Wraps OrtApi::ModelCompilationOptions_SetOutputModelBuffer
ModelCompilationOptions& SetEpContextBinaryInformation(const ORTCHAR_T* output_directory,
const ORTCHAR_T* model_name); ///< Wraps OrtApi::ModelCompilationOptions_SetEpContextBinaryInformation
- ModelCompilationOptions& SetFlags(size_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+ ModelCompilationOptions& SetFlags(uint32_t flags); ///< Wraps OrtApi::ModelCompilationOptions_SetFlags
+
+ ModelCompilationOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); ///< Wraps OrtApi::ModelCompilationOptions_SetGraphOptimizationLevel
};
/** \brief Compiles an input model to generate a model with EPContext nodes that execute EP-specific kernels. Wraps OrtApi::CompileModel.
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 05c86ae4e0c58..75bb20a82d897 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1019,11 +1019,18 @@ inline ModelCompilationOptions& ModelCompilationOptions::SetEpContextEmbedMode(
return *this;
}
-inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(size_t flags) {
+inline ModelCompilationOptions& ModelCompilationOptions::SetFlags(uint32_t flags) {
Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetFlags(this->p_, flags));
return *this;
}
+inline ModelCompilationOptions& ModelCompilationOptions::SetGraphOptimizationLevel(
+ GraphOptimizationLevel graph_optimization_level) {
+ Ort::ThrowOnError(GetCompileApi().ModelCompilationOptions_SetGraphOptimizationLevel(this->p_,
+ graph_optimization_level));
+ return *this;
+}
+
namespace detail {
template <typename T>
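To show the new wrapper in context, here is a minimal usage sketch modeled on the QNN tests further down in this diff; the model paths are placeholders and registration of the compiling EP on `session_options` is elided:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  // Sketch: compile a model from disk, opting in to basic graph optimizations
  // (model compilation defaults to ORT_DISABLE_ALL).
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "compile_example");
  Ort::SessionOptions session_options;
  // session_options.AppendExecutionProvider(...);  // register the compiling EP here

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelPath(ORT_TSTR("model.onnx"));       // placeholder path
  compile_options.SetOutputModelPath(ORT_TSTR("model_ctx.onnx"));  // placeholder path
  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);

  Ort::Status status = Ort::CompileModel(env, compile_options);
  return status.IsOK() ? 0 : 1;
}
```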
diff --git a/onnxruntime/core/session/compile_api.cc b/onnxruntime/core/session/compile_api.cc
index 59b0992d827e1..759773042debb 100644
--- a/onnxruntime/core/session/compile_api.cc
+++ b/onnxruntime/core/session/compile_api.cc
@@ -231,7 +231,7 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetEpContextEmbedMode
}
ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
- _In_ OrtModelCompilationOptions* ort_model_compile_options, size_t flags) {
+ _In_ OrtModelCompilationOptions* ort_model_compile_options, uint32_t flags) {
API_IMPL_BEGIN
#if !defined(ORT_MINIMAL_BUILD)
auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
@@ -245,6 +245,22 @@ ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetFlags,
API_IMPL_END
}
+ORT_API_STATUS_IMPL(OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
+ _In_ OrtModelCompilationOptions* ort_model_compile_options,
+ _In_ GraphOptimizationLevel graph_optimization_level) {
+ API_IMPL_BEGIN
+#if !defined(ORT_MINIMAL_BUILD)
+ auto model_compile_options = reinterpret_cast<onnxruntime::ModelCompilationOptions*>(ort_model_compile_options);
+ ORT_API_RETURN_IF_STATUS_NOT_OK(model_compile_options->SetGraphOptimizationLevel(graph_optimization_level));
+ return nullptr;
+#else
+ ORT_UNUSED_PARAMETER(ort_model_compile_options);
+ ORT_UNUSED_PARAMETER(graph_optimization_level);
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Compile API is not supported in this build");
+#endif // !defined(ORT_MINIMAL_BUILD)
+ API_IMPL_END
+}
+
ORT_API_STATUS_IMPL(OrtCompileAPI::CompileModel, _In_ const OrtEnv* env,
_In_ const OrtModelCompilationOptions* ort_model_compile_options) {
API_IMPL_BEGIN
@@ -278,6 +294,7 @@ static constexpr OrtCompileApi ort_compile_api = {
&OrtCompileAPI::ModelCompilationOptions_SetFlags,
&OrtCompileAPI::ModelCompilationOptions_SetEpContextBinaryInformation,
+ &OrtCompileAPI::ModelCompilationOptions_SetGraphOptimizationLevel,
};
// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned
diff --git a/onnxruntime/core/session/compile_api.h b/onnxruntime/core/session/compile_api.h
index 93cc5dbf20fce..51cf71cd6ec61 100644
--- a/onnxruntime/core/session/compile_api.h
+++ b/onnxruntime/core/session/compile_api.h
@@ -29,8 +29,11 @@ ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModel
bool embed_ep_context_in_model);
ORT_API_STATUS_IMPL(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_options,
- size_t flags);
+ uint32_t flags);
ORT_API_STATUS_IMPL(ModelCompilationOptions_SetEpContextBinaryInformation, _In_ OrtModelCompilationOptions* model_compile_options,
_In_ const ORTCHAR_T* output_dir, _In_ const ORTCHAR_T* model_name);
+ORT_API_STATUS_IMPL(ModelCompilationOptions_SetGraphOptimizationLevel,
+ _In_ OrtModelCompilationOptions* model_compile_options,
+ _In_ GraphOptimizationLevel graph_optimization_level);
} // namespace OrtCompileAPI
diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc
index bbb110033f54c..3ada35eeaff63 100644
--- a/onnxruntime/core/session/model_compilation_options.cc
+++ b/onnxruntime/core/session/model_compilation_options.cc
@@ -27,6 +27,8 @@ ModelCompilationOptions::ModelCompilationOptions(const onnxruntime::Environment&
// Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions.
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK());
ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionsDisableModelCompile, "0").IsOK());
+
+ session_options_.value.graph_optimization_level = TransformerLevel::Default; // L0: required transformers only
}
void ModelCompilationOptions::SetInputModelPath(const std::string& input_model_path) {
@@ -135,7 +137,7 @@ Status ModelCompilationOptions::SetEpContextEmbedMode(bool embed_ep_context_in_m
return Status::OK();
}
-Status ModelCompilationOptions::SetFlags(size_t flags) {
+Status ModelCompilationOptions::SetFlags(uint32_t flags) {
EpContextModelGenerationOptions& options = session_options_.value.ep_context_gen_options;
options.error_if_output_file_exists = flags & OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS;
options.action_if_no_compiled_nodes =
@@ -170,6 +172,34 @@ void ModelCompilationOptions::ResetInputModelSettings() {
input_model_data_size_ = 0;
}
+Status ModelCompilationOptions::SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level) {
+ switch (graph_optimization_level) {
+ case ORT_DISABLE_ALL:
+ // TransformerLevel::Default means that we only run required transformers.
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Default;
+ break;
+ case ORT_ENABLE_BASIC:
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level1;
+ break;
+ case ORT_ENABLE_EXTENDED:
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level2;
+ break;
+ case ORT_ENABLE_LAYOUT:
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::Level3;
+ break;
+ case ORT_ENABLE_ALL:
+ session_options_.value.graph_optimization_level = onnxruntime::TransformerLevel::MaxLevel;
+ break;
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "graph_optimization_level with value ",
+ static_cast<int>(graph_optimization_level), " is invalid. Valid values are: ",
+ "ORT_DISABLE_ALL (0), ORT_ENABLE_BASIC (1), ORT_ENABLE_EXTENDED (2), ",
+ "ORT_ENABLE_LAYOUT (3), and ORT_ENABLE_ALL (99).");
+ }
+
+ return Status::OK();
+}
+
Status ModelCompilationOptions::ResetOutputModelSettings() {
EpContextModelGenerationOptions& ep_context_gen_options = session_options_.value.ep_context_gen_options;
ep_context_gen_options.output_model_file_path.clear();
diff --git a/onnxruntime/core/session/model_compilation_options.h b/onnxruntime/core/session/model_compilation_options.h
index 2824df863013d..cd9091561af79 100644
--- a/onnxruntime/core/session/model_compilation_options.h
+++ b/onnxruntime/core/session/model_compilation_options.h
@@ -95,7 +95,7 @@ class ModelCompilationOptions {
/// </summary>
/// <param name="flags">unsigned integer set to the bitwise OR of enabled flags.</param>
/// <returns>Status indicating success or an error</returns>
- Status SetFlags(size_t flags);
+ Status SetFlags(uint32_t flags);
/// <summary>
/// Returns a reference to the session options object.
@@ -129,6 +129,13 @@ class ModelCompilationOptions {
/// <returns>input model buffer's size in bytes</returns>
size_t GetInputModelDataSize() const;
+ /// <summary>
+ /// Sets the graph optimization level for the underlying session that compiles the model.
+ /// </summary>
+ /// <param name="graph_optimization_level">The optimization level</param>
+ /// <returns>A Status indicating success or an error</returns>
+ Status SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
+
/// <summary>
/// Checks if the compilation options described by this object are valid.
/// </summary>
diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py
index 64c4ada07f28f..8c8ba214eb714 100644
--- a/onnxruntime/python/onnxruntime_inference_collection.py
+++ b/onnxruntime/python/onnxruntime_inference_collection.py
@@ -647,6 +647,7 @@ def __init__(
external_initializers_file_path: str | os.PathLike | None = None,
external_initializers_size_threshold: int = 1024,
flags: int = C.OrtCompileApiFlags.NONE,
+ graph_optimization_level: C.GraphOptimizationLevel = C.GraphOptimizationLevel.ORT_DISABLE_ALL,
):
"""
Creates a ModelCompiler instance.
@@ -663,6 +664,8 @@ def __init__(
is None or empty. Initializers larger than this threshold are stored in the external initializers file.
:param flags: Additional boolean options to enable. Set this parameter to a bitwise OR of
flags in onnxruntime.OrtCompileApiFlags.
+ :param graph_optimization_level: The graph optimization level.
+ Defaults to onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL.
"""
input_model_path: str | os.PathLike | None = None
input_model_bytes: bytes | None = None
@@ -694,6 +697,7 @@ def __init__(
external_initializers_file_path,
external_initializers_size_threshold,
flags,
+ graph_optimization_level,
)
else:
self._model_compiler = C.ModelCompiler(
@@ -704,6 +708,7 @@ def __init__(
external_initializers_file_path,
external_initializers_size_threshold,
flags,
+ graph_optimization_level,
)
def compile_to_file(self, output_model_path: str | None = None):
diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
index e2b069b01f95b..69929cb68a775 100644
--- a/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
+++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.cc
@@ -20,7 +20,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr<PyModelCompiler>& out,
-                                            size_t flags) {
+                                            uint32_t flags,
+                                            GraphOptimizationLevel graph_opt_level) {
  auto model_compiler = std::make_unique<PyModelCompiler>(env, sess_options, PrivateConstructorTag{});
ModelCompilationOptions& compile_options = model_compiler->model_compile_options_;
@@ -43,6 +44,8 @@ onnxruntime::Status PyModelCompiler::Create(/*out*/ std::unique_ptr<PyModelCompiler>& out,
+  ORT_RETURN_IF_ERROR(compile_options.SetGraphOptimizationLevel(graph_opt_level));
diff --git a/onnxruntime/python/onnxruntime_pybind_model_compiler.h b/onnxruntime/python/onnxruntime_pybind_model_compiler.h
--- a/onnxruntime/python/onnxruntime_pybind_model_compiler.h
+++ b/onnxruntime/python/onnxruntime_pybind_model_compiler.h
/// <param name="embed_compiled_data_into_model">True to embed compiled binary data into EPContext nodes.</param>
/// <param name="external_initializers_file_path">The file into which to store initializers for non-compiled
/// nodes.</param>
- /// <param name="flags">Flags from OrtCompileApiFlags</param>
/// <param name="external_initializers_size_threshold">Ignored if 'external_initializers_file_path' is empty.
/// Initializers with a size greater than this threshold are dumped into the external file.</param>
+ /// <param name="flags">Flags from OrtCompileApiFlags</param>
+ /// <param name="graph_opt_level">Optimization level for graph transformations on the model.
+ /// Defaults to ORT_DISABLE_ALL to allow the EP to get the original loaded model.</param>
/// <returns>A Status indicating error or success.</returns>
static onnxruntime::Status Create(/*out*/ std::unique_ptr<PyModelCompiler>& out,
onnxruntime::Environment& env,
@@ -46,7 +48,8 @@ class PyModelCompiler {
bool embed_compiled_data_into_model = false,
const std::string& external_initializers_file_path = {},
size_t external_initializers_size_threshold = 1024,
- size_t flags = 0);
+ uint32_t flags = 0,
+ GraphOptimizationLevel graph_opt_level = GraphOptimizationLevel::ORT_DISABLE_ALL);
// Note: Creation should be done via Create(). This constructor is public so that it can be called from
// std::make_shared().
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index eb06a65ad5330..27c76f7f5c482 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -2743,7 +2743,8 @@ including arg name, arg type (contains both type and shape).)pbdoc")
bool embed_compiled_data_into_model = false,
std::string external_initializers_file_path = {},
size_t external_initializers_size_threshold = 1024,
- size_t flags = OrtCompileApiFlags_NONE) {
+ uint32_t flags = OrtCompileApiFlags_NONE,
+ GraphOptimizationLevel graph_optimization_level = GraphOptimizationLevel::ORT_DISABLE_ALL) {
#if !defined(ORT_MINIMAL_BUILD)
std::unique_ptr<PyModelCompiler> result;
OrtPybindThrowIfError(PyModelCompiler::Create(result, GetEnv(), sess_options,
@@ -2751,7 +2752,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
embed_compiled_data_into_model,
external_initializers_file_path,
external_initializers_size_threshold,
- flags));
+ flags, graph_optimization_level));
return result;
#else
ORT_UNUSED_PARAMETER(sess_options);
@@ -2761,6 +2762,7 @@ including arg name, arg type (contains both type and shape).)pbdoc")
ORT_UNUSED_PARAMETER(external_initializers_file_path);
ORT_UNUSED_PARAMETER(external_initializers_size_threshold);
ORT_UNUSED_PARAMETER(flags);
+ ORT_UNUSED_PARAMETER(graph_optimization_level);
ORT_THROW("Compile API is not supported in this build.");
#endif
}))
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 739e39a6975e2..a42a56492b04a 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -317,6 +317,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_DisableEpCompile_ThenCompileExplicitly) {
Ort::ModelCompilationOptions compile_options(*ort_env, so);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -355,6 +356,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelFromPath) {
Ort::ModelCompilationOptions compile_options(*ort_env, so);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(output_model_file);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -393,6 +395,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputModelAsBuffer_Embe
compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
compile_options.SetOutputModelPath(output_model_file);
compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -427,6 +430,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer) {
// Create model compilation options from the session options. Output model is stored in a buffer.
Ort::ModelCompilationOptions compile_options(*ort_env, so);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
compile_options.SetInputModelPath(input_model_file);
Ort::AllocatorWithDefaultOptions allocator;
@@ -482,6 +486,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
compile_options.SetInputModelFromBuffer(reinterpret_cast<const void*>(model_data.data()), model_data.size());
compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -515,6 +520,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_InputAndOutputModelsInB
std::string bin_file_name = model_name.substr(0, pos) + "_qnn.bin";
compile_options.SetEpContextBinaryInformation(ToWideString(target_dir).c_str(), ToWideString(model_name).c_str());
compile_options.SetEpContextEmbedMode(false);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
@@ -573,6 +579,7 @@ TEST_F(QnnHTPBackendTests, CompileApi_FromSessionOptions_OutputModelBuffer_Outpu
compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
compile_options.SetOutputModelExternalInitializersFile(output_initializers_file, 0);
compile_options.SetEpContextEmbedMode(true);
+ compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
// Compile the model.
Ort::Status status = Ort::CompileModel(*ort_env, compile_options);
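Condensing the buffer-to-buffer test above into a standalone sketch: when neither model has an on-disk path, SetEpContextBinaryInformation supplies the directory and model name from which the EP derives the context binary file name. The helper name `CompileFromBuffers`, the "./" directory, and "model.onnx" are illustrative placeholders, and the setup of `env`/`session_options` follows the earlier example:

```cpp
#include <vector>

// Sketch: both the input and the output model live in buffers.
Ort::Status CompileFromBuffers(Ort::Env& env, Ort::SessionOptions& session_options,
                               const std::vector<char>& model_data) {  // model_data: input ONNX bytes
  Ort::AllocatorWithDefaultOptions allocator;
  void* output_model_buffer = nullptr;
  size_t output_model_buffer_size = 0;

  Ort::ModelCompilationOptions compile_options(env, session_options);
  compile_options.SetInputModelFromBuffer(model_data.data(), model_data.size());
  compile_options.SetOutputModelBuffer(allocator, &output_model_buffer, &output_model_buffer_size);
  // With no on-disk model path, tell the EP where to place the context binary and what to name it.
  compile_options.SetEpContextBinaryInformation(ORT_TSTR("./"), ORT_TSTR("model.onnx"));
  compile_options.SetEpContextEmbedMode(false);
  compile_options.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);

  // NOTE: in real code, return the buffer/size to the caller and free the buffer via the allocator.
  return Ort::CompileModel(env, compile_options);
}
```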
diff --git a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
index b102676860444..ed3cd882d7e00 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_compile_api.py
@@ -225,6 +225,56 @@ def test_compile_from_buffer_to_buffer(self):
self.assertTrue(isinstance(output_model_bytes, bytes))
self.assertGreater(len(output_model_bytes), 0)
+ def test_compile_graph_optimization_level(self):
+ """
+ Tests compiling a model with no optimizations (default) vs. basic optimizations.
+ """
+ input_model_path = get_name("test_cast_back_to_back_non_const_mixed_types_origin.onnx")
+ output_model_path_0 = os.path.join(self._tmp_dir_path, "cast.disable_all.compiled.onnx")
+ output_model_path_1 = os.path.join(self._tmp_dir_path, "cast.enable_all.compiled.onnx")
+
+ # Local function that compiles a model with a given graph optimization level and returns
+ # the count of operator types in the compiled model.
+ def compile_and_get_op_counts(
+ output_model_path: str,
+ graph_opt_level: onnxrt.GraphOptimizationLevel | None,
+ ) -> dict[str, int]:
+ session_options = onnxrt.SessionOptions()
+ if graph_opt_level is not None:
+ model_compiler = onnxrt.ModelCompiler(
+ session_options,
+ input_model_path,
+ graph_optimization_level=graph_opt_level,
+ )
+ else:
+ # graph optimization level defaults to ORT_DISABLE_ALL if not provided.
+ model_compiler = onnxrt.ModelCompiler(session_options, input_model_path)
+
+ model_compiler.compile_to_file(output_model_path)
+ self.assertTrue(os.path.exists(output_model_path))
+
+ model: onnx.ModelProto = onnx.load(get_name(output_model_path))
+ op_counts = {}
+ for node in model.graph.node:
+ if node.op_type not in op_counts:
+ op_counts[node.op_type] = 1
+ else:
+ op_counts[node.op_type] += 1
+
+ return op_counts
+
+ # Compile model on CPU with no graph optimizations (default).
+ # Model should have 9 Casts
+ op_counts_0 = compile_and_get_op_counts(output_model_path_0, graph_opt_level=None)
+ self.assertEqual(op_counts_0["Cast"], 9)
+
+ # Compile model on CPU with BASIC graph optimizations.
+ # Model should have fewer Casts (some are optimized out).
+ op_counts_1 = compile_and_get_op_counts(
+ output_model_path_1, graph_opt_level=onnxrt.GraphOptimizationLevel.ORT_ENABLE_BASIC
+ )
+ self.assertEqual(op_counts_1["Cast"], 8)
+
def test_fail_load_uncompiled_model_and_then_compile(self):
"""
Tests compiling scenario: