Skip to content

Commit 5598a14

Browse files
slaren and Nexesenex
authored and committed
cuda : synchronize graph capture and cublas handle destruction (ggml-org#14288)
Works around an issue that may cause CUDA graph capture to fail when a cuBLAS handle is destroyed in a different thread
1 parent 47190fe commit 5598a14

File tree

2 files changed

+43
-19
lines changed

2 files changed

+43
-19
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@
1919
#endif
2020
#include "ggml-common.h"
2121

22-
#include <cstdio>
2322
#include <array>
2423
#include <cassert>
2524
#include <cfloat>
25+
#include <cstdio>
2626
#include <string>
2727
#include <vector>
2828

@@ -935,21 +935,7 @@ struct ggml_backend_cuda_context {
935935
name(GGML_CUDA_NAME + std::to_string(device)) {
936936
}
937937

938-
~ggml_backend_cuda_context() {
939-
if (copy_event != nullptr) {
940-
CUDA_CHECK(cudaEventDestroy(copy_event));
941-
}
942-
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
943-
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
944-
if (streams[i][j] != nullptr) {
945-
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
946-
}
947-
}
948-
if (cublas_handles[i] != nullptr) {
949-
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
950-
}
951-
}
952-
}
938+
~ggml_backend_cuda_context();
953939

954940
cudaStream_t stream(int device, int stream) {
955941
if (streams[device][stream] == nullptr) {

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ bool g_mul_mat_q = true;
5555
#include <atomic>
5656
#include <charconv>
5757
#include <cinttypes>
58+
#include <condition_variable>
5859
#include <cstddef>
5960
#include <cstdint>
6061
#include <float.h>
6162
#include <limits>
6263
#include <map>
6364
#include <memory>
6465
#include <mutex>
65-
#include <stdint.h>
66-
#include <stdio.h>
6766
#include <stdarg.h>
67+
#include <stdio.h>
6868
#include <stdlib.h>
6969
#include <string>
7070
#include <vector>
@@ -521,6 +521,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
521521
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
522522
}
523523

524+
// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
525+
// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
526+
527+
static std::mutex ggml_cuda_lock;
528+
static std::condition_variable ggml_cuda_lock_cv;
529+
static std::atomic<int> ggml_cuda_lock_counter;
530+
531+
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
532+
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
533+
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
534+
535+
if (copy_event != nullptr) {
536+
CUDA_CHECK(cudaEventDestroy(copy_event));
537+
}
538+
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
539+
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
540+
if (streams[i][j] != nullptr) {
541+
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
542+
}
543+
}
544+
if (cublas_handles[i] != nullptr) {
545+
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
546+
}
547+
}
548+
}
549+
550+
524551
// cuda buffer
525552

526553
struct ggml_backend_cuda_buffer_context {
@@ -2992,6 +3019,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
29923019

29933020
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
29943021
graph_evaluated_or_captured = true; // CUDA graph has been captured
3022+
3023+
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
3024+
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
3025+
ggml_cuda_lock_cv.notify_all();
3026+
}
29953027
} else {
29963028
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
29973029
}
@@ -3067,7 +3099,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
30673099
}
30683100
}
30693101

3070-
if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
3102+
if (use_cuda_graph && cuda_graph_update_required) {
3103+
// Start CUDA graph capture
3104+
{
3105+
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
3106+
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
3107+
}
3108+
30713109
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
30723110
}
30733111

0 commit comments

Comments
 (0)