Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external/onnx
Submodule onnx updated 457 files
54 changes: 54 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ typedef struct OrtAllocator {
const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_);
} OrtAllocator;

/**
 * Table of callbacks describing a user-provided arena allocator.
 *
 * device_allocator is the allocator whose Info callback describes the memory this arena manages.
 * Unlike the OrtAllocator callbacks, the arena callbacks receive no `this_`/context pointer, so any
 * state they need has to be reachable by the callbacks themselves.
 *
 * Alloc   - returns a block of at least `size` bytes.
 * Free    - releases a block previously returned by Alloc or Reserve.
 * Reserve - reserves a block of `size` bytes.
 * Used    - returns the number of bytes currently in use.
 * Max     - returns the memory limit of the arena.
 */
typedef struct OrtAllocatorArena {
  OrtAllocator* device_allocator;
  void*(ORT_API_CALL* Alloc)(size_t size);
  void(ORT_API_CALL* Free)(void* p);
  void*(ORT_API_CALL* Reserve)(size_t size);
  /* `(void)` rather than `()`: in C an empty list means "unspecified parameters", not "no parameters". */
  size_t(ORT_API_CALL* Used)(void);
  size_t(ORT_API_CALL* Max)(void);
} OrtAllocatorArena;

typedef void(ORT_API_CALL* OrtLoggingFunction)(
void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location,
const char* message);
Expand Down Expand Up @@ -1157,6 +1166,51 @@ struct OrtApi {
*/
ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata,
_Inout_ OrtAllocator* allocator, _Outptr_ char** value);

/**
* Creates a new OrtAllocator whose fields are initialized from the given arguments.
* The returned OrtAllocator* is owned by the caller, who is responsible for releasing it.
* NOTE(review): confirm the intended release path for the returned object; it is a plain OrtAllocator,
* which may not be what ReleaseAllocator expects.
* \param version - C API version (available from 7).
* \param AllocFunc - Callback that will be invoked on every memory allocation.
* \param FreeFunc - Callback that will be invoked whenever memory is freed.
* \param InfoFunc - Callback that returns an OrtMemoryInfo* whose OrtAllocatorType is set to OrtDeviceAllocator.
* \param out - Receives the custom OrtAllocator.
The caller is responsible for freeing it.
*/
ORT_API2_STATUS(CreateCustomDeviceAllocator, uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out);

/**
* Creates a new OrtAllocatorArena whose fields are initialized from the given arguments.
* The returned OrtAllocatorArena* is owned by the caller, who is responsible for releasing it.
* \param device_allocator - The underlying device allocator that the arena allocator is built on. Its Info callback
* should return an OrtMemoryInfo* with OrtAllocatorType set to OrtDeviceAllocator.
* NOTE(review): confirm the expected OrtAllocatorType for the underlying allocator of an arena.
* \param AllocFunc - Callback that will be invoked when Alloc is called from within the arena context.
* \param FreeFunc - Callback that will be invoked when Free is called from within the arena context.
* \param ReserveFunc - Callback that will be invoked when memory is reserved from within the arena context.
* \param UsedFunc - Callback that will be invoked to get the total size of allocated memory from within the arena context.
* \param MaxFunc - Callback that will be invoked to get the memory limit from within the arena context.
* \param out - Receives the custom OrtAllocatorArena.
The caller is responsible for freeing it.
*/
ORT_API2_STATUS(CreateCustomArenaAllocator, _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out);

/**
* Registers a custom OrtAllocator* with the given env. Whenever a new session is created
* and associated with this env, if its session_options are configured to use the env allocators instead of the
* default ones, and not to use an arena allocator, memory management is done by the given allocator.
* It is the user's responsibility to release the OrtAllocator*.
* NOTE(review): the allocator presumably has to stay alive for as long as the env may use it - confirm.
*/
ORT_API2_STATUS(RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator *CustomAllocator);

