-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Support CUDA Graph #9978
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Support CUDA Graph #9978
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
6645bcb
Support CUDA Graph
feihugis 535464a
Update the APIs
feihugis 5de03e4
Address some comments
feihugis b819c44
Reflect the latest API changes
feihugis 0ce98f0
Address some comments
feihugis 377cce9
Support warmup runs before capturing cuda graph and solve the sync re…
feihugis 7990c9b
Rename python test file and minor fix
feihugis e484ca6
Remove some unnecessary comments, superfluous syncs, and warm up runs…
feihugis 297b27c
Add a CApiTest for run cuda graph
feihugis d9894a7
Fix build failures
feihugis ade7af2
Change the default run to be 1 and avoid using the word warmup
feihugis 8e3fcc9
Change static var to private member var and a minor format fixing
feihugis 9edf6c6
Run cuda graph with multiple threads
feihugis 21e5192
Address comments
feihugis 87f1075
Remove a file added by mistake
feihugis 9a7ab55
Update doc for update_inplace
feihugis 832cf77
Enhance the support of CUDA graph with multi threads
feihugis af0bc10
Add CachedExecutionProviderForGraphReplay and more comments
feihugis 5ee89bb
Rollback the support of cuda graph with multi threads
feihugis a14f819
Address the comments
feihugis File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,89 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "core/providers/cuda/cuda_graph.h" | ||
pranavsharma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| #include "core/providers/cuda/cuda_common.h" | ||
| #include <cuda_runtime_api.h> | ||
| #include <driver_types.h> | ||
|
|
||
|
|
||
| namespace onnxruntime { | ||
|
|
||
| // Binds this graph wrapper to `stream`; the stream is not owned by the wrapper. | ||
| // Throws when the build's CUDA_VERSION is defined and older than 10.0. | ||
| CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) { | ||
| #if defined(CUDA_VERSION) && CUDA_VERSION < 10000 | ||
| ORT_THROW("CUDA graphs can only be used in Onnxruntime built with CUDA >= 10.0"); | ||
| #endif | ||
| } | ||
|
|
||
| // Replaces the stream used by subsequent capture and replay calls. | ||
| // The stream stays owned by the caller. | ||
| void CUDAGraph::SetStream(cudaStream_t stream) { | ||
| this->stream_ = stream; | ||
| } | ||
|
|
||
| // Begins capturing the work submitted to stream_ into a CUDA graph. | ||
| // Throws if this instance already holds an instantiated graph, if a CUDA call | ||
| // fails, or if the build does not support CUDA graphs (CUDA < 10.0). | ||
| void CUDAGraph::CaptureBegin() { | ||
| #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 10000) | ||
| ORT_THROW("CUDA graphs can only be used in Onnxruntime built with CUDA >= 10.0"); | ||
| #else | ||
| ORT_ENFORCE(!has_graph_exec_, | ||
| "This cuda graph has already captured a graph. Create a new instance to capture a new graph."); | ||
|
|
||
| // Synchronize the stream before switching it into capture mode. | ||
| CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); | ||
| // Single-threaded use only for now. Supporting multiple threads, each with | ||
| // its own graph and stream, requires switching `cudaStreamCaptureModeGlobal` | ||
| // to `cudaStreamCaptureModeThreadLocal`. | ||
| CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal)); | ||
| #endif | ||
| } | ||
|
|
||
| // Ends the capture on stream_ and instantiates the captured graph into an | ||
| // executable graph (graph_exec_) that Replay() launches. | ||
| // Throws if ending the capture fails, if the captured graph is NULL, if | ||
| // instantiation fails, or if the build does not support CUDA graphs (CUDA < 10.0). | ||
| void CUDAGraph::CaptureEnd() { | ||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 | ||
| CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_)); | ||
| if (graph_ == NULL) { | ||
| ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL"); | ||
| } | ||
|
|
||
| // Track ownership of graph_ so Reset() can release it if instantiation throws. | ||
| has_graph_ = true; | ||
| // CUDA 12 replaced the (error node, log buffer, buffer size) parameters of | ||
| // cudaGraphInstantiate with a single flags argument; pick the matching form. | ||
| #if CUDA_VERSION >= 12000 | ||
| CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, 0)); | ||
| #else | ||
| CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); | ||
| #endif | ||
| has_graph_exec_ = true; | ||
| // graph_ is not used again after instantiation (Replay() launches graph_exec_), | ||
| // so release it immediately. | ||
| CUDA_CALL_THROW(cudaGraphDestroy(graph_)); | ||
| has_graph_ = false; | ||
| #else | ||
| ORT_THROW("CUDA graphs can only be used in Onnxruntime built with CUDA >= 10.0"); | ||
| #endif | ||
| } | ||
|
|
||
| // Launches the instantiated graph on stream_ and blocks until the stream has | ||
| // finished. Returns a non-OK Status if the launch or the synchronization | ||
| // fails; throws if the build does not support CUDA graphs (CUDA < 10.0). | ||
| Status CUDAGraph::Replay() { | ||
| // This function is not thread safe, but no lock is taken here because | ||
| // CUDA EP maintains a separate cuda graph per thread. | ||
| #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 10000) | ||
| ORT_THROW("CUDA graphs can only be used in Onnxruntime built with CUDA >= 10.0"); | ||
| #else | ||
| LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_; | ||
| CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_)); | ||
| CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); | ||
| #endif | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| // Destroys the captured graph and the executable graph if this instance | ||
| // currently holds them, clearing the matching ownership flags. | ||
| // Throws if a destroy call fails, and always throws on builds that do not | ||
| // support CUDA graphs (CUDA < 10.0). | ||
| void CUDAGraph::Reset() { | ||
| #if !(defined(CUDA_VERSION) && CUDA_VERSION >= 10000) | ||
| ORT_THROW("CUDA graphs can only be used in Onnxruntime built with CUDA >= 10.0"); | ||
| #else | ||
| // graph_ is only still held here if CaptureEnd() failed part-way through. | ||
| if (has_graph_) { | ||
| CUDA_CALL_THROW(cudaGraphDestroy(graph_)); | ||
| has_graph_ = false; | ||
| } | ||
| if (has_graph_exec_) { | ||
| CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_)); | ||
| has_graph_exec_ = false; | ||
| } | ||
| #endif | ||
| } | ||
|
|
||
| CUDAGraph::~CUDAGraph() { | ||
| Reset(); | ||
| } | ||
|
|
||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #pragma once | ||
pranavsharma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| #include "core/common/common.h" | ||
| #include "core/platform/ort_mutex.h" | ||
| #include "core/providers/cuda/cuda_pch.h" | ||
|
|
||
| namespace onnxruntime { | ||
|
|
||
| using CaptureId_t = unsigned long long; | ||
|
|
||
| // Captures the work submitted to a CUDA stream into a CUDA graph and replays it. | ||
| // The stream is supplied by the caller and is never owned by this object. | ||
| // NOTE(review): the type is implicitly copyable although it destroys graph_exec_ | ||
| // on Reset()/destruction; copying an instance that holds a graph looks unsafe | ||
| // (double destroy) -- confirm that no caller copies it. | ||
| struct CUDAGraph { | ||
| // Creates an instance with no stream; call SetStream() before capturing. | ||
| CUDAGraph() = default; | ||
| // Creates an instance bound to `stream`. | ||
| CUDAGraph(cudaStream_t stream); | ||
| // Releases any graph resources still held. | ||
| ~CUDAGraph(); | ||
|
|
||
| // Sets the stream used for capture and replay. | ||
| void SetStream(cudaStream_t stream); | ||
| // Starts capturing work submitted to the stream. | ||
| void CaptureBegin(); | ||
| // Ends the capture and instantiates the executable graph. | ||
| void CaptureEnd(); | ||
| // Launches the executable graph and waits for the stream to finish. | ||
| Status Replay(); | ||
| // Destroys any graph resources held so a new instance state is reached. | ||
| void Reset(); | ||
|
|
||
| private: | ||
| #if defined(CUDA_VERSION) && CUDA_VERSION >= 10000 | ||
| cudaGraph_t graph_ = NULL;  // captured graph; valid only while has_graph_ is true | ||
| cudaGraphExec_t graph_exec_ = NULL;  // executable graph; valid only while has_graph_exec_ is true | ||
| #endif | ||
|
|
||
| bool has_graph_ = false; | ||
| bool has_graph_exec_ = false; | ||
|
|
||
| // Initialized so the member never holds an indeterminate value; it is not | ||
| // read anywhere in the visible implementation. | ||
| CaptureId_t id_ = 0; | ||
| cudaStream_t stream_ = nullptr; // Does not own the stream | ||
| }; | ||
|
|
||
| } // namespace onnxruntime | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.