diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index b64a5c3e5a4a2..77c35aac65b92 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -3,6 +3,7 @@
using System;
using System.Runtime.InteropServices;
+using static Microsoft.ML.OnnxRuntime.NativeMethods;
namespace Microsoft.ML.OnnxRuntime
{
@@ -325,6 +326,16 @@ public struct OrtApi
public IntPtr CreateLoraAdapterFromArray;
public IntPtr ReleaseLoraAdapter;
public IntPtr RunOptionsAddActiveLoraAdapter;
+ public IntPtr SetEpDynamicOptions;
+ public IntPtr ReleaseValueInfo;
+ public IntPtr ReleaseNode;
+ public IntPtr ReleaseGraph;
+ public IntPtr ReleaseModel;
+ public IntPtr GetValueInfoName;
+ public IntPtr GetValueInfoTypeInfo;
+ public IntPtr GetModelEditorApi;
+ public IntPtr CreateTensorWithDataAndDeleterAsOrtValue;
+ public IntPtr SessionOptionsSetLoadCancellationFlag;
}
internal static class NativeMethods
@@ -404,6 +415,7 @@ static NativeMethods()
OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions));
OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions));
OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode));
+ OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag));
OrtSetOptimizedModelFilePath = (DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath));
OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, typeof(DOrtEnableProfiling));
OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling));
@@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
ExecutionMode execution_mode);
public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode;
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options,
+ bool value);
+ public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag;
+
+
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath);
public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath;
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index bd450451a1265..9b0f183f03681 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode
}
private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL;
+ /// <summary>
+ /// Sets the load cancellation flag for the session. Default is set to false.
+ /// Provides an opportunity for the user to cancel model loading.
+ /// </summary>
+ /// <param name="value">true to request cancellation, false to proceed</param>
+ public void SetLoadCancellationFlag(bool value)
+ {
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value));
+ }
+
#endregion
#region Private Methods
diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h
index 0822eba950f50..10f658f52e0d9 100644
--- a/include/onnxruntime/core/common/common.h
+++ b/include/onnxruntime/core/common/common.h
@@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
abort(); \
} while (false)
+#define ORT_THROW_FROM_STATUS(status) \
+ do { \
+ ::onnxruntime::PrintFinalMessage( \
+ ::onnxruntime::OnnxRuntimeException( \
+ ORT_WHERE_WITH_STACK, status.ToString()) \
+ .what()); \
+ abort(); \
+ } while (false)
+
+#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
+ do { \
+ ::onnxruntime::PrintFinalMessage( \
+ ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
+ ::onnxruntime::MakeString(__VA_ARGS__), \
+ ::onnxruntime::common::category, \
+ ::onnxruntime::common::code) \
+ .what()); \
+ abort(); \
+ } while (false)
+
#else
#define ORT_TRY try
@@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
#define ORT_THROW_EX(ex, ...) \
throw ex(__VA_ARGS__)
+#define ORT_THROW_FROM_STATUS(status) \
+ throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \
+ static_cast<::onnxruntime::common::StatusCode>(status.Code()))
+
+#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
+ throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
+ ::onnxruntime::MakeString(__VA_ARGS__), \
+ ::onnxruntime::common::category, \
+ ::onnxruntime::common::code)
+
#endif
#define ORT_MAKE_STATUS(category, code, ...) \
@@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
auto _status = (expr); \
if ((!_status.IsOK())) { \
::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
- ORT_THROW(_status); \
+ ORT_THROW_FROM_STATUS(_status); \
} \
} while (0)
diff --git a/include/onnxruntime/core/common/exceptions.h b/include/onnxruntime/core/common/exceptions.h
index 494a770b8db98..6d0f6edd6e7c4 100644
--- a/include/onnxruntime/core/common/exceptions.h
+++ b/include/onnxruntime/core/common/exceptions.h
@@ -11,6 +11,7 @@
#include <vector>
#include "core/common/common.h"
+#include "core/common/status.h"
#include "core/common/code_location.h"
namespace onnxruntime {
@@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception {
/**
Create a new exception that captures the location it was thrown from.
@param location Location in the source code the exception is being thrown from
+ @param msg Message containing additional information about the exception cause.
+ @param category Error category
+ @param code Error code
+ */
+
+ OnnxRuntimeException(const CodeLocation& location,
+ const std::string& message,
+ common::StatusCategory category,
+ common::StatusCode code) noexcept
+ : OnnxRuntimeException(location, nullptr, message, category, code) {
+ }
+
+ /**
+ Create a new exception that captures the location it was thrown from.
+ The instance will be created with ONNXRUNTIME category and FAIL code.
+ @param location Location in the source code the exception is being thrown from
@param failed_condition Optional string containing the condition that failed.
e.g. "tensor.Size() == input.Size()". May be nullptr.
@param msg Message containing additional information about the exception cause.
*/
- OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
- : location_{location} {
+ OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept
+ : OnnxRuntimeException(location, failed_condition, msg,
+ common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) {
+ }
+
+ /**
+ Create a new exception that captures the location it was thrown from.
+ @param location Location in the source code the exception is being thrown from
+ @param failed_condition Optional string containing the condition that failed.
+ e.g. "tensor.Size() == input.Size()". May be nullptr.
+ @param msg Message containing additional information about the exception cause.
+ @param category Error category
+ @param code Error code
+ */
+ OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg,
+ common::StatusCategory category,
+ common::StatusCode code)
+ : location_{location}, category_(category), code_(code) {
std::ostringstream ss;
ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous
@@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception {
what_ = ss.str();
}
+ common::StatusCategory Category() const noexcept {
+ return category_;
+ }
+
+ common::StatusCode Code() const noexcept {
+ return code_;
+ }
+
const char* what() const noexcept override {
return what_.c_str();
}
@@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception {
const CodeLocation location_;
const std::vector<std::string> stacktrace_;
std::string what_;
+ common::StatusCategory category_;
+ common::StatusCode code_;
};
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h
index 8f171daabbb1e..b222e411d7804 100644
--- a/include/onnxruntime/core/common/status.h
+++ b/include/onnxruntime/core/common/status.h
@@ -43,7 +43,8 @@ enum StatusCode {
MODEL_LOADED = 8,
NOT_IMPLEMENTED = 9,
INVALID_GRAPH = 10,
- EP_FAIL = 11
+ EP_FAIL = 11,
+ MODEL_LOAD_CANCELED = 12,
};
constexpr const char* StatusCodeToString(StatusCode status) noexcept {
@@ -72,6 +73,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept {
return "INVALID_GRAPH";
case StatusCode::EP_FAIL:
return "EP_FAIL";
+ case StatusCode::MODEL_LOAD_CANCELED:
+ return "MODEL_LOAD_CANCELED";
default:
return "GENERAL ERROR";
}
@@ -104,6 +107,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept {
return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
case StatusCode::EP_FAIL:
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
+ case StatusCode::MODEL_LOAD_CANCELED:
+ return HRESULT_FROM_WIN32(ERROR_CANCELLED);
default:
return E_FAIL;
}
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 6d4cc8a1f2fa9..3bf0d5e19c525 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -255,6 +255,7 @@ typedef enum OrtErrorCode {
ORT_NOT_IMPLEMENTED,
ORT_INVALID_GRAPH,
ORT_EP_FAIL,
+ ORT_MODEL_LOAD_CANCELED,
} OrtErrorCode;
typedef enum OrtOpAttrType {
@@ -4898,6 +4899,24 @@ struct OrtApi {
_In_ const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type,
_Outptr_ OrtValue** out);
+
+ /** \brief sets load cancellation flag to abort session loading process.
+ *
+ * \param[in] options instance that was passed to the session at creation time.
+ * \param[in] cancel setting this to true after model loading process was initiated will
+ * attempt to cancel the loading process. If cancellation is successful, CreateSession()
+ * CreateSessionFromArray() or any other session creation API that take session options as an
+ * argument will return an OrtStatus indicating that session loading was canceled at user request,
+ * error code ORT_MODEL_LOAD_CANCELED.
+ * The APIs above would not return any valid Session instance. This is the best case effort and the result
+ * is not guaranteed. The session may have already been created and initialized
+ * before the cancellation request was issued.
+ *
+ * \snippet{doc} snippets.dox OrtStatus
+ *
+ */
+ ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
+ _In_ bool cancel);
};
/*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 979b478e2fbb4..ce7dc1c45b05e 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl {
SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode
+ SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag
+
SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 48c5e52e33c53..524e3ecc92936 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -747,6 +747,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionM
return *this;
}
+template <typename T>
+inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLoadCancellationFlag(bool value) {
+ ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
+ return *this;
+}
+
template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h
index 703d183ea5c87..b42c6a9ba3e10 100644
--- a/onnxruntime/core/framework/error_code_helper.h
+++ b/onnxruntime/core/framework/error_code_helper.h
@@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c
#ifndef ORT_NO_EXCEPTIONS
#define API_IMPL_BEGIN try {
-#define API_IMPL_END \
- } \
- catch (const onnxruntime::NotImplementedException& ex) { \
- return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
- } \
- catch (const std::exception& ex) { \
- return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
- } \
- catch (...) { \
- return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
+#define API_IMPL_END \
+ } \
+ catch (const onnxruntime::OnnxRuntimeException& ex) { \
+ return OrtApis::CreateStatus(static_cast<OrtErrorCode>(ex.Code()), ex.what()); \
+ } \
+ catch (const onnxruntime::NotImplementedException& ex) { \
+ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
+ } \
+ catch (const std::exception& ex) { \
+ return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
+ } \
+ catch (...) { \
+ return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
}
#else
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc
index ff4d300f665b1..50f14104cfd7a 100644
--- a/onnxruntime/core/framework/graph_partitioner.cc
+++ b/onnxruntime/core/framework/graph_partitioner.cc
@@ -56,6 +56,7 @@ namespace {
// contains some common parameters used by the partitioning helper functions
struct PartitionParams {
std::reference_wrapper<Graph> graph;
+ std::reference_wrapper<const CheckLoadCancellationFn> check_load_cancellation_fn;
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
std::reference_wrapper<FuncManager> func_mgr;
std::reference_wrapper<KernelRegistry> fused_kernel_registry;
@@ -143,6 +144,7 @@ struct GetCapabilityForEPParams {
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
IResourceAccountant* resource_accountant;
std::reference_wrapper graph_optimizer_registry;
+ std::reference_wrapper<const CheckLoadCancellationFn> check_load_cancellation_fn;
};
auto get_capabilities = [](const IExecutionProvider& ep,
@@ -188,7 +190,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
{
const GraphViewer graph_viewer(graph);
- capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry);
+ capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant,
+ graph_optimizer_registry);
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Graph partitioning was canceled by user request");
+ }
if (capabilities.empty()) {
return Status::OK();
@@ -209,6 +216,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
// Perform layout transformation on the specific EP assigned graph
bool modified = false;
ORT_RETURN_IF_ERROR(params.transform_layout(graph, modified, current_ep, params.debug_graph_fn));
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "GetCapabilities was canceled by user request");
+ }
// It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes
// which are reconstructed to update domain or completely new nodes which are necessary for layout transformation.
@@ -226,7 +237,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
capabilities.clear();
const GraphViewer graph_viewer(graph);
- capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry);
+ capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant,
+ graph_optimizer_registry);
+ if (params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "GetCapabilities was canceled by user request");
+ }
// all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities
InlinedHashSet new_nodes_in_capabilities;
@@ -405,6 +421,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
int& fused_node_unique_id,
const layout_transformation::TransformLayoutFunction& transform_layout_fn,
const layout_transformation::DebugGraphFn& debug_graph_fn,
+ const CheckLoadCancellationFn& check_load_cancellation_fn,
const logging::Logger& logger, IResourceAccountant* resource_accountant,
const GraphOptimizerRegistry& graph_optimizer_registry) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@@ -420,7 +437,10 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
// we pass through the FuncManager from the top level graph
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr,
fused_kernel_registry, current_ep, mode, fused_node_unique_id,
- transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry));
+ transform_layout_fn, debug_graph_fn,
+ check_load_cancellation_fn,
+ logger, resource_accountant,
+ graph_optimizer_registry));
}
}
@@ -445,7 +465,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
std::cref(transform_layout_fn),
std::cref(debug_graph_fn),
resource_accountant,
- std::ref(graph_optimizer_registry)};
+ std::ref(graph_optimizer_registry),
+ std::cref(check_load_cancellation_fn)};
ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger));
if (capabilities.empty()) {
@@ -532,6 +553,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr,
}
ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));
+ ORT_RETURN_IF(check_load_cancellation_fn(),
+ "Graph partitioning is canceled due to user request.");
if (node_compute_funcs.size() != nodes_to_compile.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
@@ -633,6 +656,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
Graph& graph,
const GraphOptimizerRegistry& graph_optimizer_registry,
const logging::Logger& logger,
+ const CheckLoadCancellationFn& check_load_cancellation_fn,
InlinedHashSet<std::string>& not_inlined,
size_t& inlined_count) {
// handle testing edge case where optimizers or constant lifting results in graph with no nodes.
@@ -650,6 +674,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
*subgraph,
graph_optimizer_registry,
logger,
+ check_load_cancellation_fn,
not_inlined,
inlined_count));
}
@@ -673,8 +698,13 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
InlinedHashSet<std::string> claimed_by_ep;
for (const auto& ep : execution_providers) {
std::vector<std::unique_ptr<ComputeCapability>> capabilities;
- ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger,
+ ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep,
+ graph_optimizer_registry, logger,
capabilities));
+ if (check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request.");
+ }
+
for (auto& capability : capabilities) {
const auto& nodes = capability->sub_graph->nodes;
if (nodes.size() == 1) {
@@ -707,6 +737,9 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id)));
}
}
+ if (check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request.");
+ }
}
return Status::OK();
@@ -846,6 +879,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
auto& fused_kernel_registry = partition_params.fused_kernel_registry.get();
auto& fused_node_unique_id = partition_params.fused_node_unique_id.get();
const auto& transform_layout_function = partition_params.transform_layout_function;
+ const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn;
do {
// process full graph with each EP
@@ -861,6 +895,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
fused_kernel_registry, *ep, mode, fused_node_unique_id,
transform_layout_function,
partition_params.debug_graph_fn,
+ check_load_cancellation_fn,
logger, resource_accountant, graph_optimizer_registry));
}
@@ -915,7 +950,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
std::cref(partition_params.debug_graph_fn),
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
nullptr,
- std::ref(graph_optimizer_registry)
+ std::ref(graph_optimizer_registry),
+ partition_params.check_load_cancellation_fn
};
// clang-format on
@@ -972,6 +1008,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param
std::vector<NodeComputeInfo> single_node_compute_func;
ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}},
single_node_compute_func));
+ if (partition_params.check_load_cancellation_fn()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning is canceled due to user request.");
+ }
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element.");
auto& func_mgr = partition_params.func_mgr.get();
@@ -1032,6 +1071,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
return Status::OK();
}
+ auto check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); };
+
auto& graph = model.MainGraph();
InlinedHashSet<std::string> not_inlined;
do {
@@ -1041,13 +1082,13 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,
graph,
*graph_optimizer_registry_,
logger,
+ check_load_cancellation_fn,
not_inlined,
inlined_count));
if (inlined_count == 0) {
break;
}
-
ORT_RETURN_IF_ERROR(graph.Resolve());
} while (true);
@@ -1082,6 +1123,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
}
+ CheckLoadCancellationFn check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); };
+
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// fused_kernel_registry is preparing the kernels created on the fly for fused sub graph.
// It is only visible for current session.
@@ -1092,6 +1135,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
PartitionParams partition_params{
std::ref(graph),
+ std::cref(check_load_cancellation_fn),
std::ref(func_mgr),
std::ref(*fused_kernel_registry),
std::ref(fused_node_unique_id),
@@ -1105,6 +1149,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
ORT_UNUSED_PARAMETER(debug_graph_fn);
PartitionParams partition_params{
std::ref(graph),
+ std::cref(check_load_cancellation_fn),
};
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h
index b9d4022cb5a14..87edc7a64c6b5 100644
--- a/onnxruntime/core/framework/graph_partitioner.h
+++ b/onnxruntime/core/framework/graph_partitioner.h
@@ -33,6 +33,16 @@ class GraphPartitioner {
graph_optimizer_registry_(std::move(graph_optimizer_registry)) {
}
+ GraphPartitioner(KernelRegistryManager& kernel_registry_mgr,
+ const ExecutionProviders& providers,
+ std::unique_ptr<GraphOptimizerRegistry> graph_optimizer_registry,
+ CheckLoadCancellationFn check_load_cancellation_fn)
+ : kernel_registry_mgr_(kernel_registry_mgr),
+ providers_(providers),
+ graph_optimizer_registry_(std::move(graph_optimizer_registry)),
+ check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) {
+ }
+
// Run partitioning.
Status Partition(Graph& graph, FuncManager& func_mgr,
const layout_transformation::TransformLayoutFunction& transform_layout_function,
@@ -41,6 +51,10 @@ class GraphPartitioner {
Mode mode = Mode::kNormal,
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;
+ bool IsLoadCancellationFlagSet() const {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
#ifndef ORT_MINIMAL_BUILD
///
// Ahead of Time Function inlining. The main purpose of the function is to inline as many
@@ -69,6 +83,7 @@ class GraphPartitioner {
KernelRegistryManager& kernel_registry_mgr_;
const ExecutionProviders& providers_;
std::unique_ptr<GraphOptimizerRegistry> graph_optimizer_registry_;
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 8d4db36106f28..ef323b99b006c 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -8,6 +8,7 @@
#include
#include
#include
+#include <atomic>
#include
#include "core/common/inlined_containers.h"
#include "core/framework/config_options.h"
@@ -66,6 +67,8 @@ struct FreeDimensionOverride {
int64_t dim_value;
};
+using CheckLoadCancellationFn = std::function<bool()>;
+
/**
* Configuration information for a session.
*/
@@ -184,6 +187,18 @@ struct SessionOptions {
// User specified logging func and param
OrtLoggingFunction user_logging_function = nullptr;
void* user_logging_param = nullptr;
+
+ void SetLoadCancellationFlag(bool value) noexcept {
+ *load_cancellation_flag = value;
+ }
+
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return *load_cancellation_flag;
+ }
+
+ // Load cancellation flag is necessary to be within shared memory as session_options are
+ // copied internally and the flag needs to be accessible across all copies.
+ std::shared_ptr<std::atomic_bool> load_cancellation_flag = std::make_shared<std::atomic_bool>(false);
};
inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) {
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index d174d6cc72ead..6362a3169f3a3 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -422,6 +422,10 @@ Status SessionState::PrepackConstantInitializedTensors(
auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map](
bool should_cache_prepacked_weights_for_shared_initializers) -> Status {
for (auto& node : GetGraphViewer().Nodes()) {
+ if (sess_options_.IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Weight pre-packing was canceled due to user request.");
+ }
auto kernel = GetMutableKernel(node.Index());
int input_idx = 0;
for (auto& input_def : node.InputDefs()) {
@@ -1541,6 +1545,11 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringname();
diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index 39ffc6a5b0cee..334ecb3887d14 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model,
#endif
}
+ if (owning_model_.IsLoadCancellationFlagSet()) {
+ ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request.");
+ }
+
// Remove constant nodes as they're replaced with initializers above.
const gsl::not_null<RepeatedPtrField<NodeProto>*> graph_mutable_nodes{graph_proto_->mutable_node()};
graph_mutable_nodes->erase(
@@ -1365,6 +1369,10 @@ Graph::Graph(const Model& owning_model,
}
}
+ if (owning_model_.IsLoadCancellationFlagSet()) {
+ ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request.");
+ }
+
for (auto& graph_output : graph_proto_->output()) {
if (utils::HasName(graph_output) && utils::HasType(graph_output)) {
auto& name = graph_output.name();
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc
index 7629e40c1b5fe..436af7115eb1a 100644
--- a/onnxruntime/core/graph/model.cc
+++ b/onnxruntime/core/graph/model.cc
@@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name,
const std::vector<ONNX_NAMESPACE::FunctionProto>& model_local_functions,
const logging::Logger& logger,
const ModelOptions& options)
- : model_path_(model_path) {
+ : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) {
model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
model_proto_.mutable_graph()->set_name(graph_name);
model_metadata_ = model_metadata;
@@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path,
Model::Model(ModelProto&& model_proto, const PathString& model_path,
const IOnnxRuntimeOpSchemaRegistryList* local_registries,
const logging::Logger& logger, const ModelOptions& options)
- : model_path_(model_path) {
+ : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) {
if (!utils::HasGraph(model_proto)) {
ORT_THROW("ModelProto does not have a graph.");
}
@@ -435,6 +435,11 @@ Status Model::Load(const ModelProto& model_proto,
ORT_TRY {
model = std::make_unique(model_proto, model_path, local_registries, logger, options);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
@@ -474,6 +479,11 @@ Status Model::Load(ModelProto&& model_proto,
ORT_TRY {
model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what()));
@@ -509,6 +519,11 @@ static Status LoadModelHelper(const T& file_path, Loader loader) {
ORT_TRY {
status = loader(fd);
}
+ ORT_CATCH(const OnnxRuntimeException& ex) {
+ ORT_HANDLE_EXCEPTION([&]() {
+ status = Status(ex.Category(), ex.Code(), ex.what());
+ });
+ }
ORT_CATCH(const std::exception& ex) {
ORT_HANDLE_EXCEPTION([&]() {
status = Status(ONNXRUNTIME, FAIL, ex.what());
diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h
index 6fd94c60d6b99..70f82bcfb160b 100644
--- a/onnxruntime/core/graph/model.h
+++ b/onnxruntime/core/graph/model.h
@@ -11,6 +11,7 @@
#include "core/common/flatbuffers.h"
+#include "core/framework/session_options.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/ort_format_load_options.h"
#include "core/session/onnxruntime_c_api.h"
@@ -38,6 +39,14 @@ struct ModelOptions {
// be returned.
bool strict_shape_type_inference;
+ CheckLoadCancellationFn check_load_cancellation_fn;
+
+ ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference,
+ CheckLoadCancellationFn check_load_cancellation_fn)
+ : allow_released_opsets_only(allow_released_opsets_only),
+ strict_shape_type_inference(strict_shape_type_inference),
+ check_load_cancellation_fn(std::move(check_load_cancellation_fn)) {}
+
ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference)
: allow_released_opsets_only(allow_released_opsets_only),
strict_shape_type_inference(strict_shape_type_inference) {}
@@ -102,6 +111,11 @@ class Model {
#endif // !defined(ORT_MINIMAL_BUILD)
+ // Check for load cancellation.
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
#if !defined(ORT_MINIMAL_BUILD)
// Get model's IR version.
// Return if not specified.
@@ -343,5 +357,7 @@ class Model {
// Main graph of the model.
std::unique_ptr graph_;
+
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc
index 039283bb2d4e1..83c3f70799987 100644
--- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc
@@ -27,6 +27,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor
}
for (unsigned step = 0; step < steps_; ++step) {
+ if (IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph transformation canceled due to user request.");
+ }
bool graph_changed = false;
for (const auto& transformer : transformers->second) {
if (step > 0 && transformer->ShouldOnlyApplyOnce())
diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h
index ed66302434ab2..eab57f12bfcbb 100644
--- a/onnxruntime/core/optimizer/graph_transformer_mgr.h
+++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h
@@ -24,6 +24,16 @@ class GraphTransformerManager {
// Get the maximum number of graph transformation steps
common::Status GetSteps(unsigned& steps) const;
+ // Set the load-cancellation check function (typically sourced from session_options)
+ void SetLoadCancellationFn(CheckLoadCancellationFn check_load_cancellation_fn) {
+ check_load_cancellation_fn_ = std::move(check_load_cancellation_fn);
+ }
+
+ // Returns true if a load cancellation request has been signaled
+ bool IsLoadCancellationFlagSet() const noexcept {
+ return check_load_cancellation_fn_ && check_load_cancellation_fn_();
+ }
+
// Register a transformer with a level.
common::Status Register(std::unique_ptr transformer, TransformerLevel level);
@@ -38,5 +48,6 @@ class GraphTransformerManager {
InlinedHashMap>> level_to_transformer_map_;
InlinedHashMap transformers_info_;
+ CheckLoadCancellationFn check_load_cancellation_fn_;
};
} // namespace onnxruntime
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index 2e733f67a888c..e50ee5738c30e 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -340,3 +340,11 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions*
return nullptr;
API_IMPL_END
}
+
+ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
+ _In_ bool is_cancel) {
+ API_IMPL_BEGIN
+ options->value.SetLoadCancellationFlag(is_cancel);
+ return nullptr;
+ API_IMPL_END
+}
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index e5ea562ce3535..1bcb651ec9605 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -383,6 +383,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
#if !defined(ORT_MINIMAL_BUILD)
// Update the number of steps for the graph transformer manager using the "finalized" session options
ORT_THROW_IF_ERROR(graph_transformer_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps));
+ graph_transformer_mgr_.SetLoadCancellationFn(this->check_load_cancellation_fn_);
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1004,11 +1005,13 @@ common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) {
std::copy(std::begin(interop_domains_), std::end(interop_domains_), std::back_inserter(domain_ptrs));
ORT_RETURN_IF_ERROR(AddCustomOpDomains(domain_ptrs));
#endif
+
const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1";
return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr,
*session_logger_,
- ModelOptions(true, strict_shape_type_inference));
+ ModelOptions(true, strict_shape_type_inference,
+ check_load_cancellation_fn_));
};
common::Status st = LoadWithLoader(loader, "model_loading_uri");
@@ -1101,7 +1104,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len
return onnxruntime::Model::Load(std::move(model_proto), model_location_, model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_,
- ModelOptions(true, strict_shape_type_inference));
+ ModelOptions(true, strict_shape_type_inference,
+ check_load_cancellation_fn_));
};
return LoadWithLoader(loader, "model_loading_array");
@@ -1139,7 +1143,8 @@ common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) {
// This call will move model_proto to the constructed model instance
return onnxruntime::Model::Load(std::move(model_proto), model_location_, model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_,
- ModelOptions(true, strict_shape_type_inference));
+ ModelOptions(true, strict_shape_type_inference,
+ check_load_cancellation_fn_));
};
return LoadWithLoader(loader, "model_loading_proto");
@@ -1172,7 +1177,8 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re
const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1";
ModelOptions model_opts(allow_released_opsets_only,
- strict_shape_type_inference);
+ strict_shape_type_inference,
+ check_load_cancellation_fn_);
std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault(
kOrtSessionOptionsModelExternalInitializersFileFolderPath, "");
@@ -1211,7 +1217,8 @@ common::Status InferenceSession::Load() {
// Pass on ownership of the parsed ModelProto to the Model instance (its job here is done by this stage)
return Model::Load(std::move(this->model_proto_), model_location_, model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_,
- ModelOptions(allow_released_opsets_only, strict_shape_type_inference));
+ ModelOptions(allow_released_opsets_only, strict_shape_type_inference,
+ check_load_cancellation_fn_));
};
return LoadWithLoader(loader, "model_loading_from_saved_proto");
@@ -1239,7 +1246,8 @@ common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) {
std::unique_ptr tmp_model;
ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model,
HasLocalSchema() ? &custom_schema_registries_ : nullptr,
- ModelOptions(true, strict_shape_type_inference),
+ ModelOptions(true, strict_shape_type_inference,
+ check_load_cancellation_fn_),
*session_logger_, tmp_model));
model_ = std::move(tmp_model);
@@ -1283,7 +1291,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool
auto graph_optimizer_registry = std::make_unique(&session_options_,
execution_providers_.Get(onnxruntime::kCpuExecutionProvider),
session_logger_);
- GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry));
+ GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry),
+ check_load_cancellation_fn_);
// Run Ahead Of time function inlining
if (const bool disable_aot_function_inlining =
@@ -1711,7 +1720,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
providers.Get(onnxruntime::kCpuExecutionProvider),
&logger);
- GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry));
+ GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry),
+ [&sess_options]() -> bool { return sess_options.IsLoadCancellationFlagSet(); });
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
session_state.GetMutableFuncMgr(),
transform_layout_fn,
@@ -1784,6 +1794,11 @@ common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() con
#pragma warning(disable : 26117)
#endif
common::Status InferenceSession::Initialize() {
+ if (session_options_.IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Session initialization canceled due to user request.");
+ }
+
Status status = Status::OK();
TimePoint tp;
if (session_profiler_.IsEnabled()) {
@@ -2009,6 +2024,10 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());
+ if (session_options_.IsLoadCancellationFlagSet()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED,
+ "Session initialization canceled due to user request.");
+ }
// Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP.
//
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 5b484103c9ecf..7b5d98c38a0fa 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -781,6 +781,10 @@ class InferenceSession {
// the session options are released after the individual operators are destroyed.
SessionOptions session_options_;
+ CheckLoadCancellationFn check_load_cancellation_fn_ = [this]() {
+ return session_options_.IsLoadCancellationFlagSet();
+ };
+
/// Logging manager if provided.
logging::LoggingManager* logging_manager_;
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 0e23d7a791bec..ac67a3ce5c1a2 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -720,22 +720,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O
_In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr sess;
- OrtStatus* status = nullptr;
*out = nullptr;
-
- ORT_TRY {
- ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
- ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
-
- *out = reinterpret_cast(sess.release());
- }
- ORT_CATCH(const std::exception& e) {
- ORT_HANDLE_EXCEPTION([&]() {
- status = OrtApis::CreateStatus(ORT_FAIL, e.what());
- });
- }
-
- return status;
+ ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess));
+ ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
+ *out = reinterpret_cast(sess.release());
+ return nullptr;
API_IMPL_END
}
@@ -743,22 +732,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In
size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) {
API_IMPL_BEGIN
std::unique_ptr sess;
- OrtStatus* status = nullptr;
- *out = nullptr;
-
- ORT_TRY {
- ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess));
- ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
-
- *out = reinterpret_cast(sess.release());
- }
- ORT_CATCH(const std::exception& e) {
- ORT_HANDLE_EXCEPTION([&]() {
- status = OrtApis::CreateStatus(ORT_FAIL, e.what());
- });
- }
-
- return status;
+ ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess));
+ ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess));
+ *out = reinterpret_cast(sess.release());
+ return nullptr;
API_IMPL_END
}
@@ -2810,6 +2787,7 @@ static constexpr OrtApi ort_api_1_to_22 = {
&OrtApis::GetModelEditorApi,
&OrtApis::CreateTensorWithDataAndDeleterAsOrtValue,
+ &OrtApis::SessionOptionsSetLoadCancellationFlag,
};
// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 9d8aeb18a782f..0a87036a0dd1d 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -549,4 +549,7 @@ ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator*
ONNXTensorElementDataType type,
_Outptr_ OrtValue** out);
+ORT_API_STATUS_IMPL(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
+ _In_ bool is_cancel);
+
} // namespace OrtApis
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 975502063ac2a..a069cfa0b4713 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -1753,6 +1753,12 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc")
options->value.execution_mode = execution_mode;
},
R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc")
+ .def(
+ "set_load_cancellation_flag",
+ [](PySessionOptions* options, bool value) -> void {
+ options->value.SetLoadCancellationFlag(value);
+ },
+ R"pbdoc(Request inference session load cancellation)pbdoc")
.def_property(
"execution_order",
[](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; },
diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc
index 95101c8075fc2..dc776f74d8758 100644
--- a/onnxruntime/test/framework/inference_session_test.cc
+++ b/onnxruntime/test/framework/inference_session_test.cc
@@ -7,6 +7,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -498,6 +499,30 @@ TEST(InferenceSessionTests, TestModelSerialization) {
ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK());
}
+TEST(InferenceSessionTests, RequestLoadCancellation) {
+ {
+ // Explicitly request cancellation before Load(); a small model suffices for this test
+ SessionOptions so;
+ so.session_logid = "InferenceSessionTests.TestLoadCancellation";
+
+ const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx");
+ InferenceSession session_object{so, GetEnvironment()};
+ so.SetLoadCancellationFlag(true);
+ ASSERT_FALSE(session_object.Load(model_uri).IsOK());
+ }
+ {
+ // Explicitly request cancellation before Initialize(); a small model suffices for this test
+ const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx");
+ SessionOptions so;
+ so.session_logid = "InferenceSessionTests.TestLoadCancellation";
+ so.SetLoadCancellationFlag(false);
+ InferenceSession session_object{so, GetEnvironment()};
+ ASSERT_STATUS_OK(session_object.Load(model_uri));
+ so.SetLoadCancellationFlag(true);
+ ASSERT_FALSE(session_object.Initialize().IsOK());
+ }
+}
+
#ifdef ORT_RUN_EXTERNAL_ONNX_TESTS
static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) {
if (f_arg.size() != s_arg.size()) {
diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
index 9c0b779870c70..a5fd37361a255 100644
--- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
+++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc
@@ -576,11 +576,10 @@ TEST(Loop, InfiniteLoopTermination) {
test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true",
{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options); // Disable TensorRT on unsupported data type BOOL
- // call get to propagate any exception
- terminator_result.get();
-
// done with the thread
terminator_thread.join();
+ // call get to propagate any exception
+ terminator_result.get();
}
// Add basic test to trigger types override logic in Graph::InferAndVerifySubgraphTypes as well as
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index b517ba7032886..e00606af1c086 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -4,6 +4,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -4669,6 +4670,34 @@ TEST(CApiTest, RunBaseLoraModel) {
}
}
+TEST(CApiTest, RequestLoadCancellation) {
+ constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx");
+ Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
+ Ort::SessionOptions session_options;
+
+ auto terminator = [&session_options]() {
+ session_options.SetLoadCancellationFlag(true);
+ return;
+ };
+
+ std::packaged_task task{terminator};
+ std::future terminator_result = task.get_future();
+ std::thread terminator_thread{std::move(task)};
+ bool terminated = false;
+ try {
+ Ort::Session session(env, model_path, session_options);
+ } catch (const Ort::Exception& ex) {
+ terminated = OrtErrorCode::ORT_MODEL_LOAD_CANCELED == ex.GetOrtErrorCode();
+ }
+ // done with the thread
+ terminator_thread.join();
+
+ // call get to propagate any exception
+ terminator_result.get();
+
+ ASSERT_TRUE(terminated);
+}
+
struct MockGQA : public OrtCustomOp {
MockGQA() {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {