@@ -29,4 +29,5 @@ struct OrtTensorRTProviderOptionsV2 {
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_cuda_graph_enable; // enable cuda graph. Default 0 = false, nonzero = true
};
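For reference, a minimal sketch of how a caller could turn the new field on through the V2 options struct, mirroring the pattern this PR uses in ort_test_session.cc further down; the helper name and the include for the struct definition are assumptions for illustration, not part of the change.

#include <onnxruntime_cxx_api.h>
// Assumed location of the OrtTensorRTProviderOptionsV2 definition extended above.
#include <tensorrt_provider_options.h>

// Illustrative helper, not part of the PR.
void AppendTensorRtWithCudaGraph(Ort::SessionOptions& session_options) {
  OrtTensorRTProviderOptionsV2 trt_options{};  // value-initialize: flags off, pointers null
  trt_options.device_id = 0;
  trt_options.trt_cuda_graph_enable = 1;       // new flag added by this PR; default 0 = disabled
  session_options.AppendExecutionProvider_TensorRT_V2(trt_options);
}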
52 changes: 44 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -285,6 +285,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
engine_decryption_lib_path_ = info.engine_decryption_lib_path;
}
force_sequential_engine_build_ = info.force_sequential_engine_build;
cuda_graph_enable_ = info.cuda_graph_enable;
} else {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
if (!max_partition_iterations_env.empty()) {
@@ -369,6 +370,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (!force_sequential_engine_build_env.empty()) {
force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true);
}

const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCUDAGraphEnable);
if (!cuda_graph_enable_env.empty()) {
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}
}

// Validate setting
@@ -429,7 +435,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
<< ", trt_cache_path: " << cache_path_
<< ", trt_engine_decryption_enable: " << engine_decryption_enable_
<< ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_
<< ", trt_force_sequential_engine_build: " << force_sequential_engine_build_;
<< ", trt_force_sequential_engine_build: " << force_sequential_engine_build_
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_;
}

TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -1009,7 +1016,11 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetEngineBuildLock() const

common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (const auto* fused_node : fused_nodes) {
size_t fused_nodes_size = fused_nodes.size();
std::vector<std::unique_ptr<cudaGraphExec_t>> executable_cuda_graphs(fused_nodes_size);
for (size_t node_idx = 0; node_idx < fused_nodes_size; node_idx++) {
executable_cuda_graphs[node_idx] = std::make_unique<cudaGraphExec_t>();
Review comment (Member): Since cudaGraph_t and cudaGraphExec_t are already pointer types, I don't think we can use make_unique/unique_ptr on them like this. We need to see if there is a better way to manage the memory of the cudaGraphExec and cudaGraph.
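A possible direction for the ownership question above, as a sketch rather than what the PR currently does: wrap the raw handles in std::unique_ptr with custom deleters instead of make_unique over the handle typedefs.

#include <cuda_runtime_api.h>
#include <memory>
#include <type_traits>

// cudaGraph_t and cudaGraphExec_t are pointer typedefs, so the owned element type is the
// pointed-to struct and the deleter calls the matching CUDA destroy API.
struct CudaGraphDeleter {
  void operator()(cudaGraph_t g) const noexcept { if (g) cudaGraphDestroy(g); }
};
struct CudaGraphExecDeleter {
  void operator()(cudaGraphExec_t e) const noexcept { if (e) cudaGraphExecDestroy(e); }
};

using unique_cuda_graph = std::unique_ptr<std::remove_pointer_t<cudaGraph_t>, CudaGraphDeleter>;
using unique_cuda_graph_exec = std::unique_ptr<std::remove_pointer_t<cudaGraphExec_t>, CudaGraphExecDeleter>;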

const auto* fused_node = fused_nodes[node_idx];
Review comment (Member): In general, how does the TRT EP handle control-flow nodes? I fear we must explicitly not support CUDA graphs for models with control-flow nodes, since the graph captured for one input may not be the same graph required for another input (because of dynamic graph branching).

Review comment (Member): Agreed. We should explicitly exclude graphs with loops/conditionals.

Review comment (Contributor Author): We may highlight the constraints for dynamic-shape cases in the documentation, so that users can choose whether or not to enable CUDA graphs.

Review comment (Member): It would be nice to enforce the constraints (no dynamic shapes, no dynamic graphs) in code rather than punting to the user/documentation. Let's see if there is a reasonable balance that can be achieved here.

Review comment (Contributor Author, @stevenlix, Feb 2, 2022): This part is a little tricky. A dynamic-shape model is okay, but when the input shapes of incoming data change, the CUDA graph needs to be recaptured, so the check has to be done at runtime. There is an API to update an executable graph, but I haven't seen any API that can inspect an existing CUDA graph's profile, and we can't afford to update the graph on every enqueue.
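To make that runtime check concrete, a rough sketch of the invalidate-and-recapture bookkeeping being described; the type and member names are hypothetical and this is not code from the PR.

#include <cuda_runtime_api.h>
#include <cstdint>
#include <vector>

// Hypothetical per-node state: remember the input shapes that were active when the CUDA
// graph was captured and drop the capture when they change.
struct CapturedGraphState {
  std::vector<std::vector<int64_t>> captured_shapes;
  cudaGraphExec_t exec_graph = nullptr;

  bool NeedsRecapture(const std::vector<std::vector<int64_t>>& current_shapes) const {
    return exec_graph == nullptr || current_shapes != captured_shapes;
  }

  void Invalidate(const std::vector<std::vector<int64_t>>& current_shapes) {
    if (exec_graph != nullptr) {
      cudaGraphExecDestroy(exec_graph);  // stale capture no longer matches the incoming shapes
      exec_graph = nullptr;              // forces a re-capture on the next run
    }
    captured_shapes = current_shapes;
  }
};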

// Build map from input name to its index in input definitions
std::unordered_map<std::string, size_t> input_map;
const auto& input_defs = fused_node->InputDefs();
@@ -1257,6 +1268,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
contexts_.emplace(fused_node->Name(), std::move(trt_context));
builders_.emplace(fused_node->Name(), std::move(trt_builder));
networks_.emplace(fused_node->Name(), std::move(trt_network));
executable_cuda_graph_map_.emplace(fused_node->Name(), std::move(executable_cuda_graphs[node_idx]));
input_info_[fused_node->Name()].push_back(input_indexes);
output_info_[fused_node->Name()].push_back(output_indexes);
output_info_[fused_node->Name()].push_back(output_types);
@@ -1271,8 +1283,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
&engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
dla_enable_, dla_core_, cuda_graph_enable_, nullptr, &executable_cuda_graph_map_[context->node_name], &max_workspace_size_,
trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr, allocator_,
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
*state = p.release();
return 0;
};
@@ -1295,6 +1308,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
auto trt_engine = trt_state->engine->get();
auto trt_context = trt_state->context->get();
auto trt_profile = &(trt_state->trt_profile);
auto cuda_graph_ptr = &(trt_state->cuda_graph_ptr);

auto alloc = trt_state->scratch_allocator;
int num_inputs = static_cast<int>(input_indexes.size());
int num_outputs = static_cast<int>(output_indexes.size());
@@ -1856,10 +1871,31 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
}
}

// Run TRT inference
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
// Run TRT inference
if (trt_state->cuda_graph_enable)
{
if (*cuda_graph_ptr == nullptr) {
std::unique_ptr<cudaGraph_t> graph = std::make_unique<cudaGraph_t>();
*cuda_graph_ptr = trt_state->executable_cuda_graph->get();
//warm up to avoid capturing initialization
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
Review comment (Member): Does enqueueV2() synchronize with the GPU before returning? If not, we may have to wait for the warm-up tasks queued on the stream to finish before the stream capture starts...

Review comment (Member): Why is this warm-up even needed?

Review comment (Member): It seems to be for handling a known issue with dynamic shapes? (Please add a comment.)

Review comment (Contributor Author, @stevenlix, Jan 28, 2022): The warm-up is needed to do initialization (flushing any old context) before graph capturing, according to Nvidia. CUDA graphs still have issues in some dynamic-shape cases.

Review comment (Member): The note section of the API doc for enqueueV2 (https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a2f4429652736e8ef6e19f433400108c7) seems to allude to dynamic shapes working if you call enqueueV2() once before graph capture. Is that similar to what you are doing here?

Review comment (Contributor Author): If the shape changes, the CUDA graph needs to be recaptured, which is not desired because the capturing happens during inference.
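If enqueueV2() turns out not to synchronize here, one conservative way to address the first question in this thread is to wait for the warm-up work to drain before starting the capture. A sketch under that assumption, not the code in this PR:

#include <cuda_runtime_api.h>
#include <NvInfer.h>

// Sketch only: run one warm-up enqueue, wait for it to finish so it is not part of the
// capture, then begin capturing the stream.
inline bool WarmUpThenBeginCapture(nvinfer1::IExecutionContext* ctx, void** bindings, cudaStream_t stream) {
  if (!ctx->enqueueV2(bindings, stream, nullptr)) {
    return false;  // warm-up enqueue failed
  }
  if (cudaStreamSynchronize(stream) != cudaSuccess) {
    return false;  // warm-up work must be complete before capture starts
  }
  return cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed) == cudaSuccess;
}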

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"TensorRT engine has operations that are not allowed in CUDA graph capture mode. ",
"Please disable trt_cuda_graph_enable.");
}
CUDA_CALL_THROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
Review comment (Member): If certain operators can't be supported by the TensorRT and CUDA EPs, does the CPU EP come into play? If so, the same comment as this: #9978 (comment)

Review comment (Contributor Author): The CUDA graph capturing in this PR is only for TensorRT subgraphs. Unsupported ops can still fall back to the CUDA/CPU EPs.

Review comment (Member): Understood now, thanks. One question though: say the model looks like TensorRT subgraph 1 -> CPU op -> TensorRT subgraph 2, and you are capturing the graphs for both TRT subgraphs. Will the necessary synchronization logic before the CPU op still happen with the CUDA graph setup?

Review comment (Contributor Author): I think the case is the same with and without CUDA graphs, isn't it? The CPU needs to wait until TRT subgraph 1 produces its result. It would be interesting if there were a better CUDA graph setup that could avoid the wait.

if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
CUDA_CALL_THROW(cudaStreamEndCapture(stream, graph.get()));
CUDA_CALL_THROW(cudaGraphInstantiate(*cuda_graph_ptr, *(graph.get()), NULL, NULL, 0));
}
Review comment (Member): What happens to graph here? Do you need to destroy it? Otherwise we leak memory.
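On the leak question: cudaGraphInstantiate does not keep a dependency on the captured graph, so the cudaGraph_t can be destroyed once instantiation is done. A sketch of that cleanup under that assumption, not the PR's current code:

#include <cuda_runtime_api.h>

// Sketch only: end the capture, instantiate the executable graph, then destroy the captured
// cudaGraph_t so it is not leaked; the instantiated graph remains valid on its own.
inline cudaError_t EndCaptureAndInstantiate(cudaStream_t stream, cudaGraphExec_t* exec_out) {
  cudaGraph_t graph = nullptr;
  cudaError_t err = cudaStreamEndCapture(stream, &graph);
  if (err != cudaSuccess) {
    return err;
  }
  err = cudaGraphInstantiate(exec_out, graph, nullptr, nullptr, 0);
  cudaGraphDestroy(graph);  // safe to release whether or not instantiation succeeded
  return err;
}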

cudaGraphLaunch(**cuda_graph_ptr, stream);
} else {
if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}
}

// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
@@ -26,6 +26,7 @@ static const std::string kCachePath = "ORT_TENSORRT_CACHE_PATH";
static const std::string kDecryptionEnable = "ORT_TENSORRT_ENGINE_DECRYPTION_ENABLE";
static const std::string kDecryptionLibPath = "ORT_TENSORRT_ENGINE_DECRYPTION_LIB_PATH";
static const std::string kForceSequentialEngineBuild= "ORT_TENSORRT_FORCE_SEQUENTIAL_ENGINE_BUILD";
static const std::string kCUDAGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
@@ -96,6 +97,9 @@ struct TensorrtFuncState {
bool int8_calibration_cache_available;
bool dla_enable;
int dla_core;
bool cuda_graph_enable;
cudaGraphExec_t* cuda_graph_ptr = nullptr;
std::unique_ptr<cudaGraphExec_t>* executable_cuda_graph = nullptr;
size_t* max_workspace_size_ptr = nullptr;
std::string trt_node_name_with_precision;
bool engine_cache_enable;
@@ -167,6 +171,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool engine_decryption_enable_ = false;
int (*engine_decryption_)(const char*, char*, size_t*);
int (*engine_encryption_)(const char*, char*, size_t);
bool cuda_graph_enable_ = false;

std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvonnxparser::IParser>> parsers_;
std::unordered_map<std::string, tensorrt_ptr::unique_pointer<nvinfer1::ICudaEngine>> engines_;
@@ -176,6 +181,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> input_info_;
std::unordered_map<std::string, std::vector<std::unordered_map<std::string, size_t>>> output_info_;
std::unordered_map<std::string, std::unordered_map<std::string, std::unordered_map<size_t, std::pair<int64_t, int64_t>>>> input_shape_ranges_;
std::unordered_map<std::string, std::unique_ptr<cudaGraphExec_t>> executable_cuda_graph_map_;

/**Get IndexedSubGraph based on node list of the subgraph*/
std::unique_ptr<IndexedSubGraph> GetSubGraph(SubGraph_t graph_nodes_index,
@@ -27,7 +27,8 @@ constexpr const char* kCachePath = "trt_engine_cache_path";
constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable";
constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path";
constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build";
// add new provider option name here.
// add new provider option name here.
constexpr const char* kCUDAGraphEnable = "trt_cuda_graph_enable";
} // namespace provider_option_names
} // namespace tensorrt

@@ -64,6 +65,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build)
.AddAssignmentToReference(tensorrt::provider_option_names::kCUDAGraphEnable, info.cuda_graph_enable)//slx
.Parse(options)); // add new provider option here.

return info;
@@ -88,6 +90,7 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kDecryptionEnable, MakeStringWithClassicLocale(info.engine_decryption_enable)},
{tensorrt::provider_option_names::kDecryptionLibPath, MakeStringWithClassicLocale(info.engine_decryption_lib_path)},
{tensorrt::provider_option_names::kForceSequentialEngineBuild, MakeStringWithClassicLocale(info.force_sequential_engine_build)},
{tensorrt::provider_option_names::kCUDAGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
// add new provider option here.
};
return options;
@@ -31,6 +31,7 @@ struct TensorrtExecutionProviderInfo {
bool engine_decryption_enable{false};
std::string engine_decryption_lib_path{""};
bool force_sequential_engine_build{false};
bool cuda_graph_enable{false};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
@@ -70,6 +70,7 @@ struct Tensorrt_Provider : Provider {
info.engine_decryption_enable = options.trt_engine_decryption_enable != 0;
info.engine_decryption_lib_path = options.trt_engine_decryption_lib_path == nullptr ? "" : options.trt_engine_decryption_lib_path;
info.force_sequential_engine_build = options.trt_force_sequential_engine_build != 0;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;
return std::make_shared<TensorrtProviderFactory>(info);
}

@@ -135,6 +136,7 @@ struct Tensorrt_Provider : Provider {
}

trt_options.trt_force_sequential_engine_build = internal_options.force_sequential_engine_build;
trt_options.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
3 changes: 2 additions & 1 deletion onnxruntime/core/session/provider_bridge_ort.cc
@@ -1172,7 +1172,7 @@ OrtTensorRTProviderOptionsV2 OrtTensorRTProviderOptionsToOrtTensorRTProviderOpti
trt_options_converted.trt_force_sequential_engine_build = legacy_trt_options->trt_force_sequential_engine_build;
// Add new provider option below
// Use default value as this field is not available in OrtTensorRTProviderOptionsV

trt_options_converted.trt_cuda_graph_enable = 0;
return trt_options_converted;
}

@@ -1489,6 +1489,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorRTProviderOptions, _Outptr_ OrtTensorRT
(*out)->trt_engine_decryption_enable = false;
(*out)->trt_engine_decryption_lib_path = nullptr;
(*out)->trt_force_sequential_engine_build = false;
(*out)->trt_cuda_graph_enable = false;
return nullptr;
#else
ORT_UNUSED_PARAMETER(out);
9 changes: 9 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -393,6 +393,7 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
nullptr,
0,
nullptr,
0,
0};
for (auto option : it->second) {
if (option.first == "device_id") {
@@ -510,6 +511,14 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be a boolean i.e. 'True' or 'False'. Default value is False.\n");
}
} else if (option.first == "trt_cuda_graph_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_cuda_graph_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_cuda_graph_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. 'True' or 'False'. Default value is False.\n");
}
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
1 change: 1 addition & 0 deletions onnxruntime/python/tools/tensorrt/perf/mem_test/main.cpp
@@ -27,6 +27,7 @@ std::unique_ptr<OrtTensorRTProviderOptions> get_default_trt_provider_options() {
tensorrt_options->trt_engine_decryption_enable = false;
tensorrt_options->trt_engine_decryption_lib_path = "";
tensorrt_options->trt_force_sequential_engine_build = false;
tensorrt_options->trt_cuda_graph_enable = false;

return tensorrt_options;
}
1 change: 1 addition & 0 deletions onnxruntime/test/perftest/command_args_parser.cc
@@ -77,6 +77,7 @@ namespace perftest {
"\t [TensorRT only] [trt_engine_cache_enable]: Enable engine caching.\n"
"\t [TensorRT only] [trt_engine_cache_path]: Specify engine cache path.\n"
"\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n"
"\t [TensorRT only] [trt_cuda_graph_enable]: Enable CUDA graph.\n"
"\t [Usage]: -e <provider_name> -i '<key1>|<value1> <key2>|<value2>'\n\n"
"\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n"
"\t [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n"
12 changes: 11 additions & 1 deletion onnxruntime/test/perftest/ort_test_session.cc
@@ -75,6 +75,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
bool trt_engine_decryption_enable = false;
std::string trt_engine_decryption_lib_path = "";
bool trt_force_sequential_engine_build = false;
bool trt_cuda_graph_enable = false;

#ifdef _MSC_VER
std::string ov_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string);
@@ -206,8 +207,16 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_force_sequential_engine_build' should be a boolean i.e. true or false. Default value is false.\n");
}
} else if (key == "trt_cuda_graph_enable") {
if (value == "true" || value == "True") {
trt_cuda_graph_enable = true;
} else if (value == "false" || value == "False") {
trt_cuda_graph_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be a boolean i.e. true or false. Default value is false.\n");
}
} else {
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build'] \n");
ORT_THROW("[ERROR] [TensorRT] wrong key type entered. Choose from the following runtime key options that are available for TensorRT. ['device_id', 'trt_max_partition_iterations', 'trt_min_subgraph_size', 'trt_max_workspace_size', 'trt_fp16_enable', 'trt_int8_enable', 'trt_int8_calibration_table_name', 'trt_int8_use_native_calibration_table', 'trt_dla_enable', 'trt_dla_core', 'trt_dump_subgraphs', 'trt_engine_cache_enable', 'trt_engine_cache_path', 'trt_engine_decryption_enable', 'trt_engine_decryption_lib_path', 'trt_force_sequential_engine_build', 'trt_cuda_graph_enable'] \n");
}
}
OrtTensorRTProviderOptionsV2 tensorrt_options;
Expand All @@ -229,6 +238,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device
tensorrt_options.trt_engine_decryption_enable = trt_engine_decryption_enable;
tensorrt_options.trt_engine_decryption_lib_path = trt_engine_decryption_lib_path.c_str();
tensorrt_options.trt_force_sequential_engine_build = trt_force_sequential_engine_build;
tensorrt_options.trt_cuda_graph_enable = trt_cuda_graph_enable;
session_options.AppendExecutionProvider_TensorRT_V2(tensorrt_options);

OrtCUDAProviderOptions cuda_options;
1 change: 1 addition & 0 deletions onnxruntime/test/providers/cpu/model_tests.cc
@@ -610,6 +610,7 @@ TEST_P(ModelTest, Run) {
nullptr,
0,
nullptr,
0,
0};
ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(TensorrtExecutionProviderWithOptions(&params)));
} else {
1 change: 1 addition & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -156,6 +156,7 @@ TEST_P(TensorrtExecutionProviderCacheTest, Run) {
nullptr,
0,
nullptr,
0,
0};

if (cache_type.compare("engine") == 0) {