
Commit 319f734

cuda : synchronize graph capture and cublas handle destruction

Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread.

ggml-ci
1 parent 381174b commit 319f734
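
The fix gates context destruction on an "active captures" counter protected by a mutex and condition variable: each graph capture increments the counter before cudaStreamBeginCapture and decrements it (waking any waiter) after cudaStreamEndCapture, while the context destructor blocks until the counter drops to zero before destroying streams and cuBLAS handles. Below is a minimal, self-contained sketch of that pattern with the CUDA/cuBLAS calls replaced by placeholder prints; the function names (capture_graph, destroy_context) are hypothetical stand-ins, not ggml API.

    // Sketch of the counter + condition-variable gate used by this commit.
    // CUDA/cuBLAS calls are replaced by puts() placeholders.
    #include <atomic>
    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    static std::mutex              lock_mutex;
    static std::condition_variable lock_cv;
    static std::atomic<int>        lock_counter{0};

    static void capture_graph() {
        {   // announce that a capture is in flight before starting it
            std::lock_guard<std::mutex> lock(lock_mutex);
            lock_counter.fetch_add(1, std::memory_order_relaxed);
        }

        std::puts("capturing...");  // stand-in for cudaStreamBeginCapture/EndCapture

        {   // capture finished: drop the count and wake any waiting destructor
            std::lock_guard<std::mutex> lock(lock_mutex);
            if (lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
                lock_cv.notify_all();
            }
        }
    }

    static void destroy_context() {
        // block until no capture is in flight; only then destroy the handles
        std::unique_lock<std::mutex> lock(lock_mutex);
        lock_cv.wait(lock, [] { return lock_counter.load(std::memory_order_relaxed) == 0; });
        std::puts("destroying cuBLAS handles...");  // stand-in for cublasDestroy
    }

    int main() {
        std::thread t1(capture_graph);
        std::thread t2(destroy_context);
        t1.join();
        t2.join();
    }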

File tree

2 files changed, 43 additions and 19 deletions

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 16 deletions
@@ -19,10 +19,10 @@
 #endif
 #include "ggml-common.h"
 
-#include <cstdio>
 #include <array>
 #include <cassert>
 #include <cfloat>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -767,21 +767,7 @@ struct ggml_backend_cuda_context {
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
-    ~ggml_backend_cuda_context() {
-        if (copy_event != nullptr) {
-            CUDA_CHECK(cudaEventDestroy(copy_event));
-        }
-        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
-            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
-                if (streams[i][j] != nullptr) {
-                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
-                }
-            }
-            if (cublas_handles[i] != nullptr) {
-                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
-            }
-        }
-    }
+    ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
         if (streams[device][stream] == nullptr) {
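Note that the destructor is reduced to an out-of-line declaration here; its definition moves into ggml/src/ggml-cuda/ggml-cuda.cu (below), where it can wait on the file-scope ggml_cuda_lock and ggml_cuda_lock_cv introduced by this commit before tearing down the streams and cuBLAS handles.
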
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 41 additions & 3 deletions
@@ -47,16 +47,16 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
 
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
 
@@ -2685,6 +2712,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
         CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
         graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+        std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+        if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+            ggml_cuda_lock_cv.notify_all();
+        }
     } else {
         graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
     }
 
@@ -2760,7 +2792,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
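Two aspects of the locking scheme are worth noting. Using a counter instead of holding ggml_cuda_lock for the duration of the capture allows multiple backend contexts to capture graphs concurrently while still blocking context destruction, and the relaxed memory ordering on ggml_cuda_lock_counter is sufficient because every increment, decrement, and predicate check happens while ggml_cuda_lock is held, so the mutex already provides the required ordering.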