Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 40 additions & 8 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ public struct OrtApi
public IntPtr SetGlobalInterOpNumThreads;
public IntPtr SetGlobalSpinControl;
public IntPtr AddInitializer;
public IntPtr CreateEnvWithCustomLoggerAndGlobalThreadPools;
public IntPtr SessionOptionsAppendExecutionProvider_CUDA;
public IntPtr SessionOptionsAppendExecutionProvider_OpenVINO;
public IntPtr SetGlobalDenormalAsZero;
public IntPtr CreateArenaCfg;
public IntPtr ReleaseArenaCfg;
}

internal static class NativeMethods
Expand Down Expand Up @@ -260,6 +266,9 @@ static NativeMethods()
OrtRunOptionsSetTerminate = (DOrtRunOptionsSetTerminate)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsSetTerminate, typeof(DOrtRunOptionsSetTerminate));
OrtRunOptionsUnsetTerminate = (DOrtRunOptionsUnsetTerminate)Marshal.GetDelegateForFunctionPointer(api_.RunOptionsUnsetTerminate, typeof(DOrtRunOptionsUnsetTerminate));

OrtCreateArenaCfg = (DOrtCreateArenaCfg)Marshal.GetDelegateForFunctionPointer(api_.CreateArenaCfg, typeof(DOrtCreateArenaCfg));
OrtReleaseArenaCfg = (DOrtReleaseArenaCfg)Marshal.GetDelegateForFunctionPointer(api_.ReleaseArenaCfg, typeof(DOrtReleaseArenaCfg));
OrtReleaseAllocator = (DOrtReleaseAllocator)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAllocator, typeof(DOrtReleaseAllocator));
OrtCreateMemoryInfo = (DOrtCreateMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.CreateMemoryInfo, typeof(DOrtCreateMemoryInfo));
OrtCreateCpuMemoryInfo = (DOrtCreateCpuMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.CreateCpuMemoryInfo, typeof(DOrtCreateCpuMemoryInfo));
OrtReleaseMemoryInfo = (DOrtReleaseMemoryInfo)Marshal.GetDelegateForFunctionPointer(api_.ReleaseMemoryInfo, typeof(DOrtReleaseMemoryInfo));
Expand Down Expand Up @@ -311,7 +320,6 @@ static NativeMethods()
OrtGetTensorShapeElementCount = (DOrtGetTensorShapeElementCount)Marshal.GetDelegateForFunctionPointer(api_.GetTensorShapeElementCount, typeof(DOrtGetTensorShapeElementCount));
OrtReleaseValue = (DOrtReleaseValue)Marshal.GetDelegateForFunctionPointer(api_.ReleaseValue, typeof(DOrtReleaseValue));


OrtSessionGetModelMetadata = (DOrtSessionGetModelMetadata)Marshal.GetDelegateForFunctionPointer(api_.SessionGetModelMetadata, typeof(DOrtSessionGetModelMetadata));
OrtModelMetadataGetProducerName = (DOrtModelMetadataGetProducerName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetProducerName, typeof(DOrtModelMetadataGetProducerName));
OrtModelMetadataGetGraphName = (DOrtModelMetadataGetGraphName)Marshal.GetDelegateForFunctionPointer(api_.ModelMetadataGetGraphName, typeof(DOrtModelMetadataGetGraphName));
Expand Down Expand Up @@ -708,6 +716,27 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
public delegate IntPtr /*(OrtStatus*)*/DOrtAllocatorGetInfo(IntPtr /*(const OrtAllocator*)*/ ptr, out IntPtr /*(const struct OrtMemoryInfo**)*/info);
public static DOrtAllocatorGetInfo OrtAllocatorGetInfo;

/// <summary>
/// Create an instance of arena configuration which will be used to create an arena based allocator
/// See docs/C_API.md for details on what the following parameters mean and how to choose these values
/// </summary>
/// <param name="maxMemory">Maximum amount of memory the arena allocates</param>
/// <param name="arenaExtendStrategy">Strategy for arena expansion</param>
/// <param name="initialChunkSizeBytes">Size of the region that the arena allocates first</param>
/// <param name="maxDeadBytesPerChunk">Maximum amount of fragmentation allowed per chunk</param>
/// <param name="arenaCfg">Receives the pointer to the native OrtArenaCfg instance that was created</param>
/// <returns>Pointer to a native OrtStatus instance indicating success/failure of config creation</returns>
public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateArenaCfg(UIntPtr /*(size_t)*/ maxMemory, int /*(int)*/ arenaExtendStrategy,
int /*(int)*/ initialChunkSizeBytes, int /*(int)*/ maxDeadBytesPerChunk,
out IntPtr /*(OrtArenaCfg**)*/ arenaCfg);
public static DOrtCreateArenaCfg OrtCreateArenaCfg;

