Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 32 additions & 55 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyColl
{
using (var cleanupList = new DisposableList<IDisposable>())
{
var inputNamesArray = ConvertNamesToUtf8(inputs, v => v.Name, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputs, v => v.Name, cleanupList);
var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);
var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputNames, n => n, cleanupList);

var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList);
return CreateDisposableResult(ortValues, outputNames);
Expand Down Expand Up @@ -205,9 +205,9 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(

using (var cleanupList = new DisposableList<IDisposable>())
{
var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputNames, n => n, cleanupList);
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputNames, n => n, cleanupList);


var ortValues = RunImpl(options, inputNamesArray, inputValuesArray, outputNamesArray, cleanupList);
Expand Down Expand Up @@ -262,11 +262,11 @@ public void Run(
using (var cleanupList = new DisposableList<IDisposable>())
{
// prepare inputs
var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputNames, n => n, cleanupList);
IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);

// prepare outputs
var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputNames, n => n, cleanupList);
IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false);

NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
Expand Down Expand Up @@ -310,12 +310,12 @@ public void Run(
IReadOnlyCollection<NamedOnnxValue> outputs,
RunOptions options)
{
using(var cleanupList = new DisposableList<IDisposable>())
using (var cleanupList = new DisposableList<IDisposable>())
{
var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);

var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList);

NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
Expand Down Expand Up @@ -367,14 +367,14 @@ public void Run(
throw new ArgumentException($"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
}

using(var cleanupList = new DisposableList<IDisposable>())
using (var cleanupList = new DisposableList<IDisposable>())
{
// prepare inputs
var inputNamesArray = ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputs, i => i.Name, cleanupList);
var inputValuesArray = GetOrtValuesHandles(inputs, cleanupList);

// prepare outputs
var outputNamesArray = ConvertNamesToUtf8(outputNames, n => n, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputNames, n => n, cleanupList);
var outputValuesArray = GetOrtValuesHandles(outputValues, false);

NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
Expand Down Expand Up @@ -428,14 +428,14 @@ public void Run(
throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
}

using(var cleanupList = new DisposableList<IDisposable>())
using (var cleanupList = new DisposableList<IDisposable>())
{
// prepare inputs
var inputNamesArray = ConvertNamesToUtf8(inputNames, n => n, cleanupList);
var inputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(inputNames, n => n, cleanupList);
var inputValuesArray = GetOrtValuesHandles(inputValues, true);

// prepare outputs
var outputNamesArray = ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
var outputNamesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(outputs, o => o.Name, cleanupList);
var outputValuesArray = GetOrtValuesHandles(outputs, cleanupList);

NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
Expand Down Expand Up @@ -515,7 +515,8 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> RunWithBindingAnd
var ortValue = ortValues.ElementAt(i);
result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames[i], ortValue));
}
} catch(Exception e)
}
catch (Exception e)
{
result.Dispose();
throw e;
Expand All @@ -535,36 +536,12 @@ public string EndProfiling()
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionEndProfiling(_nativeHandle,
allocator.Pointer,
out nameHandle));
using(var allocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
using (var allocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
{
return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
}
}

// Delegate for string extraction from an arbitrary input/output object
private delegate string NameExtractor<in TInput>(TInput input);

/// <summary>
/// Run helper
/// </summary>
/// <param name="names">names to convert to zero terminated utf8 and pin</param>
/// <param name="cleanupList">list to add pinned memory to for later disposal</param>
/// <returns></returns>
private IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> inputs, NameExtractor<T> extractor,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to a utility class to share code

DisposableList<IDisposable> cleanupList)
{
var result = new IntPtr[inputs.Count];
for (int i = 0; i < inputs.Count; ++i)
{
var name = extractor(inputs.ElementAt(i));
var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
result[i] = pinnedHandle.Pointer;
cleanupList.Add(pinnedHandle);
}
return result;
}

/// <summary>
/// This function obtains ortValues for NamedOnnxValue.
/// The problem with NamedOnnxValue is that it does not contain any Onnx (OrtValue)
Expand Down Expand Up @@ -609,8 +586,8 @@ private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> v
}


private DisposableList<OrtValue> RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames,
DisposableList<IDisposable> cleanupList)
private DisposableList<OrtValue> RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames,
DisposableList<IDisposable> cleanupList)
{
var ortValues = new DisposableList<OrtValue>(outputNames.Length);
cleanupList.Add(ortValues);
Expand Down Expand Up @@ -680,11 +657,11 @@ public ModelMetadata ModelMetadata
/// </summary>
public ulong ProfilingStartTimeNs
{
get
{
return _profilingStartTimeNs;
}
}
get
{
return _profilingStartTimeNs;
}
}

