-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Add support in C# to configure a CUDA EP instance #6291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5d14b08
36aa7b4
94e4603
9c74b15
6b8ca35
2fc1a8b
1e28cfa
4b054a1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -189,6 +189,9 @@ public struct OrtApi | |
| public IntPtr CreateArenaCfg; | ||
| public IntPtr ReleaseArenaCfg; | ||
| public IntPtr ModelMetadataGetGraphDescription; | ||
| public IntPtr CreateCUDAProviderOptions; | ||
| public IntPtr UpdateCUDAProviderOptions; | ||
| public IntPtr ReleaseCUDAProviderOptions; | ||
| } | ||
|
|
||
| internal static class NativeMethods | ||
|
|
@@ -255,6 +258,8 @@ static NativeMethods() | |
| OrtRegisterCustomOpsLibrary = (DOrtRegisterCustomOpsLibrary)Marshal.GetDelegateForFunctionPointer(api_.RegisterCustomOpsLibrary, typeof(DOrtRegisterCustomOpsLibrary)); | ||
| OrtAddSessionConfigEntry = (DOrtAddSessionConfigEntry)Marshal.GetDelegateForFunctionPointer(api_.AddSessionConfigEntry, typeof(DOrtAddSessionConfigEntry)); | ||
| OrtAddInitializer = (DOrtAddInitializer)Marshal.GetDelegateForFunctionPointer(api_.AddInitializer, typeof(DOrtAddInitializer)); | ||
| SessionOptionsAppendExecutionProvider_CUDA = (DSessionOptionsAppendExecutionProvider_CUDA)Marshal.GetDelegateForFunctionPointer( | ||
| api_.SessionOptionsAppendExecutionProvider_CUDA, typeof(DSessionOptionsAppendExecutionProvider_CUDA)); | ||
|
|
||
| OrtCreateRunOptions = (DOrtCreateRunOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateRunOptions, typeof(DOrtCreateRunOptions)); | ||
| OrtReleaseRunOptions = (DOrtReleaseRunOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseRunOptions, typeof(DOrtReleaseRunOptions)); | ||
|
|
@@ -334,6 +339,10 @@ static NativeMethods() | |
|
|
||
| OrtGetAvailableProviders = (DOrtGetAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.GetAvailableProviders, typeof(DOrtGetAvailableProviders)); | ||
| OrtReleaseAvailableProviders = (DOrtReleaseAvailableProviders)Marshal.GetDelegateForFunctionPointer(api_.ReleaseAvailableProviders, typeof(DOrtReleaseAvailableProviders)); | ||
|
|
||
| OrtCreateCUDAProviderOptions = (DOrtCreateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.CreateCUDAProviderOptions, typeof(DOrtCreateCUDAProviderOptions)); | ||
| OrtUpdateCUDAProviderOptions = (DOrtUpdateCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.UpdateCUDAProviderOptions, typeof(DOrtUpdateCUDAProviderOptions)); | ||
| OrtReleaseCUDAProviderOptions = (DOrtReleaseCUDAProviderOptions)Marshal.GetDelegateForFunctionPointer(api_.ReleaseCUDAProviderOptions, typeof(DOrtReleaseCUDAProviderOptions)); | ||
| } | ||
|
|
||
| [DllImport(nativeLib, CharSet = charSet)] | ||
|
|
@@ -356,6 +365,37 @@ static NativeMethods() | |
|
|
||
| #endregion Runtime/Environment API | ||
|
|
||
| #region Provider Options API | ||
| /// <summary> | ||
| /// Creates native OrtCUDAProviderOptions instance | ||
| /// </summary> | ||
| /// <param name="cudaProviderOptionsInstance">(output) native instance of OrtCUDAProviderOptions</param> | ||
| public delegate IntPtr /* OrtStatus* */DOrtCreateCUDAProviderOptions( | ||
| out IntPtr /*(OrtCUDAProviderOptions**)*/ cudaProviderOptionsInstance); | ||
| public static DOrtCreateCUDAProviderOptions OrtCreateCUDAProviderOptions; | ||
|
|
||
| /// <summary> | ||
| /// Updates native OrtCUDAProviderOptions instance using given key/value pairs | ||
| /// </summary> | ||
| /// <param name="cudaProviderOptionsInstance">native instance of OrtCUDAProviderOptions</param> | ||
| /// <param name="providerOptionsKeys">configuration keys of OrtCUDAProviderOptions</param> | ||
| /// <param name="providerOptionsValues">configuration values of OrtCUDAProviderOptions</param> | ||
| /// <param name="numKeys">number of configuration keys</param> | ||
| public delegate IntPtr /* OrtStatus* */DOrtUpdateCUDAProviderOptions( | ||
| IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance, | ||
| IntPtr[] /*(const char* const *)*/ providerOptionsKeys, | ||
| IntPtr[] /*(const char* const *)*/ providerOptionsValues, | ||
| UIntPtr /*(size_t)*/ numKeys); | ||
| public static DOrtUpdateCUDAProviderOptions OrtUpdateCUDAProviderOptions; | ||
|
|
||
| /// <summary> | ||
| /// Releases native OrtCUDAProviderOptions instance | ||
| /// </summary> | ||
| /// <param name="cudaProviderOptionsInstance">native instance of OrtCUDAProviderOptions to be released</param> | ||
| public delegate void DOrtReleaseCUDAProviderOptions(IntPtr /*(OrtCUDAProviderOptions*)*/ cudaProviderOptionsInstance); | ||
| public static DOrtReleaseCUDAProviderOptions OrtReleaseCUDAProviderOptions; | ||
| #endregion | ||
|
|
||
| #region Status API | ||
| public delegate ErrorCode DOrtGetErrorCode(IntPtr /*(OrtStatus*)*/status); | ||
| public static DOrtGetErrorCode OrtGetErrorCode; | ||
|
|
@@ -560,6 +600,16 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca | |
| [DllImport(nativeLib, CharSet = charSet)] | ||
| public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_CUDA(IntPtr /*(OrtSessionOptions*) */ options, int device_id); | ||
|
|
||
| /// <summary> | ||
| /// Append a CUDA EP instance (configured based on given provider options) to the native OrtSessionOptions instance | ||
| /// </summary> | ||
| /// <param name="options">Native OrtSessionOptions instance</param> | ||
| /// <param name="cudaProviderOptions">Native OrtCUDAProviderOptions instance</param> | ||
| public delegate IntPtr /*(OrtStatus*)*/DSessionOptionsAppendExecutionProvider_CUDA( | ||
| IntPtr /*(OrtSessionOptions*)*/ options, | ||
| IntPtr /*(const OrtCUDAProviderOptions*)*/ cudaProviderOptions); | ||
| public static DSessionOptionsAppendExecutionProvider_CUDA SessionOptionsAppendExecutionProvider_CUDA; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
this function's name looks different from the other append EP functions. can it be more consistent?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That is because they are fundamentally different. Also NativeMethods is an "internal concept". If you take a look at the public interface in SessionOptions.cs - the method made available to the user is consistent. There are two overloads of Unfortunately, I couldn't think of a way to make the naming here more consistent given that we already have a |
||
|
|
||
| [DllImport(nativeLib, CharSet = charSet)] | ||
| public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_DML(IntPtr /*(OrtSessionOptions*) */ options, int device_id); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,8 @@ | |
|
|
||
| using Microsoft.ML.OnnxRuntime.Tensors; | ||
| using System; | ||
| using System.Linq; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Check if this is really needed
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm referencing the C# code here. |
||
| using System.Collections.Generic; | ||
| using System.Runtime.InteropServices; | ||
| using System.Text; | ||
|
|
||
|
|
@@ -41,7 +43,7 @@ public void Dispose() | |
| // No need for the finalizer | ||
| // If this is not disposed timely GC can't help us | ||
| #endregion | ||
| } | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// This helper class contains methods to create native OrtValue from a managed value object | ||
|
|
@@ -77,14 +79,40 @@ internal static string StringFromNativeUtf8(IntPtr nativeUtf8) | |
| Marshal.Copy(nativeUtf8, buffer, 0, len); | ||
| return Encoding.UTF8.GetString(buffer, 0, buffer.Length); | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Run helper | ||
| /// </summary> | ||
| /// <param name="names">names to convert to zero terminated utf8 and pin</param> | ||
| /// <param name="extractor">delegate for string extraction from inputs</param> | ||
| /// <param name="cleanupList">list to add pinned memory to for later disposal</param> | ||
| /// <returns></returns> | ||
| internal static IntPtr[] ConvertNamesToUtf8<T>(IReadOnlyCollection<T> names, NameExtractor<T> extractor, | ||
| DisposableList<IDisposable> cleanupList) | ||
| { | ||
| var result = new IntPtr[names.Count]; | ||
| for (int i = 0; i < names.Count; ++i) | ||
| { | ||
| var name = extractor(names.ElementAt(i)); | ||
| var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name); | ||
| var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned)); | ||
| result[i] = pinnedHandle.Pointer; | ||
| cleanupList.Add(pinnedHandle); | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
| // Delegate for string extraction from an arbitrary input/output object | ||
| internal delegate string NameExtractor<in TInput>(TInput input); | ||
|
|
||
| } | ||
|
|
||
| internal static class TensorElementTypeConverter | ||
| { | ||
| public static void GetTypeAndWidth(TensorElementType elemType, out Type type, out int width) | ||
| { | ||
| TensorElementTypeInfo result = TensorBase.GetElementTypeInfo(elemType); | ||
| if(result != null) | ||
| if (result != null) | ||
| { | ||
| type = result.TensorType; | ||
| width = result.TypeSize; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moved to a utility class to share code