/// <summary>
/// Destroy a native arena configuration instance
/// </summary>
/// <param name="arenaCfg">Pointer to the native OrtArenaCfg instance to be destroyed</param>
public delegate void DOrtReleaseArenaCfg(IntPtr /*(OrtArenaCfg*)*/ arenaCfg);
public static DOrtReleaseArenaCfg OrtReleaseArenaCfg;

/// <summary>
/// Create an instance of allocator according to mem_info
/// </summary>
Expand Down Expand Up @@ -861,13 +890,16 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca

/// <summary>
/// Creates an allocator instance and registers it with the env to enable
///sharing between multiple sessions that use the same env instance.
///Lifetime of the created allocator will be valid for the duration of the environment.
///Returns an error if an allocator with the same OrtMemoryInfo is already registered.
/// </summary>
/// <param name="mem_info">must be non-null</param>
/// <param name="arena_cfg">if nullptr defaults will be used</param>
public delegate void DOrtCreateAndRegisterAllocator(IntPtr /*(OrtIoBinding)*/ io_binding);
/// sharing between multiple sessions that use the same env instance.
/// Lifetime of the created allocator will be valid for the duration of the environment.
/// Returns an error if an allocator with the same OrtMemoryInfo is already registered.
/// <param name="env">Native OrtEnv instance</param>
/// <param name="memInfo">Native OrtMemoryInfo instance</param>
/// <param name="arenaCfg">Native OrtArenaCfg instance</param>
/// <returns>A pointer to native OrtStatus indicating success/failure</returns>
public delegate IntPtr /*(OrtStatus*)*/ DOrtCreateAndRegisterAllocator(IntPtr /*(OrtEnv*)*/ env,
IntPtr /*(const OrtMemoryInfo*)*/ memInfo,
IntPtr/*(const OrtArenaCfg*)*/ arenaCfg);
public static DOrtCreateAndRegisterAllocator OrtCreateAndRegisterAllocator;

/// <summary>
Expand Down
11 changes: 11 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ public void DisableTelemetryEvents()
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableTelemetryEvents(Handle));
}

/// <summary>
/// Create and register an allocator to the OrtEnv instance
/// so as to enable sharing across all sessions using the OrtEnv instance
/// </summary>
/// <param name="memInfo">OrtMemoryInfo instance to be used for allocator creation</param>
/// <param name="arenaCfg">OrtArenaCfg instance that will be used to define the behavior of the arena based allocator.
/// May be null, in which case a null pointer is handed to the native API.</param>
public void CreateAndRegisterAllocator(OrtMemoryInfo memInfo, OrtArenaCfg arenaCfg)
{
    // Forward a missing configuration as a null pointer rather than
    // dereferencing a null reference on the managed side.
    IntPtr arenaCfgPointer = (arenaCfg == null) ? IntPtr.Zero : arenaCfg.Pointer;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateAndRegisterAllocator(Handle, memInfo.Pointer, arenaCfgPointer));
}

/// <summary>
/// Queries all the execution providers supported in the native onnxruntime shared library
/// </summary>
Expand Down
57 changes: 57 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/OrtAllocator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,63 @@ public enum OrtMemType
Default = 0, // the default allocator for execution provider
}

/// <summary>
/// Holds the arena configuration that defines how an arena based allocator behaves.
/// The class wraps a pointer to a native OrtArenaCfg instance.
/// See docs/C_API.md for more details
/// </summary>
public class OrtArenaCfg : SafeHandle
{
    /// <summary>
    /// Creates the native arena configuration that can later be used to create an arena based allocator.
    /// See docs/C_API.md for details on what the following parameters mean and how to choose these values
    /// </summary>
    /// <param name="maxMemory">Maximum amount of memory the arena allocates</param>
    /// <param name="arenaExtendStrategy">Strategy for arena expansion</param>
    /// <param name="initialChunkSizeBytes">Size of the region that the arena allocates first</param>
    /// <param name="maxDeadBytesPerChunk">Maximum amount of fragmentation allowed per chunk</param>
    public OrtArenaCfg(uint maxMemory, int arenaExtendStrategy, int initialChunkSizeBytes, int maxDeadBytesPerChunk)
        : base(IntPtr.Zero, true)
    {
        // The native signature takes a size_t, hence the UIntPtr conversion.
        UIntPtr nativeMaxMemory = (UIntPtr)maxMemory;
        IntPtr status = NativeMethods.OrtCreateArenaCfg(nativeMaxMemory,
                                                        arenaExtendStrategy,
                                                        initialChunkSizeBytes,
                                                        maxDeadBytesPerChunk,
                                                        out handle);
        NativeApiStatus.VerifySuccess(status);
    }