#endregion

Expand Down Expand Up @@ -757,8 +734,8 @@ private void InitWithSessionHandle(IntPtr session, SessionOptions options)
// set profiling's start time
UIntPtr startTime = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetProfilingStartTimeNs(_nativeHandle,
out startTime));
_profilingStartTimeNs = (ulong) startTime;
out startTime));
_profilingStartTimeNs = (ulong)startTime;
}
catch (OnnxRuntimeException e)
{
Expand Down Expand Up @@ -821,7 +798,7 @@ private string GetOverridableInitializerName(ulong index)
(UIntPtr)index,
allocator.Pointer,
out nameHandle));
using(var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
{
str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
}
Expand Down Expand Up @@ -963,7 +940,7 @@ public void Dispose()
/// <param name="disposing">true if invoked from Dispose() method</param>
protected virtual void Dispose(bool disposing)
{
if(_disposed)
if (_disposed)
{
return;
}
Expand Down Expand Up @@ -1137,7 +1114,7 @@ internal ModelMetadata(InferenceSession session)
}

// Process each key via the stored key handles
foreach(var allocation in ortAllocationKeys)
foreach (var allocation in ortAllocationKeys)
{
IntPtr keyHandle = allocation.Pointer;
IntPtr valueHandle = IntPtr.Zero;
Expand All @@ -1160,9 +1137,9 @@ internal ModelMetadata(InferenceSession session)
{

// Free ModelMetadata handle
NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle);
NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle);

}
}

}

Expand Down
50 changes: 50 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ public struct OrtApi
public IntPtr CreateArenaCfg;
public IntPtr ReleaseArenaCfg;
public IntPtr ModelMetadataGetGraphDescription;
public IntPtr CreateCUDAProviderOptions;
public IntPtr UpdateCUDAProviderOptions;
public IntPtr ReleaseCUDAProviderOptions;
}

internal static class NativeMethods
Expand Down Expand Up @@ -255,6 +258,8 @@ static NativeMethods()
OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary));
OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry));
OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer));
SessionOptionsAppendExecutionProvider_CUDA = (DSessionOptionsAppendExecutionProvider_CUDA)Marshal.GetDelegateForFunctionPointer(
api_.SessionOptionsAppendExecutionProvider_CUDA, typeof(DSessionOptionsAppendExecutionProvider_CUDA));

OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions));
OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions));
Expand Down Expand Up @@ -334,6 +339,10 @@ static NativeMethods()

OrtGetAvailableProviders = (DOrtGetAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.GetAvailableProviders, typeof(DOrtGetAvailableProviders));
OrtReleaseAvailableProviders = (DOrtReleaseAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAvailableProviders, typeof(DOrtReleaseAvailableProviders));

OrtCreateCUDAProviderOptions = (DOrtCreateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateCUDAProviderOptions, typeof(DOrtCreateCUDAProviderOptions));
OrtUpdateCUDAProviderOptions = (DOrtUpdateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateCUDAProviderOptions, typeof(DOrtUpdateCUDAProviderOptions));
OrtReleaseCUDAProviderOptions = (DOrtReleaseCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseCUDAProviderOptions, typeof(DOrtReleaseCUDAProviderOptions));
}

[DllImport(nativeLib, CharSet = charSet)]
Expand All @@ -356,6 +365,37 @@ static NativeMethods()

#endregion Runtime/Environment API

#region Provider Options API
/// <summary>
/// Creates native OrtCUDAProviderOptions instance
/// </summary>
/// <param name="cudaProviderOptionsInstance">(output) native instance of OrtCUDAProviderOptions</param>
public delegate IntPtr /* OrtStatus* */DOrtCreateCUDAProviderOptions(
out IntPtr /*(OrtCUDAProviderOptions**)*/ cudaProviderOptionsInstance);
public static DOrtCreateCUDAProviderOptions OrtCreateCUDAProviderOptions;

