cuda : improve cuda pool efficiency using virtual memory #4606
```diff
@@ -88,6 +88,7 @@
 #define __trap abort
 #else
 #include <cuda_runtime.h>
+#include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
 // CUDA 10.2 does not have these macro definitions.
```
```diff
@@ -213,6 +214,24 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     } \
 } while (0)
 
+// driver API
+#define CU_CHECK(err)                                                                   \
+    do {                                                                                \
+        CUresult err_ = (err);                                                          \
+        if (err_ != CUDA_SUCCESS) {                                                     \
+            int id;                                                                     \
+            cuDeviceGet(&id, 0);                                                        \
+            const char * err_str;                                                       \
+            cuGetErrorString(err_, &err_str);                                           \
+            fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+                err_str);                                                               \
+            fprintf(stderr, "%s\n", #err);                                              \
+            fprintf(stderr, "current device: %d\n", id);                                \
+            GGML_ASSERT(!"CUDA error");                                                 \
+        }                                                                               \
+    } while (0)
+
 #if CUDART_VERSION >= 12000
 #define CUBLAS_CHECK(err) \
     do { \
```
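For reference: `<cuda.h>` (included above) exposes the CUDA *driver* API, whose functions return `CUresult` rather than the runtime's `cudaError_t`, so the existing runtime-oriented `CUDA_CHECK` cannot wrap them; that is what the new `CU_CHECK` is for. A minimal usage sketch, assuming `CU_CHECK` (and the `GGML_ASSERT` it relies on) are in scope as defined in `ggml-cuda.cu`; `cuInit` and `cuDeviceGetCount` are real driver-API entry points:

```cpp
// Sketch only: wrapping driver-API calls with CU_CHECK.
// The driver API must be initialized explicitly with cuInit.
#include <cuda.h>
#include <cstdio>

void example() {
    CU_CHECK(cuInit(0));
    int n_devices = 0;
    CU_CHECK(cuDeviceGetCount(&n_devices)); // returns CUDA_SUCCESS or a CUresult error
    printf("driver API sees %d device(s)\n", n_devices);
}
```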
```diff
@@ -6543,21 +6562,24 @@ struct scoped_spin_lock {
     scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
 };
 
+static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+
 // #define DEBUG_CUDA_MALLOC
 struct cuda_buffer {
     void * ptr = nullptr;
     size_t size = 0;
 };
 
 static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
-static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
+static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
 
-static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
 #ifdef DEBUG_CUDA_MALLOC
     int nnz = 0;
-    size_t max_size = 0, tot_size = 0;
+    size_t max_size = 0;
 #endif
     size_t best_diff = 1ull << 36;
     int ibest = -1;
```
```diff
@@ -6566,7 +6588,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         if (b.ptr != nullptr) {
 #ifdef DEBUG_CUDA_MALLOC
             ++nnz;
-            tot_size += b.size;
             if (b.size > max_size) max_size = b.size;
 #endif
             if (b.size >= size) {
```
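For context: the legacy pool (the `_leg` functions) caches freed `cudaMalloc` blocks in a fixed-size array and serves a new request from the cached block whose size is closest to, but at least, the requested size, falling back to a fresh `cudaMalloc` on a miss. A simplified sketch of that best-fit lookup; the names here are illustrative, not ggml's exact code:

```cpp
// Best-fit reuse over a fixed array of cached buffers (simplified sketch).
#include <cstddef>

struct Buf { void * ptr; size_t size; };

static void * take_best_fit(Buf * pool, int n, size_t size) {
    size_t best_diff = (size_t) 1 << 36; // sentinel: any real fit beats this
    int    ibest     = -1;
    for (int i = 0; i < n; ++i) {
        if (pool[i].ptr != nullptr && pool[i].size >= size) {
            size_t diff = pool[i].size - size;
            if (diff < best_diff) { best_diff = diff; ibest = i; }
        }
    }
    if (ibest < 0) return nullptr;   // caller falls back to cudaMalloc
    void * p = pool[ibest].ptr;
    pool[ibest].ptr  = nullptr;      // slot becomes free for a later cache
    pool[ibest].size = 0;
    return p;
}
```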
```diff
@@ -6593,19 +6614,20 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
         b.size = 0;
         return ptr;
     }
-#ifdef DEBUG_CUDA_MALLOC
-    fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
-            (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
-#endif
     void * ptr;
     size_t look_ahead_size = (size_t) (1.05 * size);
     look_ahead_size = 256 * ((look_ahead_size + 255)/256);
     CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
     *actual_size = look_ahead_size;
+    g_cuda_pool_size[id] += look_ahead_size;
+#ifdef DEBUG_CUDA_MALLOC
+    fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
+            (uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
+#endif
     return ptr;
 }
 
-static void ggml_cuda_pool_free(void * ptr, size_t size) {
+static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
```
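The 5% look-ahead above deliberately over-allocates so that a slightly larger follow-up request can still reuse the cached block, and the result is rounded up to a multiple of 256 bytes. A worked example of the arithmetic (plain C++, no CUDA needed):

```cpp
// request  = 10,000,000 bytes
// 1.05x    = 10,500,000 bytes
// round up = 256 * ((10,500,000 + 255) / 256) = 256 * 41,016 = 10,500,096 bytes
#include <cassert>
#include <cstddef>

int main() {
    size_t size = 10000000;
    size_t look_ahead_size = (size_t) (1.05 * size);
    look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
    assert(look_ahead_size == 10500096);
    return 0;
}
```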
```diff
@@ -6620,8 +6642,122 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     }
     fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
     CUDA_CHECK(cudaFree(ptr));
+    g_cuda_pool_size[id] -= size;
 }
 
+#if !defined(GGML_USE_HIPBLAS)
+// pool with virtual memory
+static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
+static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
+static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
+static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB
+
+static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];
+
+    if (size > avail) {
+        size_t reserve_size = size - avail;
+
+        // allocate more physical memory
+        CUmemAllocationProp prop = {};
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = id;
+
+        // get the minimum allocation granularity for this device
+        size_t granularity;
+        CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+
+        // round up to the next multiple of the granularity
+        reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);
+
+        GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);
+
+        CUmemGenericAllocationHandle handle;
+        CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));
+
+        // reserve virtual address space (if not already reserved)
+        if (g_cuda_pool_addr[id] == 0) {
+            CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
+        }
+
+        // map at the end of the pool
+        CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));
+
+        // set access
+        CUmemAccessDesc access = {};
+        access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        access.location.id = id;
+        access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+        CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));
+
+        // add to the pool
+        g_cuda_pool_handles[id].push_back(handle);
+        g_cuda_pool_size[id] += reserve_size;
+
+        //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+        //       id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
+        //       (unsigned long long) (reserve_size/1024/1024));
+    }
+
+    GGML_ASSERT(g_cuda_pool_addr[id] != 0);
+
+    void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
+    *actual_size = size;
+    g_cuda_pool_used[id] += size;
+
+#ifdef DEBUG_CUDA_MALLOC
+    printf("cuda pool[%d]: allocated %llu bytes at %llx\n", id,
+           (unsigned long long) size, (unsigned long long) (uintptr_t) ptr);
+#endif
+
+    return ptr;
+}
+
+static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
+    scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+#ifdef DEBUG_CUDA_MALLOC
+    printf("cuda pool[%d]: freed %llu bytes at %llx\n", id,
+           (unsigned long long) size, (unsigned long long) (uintptr_t) ptr);
+#endif
+
+    g_cuda_pool_used[id] -= size;
+
+    // all deallocations must be in reverse order of the allocations
+    GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
+}
+
+static bool g_device_vmm[GGML_CUDA_MAX_DEVICES] = {false};
+
+static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    if (g_device_vmm[id]) {
+        return ggml_cuda_pool_malloc_vmm(size, actual_size);
+    } else {
+        return ggml_cuda_pool_malloc_leg(size, actual_size);
+    }
+}
+
+static void ggml_cuda_pool_free(void * ptr, size_t size) {
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    if (g_device_vmm[id]) {
+        ggml_cuda_pool_free_vmm(ptr, size);
+    } else {
+        ggml_cuda_pool_free_leg(ptr, size);
+    }
+}
+#else
+#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
+#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
+#endif
+
 static bool g_cublas_loaded = false;
 
 bool ggml_cublas_loaded(void) {
```

(Note: the two `DEBUG_CUDA_MALLOC` printfs above originally read `"... at %llx [%s]\n"` with only `id`, `size`, and `ptr` as arguments; the stray `[%s]` had no matching argument and `%llx` was fed a `void *` directly, so the format strings and casts are corrected here.)
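The key idea of the VMM pool: reserve a large *virtual* address range once (`cuMemAddressReserve` costs no device memory), then commit physical memory with `cuMemCreate`/`cuMemMap` only as the pool grows. Existing allocations never move, and no memory is wasted on look-ahead padding. Below is a minimal standalone sketch of that grow-and-teardown pattern, assuming a single device and one grow step; error checking is elided for brevity, and a real program should check every `CUresult`:

```cpp
// Reserve-once / map-as-you-grow pattern (standalone sketch, link with -lcuda).
#include <cuda.h>
#include <vector>

int main() {
    cuInit(0);
    CUdevice dev;
    cuDeviceGet(&dev, 0);
    CUcontext ctx;
    cuCtxCreate(&ctx, 0, dev);

    const size_t max_size = 1ull << 36; // 64 GiB of *virtual* address space

    // 1. reserve virtual address space once; this commits no physical memory
    CUdeviceptr base = 0;
    cuMemAddressReserve(&base, max_size, 0, 0, 0);

    CUmemAllocationProp prop = {};
    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id   = 0;

    size_t granularity = 0;
    cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);

    // 2. grow: create physical memory and map it at the end of the pool
    size_t pool_size = 0;
    std::vector<CUmemGenericAllocationHandle> handles;
    size_t grow = granularity; // one granule per step, for illustration

    CUmemGenericAllocationHandle h;
    cuMemCreate(&h, grow, &prop, 0);
    cuMemMap(base + pool_size, grow, 0, h, 0);

    CUmemAccessDesc access = {};
    access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    access.location.id   = 0;
    access.flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(base + pool_size, grow, &access, 1);

    handles.push_back(h);
    pool_size += grow;

    // 3. teardown mirrors construction in reverse: unmap, release, free reservation
    cuMemUnmap(base, pool_size);
    for (auto handle : handles) cuMemRelease(handle);
    cuMemAddressFree(base, max_size);
    cuCtxDestroy(ctx);
    return 0;
}
```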
```diff
@@ -6660,9 +6796,18 @@ void ggml_init_cublas() {
 #endif
         fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
         for (int id = 0; id < g_device_count; ++id) {
+            int device_vmm = 0;
+
+#if !defined(GGML_USE_HIPBLAS)
+            CUdevice device;
+            CU_CHECK(cuDeviceGet(&device, id));
+            CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
+            g_device_vmm[id] = !!device_vmm;
+#endif
+
             cudaDeviceProp prop;
             CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
-            fprintf(stderr, "  Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
+            fprintf(stderr, "  Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;
```
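VMM support is a per-device property, which is why the pool choice is made per device at init time via `g_device_vmm`. A self-contained sketch of the same capability probe, assuming only the CUDA driver library (`-lcuda`):

```cpp
// Probe each device for virtual memory management support (sketch).
#include <cuda.h>
#include <cstdio>

int main() {
    cuInit(0);
    int n = 0;
    cuDeviceGetCount(&n);
    for (int id = 0; id < n; ++id) {
        CUdevice dev;
        cuDeviceGet(&dev, id);
        int vmm = 0;
        cuDeviceGetAttribute(&vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, dev);
        printf("device %d: VMM %s\n", id, vmm ? "yes" : "no");
    }
    return 0;
}
```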
```diff
@@ -7437,13 +7582,13 @@ inline void ggml_cuda_op_mul_mat_cublas(
 
         ggml_cuda_pool_free(dst_f16, dst_as);
 
-        if (src0_as != 0) {
-            ggml_cuda_pool_free(src0_as_f16, src0_as);
-        }
-
         if (src1_as != 0) {
             ggml_cuda_pool_free(src1_as_f16, src1_as);
         }
+
+        if (src0_as != 0) {
+            ggml_cuda_pool_free(src0_as_f16, src0_as);
+        }
     }
     else {
         float * src0_ddq_as_f32 = nullptr;
```
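This reordering (and the analogous ones in the hunks below) is required by the VMM pool: `ggml_cuda_pool_free_vmm` asserts that frees happen in reverse order of allocation, because the pool is a simple stack (bump) allocator. A tiny sketch of the contract, with hypothetical buffer names, assuming the pool functions above are in scope:

```cpp
// LIFO contract enforced by ggml_cuda_pool_free_vmm (illustrative).
void lifo_example() {
    size_t as_a = 0, as_b = 0;
    void * a = ggml_cuda_pool_malloc(1024, &as_a); // allocated first
    void * b = ggml_cuda_pool_malloc(2048, &as_b); // allocated second

    // frees must mirror the allocation order in reverse, otherwise the
    // GGML_ASSERT in ggml_cuda_pool_free_vmm fires:
    ggml_cuda_pool_free(b, as_b);
    ggml_cuda_pool_free(a, as_a);
}
```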
```diff
@@ -7800,14 +7945,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
         CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
     }
 
-    if (src0_asf > 0) {
-        ggml_cuda_pool_free(src0_ddf, src0_asf);
+    if (dst_asf > 0) {
+        ggml_cuda_pool_free(dst_ddf, dst_asf);
     }
     if (src1_asf > 0) {
         ggml_cuda_pool_free(src1_ddf, src1_asf);
     }
-    if (dst_asf > 0) {
-        ggml_cuda_pool_free(dst_ddf, dst_asf);
+    if (src0_asf > 0) {
+        ggml_cuda_pool_free(src0_ddf, src0_asf);
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
```
```diff
@@ -8119,17 +8264,17 @@ static void ggml_cuda_op_mul_mat(
             CUDA_CHECK(ggml_cuda_set_device(id));
 
             // free buffers again when done
-            if (src0_as[id] > 0) {
-                ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
-            }
-            if (src1_asf[id] > 0) {
-                ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+            if (dst_as[id] > 0) {
+                ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
             }
             if (src1_asq[id] > 0) {
                 ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
             }
-            if (dst_as[id] > 0) {
-                ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
+            if (src1_asf[id] > 0) {
+                ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+            }
+            if (src0_as[id] > 0) {
+                ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
             }
         }
```
```diff
@@ -8497,12 +8642,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
             cu_compute_type,
             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
     if (ptrs_dst_s != 0) {
         ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
     }
+    if (ptrs_src_s != 0) {
+        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
+    }
 }
 #endif
```
```diff
@@ -8903,8 +9048,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
             }
         }
 
-        ggml_cuda_pool_free(src1_contiguous, as_src1);
         ggml_cuda_pool_free(dst_contiguous, as_dst);
+        ggml_cuda_pool_free(src1_contiguous, as_src1);
     }
 
     if (dst->backend == GGML_BACKEND_CPU) {
```
By the way, `LLAMA_EXTRA_LIBS` ends up in `target_link_libraries(... PUBLIC ...)`, which probably isn't what was intended here?