diff --git a/cmake/external/onnx b/cmake/external/onnx
index 237926eab41de..174de7d086a76 160000
--- a/cmake/external/onnx
+++ b/cmake/external/onnx
@@ -1 +1 @@
-Subproject commit 237926eab41de21fb9addc4b03b751fd6a3343ec
+Subproject commit 174de7d086a768cba29374a56a7461eff87cfdb3
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index 3a602b195aaa8..bd11d106c9127 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -197,6 +197,15 @@ typedef struct OrtAllocator {
   const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
 } OrtAllocator;
 
+typedef struct OrtAllocatorArena {
+  OrtAllocator* device_allocator;             // underlying device allocator used by the arena
+  void*(ORT_API_CALL* Alloc)(size_t size);    // called when the arena is asked to allocate `size` bytes
+  void(ORT_API_CALL* Free)(void* p);          // called when the arena is asked to free `p`
+  void*(ORT_API_CALL* Reserve)(size_t size);  // called when the arena is asked to reserve `size` bytes
+  size_t(ORT_API_CALL* Used)(void);           // total size of the memory currently allocated; (void) keeps a C prototype
+  size_t(ORT_API_CALL* Max)(void);            // memory limit of the arena; (void) keeps a C prototype
+} OrtAllocatorArena;
+
 typedef void(ORT_API_CALL* OrtLoggingFunction)(
     void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location,
     const char* message);
@@ -1157,6 +1166,51 @@ struct OrtApi {
    */
   ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata,
                   _Inout_ OrtAllocator* allocator, _Outptr_ char** value);