    /// <summary>
    /// Raw pointer to the native OrtArenaCfg instance
    /// </summary>
    internal IntPtr Pointer
    {
        get { return handle; }
    }

    #region SafeHandle

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>true when the wrapped handle is IntPtr.Zero</value>
    public override bool IsInvalid
    {
        get { return IntPtr.Zero == handle; }
    }

    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to release
    /// the native instance of OrtArenaCfg
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseArenaCfg(handle);
        handle = IntPtr.Zero;
        return true;
    }

    #endregion
}

/// <summary>
/// This class encapsulates and most of the time owns the underlying native OrtMemoryInfo instance.
/// Instance returned from OrtAllocator will not own OrtMemoryInfo, the class must be disposed
Expand Down
75 changes: 75 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2091,6 +2091,81 @@ private void TestWeightSharingBetweenSessions()
}
}

[Fact]
private void TestSharedAllocatorUsingCreateAndRegisterAllocator()
{
    string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "mul_1.onnx");

    using (var memInfo = new OrtMemoryInfo(OrtMemoryInfo.allocatorCPU,
                                           OrtAllocatorType.ArenaAllocator, 0, OrtMemType.Default))
    using (var arenaCfg = new OrtArenaCfg(0, -1, -1, -1))
    {
        var env = OrtEnv.Instance();
        // Create and register the arena based allocator
        env.CreateAndRegisterAllocator(memInfo, arenaCfg);

        using (var sessionOptions = new SessionOptions())
        {
            // Key must match kOrtSessionOptionsConfigUseEnvAllocators in onnxruntime_session_options_config_keys.h
            sessionOptions.AddSessionConfigEntry("session.use_env_allocators", "1");

            // Create two sessions to share the allocator
            // Create a third session that DOES NOT use the allocator in the environment
            using (var session1 = new InferenceSession(modelPath, sessionOptions))
            using (var session2 = new InferenceSession(modelPath, sessionOptions))
            using (var session3 = new InferenceSession(modelPath)) // Use the default SessionOptions instance
            {
                // Input data
                var input = new float[] { 1.0F, 2.0F, 3.0F, 4.0F, 5.0F, 6.0F };

                // Expected output data
                int[] outputDims = { 3, 2 };
                float[] output = { 1.0F, 4.0F, 9.0F, 16.0F, 25.0F, 36.0F };

                // Build the input container once from the metadata of the first session;
                // the same container is fed to every session below.
                var inputMeta = session1.InputMetadata;
                var container = new List<NamedOnnxValue>();

                foreach (var name in inputMeta.Keys)
                {
                    Assert.Equal(typeof(float), inputMeta[name].ElementType);
                    Assert.True(inputMeta[name].IsTensor);
                    var tensor = new DenseTensor<float>(input, inputMeta[name].Dimensions);
                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
                }

                // Run inference on all three sessions, in order, and validate each result
                foreach (var session in new[] { session1, session2, session3 })
                {
                    // Run inference with named inputs and outputs created within Run()
                    using (var results = session.Run(container)) // results is an IReadOnlyList<NamedOnnxValue> container
                    {
                        foreach (var r in results)
                        {
                            validateRunResultData(r.AsTensor<float>(), output, outputDims);
                        }
                    }
                }
            }
        }
    }
}

[DllImport("kernel32", SetLastError = true)]
static extern IntPtr LoadLibrary(string lpFileName);

Expand Down
9 changes: 9 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
#include "ortdevice.h"
#include "ortmemoryinfo.h"

// Configuration of the arena based allocator used by ORT.
// See docs/C_API.md for details on what these mean and how to choose these values.
struct OrtArenaCfg {
  // use 0 to allow ORT to choose the default
  size_t max_mem;

  // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
  int arena_extend_strategy;

  // use -1 to allow ORT to choose the default
  int initial_chunk_size_bytes;

  // use -1 to allow ORT to choose the default
  int max_dead_bytes_per_chunk;
};

