46 changes: 46 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -38,6 +38,7 @@ public class SessionOptions : SafeHandle
{
// Delay-loaded CUDA or cuDNN DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information.
private static string[] cudaDelayLoadedLibs = { };
private static string[] trtDelayLoadedLibs = { };

#region Constructor and Factory methods

@@ -75,6 +76,30 @@ public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId = 0)
return options;
}

/// <summary>
/// A helper method to construct a SessionOptions object for TensorRT execution.
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">Device ID of the GPU to execute on.</param>
/// <returns>A SessionOptions object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithTensorrtProvider(int deviceId = 0)
{
CheckTensorrtExecutionProviderDLLs();
SessionOptions options = new SessionOptions();
try
{
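// Providers are appended in priority order: TensorRT first, then CUDA,
// then CPU as the final fallback for nodes neither GPU provider supports.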
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(options.Handle, deviceId));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, deviceId));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
return options;
}
catch (Exception)
{
options.Dispose();
throw;
}
}
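For reference, a minimal usage sketch of the new helper, assuming the TensorRT-enabled onnxruntime package is installed; the "model.onnx" path is illustrative and not part of this diff:

using Microsoft.ML.OnnxRuntime;

// Illustrative only: "model.onnx" is a placeholder model path.
using (SessionOptions options = SessionOptions.MakeSessionOptionWithTensorrtProvider(deviceId: 0))
using (var session = new InferenceSession("model.onnx", options))
{
    // session.Run(...) executes supported subgraphs on TensorRT,
    // falling back to CUDA and then CPU for the rest.
}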

/// <summary>
/// A helper method to construct a SessionOptions object for Nuphar execution.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
@@ -624,6 +649,27 @@ private static bool CheckCudaExecutionProviderDLLs()
return true;
}

private static bool CheckTensorrtExecutionProviderDLLs()
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
foreach (var dll in trtDelayLoadedLibs)
{
IntPtr handle = LoadLibrary(dll);
if (handle != IntPtr.Zero)
continue;
var sysdir = new StringBuilder(String.Empty, 2048);
GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
throw new OnnxRuntimeException(
ErrorCode.NoSuchFile,
$"kernel32.LoadLibrary():'{dll}' not found. TensorRT/CUDA are required for GPU execution. " +
$". Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
);
}
}
return true;
}
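Note that trtDelayLoadedLibs is declared empty above (delay-loading is disabled, per the comment in the source), so this loop is currently a no-op that mirrors CheckCudaExecutionProviderDLLs. If delay-loading were re-enabled, the array would list the native dependencies to probe; a hypothetical sketch, with DLL names that are assumptions rather than part of this PR:

// Hypothetical: DLL names are illustrative assumptions, not part of this diff.
private static string[] trtDelayLoadedLibs = { "nvinfer.dll", "cudart64_100.dll" };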


#endregion
#region SafeHandle
34 changes: 34 additions & 0 deletions csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -227,6 +227,40 @@ public void CanCreateAndDisposeSessionWithModelPath()
}
}

#if USE_TENSORRT
[Fact]
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");

using (var cleanUp = new DisposableListTest<IDisposable>())
{
SessionOptions options = SessionOptions.MakeSessionOptionWithTensorrtProvider(0);
cleanUp.Add(options);

var session = new InferenceSession(modelPath, options);
cleanUp.Add(session);

var inputMeta = session.InputMetadata;
var container = new List<NamedOnnxValue>();
float[] inputData = LoadTensorFromFile(@"bench.in"); // this is the data for only one input tensor for this model
foreach (var name in inputMeta.Keys)
{
Assert.Equal(typeof(float), inputMeta[name].ElementType);
Assert.True(inputMeta[name].IsTensor);
var tensor = new DenseTensor<float>(inputData, inputMeta[name].Dimensions);
container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
}

using (var results = session.Run(container))
{
validateRunResults(results);
}
}
}
#endif
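The test reads its input with the suite's existing LoadTensorFromFile helper. As a rough sketch of what such a loader might look like (the real helper already exists in InferenceTest.cs, and the exact on-disk format of bench.in is an assumption here):

using System;
using System.Collections.Generic;
using System.IO;

// Hypothetical sketch only: the real LoadTensorFromFile already exists in this
// test project, and the exact format of bench.in is an assumption.
static float[] LoadTensorFromFileSketch(string path)
{
    var values = new List<float>();
    foreach (var line in File.ReadLines(path))
    {
        foreach (var token in line.Split(new[] { ',', ' ', '\t' },
                                         StringSplitOptions.RemoveEmptyEntries))
        {
            if (float.TryParse(token, out var v))
                values.Add(v);
        }
    }
    return values.ToArray();
}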

[Theory]
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)]
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)]