/**
* Registers a custom OrtAllocatorArena* with the given env. Whenever a new session is created
* and associated with this env, if its session_options are configured to use the env allocators instead of the
* default ones, memory management (which uses an arena by default) is done by the given allocator.
* It is the user's responsibility to release the OrtAllocatorArena*.
* NOTE(review): the arena presumably has to stay alive for as long as the env may use it - confirm.
*/
ORT_API2_STATUS(RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena *CustomArenaAllocator);
/**
* Append TensorRT execution provider to the session options
* If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure.
Expand Down
35 changes: 35 additions & 0 deletions onnxruntime/core/session/arena_allocator_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/session/onnxruntime_c_api.h"
#include "core/framework/allocator.h"
#include "core/framework/arena.h"
#include "core/session/device_allocator.h"

namespace onnxruntime {
/**
 * Adapts a user-supplied OrtAllocatorArena (a table of C callbacks) to the IArenaAllocator interface.
 *
 * The wrapper stores the raw pointer it is given and never deletes it, so the OrtAllocatorArena (and the
 * device allocator it points to) must outlive the wrapper.
 *
 * Preconditions (not checked here): `impl`, `impl->device_allocator`, its `Info` callback and the
 * OrtMemoryInfo that callback returns must all be non-null, because the constructor dereferences each of them.
 */
class ArenaAllocatorWrapper : public IArenaAllocator {
 public:
  // `explicit` so that an OrtAllocatorArena* is never silently converted into a wrapper.
  explicit ArenaAllocatorWrapper(OrtAllocatorArena* impl)
      : IArenaAllocator(*impl->device_allocator->Info(impl->device_allocator)),
        impl_(impl) {}

  // Forwards to the user's Alloc callback.
  void* Alloc(size_t size) override {
    return impl_->Alloc(size);
  }

  // Forwards to the user's Free callback.
  void Free(void* p) override {
    impl_->Free(p);
  }

  // Forwards to the user's Reserve callback.
  void* Reserve(size_t size) override {
    return impl_->Reserve(size);
  }

  // Forwards to the user's Used callback.
  size_t Used() const override {
    return impl_->Used();
  }

  // Forwards to the user's Max callback.
  size_t Max() const override {
    return impl_->Max();
  }

 private:
  OrtAllocatorArena* impl_;  // not owned
};

} // namespace onnxruntime
29 changes: 29 additions & 0 deletions onnxruntime/core/session/device_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/session/inference_session.h"
#include "core/session/ort_env.h"
#include "core/session/allocator_impl.h"
#include "core/session/arena_allocator_impl.h"

#ifndef ORT_NO_EXCEPTIONS

Expand Down Expand Up @@ -53,6 +54,34 @@ ORT_API_STATUS_IMPL(OrtApis::CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _I
return nullptr;
}

// Wraps the user-supplied OrtAllocator in an AllocatorWrapper and registers it with the env.
// Returns an ORT_INVALID_ARGUMENT status when `env` or `CustomAllocator` is null, or when the env
// rejects the registration; returns nullptr on success.
// The caller keeps ownership of `CustomAllocator`.
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* CustomAllocator) {
  using namespace onnxruntime;
  if (!env) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null");
  }
  // Reject a null allocator up front instead of registering a wrapper around nothing.
  if (!CustomAllocator) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CustomAllocator is null");
  }

  // Hand the freshly created wrapper straight to the owning AllocatorPtr so it cannot leak.
  auto st = env->RegisterAllocator(AllocatorPtr(new AllocatorWrapper(CustomAllocator)));

  if (!st.IsOK()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str());
  }
  return nullptr;
}

// Wraps the user-supplied OrtAllocatorArena in an ArenaAllocatorWrapper and registers it with the env.
// Returns an ORT_INVALID_ARGUMENT status when `env` is null, when the arena (or the device allocator /
// Info callback it must provide) is null, or when the env rejects the registration; nullptr on success.
// The caller keeps ownership of `CustomArenaAllocator`.
ORT_API_STATUS_IMPL(OrtApis::RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena* CustomArenaAllocator) {
  using namespace onnxruntime;
  if (!env) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Env is null");
  }
  // ArenaAllocatorWrapper's constructor dereferences the arena, its device allocator and the Info
  // callback, so validate all three before constructing the wrapper.
  if (!CustomArenaAllocator) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CustomArenaAllocator is null");
  }
  if (!CustomArenaAllocator->device_allocator || !CustomArenaAllocator->device_allocator->Info) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "CustomArenaAllocator has no device allocator or the device allocator has no Info callback");
  }

  // Hand the freshly created wrapper straight to the owning AllocatorPtr so it cannot leak.
  auto st = env->RegisterAllocator(AllocatorPtr(new ArenaAllocatorWrapper(CustomArenaAllocator)));

  if (!st.IsOK()) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, st.ErrorMessage().c_str());
  }
  return nullptr;
}

