Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Runtime.InteropServices;
using static Microsoft.ML.OnnxRuntime.NativeMethods;

namespace Microsoft.ML.OnnxRuntime
{
Expand Down Expand Up @@ -325,6 +326,16 @@ public struct OrtApi
public IntPtr CreateLoraAdapterFromArray;
public IntPtr ReleaseLoraAdapter;
public IntPtr RunOptionsAddActiveLoraAdapter;
public IntPtr SetEpDynamicOptions;
public IntPtr ReleaseValueInfo;
public IntPtr ReleaseNode;
public IntPtr ReleaseGraph;
public IntPtr ReleaseModel;
public IntPtr GetValueInfoName;
public IntPtr GetValueInfoTypeInfo;
public IntPtr GetModelEditorApi;
public IntPtr CreateTensorWithDataAndDeleterAsOrtValue;
public IntPtr SessionOptionsSetLoadCancellationFlag;
}

internal static class NativeMethods
Expand Down Expand Up @@ -404,6 +415,7 @@ static NativeMethods()
OrtReleaseSessionOptions = (DOrtReleaseSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseSessionOptions, typeof(DOrtReleaseSessionOptions));
OrtCloneSessionOptions = (DOrtCloneSessionOptions)Marshal.GetDelegateForFunctionPointer(api_.CloneSessionOptions, typeof(DOrtCloneSessionOptions));
OrtSetSessionExecutionMode = (DOrtSetSessionExecutionMode)Marshal.GetDelegateForFunctionPointer(api_.SetSessionExecutionMode, typeof(DOrtSetSessionExecutionMode));
OrtSessionOptionsSetLoadCancellationFlag = (DOrtSessionOptionsSetLoadCancellationFlag)Marshal.GetDelegateForFunctionPointer(api_.SessionOptionsSetLoadCancellationFlag, typeof(DOrtSessionOptionsSetLoadCancellationFlag));
OrtSetOptimizedModelFilePath = (DOrtSetOptimizedModelFilePath)Marshal.GetDelegateForFunctionPointer(api_.SetOptimizedModelFilePath, typeof(DOrtSetOptimizedModelFilePath));
OrtEnableProfiling = (DOrtEnableProfiling)Marshal.GetDelegateForFunctionPointer(api_.EnableProfiling, typeof(DOrtEnableProfiling));
OrtDisableProfiling = (DOrtDisableProfiling)Marshal.GetDelegateForFunctionPointer(api_.DisableProfiling, typeof(DOrtDisableProfiling));
Expand Down Expand Up @@ -1025,6 +1037,12 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
ExecutionMode execution_mode);
public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode;

[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSessionOptionsSetLoadCancellationFlag(IntPtr /*(OrtSessionOptions*)*/ options,
bool value);
public static DOrtSessionOptionsSetLoadCancellationFlag OrtSessionOptionsSetLoadCancellationFlag;


[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath);
public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath;
Expand Down
10 changes: 10 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,16 @@ public ExecutionMode ExecutionMode
}
private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL;

/// <summary>
/// Sets the load cancellation flag for the session. Default is set to false.
/// Provides an opportunity for the user to cancel model loading.
/// </summary>
/// <param name="value">true to request cancellation, false to proceed</param>
public void SetLoadCancellationFlag(bool value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value));
}

#endregion

#region Private Methods
Expand Down
32 changes: 31 additions & 1 deletion include/onnxruntime/core/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
abort(); \
} while (false)

#define ORT_THROW_FROM_STATUS(status) \
do { \
::onnxruntime::PrintFinalMessage( \
::onnxruntime::OnnxRuntimeException( \
ORT_WHERE_WITH_STACK, status.ToString()) \
.what()); \
abort(); \
} while (false)

#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
do { \
::onnxruntime::PrintFinalMessage( \
::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
::onnxruntime::MakeString(__VA_ARGS__), \
::onnxruntime::common::category, \
::onnxruntime::common::code) \
.what()); \
abort(); \
} while (false)

#else

#define ORT_TRY try
Expand Down Expand Up @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
#define ORT_THROW_EX(ex, ...) \
throw ex(__VA_ARGS__)

#define ORT_THROW_FROM_STATUS(status) \
throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \
static_cast<::onnxruntime::common::StatusCode>(status.Code()))

#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \
throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \
::onnxruntime::MakeString(__VA_ARGS__), \
::onnxruntime::common::category, \
::onnxruntime::common::code)

#endif

#define ORT_MAKE_STATUS(category, code, ...) \
Expand Down Expand Up @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch
auto _status = (expr); \
if ((!_status.IsOK())) { \
::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast<const char*>(__FUNCTION__), __LINE__); \
ORT_THROW(_status); \
ORT_THROW_FROM_STATUS(_status); \
} \
} while (0)

Expand Down
47 changes: 45 additions & 2 deletions include/onnxruntime/core/common/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <vector>

#include "core/common/common.h"
#include "core/common/status.h"
#include "core/common/code_location.h"

