Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7436,6 +7436,32 @@ struct OrtApi {
_In_ const OrtThreadPoolCallbacksConfig* config);

/// @}

/** \brief Check if the memory pattern optimization is enabled in the session options.
*
* \param[in] options
* \param[out] out Set to 1 if the memory pattern optimization is enabled, 0 otherwise.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.27.
*
* \see OrtApi::EnableMemPattern, OrtApi::DisableMemPattern
*/
ORT_API2_STATUS(GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out);

/** \brief Get the current execution mode setting.
*
* \param[in] options
* \param[out] out Set to the current execution mode (ORT_SEQUENTIAL or ORT_PARALLEL).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.27.
*
* \see OrtApi::SetSessionExecutionMode
*/
ORT_API2_STATUS(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out);
Comment thread
adrianlizarraga marked this conversation as resolved.
};

/*
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,9 @@ struct ConstSessionOptionsImpl : Base<T> {
std::string GetConfigEntry(const char* config_key) const; ///< Wraps OrtApi::GetSessionConfigEntry
bool HasConfigEntry(const char* config_key) const; ///< Wraps OrtApi::HasSessionConfigEntry
std::string GetConfigEntryOrDefault(const char* config_key, const std::string& def) const;

bool GetMemPatternEnabled() const; ///< Wraps OrtApi::GetMemPatternEnabled
ExecutionMode GetExecutionMode() const; ///< Wraps OrtApi::GetSessionExecutionMode
};

template <typename T>
Expand Down
14 changes: 14 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,20 @@ inline std::string ConstSessionOptionsImpl<T>::GetConfigEntryOrDefault(const cha
return this->GetConfigEntry(config_key);
}

template <typename T>
inline bool ConstSessionOptionsImpl<T>::GetMemPatternEnabled() const {
int out = 0;
ThrowOnError(GetApi().GetMemPatternEnabled(this->p_, &out));
return out != 0;
}

template <typename T>
inline ExecutionMode ConstSessionOptionsImpl<T>::GetExecutionMode() const {
ExecutionMode out{};
ThrowOnError(GetApi().GetSessionExecutionMode(this->p_, &out));
return out;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetIntraOpNumThreads(int intra_op_num_threads) {
ThrowOnError(GetApi().SetIntraOpNumThreads(this->p_, intra_op_num_threads));
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,15 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionExecutionMode, _In_ OrtSessionOptions* op
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out) {
API_IMPL_BEGIN
ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL");
ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "'out' parameter must not be NULL");
*out = options->value.execution_mode;
return nullptr;
API_IMPL_END
}

// set filepath to save optimized onnx model.
ORT_API_STATUS_IMPL(OrtApis::SetOptimizedModelFilePath, _In_ OrtSessionOptions* options, _In_ const ORTCHAR_T* optimized_model_filepath) {
options->value.optimized_model_filepath = optimized_model_filepath;
Expand Down Expand Up @@ -149,6 +158,15 @@ ORT_API_STATUS_IMPL(OrtApis::DisableMemPattern, _In_ OrtSessionOptions* options)
return nullptr;
}

ORT_API_STATUS_IMPL(OrtApis::GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out) {
API_IMPL_BEGIN
ORT_API_RETURN_IF(options == nullptr, ORT_INVALID_ARGUMENT, "'options' parameter must not be NULL");
ORT_API_RETURN_IF(out == nullptr, ORT_INVALID_ARGUMENT, "'out' parameter must not be NULL");
*out = options->value.enable_mem_pattern ? 1 : 0;
return nullptr;
API_IMPL_END
}

// enable the memory arena on CPU
// Arena may pre-allocate memory for future usage.
// set this option to false if you don't want it.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4895,6 +4895,8 @@ static constexpr OrtApi ort_api_1_to_27 = {
&OrtApis::SetPerSessionThreadPoolCallbacks,
// End of Version 25 - DO NOT MODIFY ABOVE (see above text for more information)
// End of Version 26 - DO NOT MODIFY ABOVE (see above text for more information)
&OrtApis::GetMemPatternEnabled,
&OrtApis::GetSessionExecutionMode,
};
Comment thread
adrianlizarraga marked this conversation as resolved.
Comment thread
adrianlizarraga marked this conversation as resolved.

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ ORT_API_STATUS_IMPL(EnableProfiling, _In_ OrtSessionOptions* options, _In_ const
ORT_API_STATUS_IMPL(DisableProfiling, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(EnableMemPattern, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(DisableMemPattern, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(GetMemPatternEnabled, _In_ const OrtSessionOptions* options, _Out_ int* out);
ORT_API_STATUS_IMPL(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out);
ORT_API_STATUS_IMPL(EnableCpuMemArena, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(DisableCpuMemArena, _In_ OrtSessionOptions* options);
ORT_API_STATUS_IMPL(SetSessionLogId, _In_ OrtSessionOptions* options, const char* logid);
Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/test/shared_lib/test_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,36 @@ TEST(CApiTest, session_options_deterministic_compute) {
options.SetDeterministicCompute(true);
}

TEST(CApiTest, session_options_get_mem_pattern_enabled) {
Ort::SessionOptions options;

// Memory pattern is enabled by default
ASSERT_TRUE(options.GetMemPatternEnabled());

// Disable and verify
options.DisableMemPattern();
ASSERT_FALSE(options.GetMemPatternEnabled());

// Re-enable and verify
options.EnableMemPattern();
ASSERT_TRUE(options.GetMemPatternEnabled());
}

TEST(CApiTest, session_options_get_execution_mode) {
Ort::SessionOptions options;

// Default is sequential
ASSERT_EQ(options.GetExecutionMode(), ORT_SEQUENTIAL);

// Set to parallel and verify
options.SetExecutionMode(ORT_PARALLEL);
ASSERT_EQ(options.GetExecutionMode(), ORT_PARALLEL);

// Set back to sequential and verify
options.SetExecutionMode(ORT_SEQUENTIAL);
ASSERT_EQ(options.GetExecutionMode(), ORT_SEQUENTIAL);
}

#if !defined(ORT_MINIMAL_BUILD) && !defined(ORT_EXTENDED_MINIMAL_BUILD) && !defined(ORT_NO_EXCEPTIONS)

TEST(CApiTest, session_options_oversized_affinity_string) {
Expand Down
Loading