@@ -514,6 +514,10 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
514514    return  std::unique_ptr<ggml_cuda_pool>(new  ggml_cuda_pool_leg (device));
515515}
516516
// Synchronization primitives guarding CUDA graph capture across backend
// threads: the counter is incremented (under the mutex) before
// cudaStreamBeginCapture and decremented after cudaStreamEndCapture; when it
// drops back to zero, waiters on the condition variable are notified.
std::mutex ggml_cuda_lock;
std::condition_variable ggml_cuda_lock_cv;
// Count of in-flight graph captures. Explicitly zero-initialized: prior to
// C++20, std::atomic's default constructor leaves the value indeterminate,
// which would break the fetch_sub(...) == 1 "last one out notifies" check.
std::atomic<int> ggml_cuda_lock_counter{0};
517521//  cuda buffer
518522
519523struct  ggml_backend_cuda_buffer_context  {
@@ -2685,6 +2689,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
26852689
26862690            CUDA_CHECK (cudaStreamEndCapture (cuda_ctx->stream (), &cuda_ctx->cuda_graph ->graph ));
26872691            graph_evaluated_or_captured = true ; //  CUDA graph has been captured
2692+ 
2693+             std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2694+             if  (ggml_cuda_lock_counter.fetch_sub (1 , std::memory_order_relaxed) == 1 ) {
2695+                 ggml_cuda_lock_cv.notify_all ();
2696+             }
26882697        } else  {
26892698            graph_evaluated_or_captured = true ; //  ggml graph has been directly evaluated
26902699        }
@@ -2760,7 +2769,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
27602769        }
27612770    }
27622771
2763-     if  (use_cuda_graph && cuda_graph_update_required) { //  Start CUDA graph capture
2772+     if  (use_cuda_graph && cuda_graph_update_required) {
2773+         //  Start CUDA graph capture
2774+         {
2775+             std::lock_guard<std::mutex> lock (ggml_cuda_lock);
2776+             ggml_cuda_lock_counter.fetch_add (1 , std::memory_order_relaxed);
2777+         }
2778+ 
27642779        CUDA_CHECK (cudaStreamBeginCapture (cuda_ctx->stream (), cudaStreamCaptureModeRelaxed));
27652780    }
27662781
0 commit comments