/// <summary>
/// Updates native OrtCUDAProviderOptions instance using given key/value pairs
/// </summary>
/// <param name="cudaProviderOptionsInstance">native instance of OrtCUDAProviderOptions</param>
/// <param name="providerOptionsKeys">configuration keys of OrtCUDAProviderOptions</param>
/// <param name="providerOptionsValues">configuration values of OrtCUDAProviderOptions</param>
/// <param name="numKeys">number of configuration keys</param>
public delegate IntPtr /* OrtStatus* */DOrtUpdateCUDAProviderOptions(
IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance,
IntPtr[] /*(const char* const *)*/ providerOptionsKeys,
IntPtr[] /*(const char* const *)*/ providerOptionsValues,
UIntPtr /*(size_t)*/ numKeys);
public static DOrtUpdateCUDAProviderOptions OrtUpdateCUDAProviderOptions;

/// <summary>
/// Releases native OrtCUDAProviderOptions instance
/// </summary>
/// <param name="cudaProviderOptionsInstance">native instance of OrtCUDAProviderOptions to be released</param>
public delegate void DOrtReleaseCUDAProviderOptions(IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance);
public static DOrtReleaseCUDAProviderOptions OrtReleaseCUDAProviderOptions;
#endregion

#region Status API
public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/status);
public static DOrtGetErrorCode OrtGetErrorCode;
Expand Down Expand Up @@ -560,6 +600,16 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id);

/// <summary>
/// Append a CUDA EP instance (configured based on given provider options) to the native OrtSessionOptions instance
/// </summary>
/// <param name="options">Native OrtSessionOptions instance</param>
/// <param name="cudaProviderOptions">Native OrtCUDAProviderOptions instance</param>
public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA(
IntPtr /*(OrtSessionOptions*)*/ options,
IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions);
public static DSessionOptionsAppendExecutionProvider_CUDA SessionOptionsAppendExecutionProvider_CUDA;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SessionOptionsAppendExecutionProvider_CUDA [](start = 66, length = 42)

this function's name looks different from the other append EP functions. can it be more consistent?

Copy link
Member Author

@hariharans29 hariharans29 Jan 9, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is because they are fundamentally different. OrtSessionOptionsAppendExecutionProvider_CUDA is a symbol available in the shared library if the CUDA EP is built. SessionOptionsAppendExecutionProvider_CUDA is available via the C API struct always (whether built with CUDA support or not).

Also NativeMethods is an "internal concept". If you take a look at the public interface in SessionOptions.cs - the method made available to the user is consistent. There are two overloads of AppendExecutionProvider_CUDA - one calls the OrtSessionOptionsAppendExecutionProvider_CUDA and the other calls SessionOptionsAppendExecutionProvider_CUDA and all these details are abstracted from the user.

Unfortunately, I couldn't think of a way to make the naming here more consistent given that we already have a OrtSessionOptionsAppendExecutionProvider_CUDA() defined here. Open to suggestions though.


[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id);

Expand Down
32 changes: 30 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeOnnxValueHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Linq;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

System.Linq; [](start = 6, length = 12)

Check if this is really needed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

System.Linq;

Is this needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm referencing the C# code here.
I think is for ElementAt() and if I don't include System.Linq, I will get:
error CS1061: 'IReadOnlyCollection' does not contain a definition for 'ElementAt' and no accessible extension method 'ElementAt' accepting a first argument of type 'IReadOnlyCollection' could be found

using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

Expand Down Expand Up @@ -41,7 +43,7 @@ public void Dispose()
// No need for the finalizer
// If this is not disposed timely GC can't help us
#endregion
}
}

/// <summary>
/// This helper class contains methods to create native OrtValue from a managed value object
Expand Down Expand Up @@ -77,14 +79,40 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8)
Marshal.Copy(nativeUtf8, buffer, 0, len);
return Encoding.UTF8.GetString(buffer, 0, buffer.Length);
}

/// <summary>
/// Run helper
/// </summary>
/// <param name="names">names to convert to zero terminated utf8 and pin</param>
/// <param name="extractor">delegate for string extraction from inputs</param>
/// <param name="cleanupList">list to add pinned memory to for later disposal</param>
/// <returns></returns>
internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, NameExtractor<T> extractor,
DisposableList<IDisposable> cleanupList)
{
var result = new IntPtr[names.Count];
for (int i = 0; i < names.Count; ++i)
{
var name = extractor(names.ElementAt(i));
var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
result[i] = pinnedHandle.Pointer;
cleanupList.Add(pinnedHandle);
}
return result;
}

// Delegate for string extraction from an arbitrary input/output object
internal delegate string NameExtractor<in TInput>(TInput input);

}

internal static class TensorElementTypeConverter
{
public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width)
{
TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType);
if(result != null)
if (result != null)
{
type = result.TensorType;
width = result.TypeSize;
Expand Down
Loading