From c1c3b9b418bffae8d6700925d0bdd6d5b979fd32 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 30 May 2026 12:14:32 +0200 Subject: [PATCH 1/5] cuda: reset device in get_memory function if no backend is active --- ggml/src/ggml-cuda/ggml-cuda.cu | 29 +++++++++++++++++++++++------ ggml/src/ggml-cuda/vendors/hip.h | 1 + ggml/src/ggml-cuda/vendors/musa.h | 1 + 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f5293ad4cbb..efeb5818901 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3137,12 +3137,8 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { return cuda_ctx->name.c_str(); } -static void ggml_backend_cuda_free(ggml_backend_t backend) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +static void ggml_backend_cuda_free(ggml_backend_t backend); - delete cuda_ctx; - delete backend; -} static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; @@ -4877,6 +4873,7 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; int op_offload_min_batch_size; + std::atomic backend_count{0}; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -4884,6 +4881,16 @@ static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { return ctx->name.c_str(); } +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + dev_ctx->backend_count.fetch_sub(1, std::memory_order_relaxed); + + delete cuda_ctx; + delete backend; +} + static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->description.c_str(); @@ -4993,6 +5000,11 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) + // If no backends are active, the cudaMemGetInfo call above lazily created a CUDA context + // that permanently consumes VRAM. Reset the device to free it. + if (ctx->backend_count.load(std::memory_order_relaxed) == 0) { + cudaDeviceReset(); + } } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { @@ -5687,13 +5699,18 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device); + ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .iface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .device = */ dev, /* .context = */ ctx, }; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + dev_ctx->backend_count.fetch_add(1, std::memory_order_relaxed); + return cuda_backend; } diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 5e0e22c7fc2..ca38125852b 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -98,6 +98,7 @@ #define cudaMemsetAsync hipMemsetAsync #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaDeviceReset hipDeviceReset #define cudaSetDevice hipSetDevice #define cuDeviceGet hipDeviceGet #define CUdevice hipDevice_t diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 99e8fa3703e..6b2cd5a3154 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -81,6 +81,7 @@ #define cudaMemsetAsync musaMemsetAsync #define cudaMemGetInfo musaMemGetInfo #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize +#define cudaDeviceReset musaDeviceReset #define cudaSetDevice musaSetDevice #define cudaStreamCreateWithFlags musaStreamCreateWithFlags #define cudaStreamDestroy musaStreamDestroy From 61e659a4272a487d6f68092cae264439496ce4b3 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 4 Jun 2026 13:43:39 +0200 Subject: [PATCH 2/5] also count device and host buffers --- ggml/src/ggml-cuda/ggml-cuda.cu | 38 ++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index efeb5818901..34c8c1cd50c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -622,6 +622,9 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer +static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev); +static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev); + struct ggml_backend_cuda_buffer_context { int device; void * dev_ptr = nullptr; @@ -639,6 +642,9 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_backend_cuda_device_dec_active(buffer->buft->device); + delete ctx; } @@ -791,6 +797,8 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); + ggml_backend_cuda_device_inc_active(buft->device); + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -1490,6 +1498,8 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_cuda_device_dec_active(buffer->buft->device); + CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1498,6 +1508,8 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } + ggml_cuda_set_device(0); + void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { @@ -1523,6 +1535,8 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; + ggml_backend_cuda_device_inc_active(buft->device); + return buffer; } @@ -4873,9 +4887,19 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; int op_offload_min_batch_size; - std::atomic backend_count{0}; + std::atomic active_count{0}; }; +static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + ctx->active_count.fetch_add(1, std::memory_order_relaxed); +} + +static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev) { + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + ctx->active_count.fetch_sub(1, std::memory_order_relaxed); +} + static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4884,8 +4908,7 @@ static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; - dev_ctx->backend_count.fetch_sub(1, std::memory_order_relaxed); + ggml_backend_cuda_device_dec_active(backend->device); delete cuda_ctx; delete backend; @@ -5000,9 +5023,9 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) - // If no backends are active, the cudaMemGetInfo call above lazily created a CUDA context - // that permanently consumes VRAM. Reset the device to free it. - if (ctx->backend_count.load(std::memory_order_relaxed) == 0) { + // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA + // context that permanently consumes VRAM. Reset the device to free it. + if (ctx->active_count.load(std::memory_order_relaxed) == 0) { cudaDeviceReset(); } } @@ -5708,8 +5731,7 @@ ggml_backend_t ggml_backend_cuda_init(int device) { /* .context = */ ctx, }; - ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; - dev_ctx->backend_count.fetch_add(1, std::memory_order_relaxed); + ggml_backend_cuda_device_inc_active(dev); return cuda_backend; } From 94b62910c0ad075d5d05e799dbdd4395167c2408 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Thu, 4 Jun 2026 13:53:11 +0200 Subject: [PATCH 3/5] exclude hip and musa from counting and device reset --- ggml/src/ggml-cuda/ggml-cuda.cu | 9 +++++++++ ggml/src/ggml-cuda/vendors/hip.h | 1 - ggml/src/ggml-cuda/vendors/musa.h | 1 - 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 34c8c1cd50c..b4d162773d9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4887,9 +4887,12 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) std::atomic active_count{0}; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) }; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; ctx->active_count.fetch_add(1, std::memory_order_relaxed); @@ -4899,6 +4902,10 @@ static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; ctx->active_count.fetch_sub(1, std::memory_order_relaxed); } +#else +static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev) { GGML_UNUSED(dev); } +static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev) { GGML_UNUSED(dev); } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; @@ -5023,11 +5030,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA // context that permanently consumes VRAM. Reset the device to free it. if (ctx->active_count.load(std::memory_order_relaxed) == 0) { cudaDeviceReset(); } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index ca38125852b..5e0e22c7fc2 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -98,7 +98,6 @@ #define cudaMemsetAsync hipMemsetAsync #define cudaMemGetInfo hipMemGetInfo #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize -#define cudaDeviceReset hipDeviceReset #define cudaSetDevice hipSetDevice #define cuDeviceGet hipDeviceGet #define CUdevice hipDevice_t diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 6b2cd5a3154..99e8fa3703e 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -81,7 +81,6 @@ #define cudaMemsetAsync musaMemsetAsync #define cudaMemGetInfo musaMemGetInfo #define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize -#define cudaDeviceReset musaDeviceReset #define cudaSetDevice musaSetDevice #define cudaStreamCreateWithFlags musaStreamCreateWithFlags #define cudaStreamDestroy musaStreamDestroy From f2f5f24677b2501b8e9a99d2de2414036363efaa Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Fri, 5 Jun 2026 12:21:44 +0200 Subject: [PATCH 4/5] use device mutex instead of atomic --- ggml/src/ggml-cuda/ggml-cuda.cu | 86 +++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b4d162773d9..497bcced39c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -622,8 +622,17 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer -static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev); -static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev); +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::mutex device_mutex; + int active_count = 0; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +}; struct ggml_backend_cuda_buffer_context { int device; @@ -643,7 +652,11 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; - ggml_backend_cuda_device_dec_active(buffer->buft->device); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) delete ctx; } @@ -797,7 +810,11 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); - ggml_backend_cuda_device_inc_active(buft->device); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -1498,7 +1515,11 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_cuda_device_dec_active(buffer->buft->device); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1508,7 +1529,7 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } - ggml_cuda_set_device(0); + ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0. void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); @@ -1535,7 +1556,11 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; - ggml_backend_cuda_device_inc_active(buft->device); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) return buffer; } @@ -4881,32 +4906,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { // backend device -struct ggml_backend_cuda_device_context { - int device; - std::string name; - std::string description; - std::string pci_bus_id; - int op_offload_min_batch_size; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - std::atomic active_count{0}; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -}; - -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev) { - ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; - ctx->active_count.fetch_add(1, std::memory_order_relaxed); -} - -static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev) { - ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; - ctx->active_count.fetch_sub(1, std::memory_order_relaxed); -} -#else -static void ggml_backend_cuda_device_inc_active(ggml_backend_dev_t dev) { GGML_UNUSED(dev); } -static void ggml_backend_cuda_device_dec_active(ggml_backend_dev_t dev) { GGML_UNUSED(dev); } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4915,7 +4914,11 @@ static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - ggml_backend_cuda_device_dec_active(backend->device); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) delete cuda_ctx; delete backend; @@ -5004,6 +5007,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::lock_guard lock(ctx->device_mutex); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); @@ -5033,8 +5041,8 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA // context that permanently consumes VRAM. Reset the device to free it. - if (ctx->active_count.load(std::memory_order_relaxed) == 0) { - cudaDeviceReset(); + if (ctx->active_count == 0) { + CUDA_CHECK(cudaDeviceReset()); } #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } @@ -5740,7 +5748,11 @@ ggml_backend_t ggml_backend_cuda_init(int device) { /* .context = */ ctx, }; - ggml_backend_cuda_device_inc_active(dev); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) return cuda_backend; } From 3167ccb81b6255f15dc6fc2f2f198008e7a09509 Mon Sep 17 00:00:00 2001 From: Ruben Ortlam Date: Sat, 6 Jun 2026 09:49:36 +0200 Subject: [PATCH 5/5] undo backend_free function move --- ggml/src/ggml-cuda/ggml-cuda.cu | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 497bcced39c..e779a9be9e9 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3176,8 +3176,18 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { return cuda_ctx->name.c_str(); } -static void ggml_backend_cuda_free(ggml_backend_t backend); +static void ggml_backend_cuda_free(ggml_backend_t backend) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete cuda_ctx; + delete backend; +} static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; @@ -4911,19 +4921,6 @@ static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { return ctx->name.c_str(); } -static void ggml_backend_cuda_free(ggml_backend_t backend) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; - std::lock_guard lock(dev_ctx->device_mutex); - dev_ctx->active_count--; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - - delete cuda_ctx; - delete backend; -} - static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->description.c_str();