@@ -55,16 +55,16 @@ bool g_mul_mat_q = true;
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -521,6 +521,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
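
The three new globals implement a simple "wait until no capture is in flight" gate: each graph capture bumps ggml_cuda_lock_counter while it runs, and the context destructor blocks on ggml_cuda_lock_cv until the counter is back to zero before it tears down the streams and cuBLAS handles. Below is a minimal, CUDA-free sketch of the same pattern; the names (capture_lock, capture_cv, capture_counter, worker) are illustrative only and do not exist in ggml.

// Minimal sketch of the wait-until-idle gate above, with no CUDA dependency.
// All names here (capture_lock, capture_cv, capture_counter, worker) are
// illustrative only; they do not appear in ggml.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <thread>

static std::mutex              capture_lock;
static std::condition_variable capture_cv;
static std::atomic<int>        capture_counter{0};

// stands in for one graph-capture pass: announce the capture, do some work,
// then retire it and wake anyone waiting for the counter to reach zero
static void worker() {
    {
        std::lock_guard<std::mutex> lock(capture_lock);
        capture_counter.fetch_add(1, std::memory_order_relaxed);
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // the "capture"

    std::lock_guard<std::mutex> lock(capture_lock);
    if (capture_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
        capture_cv.notify_all();
    }
}

int main() {
    std::thread t1(worker);
    std::thread t2(worker);

    // give the workers a moment to start their "captures"
    std::this_thread::sleep_for(std::chrono::milliseconds(10));

    // stands in for the context destructor: block until no capture is in flight
    {
        std::unique_lock<std::mutex> lock(capture_lock);
        capture_cv.wait(lock, []{ return capture_counter.load(std::memory_order_relaxed) == 0; });
        std::printf("no capture in flight, safe to destroy handles\n");
    }

    t1.join();
    t2.join();
    return 0;
}

Because the counter is only ever modified while the mutex is held, the predicate-based wait cannot miss a wakeup; the same holds for the real patch, where both fetch_add and fetch_sub happen under a lock_guard on ggml_cuda_lock.
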
@@ -2992,6 +3019,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
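
fetch_sub returns the counter's value before the subtraction, so comparing the result to 1 detects the capture that takes the count to zero, and only that caller issues the notification; doing the decrement while ggml_cuda_lock is held pairs it with the destructor's locked predicate check. A tiny check of the return-value convention (plain C++, not ggml code):

#include <atomic>
#include <cassert>

int main() {
    std::atomic<int> counter{2};
    assert(counter.fetch_sub(1) == 2); // returns the old value; counter is now 1
    assert(counter.fetch_sub(1) == 1); // this call is the one that reaches zero
    assert(counter.load() == 0);
    return 0;
}
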
@@ -3067,7 +3099,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
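
Taken together: the counter is incremented under ggml_cuda_lock just before cudaStreamBeginCapture and decremented (with a notify) just after cudaStreamEndCapture, so the destructor's wait brackets the whole capture. Since the two halves live in different functions (ggml_backend_cuda_graph_compute and evaluate_and_capture_cuda_graph), the patch keeps them explicit. Purely as an illustration, the same bracket could be written as an RAII guard in code where both ends share one scope; the capture_guard type below is hypothetical and not part of the patch.

// Hypothetical RAII form of the same increment/decrement bracket; useful only
// where begin and end of the protected region share one scope, which is not
// the case in the patch above.
#include <atomic>
#include <condition_variable>
#include <mutex>

struct capture_guard {
    std::mutex              & m;
    std::condition_variable & cv;
    std::atomic<int>        & counter;

    capture_guard(std::mutex & m, std::condition_variable & cv, std::atomic<int> & counter)
            : m(m), cv(cv), counter(counter) {
        std::lock_guard<std::mutex> lock(m);
        counter.fetch_add(1, std::memory_order_relaxed);
    }

    ~capture_guard() {
        std::lock_guard<std::mutex> lock(m);
        if (counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
            cv.notify_all();
        }
    }

    capture_guard(const capture_guard &) = delete;
    capture_guard & operator=(const capture_guard &) = delete;
};

int main() {
    std::mutex              m;
    std::condition_variable cv;
    std::atomic<int>        counter{0};
    {
        capture_guard guard(m, cv, counter); // counter is 1 inside this scope
    }
    return counter.load(); // back to 0 once the guard is destroyed
}
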