ORT_API(void, OrtApis::ReleaseAllocator, _Frees_ptr_opt_ OrtAllocator* allocator) {
delete reinterpret_cast<onnxruntime::OrtAllocatorForDevice*>(allocator);
}
27 changes: 27 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,28 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetOverridableInitializerName, _In_ const Or
API_IMPL_END
}


// Allocates a new OrtAllocator initialized with the given version and callbacks and stores it in `*out`.
// Returns an ORT_INVALID_ARGUMENT status (leaving `*out` set to nullptr when `out` itself is valid) if
// `out` or any of the callbacks is null; returns nullptr on success.
// The caller owns the returned object.
ORT_API_STATUS_IMPL(OrtApis::CreateCustomDeviceAllocator,
                    uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
                    const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out) {
  API_IMPL_BEGIN
  if (out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "out is null");
  }
  *out = nullptr;
  // An allocator with a missing callback would crash on first use; reject it here instead.
  if (AllocFunc == nullptr || FreeFunc == nullptr || InfoFunc == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "AllocFunc, FreeFunc and InfoFunc must not be null");
  }
  *out = new OrtAllocator{version, AllocFunc, FreeFunc, InfoFunc};
  return nullptr;
  API_IMPL_END
}

// Allocates a new OrtAllocatorArena initialized with the given device allocator and callbacks and stores
// it in `*out`.
// Returns an ORT_INVALID_ARGUMENT status (leaving `*out` set to nullptr when `out` itself is valid) if
// `out`, `device_allocator` or any of the callbacks is null; returns nullptr on success.
// The caller owns the returned object; `device_allocator` is referenced, not copied.
ORT_API_STATUS_IMPL(OrtApis::CreateCustomArenaAllocator,
                    _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
                    size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out) {
  API_IMPL_BEGIN
  if (out == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "out is null");
  }
  *out = nullptr;
  if (device_allocator == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "device_allocator is null");
  }
  // An arena with a missing callback would crash on first use; reject it here instead.
  if (AllocFunc == nullptr || FreeFunc == nullptr || ReserveFunc == nullptr ||
      UsedFunc == nullptr || MaxFunc == nullptr) {
    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
                                 "AllocFunc, FreeFunc, ReserveFunc, UsedFunc and MaxFunc must not be null");
  }
  *out = new OrtAllocatorArena{device_allocator, AllocFunc, FreeFunc, ReserveFunc, UsedFunc, MaxFunc};
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::AllocatorAlloc, _Inout_ OrtAllocator* ptr, size_t size, _Outptr_ void** out) {
API_IMPL_BEGIN
*out = ptr->Alloc(ptr, size);
Expand Down Expand Up @@ -2105,6 +2127,11 @@ static constexpr OrtApi ort_api_1_to_7 = {
// End of Version 6 - DO NOT MODIFY ABOVE (see above text for more information)

&OrtApis::ModelMetadataGetGraphDescription,
&OrtApis::CreateCustomDeviceAllocator,
&OrtApis::CreateCustomArenaAllocator,
&OrtApis::RegisterCustomDeviceAllocator,
&OrtApis::RegisterCustomArenaAllocator,

&OrtApis::SessionOptionsAppendExecutionProvider_TensorRT,
&OrtApis::SetCurrentGpuDeviceId,
&OrtApis::GetCurrentGpuDeviceId,
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,14 @@ ORT_API_STATUS_IMPL(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options,
ORT_API_STATUS_IMPL(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out);

ORT_API_STATUS_IMPL(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg);
// Creates an OrtAllocator from user-provided Alloc/Free/Info callbacks; the result is returned through `out`.
ORT_API_STATUS_IMPL(CreateCustomDeviceAllocator, uint32_t version, void* AllocFunc(OrtAllocator*, size_t), void FreeFunc(OrtAllocator*, void*),
const OrtMemoryInfo* InfoFunc(const OrtAllocator*), _Outptr_ OrtAllocator** out);

// Creates an OrtAllocatorArena from a device allocator and user-provided arena callbacks; the result is returned through `out`.
ORT_API_STATUS_IMPL(CreateCustomArenaAllocator, _In_ OrtAllocator* device_allocator, void* AllocFunc(size_t), void FreeFunc(void*), void* ReserveFunc(size_t),
size_t UsedFunc(void), size_t MaxFunc(void), _Outptr_ OrtAllocatorArena** out);
// Registers a user-provided OrtAllocator with the given env.
ORT_API_STATUS_IMPL(RegisterCustomDeviceAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator *CustomAllocator);

// Registers a user-provided OrtAllocatorArena with the given env.
ORT_API_STATUS_IMPL(RegisterCustomArenaAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocatorArena *CustomArenaAllocator);

ORT_API_STATUS_IMPL(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection);
ORT_API_STATUS_IMPL(SessionGetProfilingStartTimeNs, _In_ const OrtSession* sess, _Out_ uint64_t* out);
Expand Down
187 changes: 187 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,3 +1381,190 @@ TEST(CApiTest, TestIncorrectInputTypeToModel_SequenceTensors) {
ASSERT_TRUE(exception_thrown);
}
#endif

// Running total of bytes currently handed out by the custom test allocators in this file.
// The allocator tests assert on it to check that allocations really went through those allocators.
std::atomic<size_t> memory_inuse{0};

// Test allocation callback: allocates `size` bytes plus a size_t header via malloc, stores the total
// block length in the header, adds that length to memory_inuse and returns the address just past the
// header. Returns nullptr (and leaves memory_inuse untouched) if the request overflows or malloc fails.
void* myAlloc(OrtAllocator* ptr, size_t size) {
  ORT_UNUSED_PARAMETER(ptr);
  constexpr size_t extra_len = sizeof(size_t);
  // Guard the header addition against wrap-around for absurdly large requests.
  if (size > SIZE_MAX - extra_len) {
    return nullptr;
  }
  const size_t total_len = size + extra_len;
  void* p = ::malloc(total_len);
  if (p == nullptr) {
    return nullptr;
  }
  // Only count the block once it really exists, so a failed malloc cannot inflate the counter.
  memory_inuse.fetch_add(total_len);
  *static_cast<size_t*>(p) = total_len;
  return static_cast<char*>(p) + extra_len;
}

// Test free callback, counterpart of myAlloc: steps back over the size_t header that precedes `p`,
// subtracts the length recorded there from memory_inuse and releases the whole block.
// A null `p` is ignored.
void myFree(OrtAllocator* ptr, void* p) {
  ORT_UNUSED_PARAMETER(ptr);
  if (p == nullptr) {
    return;
  }
  constexpr size_t header_len = sizeof(size_t);
  char* block_start = static_cast<char*>(p) - header_len;
  const size_t recorded_len = *reinterpret_cast<size_t*>(block_start);
  memory_inuse.fetch_sub(recorded_len);
  ::free(block_start);
}

// Test Info callback: returns a CPU OrtMemoryInfo of type OrtDeviceAllocator / OrtMemTypeDefault,
// or nullptr if it could not be created.
// The object is created once on first use and reused by every later call, instead of allocating a new
// (never released) OrtMemoryInfo each time the callback is invoked.
const OrtMemoryInfo* myInfo(const OrtAllocator* allocator) {
  ORT_UNUSED_PARAMETER(allocator);
  static const OrtMemoryInfo* const cached_info = []() -> const OrtMemoryInfo* {
    const auto& api = Ort::GetApi();
    OrtMemoryInfo* mem_info = nullptr;
    OrtStatus* status = api.CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &mem_info);
    if (status != nullptr) {
      // The status object is owned by the caller; release it rather than dropping it.
      api.ReleaseStatus(status);
      return nullptr;
    }
    return mem_info;
  }();
  return cached_info;
}

// Registers a custom device allocator (built with CreateCustomDeviceAllocator and attached to the env with
// RegisterCustomDeviceAllocator) that replaces the default Alloc and Free functions, then checks through
// memory_inuse that creating a session really allocates through it.
TEST(CApiTest, TestCustomDeviceAllocator) {
  // Single model input "X" of shape {3, 2}.
  std::vector<Input> inputs(1);
  Input& x_input = inputs[0];
  x_input.name = "X";
  x_input.dims = {3, 2};
  x_input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // Expected output "Y": same shape, every element squared.
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
  OrtEnv* env_ptr = static_cast<OrtEnv*>(*ort_env);

  // Build the custom allocator and attach it to the shared env.
  const auto& api = Ort::GetApi();
  OrtAllocator* my_allocator = nullptr;
  OrtStatus* create_status =
      api.CreateCustomDeviceAllocator(ORT_API_VERSION, myAlloc, myFree, myInfo, &my_allocator);
  ASSERT_TRUE(create_status == nullptr);
  OrtStatus* register_status = api.RegisterCustomDeviceAllocator(env_ptr, my_allocator);
  ASSERT_TRUE(register_status == nullptr);

  // Use the env allocators and switch the CPU arena off.
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1");
  session_options.DisableCpuMemArena();

  // Nothing may be outstanding before the session exists.
  ASSERT_EQ(memory_inuse, 0);
  size_t model_size = 256;
  {
    // Creating the session allocates the model through the custom allocator.
    Ort::Session session1(*ort_env, MODEL_URI, session_options);
    ASSERT_EQ(memory_inuse, model_size + sizeof(size_t));
    RunSession<float>(my_allocator, session1, inputs, "Y", expected_dims_y, expected_values_y, nullptr);
  }
  // Once the session is gone everything must have been returned.
  ASSERT_EQ(memory_inuse, 0);
  delete my_allocator;
}

// Device allocator created by TestCustomArenaAllocator; the arena callbacks pass it to
// ArenaDeviceAlloc/ArenaDeviceFree.
OrtAllocator *my_device_allocator = nullptr;
// Chunks recorded by ArenaAlloc/ArenaReserve: key is the chunk pointer, value is the requested size.
std::unordered_map<void*, size_t> reserved_chunks;

// Device-level allocation callback used underneath the test arena: a plain malloc of `size` bytes.
void* ArenaDeviceAlloc(OrtAllocator* ptr, size_t size) {
  ORT_UNUSED_PARAMETER(ptr);
  return malloc(size);
}

// Device-level free callback used underneath the test arena: hands `p` back to the C heap.
void ArenaDeviceFree(OrtAllocator* ptr, void* p) {
  ORT_UNUSED_PARAMETER(ptr);
  ::free(p);
}

// Test Info callback for the arena's device allocator: returns a CPU OrtMemoryInfo of type
// OrtArenaAllocator / OrtMemTypeDefault, or nullptr if it could not be created.
// The object is created once on first use and reused by every later call, instead of allocating a new
// (never released) OrtMemoryInfo each time the callback is invoked.
const OrtMemoryInfo* myArenaInfo(const OrtAllocator* allocator) {
  ORT_UNUSED_PARAMETER(allocator);
  static const OrtMemoryInfo* const cached_info = []() -> const OrtMemoryInfo* {
    const auto& api = Ort::GetApi();
    OrtMemoryInfo* mem_info = nullptr;
    OrtStatus* status = api.CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &mem_info);
    if (status != nullptr) {
      // The status object is owned by the caller; release it rather than dropping it.
      api.ReleaseStatus(status);
      return nullptr;
    }
    return mem_info;
  }();
  return cached_info;
}

// Arena Alloc callback for the test: returns an already recorded chunk whose size equals `size` if there
// is one; otherwise allocates a new chunk from the device allocator, records it in reserved_chunks and
// adds its size to memory_inuse. Returns nullptr, without recording anything, if the device allocation fails.
// NOTE(review): the size-match lookup does not track whether the recorded chunk is currently handed out,
// so two live requests of the same size receive the same pointer - confirm this is intended for the test.
void* ArenaAlloc(size_t size) {
  for (const auto& reserved_chunk : reserved_chunks) {
    if (reserved_chunk.second == size) {
      return reserved_chunk.first;
    }
  }
  void* p = ArenaDeviceAlloc(my_device_allocator, size);
  // Do not record a failed allocation: a {nullptr, size} entry would be returned by later lookups and
  // would leave memory_inuse permanently inflated.
  if (p == nullptr) {
    return nullptr;
  }
  reserved_chunks.insert({p, size});
  memory_inuse += size;
  return p;
}

// Arena Free callback for the test: releases a chunk previously recorded in reserved_chunks, removes its
// record and subtracts its size from memory_inuse.
// Pointers that are not recorded (including nullptr and chunks that were already freed) are ignored.
void ArenaFree(void* p) {
  const auto it = reserved_chunks.find(p);
  // Dereferencing end() is undefined behaviour, so unknown pointers must be filtered out first.
  if (it == reserved_chunks.end()) {
    return;
  }
  const size_t chunk_size = it->second;
  // Drop the record so the freed pointer can neither be handed out again nor be freed twice.
  reserved_chunks.erase(it);
  memory_inuse -= chunk_size;
  ArenaDeviceFree(my_device_allocator, p);
}

// Arena Reserve callback for the test: allocates a new chunk of `size` bytes from the device allocator,
// records it in reserved_chunks and adds its size to memory_inuse.
// Returns nullptr, without recording anything, if the device allocation fails.
void* ArenaReserve(size_t size) {
  void* p = ArenaDeviceAlloc(my_device_allocator, size);
  // Do not record a failed allocation: it would inflate memory_inuse and leave a null chunk behind.
  if (p == nullptr) {
    return nullptr;
  }
  reserved_chunks.insert({p, size});
  memory_inuse += size;
  return p;
}

// Arena Used callback for the test: reports the current value of the memory_inuse counter.
size_t ArenaUsed() {
  const size_t bytes_in_use = memory_inuse.load();
  return bytes_in_use;
}

// Arena Max callback for the test: reports SIZE_MAX as the memory limit.
size_t ArenaMax() {
  constexpr size_t memory_limit = SIZE_MAX;
  return memory_limit;
}

// Registers a custom arena allocator (built with CreateCustomArenaAllocator and attached to the env with
// RegisterCustomArenaAllocator) that replaces the default Alloc and Free functions, then checks through
// ArenaUsed() and reserved_chunks that creating a session really allocates through it.
TEST(CApiTest, TestCustomArenaAllocator) {
  // Single model input "X" of shape {3, 2}.
  std::vector<Input> inputs(1);
  Input& x_input = inputs[0];
  x_input.name = "X";
  x_input.dims = {3, 2};
  x_input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};

  // Expected output "Y": same shape, every element squared.
  std::vector<int64_t> expected_dims_y = {3, 2};
  std::vector<float> expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};

  // Replace the shared env with a fresh one so that no arena allocator is registered with it yet.
  const auto& api = Ort::GetApi();
  ort_env.reset();
  OrtEnv* my_env_ptr = nullptr;
  OrtStatus* env_status = api.CreateEnv(ORT_LOGGING_LEVEL_WARNING, "my_env", &my_env_ptr);
  ASSERT_TRUE(env_status == nullptr);
  ort_env.reset(new Ort::Env(my_env_ptr));

  // Build the underlying device allocator, then the arena on top of it, and register the arena.
  OrtStatus* device_status = api.CreateCustomDeviceAllocator(ORT_API_VERSION, ArenaDeviceAlloc, ArenaDeviceFree,
                                                             myArenaInfo, &my_device_allocator);
  ASSERT_TRUE(device_status == nullptr);
  OrtAllocatorArena* my_arena_allocator = nullptr;
  OrtStatus* arena_status = api.CreateCustomArenaAllocator(my_device_allocator, ArenaAlloc, ArenaFree, ArenaReserve,
                                                           ArenaUsed, ArenaMax, &my_arena_allocator);
  ASSERT_TRUE(arena_status == nullptr);
  OrtStatus* register_status = api.RegisterCustomArenaAllocator(my_env_ptr, my_arena_allocator);
  ASSERT_TRUE(register_status == nullptr);

  // Use the env allocators; the arena stays enabled.
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry(kOrtSessionOptionsConfigUseEnvAllocators, "1");

  // Nothing may be outstanding before the session exists.
  ASSERT_EQ(ArenaUsed(), 0);
  size_t model_size = 256;
  {
    // Creating the session allocates the model through the custom arena allocator.
    Ort::Session session1(*ort_env, MODEL_URI, session_options);
    ASSERT_EQ(ArenaUsed(), model_size);
    ASSERT_EQ(reserved_chunks.size(), 1);
    RunSession<float>(my_device_allocator, session1, inputs, "Y", expected_dims_y, expected_values_y, nullptr);
  }
  // Once the session is gone everything must have been returned.
  ASSERT_EQ(ArenaUsed(), 0);
  delete my_device_allocator;
  delete my_arena_allocator;
}