Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 74 additions & 14 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() {

// cuda buffer

struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::mutex device_mutex;
int active_count = 0;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
};

struct ggml_backend_cuda_buffer_context {
int device;
void * dev_ptr = nullptr;
Expand All @@ -639,6 +651,13 @@ struct ggml_backend_cuda_buffer_context {

static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

delete ctx;
}

Expand Down Expand Up @@ -791,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac

ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
}

Expand Down Expand Up @@ -1490,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
}

static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

CUDA_CHECK(cudaFreeHost(buffer->context));
}

Expand All @@ -1498,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) {
return nullptr;
}

ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0.

void * ptr = nullptr;
cudaError_t err = cudaMallocHost((void **) &ptr, size);
if (err != cudaSuccess) {
Expand All @@ -1523,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm
buffer->buft = buft;
buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

return buffer;
}

Expand Down Expand Up @@ -3137,12 +3176,8 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
return cuda_ctx->name.c_str();
}

static void ggml_backend_cuda_free(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
static void ggml_backend_cuda_free(ggml_backend_t backend);

delete cuda_ctx;
delete backend;
}

static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
Expand Down Expand Up @@ -4871,19 +4906,24 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {

// backend device

struct ggml_backend_cuda_device_context {
int device;
std::string name;
std::string description;
std::string pci_bus_id;
int op_offload_min_batch_size;
};

static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->name.c_str();
}

static void ggml_backend_cuda_free(ggml_backend_t backend) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count--;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

delete cuda_ctx;
delete backend;
}

static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
return ctx->description.c_str();
Expand Down Expand Up @@ -4967,6 +5007,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k

static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
std::lock_guard<std::mutex> lock(ctx->device_mutex);
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

ggml_cuda_set_device(ctx->device);
CUDA_CHECK(cudaMemGetInfo(free, total));

Expand All @@ -4993,6 +5038,13 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
}
#endif // defined(__linux__)

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
// If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA
// context that permanently consumes VRAM. Reset the device to free it.
if (ctx->active_count == 0) {
CUDA_CHECK(cudaDeviceReset());
}
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}

static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
Expand Down Expand Up @@ -5687,13 +5739,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
return nullptr;
}

ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device);

ggml_backend_t cuda_backend = new ggml_backend {
/* .guid = */ ggml_backend_cuda_guid(),
/* .iface = */ ggml_backend_cuda_interface,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
/* .device = */ dev,
/* .context = */ ctx,
};

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
std::lock_guard<std::mutex> lock(dev_ctx->device_mutex);
dev_ctx->active_count++;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)

return cuda_backend;
}

Expand Down
Loading