Enable cuda graph in TensorRT EP #10423
Changes from all commits: 04c9fd7, 634c2e3, 2489976, 1e71901, a330b5a, 3e2e847, 6427a24, f17e84c, d09a39d
@@ -285,6 +285,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
       engine_decryption_lib_path_ = info.engine_decryption_lib_path;
     }
     force_sequential_engine_build_ = info.force_sequential_engine_build;
+    cuda_graph_enable_ = info.cuda_graph_enable;
   } else {
     const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
     if (!max_partition_iterations_env.empty()) {
@@ -369,6 +370,11 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
     if (!force_sequential_engine_build_env.empty()) {
       force_sequential_engine_build_ = (std::stoi(force_sequential_engine_build_env) == 0 ? false : true);
     }
+
+    const std::string cuda_graph_enable_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kCUDAGraphEnable);
+    if (!cuda_graph_enable_env.empty()) {
+      cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
+    }
   }

   // Validate setting
@@ -429,7 +435,8 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
                         << ", trt_cache_path: " << cache_path_
                         << ", trt_engine_decryption_enable: " << engine_decryption_enable_
                         << ", trt_engine_decryption_lib_path: " << engine_decryption_lib_path_
-                        << ", trt_force_sequential_engine_build: " << force_sequential_engine_build_;
+                        << ", trt_force_sequential_engine_build: " << force_sequential_engine_build_
+                        << ", trt_cuda_graph_enable: " << cuda_graph_enable_;
 }

 TensorrtExecutionProvider::~TensorrtExecutionProvider() {
@@ -1009,7 +1016,11 @@ std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetEngineBuildLock() const

 common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fused_nodes,
                                                   std::vector<NodeComputeInfo>& node_compute_funcs) {
-  for (const auto* fused_node : fused_nodes) {
+  size_t fused_nodes_size = fused_nodes.size();
+  std::vector<std::unique_ptr<cudaGraphExec_t>> executable_cuda_graphs(fused_nodes_size);
+  for (size_t node_idx = 0; node_idx < fused_nodes_size; node_idx++) {
+    executable_cuda_graphs[node_idx] = std::make_unique<cudaGraphExec_t>();
+    const auto* fused_node = fused_nodes[node_idx];
Member
In general, how does the TRT EP handle control flow nodes? I fear we must explicitly not support using CUDA graphs for models with control flow nodes, because the graph captured for one input may not be the same graph required for another input (due to the dynamic graph branching).

Member
Agreed. We should explicitly exclude graphs with loops/conditionals.

Contributor (Author)
We may highlight the constraints for dynamic shape cases in the documentation, so that users can choose whether to enable CUDA graph.

Member
It would be nice to enforce the constraints (no support for dynamic shapes and dynamic graphs) in code rather than punting to the user/documentation. Let's see if there's a reasonable balance that can be achieved here.

Contributor (Author)
This is a little bit tricky. A dynamic shape model is okay, but when the input shapes of the incoming data change, the CUDA graph needs to be recaptured, so the check has to be done at runtime. There is an API to update an executable graph, but I haven't seen any API that can query an existing CUDA graph's profile, and we can't afford to update the graph for every enqueue.
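To illustrate the exclusion the reviewers are asking for, here is a minimal, hedged sketch of the kind of check that could gate CUDA graph capture on the presence of control-flow ops. The helper name and the way the op types are collected are illustrative assumptions, not existing TensorRT EP code.

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Returns true if the fused subgraph contains an op whose execution path can
// change between runs; such subgraphs should not be replayed from a captured
// CUDA graph, because the captured work may not match the next input.
bool HasControlFlowOp(const std::vector<std::string>& node_op_types) {
  static const std::unordered_set<std::string> kControlFlowOps = {"If", "Loop", "Scan"};
  for (const auto& op_type : node_op_types) {
    if (kControlFlowOps.count(op_type) != 0) {
      return true;
    }
  }
  return false;
}
```

A check like this could be evaluated once per fused node at compile time, with CUDA graph capture silently disabled (or rejected with an error) for subgraphs where it returns true.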
     // Build map from input name to its index in input definitions
     std::unordered_map<std::string, size_t> input_map;
     const auto& input_defs = fused_node->InputDefs();
@@ -1257,6 +1268,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
     contexts_.emplace(fused_node->Name(), std::move(trt_context));
     builders_.emplace(fused_node->Name(), std::move(trt_builder));
     networks_.emplace(fused_node->Name(), std::move(trt_network));
+    executable_cuda_graph_map_.emplace(fused_node->Name(), std::move(executable_cuda_graphs[node_idx]));
     input_info_[fused_node->Name()].push_back(input_indexes);
     output_info_[fused_node->Name()].push_back(output_indexes);
     output_info_[fused_node->Name()].push_back(output_types);
@@ -1271,8 +1283,9 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
           &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
           &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
           input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
-          dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
-          runtime_.get(), nullptr, allocator_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
+          dla_enable_, dla_core_, cuda_graph_enable_, nullptr, &executable_cuda_graph_map_[context->node_name], &max_workspace_size_,
+          trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), nullptr, allocator_,
+          dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_};
       *state = p.release();
       return 0;
     };
@@ -1295,6 +1308,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
       auto trt_engine = trt_state->engine->get();
       auto trt_context = trt_state->context->get();
       auto trt_profile = &(trt_state->trt_profile);
+      auto cuda_graph_ptr = &(trt_state->cuda_graph_ptr);
+
       auto alloc = trt_state->scratch_allocator;
       int num_inputs = static_cast<int>(input_indexes.size());
       int num_outputs = static_cast<int>(output_indexes.size());
@@ -1856,10 +1871,31 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<Node*>& fuse
         }
       }

-      // Run TRT inference
-      if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
-        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
-      }
+      // Run TRT inference
+      if (trt_state->cuda_graph_enable)
+      {
+        if (*cuda_graph_ptr == nullptr) {
+          std::unique_ptr<cudaGraph_t> graph = std::make_unique<cudaGraph_t>();
+          *cuda_graph_ptr = trt_state->executable_cuda_graph->get();
+          // warm up to avoid capturing initialization
+          if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
Member
Does enqueueV2() synchronize with the GPU before returning? If not, we may have to wait for the warm-up tasks queued on the stream to finish before the stream capture...

Member
Why is this warm-up even needed?

Member
It seems to be for handling a known issue with dynamic shapes? (Please add a comment.)

Contributor (Author)
The warm-up is needed to do initialization (flushing any old context) before graph capturing, according to Nvidia. CUDA graph still has issues in some dynamic shape cases.

Member
See the notes section of the API doc for enqueueV2: https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a2f4429652736e8ef6e19f433400108c7

Contributor (Author)
If the shape changes, the CUDA graph needs to be recaptured, which is not desired because the capturing happens during inference.
+            return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
+                                   "TensorRT engine has operations that are not allowed in CUDA graph capture mode. ",
+                                   "Please disable trt_cuda_graph_enable.");
+          }
+          CUDA_CALL_THROW(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
Member
If certain operators can't be supported by the TensorRT and CUDA EPs, does the CPU EP come into play? If so, the same comment as this: #9978 (comment)

Contributor (Author)
The CUDA graph capturing in this PR is only for TensorRT subgraphs. Unsupported ops can still fall back to the CUDA/CPU EPs.

Member
Understood now, thanks. One question though: let us say the model looks like this: TensorRT subgraph 1 -> CPU op -> TensorRT subgraph 2, and you are capturing the graphs for both TRT subgraphs. Will the necessary synchronization logic before the CPU op still happen with the CUDA graph setup?

Contributor (Author)
I think the case is the same with and without CUDA graph, isn't it? The CPU needs to wait until TRT subgraph 1 produces its result. It would be interesting if there were a better CUDA graph setup that could avoid the wait.
+          if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
+            return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
+          }
+          CUDA_CALL_THROW(cudaStreamEndCapture(stream, graph.get()));
+          CUDA_CALL_THROW(cudaGraphInstantiate(*cuda_graph_ptr, *(graph.get()), NULL, NULL, 0));
+        }
Member
What happens to graph? Do you need to destroy it? Otherwise we leak memory.
+        cudaGraphLaunch(**cuda_graph_ptr, stream);
+      } else {
+        if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
+        }
+      }

       // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
       for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
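On the enqueueV2()/warm-up synchronization question raised in the review thread above, here is a hedged sketch (not PR code) of the capture sequence with an explicit cudaStreamSynchronize() after the warm-up run, so that the asynchronously enqueued warm-up work has finished before capture begins. The function name is illustrative; the parameters mirror the PR's local variables.

```cpp
#include <cuda_runtime_api.h>
#include <NvInfer.h>

// Hedged sketch of the warm-up + capture + instantiate sequence.
// Returns false if any step fails.
bool CaptureTrtSubgraph(nvinfer1::IExecutionContext* trt_context,
                        void** buffers,
                        cudaStream_t stream,
                        cudaGraphExec_t* graph_exec_out) {
  cudaGraph_t graph = nullptr;

  // Warm-up run performs lazy initialization outside of the capture.
  if (!trt_context->enqueueV2(buffers, stream, nullptr)) return false;
  // enqueueV2() only enqueues work; wait for the warm-up to finish so none of
  // it leaks into (or races with) the capture that follows.
  if (cudaStreamSynchronize(stream) != cudaSuccess) return false;

  // Capture one inference into a CUDA graph, then instantiate it for replay.
  if (cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed) != cudaSuccess) return false;
  if (!trt_context->enqueueV2(buffers, stream, nullptr)) return false;
  if (cudaStreamEndCapture(stream, &graph) != cudaSuccess) return false;
  if (cudaGraphInstantiate(graph_exec_out, graph, nullptr, nullptr, 0) != cudaSuccess) return false;

  // The cudaGraph_t is no longer needed once the executable graph exists.
  cudaGraphDestroy(graph);
  return true;
}
```

The extra synchronization only costs time once per subgraph, since capture happens on the first run; whether it is acceptable here is a design choice for the PR.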
Since cudaGraph_t and cudaGraphExec_t are already pointer types, I think we can't use make_unique/unique_ptr like this? We need to see if there's a better way to manage the memory of the cudaGraphExec and the cudaGraph.
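One possible answer to the two memory-management comments above, offered as a hedged sketch rather than the PR's approach: because cudaGraph_t and cudaGraphExec_t are opaque pointer types, they can be held in std::unique_ptr with custom deleters that call cudaGraphDestroy()/cudaGraphExecDestroy(), instead of wrapping the handles with make_unique. The alias and function names below are illustrative.

```cpp
#include <cuda_runtime_api.h>
#include <memory>
#include <type_traits>

// Custom deleters so the CUDA graph and its executable instance are destroyed
// automatically when the owning unique_ptr goes out of scope, addressing the
// leak concern above. cudaGraph_t / cudaGraphExec_t are opaque pointers, so the
// unique_ptr owns the pointee type rather than the handle itself.
struct CudaGraphDeleter {
  void operator()(std::remove_pointer_t<cudaGraph_t> * g) const {
    if (g != nullptr) cudaGraphDestroy(g);
  }
};
struct CudaGraphExecDeleter {
  void operator()(std::remove_pointer_t<cudaGraphExec_t> * ge) const {
    if (ge != nullptr) cudaGraphExecDestroy(ge);
  }
};

using UniqueCudaGraph = std::unique_ptr<std::remove_pointer_t<cudaGraph_t>, CudaGraphDeleter>;
using UniqueCudaGraphExec = std::unique_ptr<std::remove_pointer_t<cudaGraphExec_t>, CudaGraphExecDeleter>;

// Usage sketch: take ownership right after capture/instantiation so both
// handles are released even on early returns.
void OwnCapturedGraph(cudaGraph_t raw_graph, cudaGraphExec_t raw_exec) {
  UniqueCudaGraph graph{raw_graph};          // released via cudaGraphDestroy
  UniqueCudaGraphExec graph_exec{raw_exec};  // released via cudaGraphExecDestroy
}
```

With this scheme the cudaGraph_t captured on the first run can be dropped as soon as the executable graph has been instantiated, while the cudaGraphExec_t stays alive for replay until the owning state is destroyed.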