diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
index 0df4c77404898..815c365d321a2 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
@@ -192,6 +192,29 @@ public struct OrtApi
public IntPtr ModelMetadataGetGraphDescription;
}
+ #region ORT Provider options
+ [StructLayout(LayoutKind.Sequential)]
+ public struct OrtTensorRTProviderOptionsNative
+ {
+ public int device_id; // cuda device id.
+ public int has_user_compute_stream; // indicator of user specified CUDA compute stream.
+ public IntPtr user_compute_stream; // user specified CUDA compute stream.
+ public int has_trt_options; // override environment variables with following TensorRT settings at runtime.
+ public UIntPtr trt_max_workspace_size; // maximum workspace size for TensorRT.
+ public int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
+ public int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
+ public IntPtr trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
+ public int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
+ public int trt_max_partition_iterations; // maximum number of iterations allowed in model partitioning for TensorRT.
+ public int trt_min_subgraph_size; // minimum node size in a subgraph after partitioning.
+ public int trt_dump_subgraphs; // dump the subgraphs that are transformed into TRT engines in onnx format to the filesystem. Default 0 = false, nonzero = true
+ public int trt_engine_cache_enable; // enable TensorRT engine caching. Default 0 = false, nonzero = true
+ public IntPtr trt_cache_path; // specify path for TensorRT engine and profile files if engine_cache_enable is enabled, or INT8 calibration table file if trt_int8_enable is enabled.
+ }
+ #endregion
+
+
+
internal static class NativeMethods
{
private const string nativeLib = "onnxruntime";
@@ -574,6 +597,9 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_Tensorrt(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
+ [DllImport(nativeLib, CharSet = charSet)]
+ public static extern IntPtr /*(OrtStatus*)*/ SessionOptionsAppendExecutionProvider_TensorRT(IntPtr /*(OrtSessionOptions*)*/ options, ref OrtTensorRTProviderOptionsNative trt_options);
+
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtSessionOptionsAppendExecutionProvider_MIGraphX(IntPtr /*(OrtSessionOptions*)*/ options, int device_id);
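
For orientation, a minimal sketch (not part of the patch) of how the new OrtTensorRTProviderOptionsNative struct and P/Invoke entry point are expected to be driven from managed code; the wrapper added to SessionOptions.cs below does this for real. The sketch assumes an existing OrtSessionOptions handle named optionsHandle and reuses this project's NativeOnnxValueHelper and NativeApiStatus helpers; the calibration table name is illustrative.

    // Sketch only. The string fields are raw IntPtrs, so they must point at pinned,
    // null-terminated UTF-8 buffers that stay pinned until the native call returns.
    var native = new OrtTensorRTProviderOptionsNative();   // value type, zero-initialized
    native.device_id = 0;
    native.has_trt_options = 1;                            // use the fields below instead of environment variables
    native.trt_max_workspace_size = (UIntPtr)(1 << 30);
    native.trt_int8_enable = 1;

    byte[] tableNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8("calibration.flatbuffers");
    GCHandle pin = GCHandle.Alloc(tableNameUtf8, GCHandleType.Pinned);
    try
    {
        native.trt_int8_calibration_table_name = pin.AddrOfPinnedObject();
        NativeApiStatus.VerifySuccess(
            NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(optionsHandle, ref native));
    }
    finally
    {
        pin.Free();
    }
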
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.cs
new file mode 100644
index 0000000000000..647e0c92a3cbb
--- /dev/null
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/ProviderOptions.cs
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.InteropServices;
+
+namespace Microsoft.ML.OnnxRuntime
+{
+ /// <summary>
+ /// Provider options for TensorRT.
+ /// </summary>
+ // Example for setting:
+ // OrtTensorRTProviderOptions trt_options;
+ // trt_options.device_id = 0;
+ // trt_options.has_trt_options = 1;
+ // trt_options.trt_max_workspace_size = (UIntPtr) (1<<30);
+ // trt_options.trt_fp16_enable = 1;
+ // trt_options.trt_int8_enable = 1;
+ // trt_options.trt_int8_calibration_table_name = "calibration.flatbuffers";
+ // trt_options.trt_int8_use_native_calibration_table = 0;
+ public struct OrtTensorRTProviderOptions
+ {
+ public int device_id; //!< cuda device id. Default is 0.
+ public int has_trt_options; //!< override environment variables with following TensorRT settings at runtime. Default 0 = false, nonzero = true.
+ public UIntPtr trt_max_workspace_size; //!< maximum workspace size for TensorRT. The native ORT struct declares this field as size_t, hence UIntPtr is used on the managed side.
+ public int trt_fp16_enable; //!< enable TensorRT FP16 precision. Default 0 = false, nonzero = true.
+ public int trt_int8_enable; //!< enable TensorRT INT8 precision. Default 0 = false, nonzero = true.
+ public String trt_int8_calibration_table_name; //!< TensorRT INT8 calibration table name.
+ public int trt_int8_use_native_calibration_table; //!< use native TensorRT generated calibration table. Default 0 = false, nonzero = true
+ public int trt_max_partition_iterations; //!< maximum number of iterations allowed in model partitioning for TensorRT.
+ public int trt_min_subgraph_size; //!< minimum node size in a subgraph after partitioning.
+ public int trt_dump_subgraphs; //!< dump the subgraphs that are transformed into TRT engines in onnx format to the filesystem. Default 0 = false, nonzero = true
+ public int trt_engine_cache_enable; //!< enable TensorRT engine caching. Default 0 = false, nonzero = true
+ public String trt_cache_path; //!< specify path for TensorRT engine and profile files if engine_cache_enable is enabled, or INT8 calibration table file if trt_int8_enable is enabled.
+ }
+
+ /// <summary>
+ /// Holds provider options configuration for creating an InferenceSession.
+ /// </summary>
+ public class ProviderOptions : SafeHandle
+ {
+ internal IntPtr Handle
+ {
+ get
+ {
+ return handle;
+ }
+ }
+
+ #region Constructor and Factory methods
+
+ /// <summary>
+ /// Constructs an empty ProviderOptions
+ /// </summary>
+ public ProviderOptions()
+ : base(IntPtr.Zero, true)
+ {
+ }
+
+ #endregion
+
+ #region Public Methods
+
+ /// <summary>
+ /// Get TensorRT provider options with default settings.
+ /// </summary>
+ /// <returns>TRT provider options instance.</returns>
+ public static OrtTensorRTProviderOptions GetDefaultTensorRTProviderOptions()
+ {
+ OrtTensorRTProviderOptions trt_options;
+ trt_options.device_id = 0;
+ trt_options.has_trt_options = 0;
+ trt_options.trt_max_workspace_size = (UIntPtr)(1 << 30);
+ trt_options.trt_fp16_enable = 0;
+ trt_options.trt_int8_enable = 0;
+ trt_options.trt_int8_calibration_table_name = "";
+ trt_options.trt_int8_use_native_calibration_table = 0;
+ trt_options.trt_max_partition_iterations = 1000;
+ trt_options.trt_min_subgraph_size = 1;
+ trt_options.trt_dump_subgraphs = 0;
+ trt_options.trt_engine_cache_enable = 0;
+ trt_options.trt_cache_path = "";
+
+ return trt_options;
+ }
+ #endregion
+
+ #region Public Properties
+
+ /// <summary>
+ /// Overrides SafeHandle.IsInvalid
+ /// </summary>
+ /// <value>returns true if handle is equal to Zero</value>
+ public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
+
+ #endregion
+
+ #region SafeHandle
+ /// <summary>
+ /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
+ /// the native instance of ProviderOptions
+ /// </summary>
+ /// <returns>always returns true</returns>
+ protected override bool ReleaseHandle()
+ {
+ handle = IntPtr.Zero;
+ return true;
+ }
+
+ #endregion
+ }
+}
\ No newline at end of file
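
As a usage sketch of the new struct and the GetDefaultTensorRTProviderOptions() helper (mirroring the unit test added further below; the model path, cache path, and field values are illustrative, not part of the patch):

    using System;
    using Microsoft.ML.OnnxRuntime;

    // Start from the defaults, then override only what is needed.
    OrtTensorRTProviderOptions trt_options = ProviderOptions.GetDefaultTensorRTProviderOptions();
    trt_options.device_id = 0;
    trt_options.has_trt_options = 1;                          // apply the fields below instead of environment variables
    trt_options.trt_max_workspace_size = (UIntPtr)(1 << 30);
    trt_options.trt_fp16_enable = 1;
    trt_options.trt_engine_cache_enable = 1;
    trt_options.trt_cache_path = "trt_engines";               // illustrative cache directory

    using (var options = SessionOptions.MakeSessionOptionWithTensorrtProvider(trt_options))
    using (var session = new InferenceSession("model.onnx", options))   // illustrative model path
    {
        // run inference as usual
    }
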
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
index 6bc48a0d704da..55f2a5d32f8b5 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
@@ -38,6 +38,7 @@ public class SessionOptions : SafeHandle
{
// Delay-loaded CUDA or cuDNN DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information.
private static string[] cudaDelayLoadedLibs = { };
+ private static string[] trtDelayLoadedLibs = { };
#region Constructor and Factory methods
@@ -75,6 +76,42 @@ public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId = 0)
return options;
}
+ /// <summary>
+ /// A helper method to construct a SessionOptions object for TensorRT execution.
+ /// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
+ /// </summary>
+ /// <param name="deviceId">Device Id</param>
+ /// <returns>A SessionOptions object configured for execution on deviceId</returns>
+ public static SessionOptions MakeSessionOptionWithTensorrtProvider(int deviceId = 0)
+ {
+ CheckTensorrtExecutionProviderDLLs();
+ SessionOptions options = new SessionOptions();
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(options.Handle, deviceId));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, deviceId));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
+ return options;
+ }
+
+ /// <summary>
+ /// A helper method to construct a SessionOptions object for TensorRT execution.
+ /// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
+ /// </summary>
+ /// <param name="trt_options">Provider Options for TensorRT EP.</param>
+ /// <returns>A SessionOptions object configured for TensorRT execution with the given provider options</returns>
+ public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTProviderOptions trt_options)
+ {
+ CheckTensorrtExecutionProviderDLLs();
+ SessionOptions options = new SessionOptions();
+
+ OrtTensorRTProviderOptionsNative trt_options_native;
+ trt_options_native = PrepareNativeTensorRTProviderOptions(trt_options);
+
+ NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(options.Handle, ref trt_options_native));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(options.Handle, trt_options.device_id));
+ NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(options.Handle, 1));
+ return options;
+ }
+
/// <summary>
/// A helper method to construct a SessionOptions object for Nuphar execution.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
@@ -180,6 +217,18 @@ public void AppendExecutionProvider_Tensorrt(int deviceId)
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
}
+ /// <summary>
+ /// Use only if you have the onnxruntime package specific to this Execution Provider.
+ /// </summary>
+ /// <param name="trt_options">Provider Options for TensorRT EP.</param>
+ public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trt_options)
+ {
+ OrtTensorRTProviderOptionsNative trt_options_native;
+ trt_options_native = PrepareNativeTensorRTProviderOptions(trt_options);
+
+ NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT(handle, ref trt_options_native));
+ }
+
/// <summary>
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
@@ -325,6 +374,7 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue)
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue));
}
}
+
#endregion
internal IntPtr Handle
@@ -624,6 +674,63 @@ private static bool CheckCudaExecutionProviderDLLs()
return true;
}
+ private static bool CheckTensorrtExecutionProviderDLLs()
+ {
+ if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
+ {
+ foreach (var dll in trtDelayLoadedLibs)
+ {
+ IntPtr handle = LoadLibrary(dll);
+ if (handle != IntPtr.Zero)
+ continue;
+ var sysdir = new StringBuilder(String.Empty, 2048);
+ GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
+ throw new OnnxRuntimeException(
+ ErrorCode.NoSuchFile,
+ $"kernel32.LoadLibrary():'{dll}' not found. TensorRT/CUDA are required for GPU execution. " +
+ $". Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
+ );
+ }
+ }
+ return true;
+ }
+
+ private static OrtTensorRTProviderOptionsNative PrepareNativeTensorRTProviderOptions(OrtTensorRTProviderOptions trt_options)
+ {
+ OrtTensorRTProviderOptionsNative trt_options_native;
+ trt_options_native.device_id = trt_options.device_id;
+ trt_options_native.has_user_compute_stream = 0;
+ trt_options_native.user_compute_stream = IntPtr.Zero;
+ trt_options_native.has_trt_options = trt_options.has_trt_options;
+ if ((ulong)trt_options.trt_max_workspace_size > (1 << 30))
+ {
+ trt_options_native.trt_max_workspace_size = (UIntPtr)(1 << 30);
+ }
+ else
+ {
+ trt_options_native.trt_max_workspace_size = trt_options.trt_max_workspace_size;
+ }
+ trt_options_native.trt_fp16_enable = trt_options.trt_fp16_enable;
+ trt_options_native.trt_int8_enable = trt_options.trt_int8_enable;
+ var tableNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(trt_options.trt_int8_calibration_table_name), GCHandleType.Pinned);
+ using (var pinnedSettingsName = new PinnedGCHandle(tableNamePinned))
+ {
+ trt_options_native.trt_int8_calibration_table_name = pinnedSettingsName.Pointer;
+ }
+ trt_options_native.trt_int8_use_native_calibration_table = trt_options.trt_int8_use_native_calibration_table;
+ trt_options_native.trt_max_partition_iterations = trt_options.trt_max_partition_iterations;
+ trt_options_native.trt_min_subgraph_size = trt_options.trt_min_subgraph_size;
+ trt_options_native.trt_dump_subgraphs = trt_options.trt_dump_subgraphs;
+ trt_options_native.trt_engine_cache_enable = trt_options.trt_engine_cache_enable;
+ var cachePathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(trt_options.trt_cache_path), GCHandleType.Pinned);
+ using (var pinnedSettingsName2 = new PinnedGCHandle(cachePathPinned))
+ {
+ trt_options_native.trt_cache_path = pinnedSettingsName2.Pointer;
+ }
+
+ return trt_options_native;
+ }
+
#endregion
#region SafeHandle
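
The same configuration can also be applied to an already constructed SessionOptions via the new AppendExecutionProvider_Tensorrt overload. A sketch, following the ordering that MakeSessionOptionWithTensorrtProvider sets up (TensorRT first, then the existing CUDA and CPU helpers as fallbacks; values are illustrative):

    var trt_options = ProviderOptions.GetDefaultTensorRTProviderOptions();
    trt_options.has_trt_options = 1;
    trt_options.trt_fp16_enable = 1;

    var options = new SessionOptions();
    options.AppendExecutionProvider_Tensorrt(trt_options);         // TensorRT EP first
    options.AppendExecutionProvider_CUDA(trt_options.device_id);   // then CUDA
    options.AppendExecutionProvider_CPU(1);                        // CPU as the final fallback
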
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
index 3deb62bddd577..404161846a41b 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
@@ -227,6 +227,54 @@ public void CanCreateAndDisposeSessionWithModelPath()
}
}
+
+
+#if USE_TENSORRT
+ [Fact]
+ private void TestTensorRTProviderOptions()
+ {
+ string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");
+ string calTablPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet_calibration.flatbuffers");
+ //Environment.SetEnvironmentVariable("ORT_TENSORRT_ENGINE_CACHE_ENABLE", "1");
+
+ OrtTensorRTProviderOptions trt_options = ProviderOptions.GetDefaultTensorRTProviderOptions();
+ trt_options.device_id = 0;
+ trt_options.trt_int8_calibration_table_name = calTablPath;
+ trt_options.has_trt_options = 1;
+ trt_options.trt_max_workspace_size = (UIntPtr)(1 << 30);
+ trt_options.trt_fp16_enable = 1;
+ trt_options.trt_int8_enable = 1;
+ trt_options.trt_int8_use_native_calibration_table = 0;
+
+ var session = new InferenceSession(modelPath, SessionOptions.MakeSessionOptionWithTensorrtProvider(trt_options));
+ var inputMeta = session.InputMetadata;
+ var container = new List<NamedOnnxValue>();
+ float[] inputData = LoadTensorFromFile(@"bench.in"); // this is the data for only one input tensor for this model
+ foreach (var name in inputMeta.Keys)
+ {
+ Assert.Equal(typeof(float), inputMeta[name].ElementType);
+ Assert.True(inputMeta[name].IsTensor);
+ var tensor = new DenseTensor<float>(inputData, inputMeta[name].Dimensions);
+ container.Add(NamedOnnxValue.CreateFromTensor(name, tensor));
+ }
+
+
+ using (var results = session.Run(container))
+ {
+ // The following checks are temporarily commented out.
+ // Even though fp16 or int8 is enabled through provider options, the TRT EP may disable it if the GPU does not support fp16 or int8.
+ // Once From/ToProviderOptions() has been implemented in the TRT EP, better test cases will be added.
+ /*
+ string[] files = Directory.GetFiles(Directory.GetCurrentDirectory(), "*int8*.engine");
+ Assert.True(files.Any());
+ files = Directory.GetFiles(Directory.GetCurrentDirectory(), "*fp16*.engine");
+ Assert.True(files.Any());
+ */
+ }
+ }
+#endif
+
+
[Theory]
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, true)]
[InlineData(GraphOptimizationLevel.ORT_DISABLE_ALL, false)]
@@ -2361,6 +2409,7 @@ private void VerifyNativeMethodsExist()
#endif
#if USE_TENSORRT
,"OrtSessionOptionsAppendExecutionProvider_Tensorrt"
+ ,"SessionOptionsAppendExecutionProvider_TensorRT"
#endif
#if USE_MIGRAPHX
,"OrtSessionOptionsAppendExecutionProvider_MIGraphX"
diff --git a/csharp/testdata/squeezenet_calibration.flatbuffers b/csharp/testdata/squeezenet_calibration.flatbuffers
new file mode 100644
index 0000000000000..e5cad768f4fe1
Binary files /dev/null and b/csharp/testdata/squeezenet_calibration.flatbuffers differ
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
index 44debc901cb77..9aa5f37ad7010 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h
@@ -8,6 +8,7 @@ extern "C" {
#endif
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);
+ORT_API_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, const OrtTensorRTProviderOptions* tensorrt_options);
#ifdef __cplusplus
}
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index faca901ce9030..7d892aacd8b4b 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -289,15 +289,20 @@ typedef struct OrtROCMProviderOptions {
/// Options for the TensorRT provider that are passed to SessionOptionsAppendExecutionProvider_TensorRT
/// </summary>
typedef struct OrtTensorRTProviderOptions {
- int device_id; // cuda device id.
- int has_user_compute_stream; // indicator of user specified CUDA compute stream.
- void* user_compute_stream; // user specified CUDA compute stream.
- int has_trt_options; // override environment variables with following TensorRT settings at runtime.
- size_t trt_max_workspace_size; // maximum workspace size for TensorRT.
- int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
- int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
- const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
- int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
+ int device_id; // cuda device id.
+ int has_user_compute_stream; // indicator of user specified CUDA compute stream.
+ void* user_compute_stream; // user specified CUDA compute stream.
+ int has_trt_options; // override environment variables with following TensorRT settings at runtime.
+ size_t trt_max_workspace_size; // maximum workspace size for TensorRT.
+ int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
+ int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
+ const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
+ int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
+ int trt_max_partition_iterations; // maximum number of iterations allowed in model partitioning for TensorRT.
+ int trt_min_subgraph_size; // minimum node size in a subgraph after partitioning.
+ int trt_dump_subgraphs; // dump the subgraphs that are transformed into TRT engines in onnx format to the filesystem. Default 0 = false, nonzero = true
+ int trt_engine_cache_enable; // enable TensorRT engine caching. Default 0 = false, nonzero = true
+ const char* trt_cache_path; // specify path for TensorRT engine and profile files if engine_cache_enable is enabled, or INT8 calibration table file if trt_int8_enable is enabled.
} OrtTensorRTProviderOptions;
/// <summary>
diff --git a/onnxruntime/core/providers/tensorrt/symbols.txt b/onnxruntime/core/providers/tensorrt/symbols.txt
index 47950c476c5e8..5e555e98a06f2 100644
--- a/onnxruntime/core/providers/tensorrt/symbols.txt
+++ b/onnxruntime/core/providers/tensorrt/symbols.txt
@@ -1 +1,2 @@
OrtSessionOptionsAppendExecutionProvider_Tensorrt
+SessionOptionsAppendExecutionProvider_TensorRT
\ No newline at end of file
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 27ed2be88115d..b8785a36cd44d 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -394,14 +394,22 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
// Get environment variables
- const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
- if (!max_partition_iterations_env.empty()) {
- max_partition_iterations_ = std::stoi(max_partition_iterations_env);
+ if (info.has_trt_options) {
+ max_partition_iterations_ = info.max_partition_iterations;
+ } else {
+ const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
+ if (!max_partition_iterations_env.empty()) {
+ max_partition_iterations_ = std::stoi(max_partition_iterations_env);
+ }
}
- const std::string min_subgraph_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMinSubgraphSize);
- if (!min_subgraph_size_env.empty()) {
- min_subgraph_size_ = std::stoi(min_subgraph_size_env);
+ if (info.has_trt_options) {
+ min_subgraph_size_ = info.min_subgraph_size;
+ } else {
+ const std::string min_subgraph_size_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMinSubgraphSize);
+ if (!min_subgraph_size_env.empty()) {
+ min_subgraph_size_ = std::stoi(min_subgraph_size_env);
+ }
}
if (info.has_trt_options) {
@@ -451,19 +459,32 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
}
}
- const std::string dump_subgraphs_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpSubgraphs);
- if (!dump_subgraphs_env.empty()) {
- dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true);
+ if (info.has_trt_options) {
+ dump_subgraphs_ = info.dump_subgraphs;
+ } else {
+ const std::string dump_subgraphs_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpSubgraphs);
+ if (!dump_subgraphs_env.empty()) {
+ dump_subgraphs_ = (std::stoi(dump_subgraphs_env) == 0 ? false : true);
+ }
}
- const std::string engine_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable);
- if (!engine_cache_enable_env.empty()) {
- engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true);
+ if (info.has_trt_options) {
+ engine_cache_enable_ = info.engine_cache_enable;
+ } else {
+ const std::string engine_cache_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCacheEnable);
+ if (!engine_cache_enable_env.empty()) {
+ engine_cache_enable_ = (std::stoi(engine_cache_enable_env) == 0 ? false : true);
+ }
}
if (engine_cache_enable_ || int8_enable_) {
const std::string engine_cache_path = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEngineCachePath);
- cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath);
+ if (info.has_trt_options) {
+ cache_path_ = info.cache_path;
+ } else {
+ cache_path_ = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCachePath);
+ }
+
if (!engine_cache_path.empty() && cache_path_.empty()) {
cache_path_ = engine_cache_path;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_ENGINE_CACHE_PATH is deprecated! Please use ORT_TENSORRT_CACHE_PATH to specify engine cache path";
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 29b03954b0c24..3a56d3d5e7ff1 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -77,6 +77,11 @@ struct TensorrtExecutionProviderInfo {
bool int8_enable{false};
std::string int8_calibration_table_name{""};
bool int8_use_native_calibration_table{false};
+ int max_partition_iterations{1000};
+ int min_subgraph_size{1};
+ int dump_subgraphs{0};
+ int engine_cache_enable{0};
+ std::string cache_path{""};
};
// Information to construct kernel function state.
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 64071a89bc47d..3f9adbf2ca130 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -549,7 +549,7 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
sess->GetSessionOptions().enable_cpu_mem_arena));
} else if (type == kTensorrtExecutionProvider) {
#ifdef USE_TENSORRT
- OrtTensorRTProviderOptions params{0, 0, nullptr, 0, 1 << 30, 0, 0, nullptr, 0};
+ OrtTensorRTProviderOptions params{0, 0, nullptr, 0, 1 << 30, 0, 0, nullptr, 0, 1000, 1, 0, 0, nullptr};
std::string trt_int8_calibration_table_name;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index dd799ac65570c..ee29c8d6f728c 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -318,7 +318,12 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
0,
0,
nullptr,
- 0};
+ 0,
+ 1000,
+ 1,
+ 0,
+ 0,
+ nullptr};
OrtCUDAProviderOptions cuda_options{
0,
diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc
index ed439eb40ca71..2b7f5f2132465 100644
--- a/onnxruntime/test/perftest/ort_test_session.cc
+++ b/onnxruntime/test/perftest/ort_test_session.cc
@@ -68,6 +68,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
bool trt_int8_enable = false;
std::string trt_int8_calibration_table_name = "";
bool trt_int8_use_native_calibration_table = false;
+ int trt_max_partition_iterations = 1000;
+ int trt_min_subgraph_size = 1;
+ bool trt_dump_subgraphs = false;
+ bool trt_engine_cache_enable = false;
+ std::string trt_cache_path = "";
#ifdef _MSC_VER
std::string ov_string = ToMBString(performance_test_config.run_config.ep_runtime_config_string);
@@ -145,6 +150,11 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_int8_enable = trt_int8_enable;
tensorrt_options.trt_int8_calibration_table_name = trt_int8_calibration_table_name.c_str();
tensorrt_options.trt_int8_use_native_calibration_table = trt_int8_use_native_calibration_table;
+ tensorrt_options.trt_max_partition_iterations = trt_max_partition_iterations;
+ tensorrt_options.trt_min_subgraph_size = trt_min_subgraph_size;
+ tensorrt_options.trt_dump_subgraphs = trt_dump_subgraphs;
+ tensorrt_options.trt_engine_cache_enable = trt_engine_cache_enable;
+ tensorrt_options.trt_cache_path = trt_cache_path.c_str();
session_options.AppendExecutionProvider_TensorRT(tensorrt_options);
OrtCUDAProviderOptions cuda_options{
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 9ea46c991f8b5..3b4cb81ad4879 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -43,7 +43,7 @@ std::unique_ptr<IExecutionProvider> DefaultCpuExecutionProvider(bool enable_aren
std::unique_ptr<IExecutionProvider> DefaultTensorrtExecutionProvider() {
#ifdef USE_TENSORRT
- OrtTensorRTProviderOptions params{0, 0, nullptr, 0, 1 << 30, 0, 0, nullptr, 0};
+ OrtTensorRTProviderOptions params{0, 0, nullptr, 0, 1 << 30, 0, 0, nullptr, 0, 1000, 1, 0, 0, nullptr};
if (auto factory = CreateExecutionProviderFactory_Tensorrt(&params))
return factory->CreateProvider();
#endif