@@ -514,6 +514,10 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514514 return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg (device));
515515}
516516
// Global synchronization primitives used to serialize CUDA graph capture
// across backend contexts: the counter is incremented (under ggml_cuda_lock)
// before cudaStreamBeginCapture and decremented after cudaStreamEndCapture;
// ggml_cuda_lock_cv is notified when the count of in-flight captures drops
// back to zero, so waiters can block until no capture is active.
// NOTE(review): the counter's initial value of 0 relies on zero-initialization
// of static-storage-duration objects — confirm no earlier writer assumes otherwise.
std::mutex ggml_cuda_lock;
std::condition_variable ggml_cuda_lock_cv;
std::atomic<int> ggml_cuda_lock_counter;
520+
517521// cuda buffer
518522
519523struct ggml_backend_cuda_buffer_context {
@@ -2685,6 +2689,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
26852689
26862690 CUDA_CHECK (cudaStreamEndCapture (cuda_ctx->stream (), &cuda_ctx->cuda_graph ->graph ));
26872691 graph_evaluated_or_captured = true ; // CUDA graph has been captured
2692+
2693+ std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2694+ if (ggml_cuda_lock_counter.fetch_sub (1 , std::memory_order_relaxed) == 1 ) {
2695+ ggml_cuda_lock_cv.notify_all ();
2696+ }
26882697 } else {
26892698 graph_evaluated_or_captured = true ; // ggml graph has been directly evaluated
26902699 }
@@ -2760,7 +2769,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
27602769 }
27612770 }
27622771
2763- if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
2772+ if (use_cuda_graph && cuda_graph_update_required) {
2773+ // Start CUDA graph capture
2774+ if (ggml_cuda_lock_counter.fetch_add (1 , std::memory_order_relaxed) == 0 ) {
2775+ ggml_cuda_lock_counter.fetch_sub (1 , std::memory_order_relaxed);
2776+ std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2777+ ggml_cuda_lock_counter.fetch_add (1 , std::memory_order_relaxed);
2778+ }
2779+
27642780 CUDA_CHECK (cudaStreamBeginCapture (cuda_ctx->stream (), cudaStreamCaptureModeRelaxed));
27652781 }
27662782
0 commit comments