+
+  /**
+   * Use this API to obtain a newly allocated OrtAllocator* object whose inner fields are the given inputs.
+   * It is the user's responsibility to release the returned OrtAllocator*.
+   * \param version - C_API version (available from 7)
+   * \param AllocFunc - A function pointer to the callback that will be called upon every memory allocation.
+   * \param FreeFunc - A function pointer to the callback that will be called upon memory free.
+   * \param InfoFunc - A function pointer to a callback that returns OrtMemoryInfo* with OrtAllocatorType set to OrtDeviceAllocator.
+   * \param out - A placeholder for the custom OrtAllocator.
+   *              The caller is responsible for freeing it.
+   */
+  ORT_API2_STATUS(CreateCustomDeviceAllocator, uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
+                  const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out);
+
+  /**
+   * Use this API to obtain a newly allocated OrtAllocatorArena* object whose inner fields are the given inputs.
+   * It is the user's responsibility to release the returned OrtAllocatorArena*.
+   * \param device_allocator - The underlying device allocator that the arena allocator will use. Its Info inner field
+   *                           should return OrtMemoryInfo* with OrtAllocatorType set to OrtDeviceAllocator.
+   * \param AllocFunc - A function pointer to the callback that will be called upon calling Alloc from within arena context.
+   * \param FreeFunc - A function pointer to the callback that will be called upon calling Free from within arena context.
+   * \param ReserveFunc - A function pointer to the callback that will be called upon calling for reserving memory from within arena context.
+   * \param UsedFunc - A function pointer to the callback that will be called to get the total size of allocated memory from within arena context.
+   * \param MaxFunc - A function pointer to the callback that will be called to get the memory limit from within arena context.
+   * \param out - A placeholder for the custom OrtAllocatorArena.
+   *              The caller is responsible for freeing it.
+   */
+  ORT_API2_STATUS(CreateCustomArenaAllocator, _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
+                  size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out);
+
+  /**
+   * Use this API to register a custom OrtAllocator* with the given env. Whenever a new session is created
+   * and associated with this env, if session_options is configured to use the env allocator instead of the default one,
+   * and not to use an arena allocator, then the memory management will be done by the given allocator.
+   * It is the user's responsibility to release the OrtAllocator*.
+   */
+  ORT_API2_STATUS(RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* CustomAllocator);
+
+  /**
+   * Use this API to register a custom OrtAllocatorArena* with the given env. Whenever a new session is created
+   * and associated with this env, if session_options is configured to use the env allocator instead of the default one,
+   * the memory management (which uses an arena by default) will be done by the given allocator.
+   * It is the user's responsibility to release the OrtAllocatorArena*.
+   */
+  ORT_API2_STATUS(RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena* CustomArenaAllocator);
   /**
    * Append TensorRT execution provider to the session options
    * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure.
diff --git a/onnxruntime/core/session/arena_allocator_impl.h b/onnxruntime/core/session/arena_allocator_impl.h
new file mode 100644
index 0000000000000..4616ed4c81c5d
--- /dev/null
+++ b/onnxruntime/core/session/arena_allocator_impl.h
@@ -0,0 +1,35 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/session/onnxruntime_c_api.h"
+#include "core/framework/allocator.h"
+#include "core/framework/arena.h"
+#include "core/session/device_allocator.h"
+
+namespace onnxruntime {
+
+// Adapts a user-supplied C-API OrtAllocatorArena (a table of callbacks) to the internal IArenaAllocator interface.
+// Every virtual call is forwarded to the callback of the same name. The wrapper does not own `impl`, so the
+// OrtAllocatorArena must outlive it.
+class ArenaAllocatorWrapper : public IArenaAllocator {
+ public:
+  // `impl`, impl->device_allocator and the OrtMemoryInfo returned by its Info callback are dereferenced here
+  // without checks, so all of them must be non-null.
+  explicit ArenaAllocatorWrapper(OrtAllocatorArena* impl)
+      : IArenaAllocator(*impl->device_allocator->Info(impl->device_allocator)), impl_(impl) {}
+
+  // IArenaAllocator overrides: each one forwards to the user callback of the same name.
+  void* Alloc(size_t size) override { return impl_->Alloc(size); }
+  void Free(void* p) override {
+    impl_->Free(p);  // the callback returns void; there is nothing to propagate
+  }
+  void* Reserve(size_t size) override { return impl_->Reserve(size); }
+  size_t Used() const override { return impl_->Used(); }
+  size_t Max() const override { return impl_->Max(); }
+
+ private:
+  OrtAllocatorArena* impl_;  // not owned
+};
+
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/session/device_allocator.cc b/onnxruntime/core/session/device_allocator.cc
index 72e1a21416e7d..d080a8c4a0425 100644
--- a/onnxruntime/core/session/device_allocator.cc
+++ b/onnxruntime/core/session/device_allocator.cc
@@ -5,6 +5,7 @@
 #include "core/session/inference_session.h"
 #include "core/session/ort_env.h"
 #include "core/session/allocator_impl.h"
+#include "core/session/arena_allocator_impl.h"
 
 #ifndef ORT_NO_EXCEPTIONS
 
@@ -53,6 +54,43 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _I
   return nullptr;
 }
 
+// Wraps the caller-owned `CustomAllocator` in an AllocatorWrapper and registers it with `env`.
+// Returns ORT_INVALID_ARGUMENT when an argument is null or when the env rejects the registration.
+ORT_API_STATUS_IMPL(OrtApis::RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* CustomAllocator) {
+  using namespace onnxruntime;
+  if (!env) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null");
+  }
+  // Reject a null allocator up front instead of handing it to the wrapper.
+  if (!CustomAllocator) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CustomAllocator is null");
+  }
+  auto st = env->RegisterAllocator(AllocatorPtr(new AllocatorWrapper(CustomAllocator)));
+  if (!st.IsOK()) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str());
+  }
+  return nullptr;
+}
+
+// Wraps the caller-owned `CustomArenaAllocator` in an ArenaAllocatorWrapper and registers it with `env`.
+// Returns ORT_INVALID_ARGUMENT when a required pointer is null or when the env rejects the registration.
+ORT_API_STATUS_IMPL(OrtApis::RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena* CustomArenaAllocator) {
+  using namespace onnxruntime;
+  if (!env) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null");
+  }
+  // ArenaAllocatorWrapper's constructor calls device_allocator->Info(...) unconditionally, so validate that chain here.
+  const OrtAllocator* device_allocator = CustomArenaAllocator ? CustomArenaAllocator->device_allocator : nullptr;
+  if (!device_allocator || !device_allocator->Info) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CustomArenaAllocator or its device allocator is null");
+  }
+  auto st = env->RegisterAllocator(AllocatorPtr(new ArenaAllocatorWrapper(CustomArenaAllocator)));
+  if (!st.IsOK()) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str());
+  }
+  return nullptr;
+}
+
 ORT_API(void, OrtApis::ReleaseAllocator, _Frees_ptr_opt_ OrtAllocator* allocator) {
   delete reinterpret_cast(allocator);
 }
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index ecafbc4ce5404..1fe88c2cc69bd 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -1122,6 +1122,37 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerName, _In_ const Or
   API_IMPL_END
 }
 
+
+// Heap-allocates an OrtAllocator whose callbacks are the given functions and returns it through `out`.
+// NOTE(review): the object comes from a plain `new`; no matching release entry point is added in this change, so
+// confirm how callers are expected to free it.
+ORT_API_STATUS_IMPL(OrtApis::CreateCustomDeviceAllocator,
+                    uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
+                    const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out) {
+  API_IMPL_BEGIN
+  if (out == nullptr || AllocFunc == nullptr || FreeFunc == nullptr || InfoFunc == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CreateCustomDeviceAllocator: null argument");
+  }
+  *out = new OrtAllocator{version, AllocFunc, FreeFunc, InfoFunc};
+  return nullptr;
+  API_IMPL_END
+}
+
+// Heap-allocates an OrtAllocatorArena that pairs `device_allocator` with the given arena callbacks and returns it
+// through `out`. The device allocator is referenced, not copied.
+ORT_API_STATUS_IMPL(OrtApis::CreateCustomArenaAllocator,
+                    _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
+                    size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out) {
+  API_IMPL_BEGIN
+  if (out == nullptr || device_allocator == nullptr || AllocFunc == nullptr || FreeFunc == nullptr ||
+      ReserveFunc == nullptr || UsedFunc == nullptr || MaxFunc == nullptr) {
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CreateCustomArenaAllocator: null argument");
+  }
+  *out = new OrtAllocatorArena{device_allocator, AllocFunc, FreeFunc, ReserveFunc, UsedFunc, MaxFunc};
+  return nullptr;
+  API_IMPL_END
+}
+
 ORT_API_STATUS_IMPL(OrtApis::AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out) {
   API_IMPL_BEGIN
   *out = ptr->Alloc(ptr, size);
@@ -2105,6 +2136,11 @@ static constexpr OrtApi ort_api_1_to_7 = {
     // End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information)
 
     &OrtApis::ModelMetadataGetGraphDescription,
+    &OrtApis::CreateCustomDeviceAllocator,
+    &OrtApis::CreateCustomArenaAllocator,
+    &OrtApis::RegisterCustomDeviceAllocator,
+    &OrtApis::RegisterCustomArenaAllocator,
+
     &OrtApis::SessionOptionsAppendExecutionProvider_TensorRT,
     &OrtApis::SetCurrentGpuDeviceId,
     &OrtApis::GetCurrentGpuDeviceId,
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 10ab7328f3b17..03ec6720330cf 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -236,6 +236,14 @@ ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
 ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);
 
 ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);
+ORT_API_STATUS_IMPL(CreateCustomDeviceAllocator, uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
+                    const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out);
+
+ORT_API_STATUS_IMPL(CreateCustomArenaAllocator, _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
+                    size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out);
+ORT_API_STATUS_IMPL(RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* CustomAllocator);
+
+ORT_API_STATUS_IMPL(RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena* CustomArenaAllocator);
 
 ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
 ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Out_ uint64_t* out);
diff --git
a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index cdb755ff7b46d..6e8fc390c117a 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -1381,3 +1381,232 @@ TEST(CApiTest, TestIncorrectInputTypeToModel_SequenceTensors) {
   ASSERT_TRUE(exception_thrown);
 }
 #endif
+
+// Bytes currently outstanding from the custom test allocators below.
+std::atomic<size_t> memory_inuse{0};
+
+// Alloc callback of the custom device allocator. Each block is malloc'ed with a size_t header that stores the
+// total block size, so myFree can subtract exactly what was added to memory_inuse.
+void* myAlloc(OrtAllocator* ptr, size_t size) {
+  ORT_UNUSED_PARAMETER(ptr);
+  constexpr size_t extra_len = sizeof(size_t);
+  const size_t total_len = size + extra_len;
+  void* p = ::malloc(total_len);
+  if (p == nullptr)
+    return p;
+  // Count the block only once malloc has succeeded; a failed allocation must not inflate the counter.
+  memory_inuse.fetch_add(total_len);
+  *static_cast<size_t*>(p) = total_len;
+  return static_cast<char*>(p) + extra_len;
+}
+
+// Free callback of the custom device allocator: steps back to the size header written by myAlloc, updates
+// memory_inuse and releases the block. A null pointer is ignored.
+void myFree(OrtAllocator* ptr, void* p) {
+  ORT_UNUSED_PARAMETER(ptr);
+  constexpr size_t extra_len = sizeof(size_t);
+  if (!p) return;
+  p = static_cast<char*>(p) - extra_len;
+  const size_t len = *static_cast<size_t*>(p);
+  memory_inuse.fetch_sub(len);
+  ::free(p);
+}
+
+// Info callback of the custom device allocator: CPU memory info tagged OrtDeviceAllocator, or nullptr if it
+// could not be created.
+const OrtMemoryInfo* myInfo(const OrtAllocator* allocator) {
+  ORT_UNUSED_PARAMETER(allocator);
+  // Built once and cached: creating a fresh OrtMemoryInfo on every call leaked one object per call.
+  static const OrtMemoryInfo* const cached_info = []() -> const OrtMemoryInfo* {
+    OrtMemoryInfo* mem_info = nullptr;
+    if (Ort::GetApi().CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &mem_info) != nullptr) {
+      return nullptr;
+    }
+    return mem_info;
+  }();
+  return cached_info;
+}
+
+// This test uses the CreateCustomDeviceAllocator and RegisterCustomDeviceAllocator APIs to register an external
+// allocator with the env that overrides the default Alloc and Free functions.
+TEST(CApiTest, TestCustomDeviceAllocator) {
+  // simple inference test
+  // prepare inputs
+  std::vector<Input> inputs(1);
+  Input& input = inputs.back();
+  input.name = "X";
+  input.dims = {3, 2};
+  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  // prepare expected inputs and outputs
+  std::vector<int64_t> expected_dims_y = {3, 2};
+  std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+  OrtEnv* env_ptr = static_cast<OrtEnv*>(*ort_env);
+
+  const auto& api = Ort::GetApi();
+  OrtAllocator* my_allocator = nullptr;
+  ASSERT_TRUE(api.CreateCustomDeviceAllocator(ORT_API_VERSION, myAlloc, myFree, myInfo, &my_allocator) == nullptr);
+  ASSERT_TRUE(api.RegisterCustomDeviceAllocator(env_ptr, my_allocator) == nullptr);
+
+  Ort::SessionOptions session_options;
+  session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1");
+  session_options.DisableCpuMemArena();
+
+  // create session while allocating the model with the custom allocator.
+  ASSERT_EQ(memory_inuse.load(), 0u);
+  size_t model_size = 256;
+  {
+    Ort::Session session1(*ort_env, MODEL_URI, session_options);
+    ASSERT_EQ(memory_inuse.load(), model_size + sizeof(size_t));
+    RunSession(my_allocator,
+               session1,
+               inputs,
+               "Y",
+               expected_dims_y,
+               expected_values_y,
+               nullptr);
+  }
+  ASSERT_EQ(memory_inuse.load(), 0u);
+  // NOTE(review): the env keeps the wrapper registered above, which still points at my_allocator after this
+  // delete - confirm nothing uses the env allocator afterwards.
+  delete my_allocator;
+}
+
+// Device allocator backing the custom arena below; created inside TestCustomArenaAllocator.
+OrtAllocator* my_device_allocator = nullptr;
+// Chunks handed out by the custom arena, keyed by address and mapped to their size in bytes.
+std::unordered_map<void*, size_t> reserved_chunks;
+
+// Alloc callback of the arena's device allocator: plain malloc.
+void* ArenaDeviceAlloc(OrtAllocator* ptr, size_t size) {
+  ORT_UNUSED_PARAMETER(ptr);
+  return malloc(size);
+}
+
+// Free callback of the arena's device allocator: plain free.
+void ArenaDeviceFree(OrtAllocator* ptr, void* p) {
+  ORT_UNUSED_PARAMETER(ptr);
+  free(p);
+}
+
+// Info callback of the arena's device allocator: CPU memory info tagged OrtArenaAllocator, or nullptr if it
+// could not be created.
+const OrtMemoryInfo* myArenaInfo(const OrtAllocator* allocator) {
+  ORT_UNUSED_PARAMETER(allocator);
+  // Built once and cached: creating a fresh OrtMemoryInfo on every call leaked one object per call.
+  static const OrtMemoryInfo* const cached_info = []() -> const OrtMemoryInfo* {
+    OrtMemoryInfo* mem_info = nullptr;
+    if (Ort::GetApi().CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info) != nullptr) {
+      return nullptr;
+    }
+    return mem_info;
+  }();
+  return cached_info;
+}
+
+// Gets `size` bytes from the device allocator and records the chunk in reserved_chunks / memory_inuse.
+// Returns nullptr, with nothing recorded, when the device allocation fails.
+static void* AllocTrackedArenaChunk(size_t size) {
+  void* p = ArenaDeviceAlloc(my_device_allocator, size);
+  if (p == nullptr) {
+    return nullptr;
+  }
+  reserved_chunks.insert({p, size});
+  memory_inuse += size;
+  return p;
+}
+
+// Alloc callback of the custom arena: hands back an already tracked chunk of exactly `size` bytes when one
+// exists, otherwise allocates and tracks a new chunk.
+// NOTE(review): a tracked chunk is returned without checking whether it is still in use, so two live requests
+// of the same size share memory - confirm this is intended for the test.
+void* ArenaAlloc(size_t size) {
+  for (const auto& reserved_chunk : reserved_chunks) {
+    if (reserved_chunk.second == size) {
+      return reserved_chunk.first;
+    }
+  }
+  return AllocTrackedArenaChunk(size);
+}
+
+// Free callback of the custom arena: releases a tracked chunk and forgets it. Pointers that are not tracked
+// (including nullptr) are ignored.
+void ArenaFree(void* p) {
+  auto it = reserved_chunks.find(p);
+  if (it == reserved_chunks.end()) {
+    return;
+  }
+  memory_inuse -= it->second;
+  // Drop the entry so ArenaAlloc can never hand out the freed address again.
+  reserved_chunks.erase(it);
+  ArenaDeviceFree(my_device_allocator, p);
+}
+
+// Reserve callback of the custom arena: always allocates and tracks a new chunk.
+void* ArenaReserve(size_t size) {
+  return AllocTrackedArenaChunk(size);
+}
+
+// Used callback of the custom arena: the byte count currently held in memory_inuse.
+size_t ArenaUsed() {
+  return memory_inuse;
+}
+
+// Max callback of the custom arena: reports SIZE_MAX.
+size_t ArenaMax() {
+  return SIZE_MAX;
+}
+
+// This test uses the CreateCustomArenaAllocator and RegisterCustomArenaAllocator APIs to register an external
+// arena allocator with the env that overrides the default Alloc and Free functions.
+TEST(CApiTest, TestCustomArenaAllocator) {
+  // simple inference test
+  // prepare inputs
+  std::vector<Input> inputs(1);
+  Input& input = inputs.back();
+  input.name = "X";
+  input.dims = {3, 2};
+  input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  // prepare expected inputs and outputs
+  std::vector<int64_t> expected_dims_y = {3, 2};
+  std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+  // Reset the env so it won't have a registered allocator of type arena.
+  const auto& api = Ort::GetApi();
+  ort_env.reset();
+  OrtEnv* my_env_ptr = nullptr;
+  ASSERT_TRUE(api.CreateEnv(ORT_LOGGING_LEVEL_WARNING, "my_env", &my_env_ptr) == nullptr);
+  ort_env.reset(new Ort::Env(my_env_ptr));
+  ASSERT_TRUE(api.CreateCustomDeviceAllocator(ORT_API_VERSION, ArenaDeviceAlloc,
+                                              ArenaDeviceFree, myArenaInfo, &my_device_allocator) == nullptr);
+  OrtAllocatorArena* my_arena_allocator = nullptr;
+  ASSERT_TRUE(api.CreateCustomArenaAllocator(my_device_allocator, ArenaAlloc, ArenaFree, ArenaReserve, ArenaUsed,
+                                             ArenaMax, &my_arena_allocator) == nullptr);
+  ASSERT_TRUE(api.RegisterCustomArenaAllocator(my_env_ptr, my_arena_allocator) == nullptr);
+
+  Ort::SessionOptions session_options;
+  session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1");
+
+  // create session while allocating the model with the custom arena allocator.
+  ASSERT_EQ(ArenaUsed(), 0u);
+  size_t model_size = 256;
+  {
+    Ort::Session session1(*ort_env, MODEL_URI, session_options);
+    ASSERT_EQ(ArenaUsed(), model_size);
+    ASSERT_EQ(reserved_chunks.size(), 1u);
+    RunSession(my_device_allocator,
+               session1,
+               inputs,
+               "Y",
+               expected_dims_y,
+               expected_values_y,
+               nullptr);
+  }
+  ASSERT_EQ(ArenaUsed(), 0u);
+  // NOTE(review): ort_env still holds the arena wrapper registered above, which points at the two objects
+  // deleted below - confirm no later test relies on this env's registered allocators.
+  delete my_device_allocator;
+  my_device_allocator = nullptr;
+  delete my_arena_allocator;
+}