namespace onnxruntime {
Expand All @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception {
/**
Create a new exception that captures the location it was thrown from.
@param location Location in the source code the exception is being thrown from
@param msg Message containing additional information about the exception cause.
@param category Error category
@param code Error code
*/

OnnxRuntimeException(const CodeLocation& location,
const std::string& message,
common::StatusCategory category,
common::StatusCode code) noexcept
: OnnxRuntimeException(location, nullptr, message, category, code) {
}

/**
Create a new exception that captures the location it was thrown from.
The instance will be created with ONNXRUNTIME category and FAIL code.
@param location Location in the source code the exception is being thrown from
@param failed_condition Optional string containing the condition that failed.
e.g. "tensor.Size() == input.Size()". May be nullptr.
@param msg Message containing additional information about the exception cause.
*/
OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg)
: location_{location} {
OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept
: OnnxRuntimeException(location, failed_condition, msg,
common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) {
}

/**
Create a new exception that captures the location it was thrown from.
@param location Location in the source code the exception is being thrown from
@param failed_condition Optional string containing the condition that failed.
e.g. "tensor.Size() == input.Size()". May be nullptr.
@param msg Message containing additional information about the exception cause.
@param category Error category
@param code Error code
*/
OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg,
common::StatusCategory category,
common::StatusCode code)
: location_{location}, category_(category), code_(code) {
std::ostringstream ss;

ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous
Expand All @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception {
what_ = ss.str();
}

common::StatusCategory Category() const noexcept {
return category_;
}

common::StatusCode Code() const noexcept {
return code_;
}

const char* what() const noexcept override {
return what_.c_str();
}
Expand All @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception {
const CodeLocation location_;
const std::vector<std::string> stacktrace_;
std::string what_;
common::StatusCategory category_;
common::StatusCode code_;
};

} // namespace onnxruntime
7 changes: 6 additions & 1 deletion include/onnxruntime/core/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
MODEL_LOADED = 8,
NOT_IMPLEMENTED = 9,
INVALID_GRAPH = 10,
EP_FAIL = 11
EP_FAIL = 11,
MODEL_LOAD_CANCELED = 12,
};

constexpr const char* StatusCodeToString(StatusCode status) noexcept {
Expand Down Expand Up @@ -72,6 +73,8 @@
return "INVALID_GRAPH";
case StatusCode::EP_FAIL:
return "EP_FAIL";
case StatusCode::MODEL_LOAD_CANCELED:
return "MODEL_LOAD_CANCELED";
default:
return "GENERAL ERROR";
}
Expand Down Expand Up @@ -104,6 +107,8 @@
return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT);
case StatusCode::EP_FAIL:
return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR);
case StatusCode::MODEL_LOAD_CANCELED:
return HRESULT_FROM_WIN32(ERROR_CANCELLED);

Check warning on line 111 in include/onnxruntime/core/common/status.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "CANCELLED" is a misspelling of "CANCELED" Raw Output: ./include/onnxruntime/core/common/status.h:111:38: "CANCELLED" is a misspelling of "CANCELED"
default:
return E_FAIL;
}
Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ typedef enum OrtErrorCode {
ORT_NOT_IMPLEMENTED,
ORT_INVALID_GRAPH,
ORT_EP_FAIL,
ORT_MODEL_LOAD_CANCELED,
} OrtErrorCode;

typedef enum OrtOpAttrType {
Expand Down Expand Up @@ -4898,6 +4899,24 @@ struct OrtApi {
_In_ const int64_t* shape, size_t shape_len,
ONNXTensorElementDataType type,
_Outptr_ OrtValue** out);

/** \brief sets load cancellation flag to abort session loading process.
*
* \param[in] options instance that was passed to the session at creation time.
* \param[in] cancel setting this to true after model loading process was initiated will
* attempt to cancel the loading process. If cancellation is successful, CreateSession()
* CreateSessionFromArray() or any other session creation API that take session options as an
* argument will return an OrtStatus indicating that session loading was canceled at user request,
* error code ORT_MODEL_LOAD_CANCELED.
* The APIs above would not return any valid Session instance. This is the best case effort and the result
* is not guaranteed. The session may have already been created and initialized
* before the cancellation request was issued.
*
* \snippet{doc} snippets.dox OrtStatus
*
*/
ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options,
_In_ bool cancel);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,8 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {

SessionOptionsImpl& SetExecutionMode(ExecutionMode execution_mode); ///< Wraps OrtApi::SetSessionExecutionMode

SessionOptionsImpl& SetLoadCancellationFlag(bool value); ///< Wraps OrtApi::SessionOptionsSetLoadCancellationFlag

SessionOptionsImpl& SetLogId(const char* logid); ///< Wraps OrtApi::SetSessionLogId
SessionOptionsImpl& SetLogSeverityLevel(int level); ///< Wraps OrtApi::SetSessionLogSeverityLevel

Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,12 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetExecutionMode(ExecutionM
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLoadCancellationFlag(bool value) {
ThrowOnError(GetApi().SessionOptionsSetLoadCancellationFlag(this->p_, value));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetLogId(const char* logid) {
ThrowOnError(GetApi().SetSessionLogId(this->p_, logid));
Expand Down
23 changes: 13 additions & 10 deletions onnxruntime/core/framework/error_code_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,19 @@ Status ToStatus(const OrtStatus* ort_status, common::StatusCategory category = c

#ifndef ORT_NO_EXCEPTIONS
#define API_IMPL_BEGIN try {
#define API_IMPL_END \
} \
catch (const onnxruntime::NotImplementedException& ex) { \
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
} \
catch (const std::exception& ex) { \
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
} \
catch (...) { \
return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
#define API_IMPL_END \
} \
catch (const onnxruntime::OnnxRuntimeException& ex) { \
return OrtApis::CreateStatus(static_cast<OrtErrorCode>(ex.Code()), ex.what()); \
} \
catch (const onnxruntime::NotImplementedException& ex) { \
return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, ex.what()); \
} \
catch (const std::exception& ex) { \
return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, ex.what()); \
} \
catch (...) { \
return OrtApis::CreateStatus(ORT_FAIL, "Unknown Exception"); \
}

#else
Expand Down
Loading
Loading