diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index b64a5c3e5a4a2..77c35aac65b92 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -3,6 +3,7 @@ using System; using System.Runtime.InteropServices; +using static Microsoft.ML.OnnxRuntime.NativeMethods; namespace Microsoft.ML.OnnxRuntime { @@ -325,6 +326,16 @@ public struct OrtApi public IntPtr CreateLoraAdapterFromArray; public IntPtr ReleaseLoraAdapter; public IntPtr RunOptionsAddActiveLoraAdapter; + public IntPtr SetEpDynamicOptions; + public IntPtr ReleaseValueInfo; + public IntPtr ReleaseNode; + public IntPtr ReleaseGraph; + public IntPtr ReleaseModel; + public IntPtr GetValueInfoName; + public IntPtr GetValueInfoTypeInfo; + public IntPtr GetModelEditorApi; + public IntPtr CreateTensorWithDataAndDeleterAsOrtValue; + public IntPtr SessionOptionsSetLoadCancellationFlag; } internal static class NativeMethods @@ -404,6 +415,7 @@ static NativeMethods() OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions)); OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions)); OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode)); + OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag)); OrtSetOptimizedModelFilePath = (DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath)); OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, typeof(DOrtEnableProfiling)); OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling)); @@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options, + bool value); + public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath); public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index bd450451a1265..9b0f183f03681 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode } private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL; + /// + /// Sets the load cancellation flag for the session. Default is set to false. + /// Provides an opportunity for the user to cancel model loading. + /// + /// true to request cancellation, false to proceed + public void SetLoadCancellationFlag(bool value) + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); + } + #endregion #region Private Methods diff --git a/include/onnxruntime/core/common/common.h b/include/onnxruntime/core/common/common.h index 0822eba950f50..10f658f52e0d9 100644 --- a/include/onnxruntime/core/common/common.h +++ b/include/onnxruntime/core/common/common.h @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch abort(); \ } while (false) +#define ORT_THROW_FROM_STATUS(status) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException( \ + ORT_WHERE_WITH_STACK, status.ToString()) \ + .what()); \ + abort(); \ + } while (false) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) \ + .what()); \ + abort(); \ + } while (false) + #else #define ORT_TRY try @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch #define ORT_THROW_EX(ex, ...) \ throw ex(__VA_ARGS__) +#define ORT_THROW_FROM_STATUS(status) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \ + static_cast<::onnxruntime::common::StatusCode>(status.Code())) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) + #endif #define ORT_MAKE_STATUS(category, code, ...) \ @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch auto _status = (expr); \ if ((!_status.IsOK())) { \ ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \ - ORT_THROW(_status); \ + ORT_THROW_FROM_STATUS(_status); \ } \ } while (0) diff --git a/include/onnxruntime/core/common/exceptions.h b/include/onnxruntime/core/common/exceptions.h index 494a770b8db98..6d0f6edd6e7c4 100644 --- a/include/onnxruntime/core/common/exceptions.h +++ b/include/onnxruntime/core/common/exceptions.h @@ -11,6 +11,7 @@ #include #include "core/common/common.h" +#include "core/common/status.h" #include "core/common/code_location.h" namespace onnxruntime { @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception { /** Create a new exception that captures the location it was thrown from. @param location Location in the source code the exception is being thrown from + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + + OnnxRuntimeException(const CodeLocation& location, + const std::string& message, + common::StatusCategory category, + common::StatusCode code) noexcept + : OnnxRuntimeException(location, nullptr, message, category, code) { + } + + /** + Create a new exception that captures the location it was thrown from. + The instance will be created with ONNXRUNTIME category and FAIL code. + @param location Location in the source code the exception is being thrown from @param failed_condition Optional string containing the condition that failed. e.g. "tensor.Size() == input.Size()". May be nullptr. @param msg Message containing additional information about the exception cause. */ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept + : OnnxRuntimeException(location, failed_condition, msg, + common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg, + common::StatusCategory category, + common::StatusCode code) + : location_{location}, category_(category), code_(code) { std::ostringstream ss; ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception { what_ = ss.str(); } + common::StatusCategory Category() const noexcept { + return category_; + } + + common::StatusCode Code() const noexcept { + return code_; + } + const char* what() const noexcept override { return what_.c_str(); } @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception { const CodeLocation location_; const std::vector stacktrace_; std::string what_; + common::StatusCategory category_; + common::StatusCode code_; }; } // namespace onnxruntime diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index 8f171daabbb1e..b222e411d7804 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -43,7 +43,8 @@ enum StatusCode { MODEL_LOADED = 8, NOT_IMPLEMENTED = 9, INVALID_GRAPH = 10, - EP_FAIL = 11 + EP_FAIL = 11, + MODEL_LOAD_CANCELED = 12, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -72,6 +73,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "INVALID_GRAPH"; case StatusCode::EP_FAIL: return "EP_FAIL"; + case StatusCode::MODEL_LOAD_CANCELED: + return "MODEL_LOAD_CANCELED"; default: return "GENERAL ERROR"; } @@ -104,6 +107,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); case StatusCode::EP_FAIL: return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::MODEL_LOAD_CANCELED: + return HRESULT_FROM_WIN32(ERROR_CANCELLED); default: return E_FAIL; } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 6d4cc8a1f2fa9..3bf0d5e19c525 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -255,6 +255,7 @@ typedef enum OrtErrorCode { ORT_NOT_IMPLEMENTED, ORT_INVALID_GRAPH, ORT_EP_FAIL, + ORT_MODEL_LOAD_CANCELED, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -4898,6 +4899,24 @@ struct OrtApi { _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** \brief sets load cancellation flag to abort session loading process. + * + * \param[in] options instance that was passed to the session at creation time. + * \param[in] cancel setting this to true after model loading process was initiated will + * attempt to cancel the loading process. If cancellation is successful, CreateSession() + * CreateSessionFromArray() or any other session creation API that take session options as an + * argument will return an OrtStatus indicating that session loading was canceled at user request, + * error code ORT_MODEL_LOAD_CANCELED. + * The APIs above would not return any valid Session instance. This is the best case effort and the result + * is not guaranteed. The session may have already been created and initialized + * before the cancellation request was issued. + * + * \snippet{doc} snippets.dox OrtStatus + * + */ + ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool cancel); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 979b478e2fbb4..ce7dc1c45b05e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode + SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag + SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48c5e52e33c53..524e3ecc92936 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -747,6 +747,12 @@ inline SessionOptionsImpl& SessionOptionsImpl::SetExecutionMode(ExecutionM return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetLoadCancellationFlag(bool value) { + ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetLogId(const char* logid) { ThrowOnError(GetApi().SetSessionLogId(this->p_, logid)); diff --git a/onnxruntime/core/framework/error_code_helper.h b/onnxruntime/core/framework/error_code_helper.h index 703d183ea5c87..b42c6a9ba3e10 100644 --- a/onnxruntime/core/framework/error_code_helper.h +++ b/onnxruntime/core/framework/error_code_helper.h @@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c #ifndef ORT_NO_EXCEPTIONS #define API_IMPL_BEGIN try { -#define API_IMPL_END \ - } \ - catch (const onnxruntime::NotImplementedException& ex) { \ - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ - } \ - catch (const std::exception& ex) { \ - return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ - } \ - catch (...) { \ - return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ +#define API_IMPL_END \ + } \ + catch (const onnxruntime::OnnxRuntimeException& ex) { \ + return OrtApis::CreateStatus(static_cast(ex.Code()), ex.what()); \ + } \ + catch (const onnxruntime::NotImplementedException& ex) { \ + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \ + } \ + catch (const std::exception& ex) { \ + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \ + } \ + catch (...) { \ + return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \ } #else diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index ff4d300f665b1..50f14104cfd7a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -56,6 +56,7 @@ namespace { // contains some common parameters used by the partitioning helper functions struct PartitionParams { std::reference_wrapper graph; + std::reference_wrapper check_load_cancellation_fn; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) std::reference_wrapper func_mgr; std::reference_wrapper fused_kernel_registry; @@ -143,6 +144,7 @@ struct GetCapabilityForEPParams { #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) IResourceAccountant* resource_accountant; std::reference_wrapper graph_optimizer_registry; + std::reference_wrapper check_load_cancellation_fn; }; auto get_capabilities = [](const IExecutionProvider& ep, @@ -188,7 +190,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l { const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Graph partitioning was canceled by user request"); + } if (capabilities.empty()) { return Status::OK(); @@ -209,6 +216,10 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l // Perform layout transformation on the specific EP assigned graph bool modified = false; ORT_RETURN_IF_ERROR(params.transform_layout(graph, modified, current_ep, params.debug_graph_fn)); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // It is possible some new nodes are introduced during transformation. These nodes can be either existing nodes // which are reconstructed to update domain or completely new nodes which are necessary for layout transformation. @@ -226,7 +237,12 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l capabilities.clear(); const GraphViewer graph_viewer(graph); - capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup, params.resource_accountant, + graph_optimizer_registry); + if (params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "GetCapabilities was canceled by user request"); + } // all nodes with an index >= first_new_node with domain of kMSInternalNHWCDomain should be in the capabilities InlinedHashSet new_nodes_in_capabilities; @@ -405,6 +421,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, int& fused_node_unique_id, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, + const CheckLoadCancellationFn& check_load_cancellation_fn, const logging::Logger& logger, IResourceAccountant* resource_accountant, const GraphOptimizerRegistry& graph_optimizer_registry) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -420,7 +437,10 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, // we pass through the FuncManager from the top level graph ORT_RETURN_IF_ERROR(PartitionOnnxFormatModelImpl(*subgraph, func_mgr, kernel_registry_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, - transform_layout_fn, debug_graph_fn, logger, resource_accountant, graph_optimizer_registry)); + transform_layout_fn, debug_graph_fn, + check_load_cancellation_fn, + logger, resource_accountant, + graph_optimizer_registry)); } } @@ -445,7 +465,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, std::cref(transform_layout_fn), std::cref(debug_graph_fn), resource_accountant, - std::ref(graph_optimizer_registry)}; + std::ref(graph_optimizer_registry), + std::cref(check_load_cancellation_fn)}; ORT_RETURN_IF_ERROR(GetCapabilityForEP(get_capability_params, logger)); if (capabilities.empty()) { @@ -532,6 +553,8 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, } ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs)); + ORT_RETURN_IF(check_load_cancellation_fn(), + "Graph partitioning is canceled due to user request."); if (node_compute_funcs.size() != nodes_to_compile.size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions"); @@ -633,6 +656,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide Graph& graph, const GraphOptimizerRegistry& graph_optimizer_registry, const logging::Logger& logger, + const CheckLoadCancellationFn& check_load_cancellation_fn, InlinedHashSet& not_inlined, size_t& inlined_count) { // handle testing edge case where optimizers or constant lifting results in graph with no nodes. @@ -650,6 +674,7 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide *subgraph, graph_optimizer_registry, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); } @@ -673,8 +698,13 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide InlinedHashSet claimed_by_ep; for (const auto& ep : execution_providers) { std::vector> capabilities; - ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, graph_optimizer_registry, logger, + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, + graph_optimizer_registry, logger, capabilities)); + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } + for (auto& capability : capabilities) { const auto& nodes = capability->sub_graph->nodes; if (nodes.size() == 1) { @@ -707,6 +737,9 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); } } + if (check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "AOT inlining is canceled due to user request."); + } } return Status::OK(); @@ -846,6 +879,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_kernel_registry = partition_params.fused_kernel_registry.get(); auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; + const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn; do { // process full graph with each EP @@ -861,6 +895,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, fused_kernel_registry, *ep, mode, fused_node_unique_id, transform_layout_function, partition_params.debug_graph_fn, + check_load_cancellation_fn, logger, resource_accountant, graph_optimizer_registry)); } @@ -915,7 +950,8 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::cref(partition_params.debug_graph_fn), #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) nullptr, - std::ref(graph_optimizer_registry) + std::ref(graph_optimizer_registry), + partition_params.check_load_cancellation_fn }; // clang-format on @@ -972,6 +1008,9 @@ static Status PartitionOrtFormatModelImpl(const PartitionParams& partition_param std::vector single_node_compute_func; ORT_RETURN_IF_ERROR(current_ep.Compile({IExecutionProvider::FusedNodeAndGraph{node, *compilation_entry.viewer}}, single_node_compute_func)); + if (partition_params.check_load_cancellation_fn()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph partitioning is canceled due to user request."); + } ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 element."); auto& func_mgr = partition_params.func_mgr.get(); @@ -1032,6 +1071,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, return Status::OK(); } + auto check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + auto& graph = model.MainGraph(); InlinedHashSet not_inlined; do { @@ -1041,13 +1082,13 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, graph, *graph_optimizer_registry_, logger, + check_load_cancellation_fn, not_inlined, inlined_count)); if (inlined_count == 0) { break; } - ORT_RETURN_IF_ERROR(graph.Resolve()); } while (true); @@ -1082,6 +1123,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified."); } + CheckLoadCancellationFn check_load_cancellation_fn = [this]() -> bool { return IsLoadCancellationFlagSet(); }; + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) // fused_kernel_registry is preparing the kernels created on the fly for fused sub graph. // It is only visible for current session. @@ -1092,6 +1135,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), std::ref(func_mgr), std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), @@ -1105,6 +1149,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, ORT_UNUSED_PARAMETER(debug_graph_fn); PartitionParams partition_params{ std::ref(graph), + std::cref(check_load_cancellation_fn), }; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index b9d4022cb5a14..87edc7a64c6b5 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -33,6 +33,16 @@ class GraphPartitioner { graph_optimizer_registry_(std::move(graph_optimizer_registry)) { } + GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, + const ExecutionProviders& providers, + std::unique_ptr graph_optimizer_registry, + CheckLoadCancellationFn check_load_cancellation_fn) + : kernel_registry_mgr_(kernel_registry_mgr), + providers_(providers), + graph_optimizer_registry_(std::move(graph_optimizer_registry)), + check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) { + } + // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, @@ -41,6 +51,10 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; + bool IsLoadCancellationFlagSet() const { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #ifndef ORT_MINIMAL_BUILD /// // Ahead of Time Function inlining. The main purpose of the function is to inline as many @@ -69,6 +83,7 @@ class GraphPartitioner { KernelRegistryManager& kernel_registry_mgr_; const ExecutionProviders& providers_; std::unique_ptr graph_optimizer_registry_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8d4db36106f28..ef323b99b006c 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include "core/common/inlined_containers.h" #include "core/framework/config_options.h" @@ -66,6 +67,8 @@ struct FreeDimensionOverride { int64_t dim_value; }; +using CheckLoadCancellationFn = std::function; + /** * Configuration information for a session. */ @@ -184,6 +187,18 @@ struct SessionOptions { // User specified logging func and param OrtLoggingFunction user_logging_function = nullptr; void* user_logging_param = nullptr; + + void SetLoadCancellationFlag(bool value) noexcept { + *load_cancellation_flag = value; + } + + bool IsLoadCancellationFlagSet() const noexcept { + return *load_cancellation_flag; + } + + // Load cancellation flag is necessary to be within shared memory as session_options are + // copied internally and the flag needs to be accessible across all copies. + std::shared_ptr load_cancellation_flag = std::make_shared(false); }; inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_options) { diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index d174d6cc72ead..6362a3169f3a3 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -422,6 +422,10 @@ Status SessionState::PrepackConstantInitializedTensors( auto prepacked_constant_weights = [this, &constant_initializers_use_count, &initializers_to_share_map]( bool should_cache_prepacked_weights_for_shared_initializers) -> Status { for (auto& node : GetGraphViewer().Nodes()) { + if (sess_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Weight pre-packing was canceled due to user request."); + } auto kernel = GetMutableKernel(node.Index()); int input_idx = 0; for (auto& input_def : node.InputDefs()) { @@ -1541,6 +1545,11 @@ Status SessionState::FinalizeSessionStateImpl(const std::basic_stringname(); diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 39ffc6a5b0cee..334ecb3887d14 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1268,6 +1268,10 @@ Graph::Graph(const Model& owning_model, #endif } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + // Remove constant nodes as they're replaced with initializers above. const gsl::not_null*> graph_mutable_nodes{graph_proto_->mutable_node()}; graph_mutable_nodes->erase( @@ -1365,6 +1369,10 @@ Graph::Graph(const Model& owning_model, } } + if (owning_model_.IsLoadCancellationFlagSet()) { + ORT_THROW_WITH_CATEGORY_AND_CODE(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph loading canceled due to user request."); + } + for (auto& graph_output : graph_proto_->output()) { if (utils::HasName(graph_output) && utils::HasType(graph_output)) { auto& name = graph_output.name(); diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 7629e40c1b5fe..436af7115eb1a 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -82,7 +82,7 @@ Model::Model(const std::string& graph_name, const std::vector& model_local_functions, const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { model_proto_.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); model_proto_.mutable_graph()->set_name(graph_name); model_metadata_ = model_metadata; @@ -161,7 +161,7 @@ Model::Model(const ModelProto& model_proto, const PathString& model_path, Model::Model(ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, const logging::Logger& logger, const ModelOptions& options) - : model_path_(model_path) { + : model_path_(model_path), check_load_cancellation_fn_(options.check_load_cancellation_fn) { if (!utils::HasGraph(model_proto)) { ORT_THROW("ModelProto does not have a graph."); } @@ -435,6 +435,11 @@ Status Model::Load(const ModelProto& model_proto, ORT_TRY { model = std::make_unique(model_proto, model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -474,6 +479,11 @@ Status Model::Load(ModelProto&& model_proto, ORT_TRY { model = std::make_unique(std::move(model_proto), model_path, local_registries, logger, options); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to load model with error: " + std::string(ex.what())); @@ -509,6 +519,11 @@ static Status LoadModelHelper(const T& file_path, Loader loader) { ORT_TRY { status = loader(fd); } + ORT_CATCH(const OnnxRuntimeException& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = Status(ex.Category(), ex.Code(), ex.what()); + }); + } ORT_CATCH(const std::exception& ex) { ORT_HANDLE_EXCEPTION([&]() { status = Status(ONNXRUNTIME, FAIL, ex.what()); diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6fd94c60d6b99..70f82bcfb160b 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -11,6 +11,7 @@ #include "core/common/flatbuffers.h" +#include "core/framework/session_options.h" #include "core/graph/graph_viewer.h" #include "core/graph/ort_format_load_options.h" #include "core/session/onnxruntime_c_api.h" @@ -38,6 +39,14 @@ struct ModelOptions { // be returned. bool strict_shape_type_inference; + CheckLoadCancellationFn check_load_cancellation_fn; + + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference, + CheckLoadCancellationFn check_load_cancellation_fn) + : allow_released_opsets_only(allow_released_opsets_only), + strict_shape_type_inference(strict_shape_type_inference), + check_load_cancellation_fn(std::move(check_load_cancellation_fn)) {} + ModelOptions(bool allow_released_opsets_only, bool strict_shape_type_inference) : allow_released_opsets_only(allow_released_opsets_only), strict_shape_type_inference(strict_shape_type_inference) {} @@ -102,6 +111,11 @@ class Model { #endif // !defined(ORT_MINIMAL_BUILD) + // Check for load cancellation. + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + #if !defined(ORT_MINIMAL_BUILD) // Get model's IR version. // Return if not specified. @@ -343,5 +357,7 @@ class Model { // Main graph of the model. std::unique_ptr graph_; + + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.cc b/onnxruntime/core/optimizer/graph_transformer_mgr.cc index 039283bb2d4e1..83c3f70799987 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.cc +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.cc @@ -27,6 +27,9 @@ common::Status GraphTransformerManager::ApplyTransformers(Graph& graph, Transfor } for (unsigned step = 0; step < steps_; ++step) { + if (IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, "Graph transformation canceled due to user request."); + } bool graph_changed = false; for (const auto& transformer : transformers->second) { if (step > 0 && transformer->ShouldOnlyApplyOnce()) diff --git a/onnxruntime/core/optimizer/graph_transformer_mgr.h b/onnxruntime/core/optimizer/graph_transformer_mgr.h index ed66302434ab2..eab57f12bfcbb 100644 --- a/onnxruntime/core/optimizer/graph_transformer_mgr.h +++ b/onnxruntime/core/optimizer/graph_transformer_mgr.h @@ -24,6 +24,16 @@ class GraphTransformerManager { // Get the maximum number of graph transformation steps common::Status GetSteps(unsigned& steps) const; + // Set the cancellation flag ptr from session_options + void SetLoadCancellationFn(CheckLoadCancellationFn check_load_cancellation_fn) { + check_load_cancellation_fn_ = std::move(check_load_cancellation_fn); + } + + // Get the cancellation flag ptr + bool IsLoadCancellationFlagSet() const noexcept { + return check_load_cancellation_fn_ && check_load_cancellation_fn_(); + } + // Register a transformer with a level. common::Status Register(std::unique_ptr transformer, TransformerLevel level); @@ -38,5 +48,6 @@ class GraphTransformerManager { InlinedHashMap>> level_to_transformer_map_; InlinedHashMap transformers_info_; + CheckLoadCancellationFn check_load_cancellation_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 2e733f67a888c..e50ee5738c30e 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -340,3 +340,11 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* return nullptr; API_IMPL_END } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool is_cancel) { + API_IMPL_BEGIN + options->value.SetLoadCancellationFlag(is_cancel); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e5ea562ce3535..1bcb651ec9605 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -383,6 +383,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options ORT_THROW_IF_ERROR(graph_transformer_mgr_.SetSteps(session_options_.max_num_graph_transformation_steps)); + graph_transformer_mgr_.SetLoadCancellationFn(this->check_load_cancellation_fn_); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1004,11 +1005,13 @@ common::Status InferenceSession::LoadOnnxModel(const PathString& model_uri) { std::copy(std::begin(interop_domains_), std::end(interop_domains_), std::back_inserter(domain_ptrs)); ORT_RETURN_IF_ERROR(AddCustomOpDomains(domain_ptrs)); #endif + const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; return onnxruntime::Model::Load(model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; common::Status st = LoadWithLoader(loader, "model_loading_uri"); @@ -1101,7 +1104,8 @@ common::Status InferenceSession::Load(const void* model_data, int model_data_len return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_array"); @@ -1139,7 +1143,8 @@ common::Status InferenceSession::LoadOnnxModel(ModelProto model_proto) { // This call will move model_proto to the constructed model instance return onnxruntime::Model::Load(std::move(model_proto), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(true, strict_shape_type_inference)); + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_proto"); @@ -1172,7 +1177,8 @@ common::Status InferenceSession::Load(std::istream& model_istream, bool allow_re const bool strict_shape_type_inference = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsConfigStrictShapeTypeInference, "0") == "1"; ModelOptions model_opts(allow_released_opsets_only, - strict_shape_type_inference); + strict_shape_type_inference, + check_load_cancellation_fn_); std::string external_data_folder_path = session_options_.config_options.GetConfigOrDefault( kOrtSessionOptionsModelExternalInitializersFileFolderPath, ""); @@ -1211,7 +1217,8 @@ common::Status InferenceSession::Load() { // Pass on ownership of the parsed ModelProto to the Model instance (its job here is done by this stage) return Model::Load(std::move(this->model_proto_), model_location_, model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, *session_logger_, - ModelOptions(allow_released_opsets_only, strict_shape_type_inference)); + ModelOptions(allow_released_opsets_only, strict_shape_type_inference, + check_load_cancellation_fn_)); }; return LoadWithLoader(loader, "model_loading_from_saved_proto"); @@ -1239,7 +1246,8 @@ common::Status InferenceSession::Load(const OrtModel& model_editor_api_model) { std::unique_ptr tmp_model; ORT_RETURN_IF_ERROR(Model::LoadFromModelEditorApiModel(model_editor_api_model, HasLocalSchema() ? &custom_schema_registries_ : nullptr, - ModelOptions(true, strict_shape_type_inference), + ModelOptions(true, strict_shape_type_inference, + check_load_cancellation_fn_), *session_logger_, tmp_model)); model_ = std::move(tmp_model); @@ -1283,7 +1291,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool auto graph_optimizer_registry = std::make_unique(&session_options_, execution_providers_.Get(onnxruntime::kCpuExecutionProvider), session_logger_); - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry), + check_load_cancellation_fn_); // Run Ahead Of time function inlining if (const bool disable_aot_function_inlining = @@ -1711,7 +1720,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, providers.Get(onnxruntime::kCpuExecutionProvider), &logger); - GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry)); + GraphPartitioner partitioner(kernel_registry_manager, providers, std::move(graph_optimizer_registry), + [&sess_options]() -> bool { return sess_options.IsLoadCancellationFlagSet(); }); ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, @@ -1784,6 +1794,11 @@ common::Status InferenceSession::HasInvalidCombinationOfExecutionProviders() con #pragma warning(disable : 26117) #endif common::Status InferenceSession::Initialize() { + if (session_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Session initialization canceled due to user request."); + } + Status status = Status::OK(); TimePoint tp; if (session_profiler_.IsEnabled()) { @@ -2009,6 +2024,10 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + if (session_options_.IsLoadCancellationFlagSet()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, MODEL_LOAD_CANCELED, + "Session initialization canceled due to user request."); + } // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. // diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 5b484103c9ecf..7b5d98c38a0fa 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -781,6 +781,10 @@ class InferenceSession { // the session options are released after the individual operators are destroyed. SessionOptions session_options_; + CheckLoadCancellationFn check_load_cancellation_fn_ = [this]() { + return session_options_.IsLoadCancellationFlagSet(); + }; + /// Logging manager if provided. logging::LoggingManager* logging_manager_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 0e23d7a791bec..ac67a3ce5c1a2 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -720,22 +720,11 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSession, _In_ const OrtEnv* env, _In_ const O _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr sess; - OrtStatus* status = nullptr; *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); - - *out = reinterpret_cast(sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, model_path, nullptr, 0, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + *out = reinterpret_cast(sess.release()); + return nullptr; API_IMPL_END } @@ -743,22 +732,10 @@ ORT_API_STATUS_IMPL(OrtApis::CreateSessionFromArray, _In_ const OrtEnv* env, _In size_t model_data_length, _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out) { API_IMPL_BEGIN std::unique_ptr sess; - OrtStatus* status = nullptr; - *out = nullptr; - - ORT_TRY { - ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); - ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); - - *out = reinterpret_cast(sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } - - return status; + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(options, env, nullptr, model_data, model_data_length, sess)); + ORT_API_RETURN_IF_ERROR(InitializeSession(options, *sess)); + *out = reinterpret_cast(sess.release()); + return nullptr; API_IMPL_END } @@ -2810,6 +2787,7 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::GetModelEditorApi, &OrtApis::CreateTensorWithDataAndDeleterAsOrtValue, + &OrtApis::SessionOptionsSetLoadCancellationFlag, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9d8aeb18a782f..0a87036a0dd1d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -549,4 +549,7 @@ ORT_API_STATUS_IMPL(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* ONNXTensorElementDataType type, _Outptr_ OrtValue** out); +ORT_API_STATUS_IMPL(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool is_cancel); + } // namespace OrtApis diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 975502063ac2a..a069cfa0b4713 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1753,6 +1753,12 @@ Applies to session load, initialization, etc. Default is 0.)pbdoc") options->value.execution_mode = execution_mode; }, R"pbdoc(Sets the execution mode. Default is sequential.)pbdoc") + .def( + "set_load_cancellation_flag", + [](PySessionOptions* options, bool value) -> void { + options->value.SetLoadCancellationFlag(value); + }, + R"pbdoc(Request inference session load cancellation)pbdoc") .def_property( "execution_order", [](const PySessionOptions* options) -> ExecutionOrder { return options->value.execution_order; }, diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 95101c8075fc2..dc776f74d8758 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -498,6 +499,30 @@ TEST(InferenceSessionTests, TestModelSerialization) { ASSERT_TRUE(session_object_emptyValidation.Initialize().IsOK()); } +TEST(InferenceSessionTests, RequestLoadCancellation) { + { + // Explicit cancel during load, small model is fine + SessionOptions so; + so.session_logid = "InferenceSessionTests.TestLoadCancellation"; + + const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx"); + InferenceSession session_object{so, GetEnvironment()}; + so.SetLoadCancellationFlag(true); + ASSERT_FALSE(session_object.Load(model_uri).IsOK()); + } + { + // Explicit cancel during initialize, small model is fine + const PathString model_uri = ORT_TSTR("testdata/constant_floats.onnx"); + SessionOptions so; + so.session_logid = "InferenceSessionTests.TestLoadCancellation"; + so.SetLoadCancellationFlag(false); + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(model_uri)); + so.SetLoadCancellationFlag(true); + ASSERT_FALSE(session_object.Initialize().IsOK()); + } +} + #ifdef ORT_RUN_EXTERNAL_ONNX_TESTS static bool Compare(const InputDefList& f_arg, const InputDefList& s_arg) { if (f_arg.size() != s_arg.size()) { diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 9c0b779870c70..a5fd37361a255 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -576,11 +576,10 @@ TEST(Loop, InfiniteLoopTermination) { test.Run(OpTester::ExpectResult::kExpectFailure, "Exiting due to terminate flag being set to true", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options); // Disable TensorRT on unsupported data type BOOL - // call get to propagate any exception - terminator_result.get(); - // done with the thread terminator_thread.join(); + // call get to propagate any exception + terminator_result.get(); } // Add basic test to trigger types override logic in Graph::InferAndVerifySubgraphTypes as well as diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index b517ba7032886..e00606af1c086 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -4669,6 +4670,34 @@ TEST(CApiTest, RunBaseLoraModel) { } } +TEST(CApiTest, RequestLoadCancellation) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + Ort::Env env(ORT_LOGGING_LEVEL_WARNING); + Ort::SessionOptions session_options; + + auto terminator = [&session_options]() { + session_options.SetLoadCancellationFlag(true); + return; + }; + + std::packaged_task task{terminator}; + std::future terminator_result = task.get_future(); + std::thread terminator_thread{std::move(task)}; + bool terminated = false; + try { + Ort::Session session(env, model_path, session_options); + } catch (const Ort::Exception& ex) { + terminated = OrtErrorCode::ORT_MODEL_LOAD_CANCELED == ex.GetOrtErrorCode(); + } + // done with the thread + terminator_thread.join(); + + // call get to propagate any exception + terminator_result.get(); + + ASSERT_TRUE(terminated); +} + struct MockGQA : public OrtCustomOp { MockGQA() { OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {