@@ -47,16 +47,16 @@
 #include <atomic>
 #include <charconv>
 #include <cinttypes>
+#include <condition_variable>
 #include <cstddef>
 #include <cstdint>
 #include <float.h>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdint.h>
-#include <stdio.h>
 #include <stdarg.h>
+#include <stdio.h>
 #include <stdlib.h>
 #include <string>
 #include <vector>
@@ -514,6 +514,33 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i |
     return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
 }
 
+// destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error
+// this lock is used to ensure that no cuBLAS handle is destroyed while a graph is being captured
+
+static std::mutex ggml_cuda_lock;
+static std::condition_variable ggml_cuda_lock_cv;
+static std::atomic<int> ggml_cuda_lock_counter;
+
+ggml_backend_cuda_context::~ggml_backend_cuda_context() {
+    std::unique_lock<std::mutex> lock(ggml_cuda_lock);
+    ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
+
+    if (copy_event != nullptr) {
+        CUDA_CHECK(cudaEventDestroy(copy_event));
+    }
+    for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
+        for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
+            if (streams[i][j] != nullptr) {
+                CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
+            }
+        }
+        if (cublas_handles[i] != nullptr) {
+            CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
+        }
+    }
+}
+
+
 // cuda buffer
 
 struct ggml_backend_cuda_buffer_context {
@@ -2685,6 +2712,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx |
 
             CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
+
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
+                ggml_cuda_lock_cv.notify_all();
+            }
         } else {
             graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
         }
@@ -2760,7 +2792,13 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, |
         }
     }
 
-    if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture
+    if (use_cuda_graph && cuda_graph_update_required) {
+        // Start CUDA graph capture
+        {
+            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
+            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
+        }
+
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
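For readers without the full file context, here is a minimal, self-contained sketch of the gating pattern this patch introduces: capture threads bump an atomic counter under a mutex before `cudaStreamBeginCapture` and drop it after `cudaStreamEndCapture`, while the context destructor waits on a condition variable until the counter reaches zero before destroying any streams or cuBLAS handles. The names (`gate_mutex`, `capture_worker`, `teardown`) and the sleep standing in for the actual graph capture are illustrative only, not part of the patch.

```cpp
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

// shared state mirroring ggml_cuda_lock / ggml_cuda_lock_cv / ggml_cuda_lock_counter
static std::mutex              gate_mutex;
static std::condition_variable gate_cv;
static std::atomic<int>        active_captures{0};

// mirrors the begin/end-capture sides: bump the counter under the mutex, do the
// long-running work, then drop the counter and wake waiters when it hits zero
static void capture_worker() {
    {
        std::lock_guard<std::mutex> lock(gate_mutex);
        active_captures.fetch_add(1, std::memory_order_relaxed);
    }

    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // the "capture"

    std::lock_guard<std::mutex> lock(gate_mutex);
    if (active_captures.fetch_sub(1, std::memory_order_relaxed) == 1) {
        gate_cv.notify_all();
    }
}

// mirrors the context destructor: block until no capture is in flight,
// after which it is safe to tear down the handles
static void teardown() {
    std::unique_lock<std::mutex> lock(gate_mutex);
    gate_cv.wait(lock, []{ return active_captures.load(std::memory_order_relaxed) == 0; });
    // cublasDestroy(...) / cudaStreamDestroy(...) would be safe here
}

int main() {
    std::vector<std::thread> workers;
    for (int i = 0; i < 4; ++i) {
        workers.emplace_back(capture_worker);
    }
    std::thread destroyer(teardown);
    for (auto & w : workers) {
        w.join();
    }
    destroyer.join();
}
```

Note that `fetch_sub` returns the value the counter held before the decrement, so comparing against 1 makes exactly the thread that takes the count to zero issue the wake-up. Relaxed memory order suffices here because, as in the patch, every access to the counter happens while holding the mutex, which already provides the required ordering.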