namespace onnxruntime {
constexpr const char* CPU = "Cpu";
constexpr const char* CUDA = "Cuda";
Expand Down
6 changes: 6 additions & 0 deletions include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class Environment {
*/
Status RegisterAllocator(AllocatorPtr allocator);

/**
* Creates and registers an allocator for sharing between multiple sessions.
* Return an error if an allocator with the same OrtMemoryInfo is already registered.
* @param mem_info memory info passed in for the allocator that is created
* @param arena_cfg optional arena configuration; defaults to nullptr
*                  (NOTE(review): presumably ORT's default arena settings are used then - confirm in the implementation)
*/
Status CreateAndRegisterAllocator(const OrtMemoryInfo& mem_info, const OrtArenaCfg* arena_cfg = nullptr);

/**
* Returns the list of registered allocators in this env.
*/
Expand Down
26 changes: 17 additions & 9 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,6 @@ typedef enum OrtErrorCode {
ORT_EP_FAIL,
} OrtErrorCode;

// This configures the arena based allocator used by ORT
// See docs/C_API.md for details on what these mean and how to choose these values
// Each field comment names the value that lets ORT pick its own default.
typedef struct OrtArenaCfg {
size_t max_mem; // use 0 to allow ORT to choose the default
int arena_extend_strategy; // use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
int initial_chunk_size_bytes; // use -1 to allow ORT to choose the default
int max_dead_bytes_per_chunk; // use -1 to allow ORT to choose the default
} OrtArenaCfg;

// Forward-declares struct Ort<X> and adds a typedef of the same name,
// so the type can be referred to as Ort<X> without the struct keyword.
#define ORT_RUNTIME_CLASS(X) \
struct Ort##X; \
typedef struct Ort##X Ort##X;
Expand All @@ -173,6 +164,7 @@ ORT_RUNTIME_CLASS(SequenceTypeInfo);
ORT_RUNTIME_CLASS(ModelMetadata);
ORT_RUNTIME_CLASS(ThreadPoolParams);
ORT_RUNTIME_CLASS(ThreadingOptions);
ORT_RUNTIME_CLASS(ArenaCfg);

#ifdef _WIN32
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
Expand Down Expand Up @@ -1126,6 +1118,22 @@ struct OrtApi {
* and that's recommended because turning this option on may hurt model accuracy.
*/
ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options);

/**
* Use this API to create the configuration of an arena that can eventually be used to define
* an arena based allocator's behavior
* See docs/C_API.md for details on what the following parameters mean and how to choose these values
* \param max_mem - use 0 to allow ORT to choose the default
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
* \param out - receives a pointer to the created OrtArenaCfg instance
* \return a nullptr in case of success or a pointer to an OrtStatus instance in case of failure
*/
ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out);

ORT_CLASS_RELEASE(ArenaCfg);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ORT_CLASS_RELEASE(ArenaCfg); [](start = 2, length = 28)

Normally, this would be declared automatically

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How so?

};

/*
Expand Down
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ ORT_DEFINE_RELEASE(Value);
ORT_DEFINE_RELEASE(ModelMetadata);
ORT_DEFINE_RELEASE(ThreadingOptions);
ORT_DEFINE_RELEASE(IoBinding);
ORT_DEFINE_RELEASE(ArenaCfg);

/*! \class Ort::Float16_t
* \brief it is a structure that represents float16 data.
Expand Down Expand Up @@ -551,6 +552,23 @@ struct IoBinding : public Base<OrtIoBinding> {
void ClearBoundOutputs();
};

/*! \struct Ort::ArenaCfg
* \brief it is a structure that represents the configuration of an arena based allocator
* \details Please see docs/C_API.md for details
*/
struct ArenaCfg : Base<OrtArenaCfg> {
// Constructs from nullptr with an empty body; no native OrtArenaCfg is created by this constructor.
explicit ArenaCfg(std::nullptr_t) {}
/**
* Creates an arena configuration from the given settings.
* See docs/C_API.md for details on what the following parameters mean and how to choose these values
* \param max_mem - use 0 to allow ORT to choose the default
* \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
* \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
* \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
*/
ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArenaCfg [](start = 2, length = 8)

perhaps a factory method would be a good idea?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why, though? Why move away from an established pattern? Can you please elaborate on why it would be a good idea here?

};

//
// Custom OPs (only needed to implement custom OPs)
//
Expand Down
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ inline void IoBinding::ClearBoundOutputs() {
GetApi().ClearBoundOutputs(p_);
}

// Asks the C API to create the arena configuration, storing the result in p_,
// and hands the returned status to ThrowOnError.
inline ArenaCfg::ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk) {
  auto* status = GetApi().CreateArenaCfg(max_mem,
                                         arena_extend_strategy,
                                         initial_chunk_size_bytes,
                                         max_dead_bytes_per_chunk,
                                         &p_);
  ThrowOnError(status);
}

inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
if (strcmp(logid, "onnxruntime-node") == 0) {
Expand Down
Loading