diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c728428348b53..28df8e69b5925 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7436,6 +7436,32 @@ struct OrtApi { _In_ const OrtThreadPoolCallbacksConfig* config); /// @} + + /** \brief Check if the memory pattern optimization is enabled in the session options. + * + * \param[in] options + * \param[out] out Set to 1 if the memory pattern optimization is enabled, 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + * + * \see OrtApi::EnableMemPattern, OrtApi::DisableMemPattern + */ + ORT_API2_STATUS(GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out); + + /** \brief Get the current execution mode setting. + * + * \param[in] options + * \param[out] out Set to the current execution mode (ORT_SEQUENTIAL or ORT_PARALLEL). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + * + * \see OrtApi::SetSessionExecutionMode + */ + ORT_API2_STATUS(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a19793f4c67d2..37920cf0d58f3 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1621,6 +1621,9 @@ struct ConstSessionOptionsImpl : Base { std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const; + + bool GetMemPatternEnabled() const; ///< Wraps OrtApi::GetMemPatternEnabled + ExecutionMode GetExecutionMode() const; ///< Wraps OrtApi::GetSessionExecutionMode }; template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 8cad4a6dd51ad..f32e560b433e6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1400,6 +1400,20 @@ inline std::string ConstSessionOptionsImpl::GetConfigEntryOrDefault(const cha return this->GetConfigEntry(config_key); } +template +inline bool ConstSessionOptionsImpl::GetMemPatternEnabled() const { + int out = 0; + ThrowOnError(GetApi().GetMemPatternEnabled(this->p_, &out)); + return out != 0; +} + +template +inline ExecutionMode ConstSessionOptionsImpl::GetExecutionMode() const { + ExecutionMode out{}; + ThrowOnError(GetApi().GetSessionExecutionMode(this->p_, &out)); + return out; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetIntraOpNumThreads(int intra_op_num_threads) { ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads)); diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 3df6d37d63794..06bd5c4d84089 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -118,6 +118,15 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionExecutionMode, _In_ OrtSessionOptions* op return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "'out' parameter must not be NULL"); + *out = options->value.execution_mode; + return nullptr; + API_IMPL_END +} + // set filepath to save optimized onnx model. ORT_API_STATUS_IMPL(OrtApis::SetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath) { options->value.optimized_model_filepath = optimized_model_filepath; @@ -149,6 +158,15 @@ ORT_API_STATUS_IMPL(OrtApis::DisableMemPattern, _In_ OrtSessionOptions* options) return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL"); + ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "'out' parameter must not be NULL"); + *out = options->value.enable_mem_pattern ? 1 : 0; + return nullptr; + API_IMPL_END +} + // enable the memory arena on CPU // Arena may pre-allocate memory for future usage. // set this option to false if you don't want it. diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 5ee5f1486b137..1252a65197184 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4895,6 +4895,8 @@ static constexpr OrtApi ort_api_1_to_27 = { &OrtApis::SetPerSessionThreadPoolCallbacks, // End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information) // End of Version 26 - DO NOT MODIFY ABOVE (see above text for more information) + &OrtApis::GetMemPatternEnabled, + &OrtApis::GetSessionExecutionMode, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 3cc55ee01a3fe..00ec258cef91e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -64,6 +64,8 @@ ORT_API_STATUS_IMPL(EnableProfiling, _In_ OrtSessionOptions* options, _In_ const ORT_API_STATUS_IMPL(DisableProfiling, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(EnableMemPattern, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(DisableMemPattern, _In_ OrtSessionOptions* options); +ORT_API_STATUS_IMPL(GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out); +ORT_API_STATUS_IMPL(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out); ORT_API_STATUS_IMPL(EnableCpuMemArena, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(DisableCpuMemArena, _In_ OrtSessionOptions* options); ORT_API_STATUS_IMPL(SetSessionLogId, _In_ OrtSessionOptions* options, const char* logid); diff --git a/onnxruntime/test/shared_lib/test_session_options.cc b/onnxruntime/test/shared_lib/test_session_options.cc index ba58344e1e3e2..7e29c6f9a71d1 100644 --- a/onnxruntime/test/shared_lib/test_session_options.cc +++ b/onnxruntime/test/shared_lib/test_session_options.cc @@ -22,6 +22,36 @@ TEST(CApiTest, session_options_deterministic_compute) { options.SetDeterministicCompute(true); } +TEST(CApiTest, session_options_get_mem_pattern_enabled) { + Ort::SessionOptions options; + + // Memory pattern is enabled by default + ASSERT_TRUE(options.GetMemPatternEnabled()); + + // Disable and verify + options.DisableMemPattern(); + ASSERT_FALSE(options.GetMemPatternEnabled()); + + // Re-enable and verify + options.EnableMemPattern(); + ASSERT_TRUE(options.GetMemPatternEnabled()); +} + +TEST(CApiTest, session_options_get_execution_mode) { + Ort::SessionOptions options; + + // Default is sequential + ASSERT_EQ(options.GetExecutionMode(), ORT_SEQUENTIAL); + + // Set to parallel and verify + options.SetExecutionMode(ORT_PARALLEL); + ASSERT_EQ(options.GetExecutionMode(), ORT_PARALLEL); + + // Set back to sequential and verify + options.SetExecutionMode(ORT_SEQUENTIAL); + ASSERT_EQ(options.GetExecutionMode(), ORT_SEQUENTIAL); +} + #if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) && !defined(ORT_NO_EXCEPTIONS) TEST(CApiTest, session_options_oversized_affinity_string) {