Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ if (LLAMA_CUBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()

set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cuda_driver)

Comment on lines +305 to +306
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By the way LLAMA_EXTRA_LIBS ends up in target_link_libraries(... PUBLIC ...) which probably isn't what was intended here?

if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
# 52 == lowest CUDA 12 standard
# 60 == f16 CUDA intrinsics
Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ endif # LLAMA_BLIS

ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
MK_LDFLAGS += -lcuda -L/usr/lib/wsl/lib -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS = -use_fast_math

ifndef JETSON_EOL_MODULE_DETECT
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
endif # JETSON_EOL_MODULE_DETECT
Expand Down
203 changes: 174 additions & 29 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
#define __trap abort
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
// CUDA 10.2 does not have these macro definitions.
Expand Down Expand Up @@ -213,6 +214,24 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)

// driver API
#define CU_CHECK(err) \
do { \
CUresult err_ = (err); \
if (err_ != CUDA_SUCCESS) { \
int id; \
cuDeviceGet(&id, 0); \
const char * err_str; \
cuGetErrorString(err_, &err_str); \
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
err_str); \
fprintf(stderr, "%s\n", #err); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"CUDA error"); \
} \
} while (0)


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my test implementation for peer access I was able to simply use CUDA_CHECK for the CUresult. I don't know whether that works universally though.

Copy link
Member Author

@slaren slaren Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't work for me, it is a different type and requires a different function to get the error description, cuGetErrorString.

#if CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) \
do { \
Expand Down Expand Up @@ -6543,21 +6562,24 @@ struct scoped_spin_lock {
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};

static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

// #define DEBUG_CUDA_MALLOC
struct cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};

static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};

static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
static void * ggml_cuda_pool_malloc_leg(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0, tot_size = 0;
size_t max_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
Expand All @@ -6566,7 +6588,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
tot_size += b.size;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
Expand All @@ -6593,19 +6614,20 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
b.size = 0;
return ptr;
}
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
g_cuda_pool_size[id] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
#endif
return ptr;
}

static void ggml_cuda_pool_free(void * ptr, size_t size) {
static void ggml_cuda_pool_free_leg(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
Expand All @@ -6620,8 +6642,122 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
g_cuda_pool_size[id] -= size;
}

#if !defined(GGML_USE_HIPBLAS)
// pool with virtual memory
static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 36; // 64 GB

static void * ggml_cuda_pool_malloc_vmm(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));

size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];

if (size > avail) {
size_t reserve_size = size - avail;

// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = id;

// get the minimum allocation granularity for this device
size_t granularity;
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

// round up to the next multiple of the granularity
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);

GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));

// reserve virtual address space (if not already reserved)
if (g_cuda_pool_addr[id] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
}

// map at the end of the pool
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));

// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));

// add to the pool
g_cuda_pool_handles[id].push_back(handle);
g_cuda_pool_size[id] += reserve_size;

//printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
// id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
// (unsigned long long) (reserve_size/1024/1024));
}

GGML_ASSERT(g_cuda_pool_addr[id] != 0);

void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
*actual_size = size;
g_cuda_pool_used[id] += size;

#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr);
#endif

return ptr;
}

static void ggml_cuda_pool_free_vmm(void * ptr, size_t size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));

#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: freed %llu bytes at %llx\n", id, (unsigned long long) size, ptr);
#endif

g_cuda_pool_used[id] -= size;

// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));
}

static bool g_device_vmm[GGML_CUDA_MAX_DEVICES] = {false};

static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_vmm[id]) {
return ggml_cuda_pool_malloc_vmm(size, actual_size);
} else {
return ggml_cuda_pool_malloc_leg(size, actual_size);
}
}

static void ggml_cuda_pool_free(void * ptr, size_t size) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (g_device_vmm[id]) {
ggml_cuda_pool_free_vmm(ptr, size);
} else {
ggml_cuda_pool_free_leg(ptr, size);
}
}
#else
#define ggml_cuda_pool_malloc ggml_cuda_pool_malloc_leg
#define ggml_cuda_pool_free ggml_cuda_pool_free_leg
#endif

static bool g_cublas_loaded = false;

bool ggml_cublas_loaded(void) {
Expand Down Expand Up @@ -6660,9 +6796,18 @@ void ggml_init_cublas() {
#endif
fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) {
int device_vmm = 0;

#if !defined(GGML_USE_HIPBLAS)
CUdevice device;
CU_CHECK(cuDeviceGet(&device, id));
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
g_device_vmm[id] = !!device_vmm;
#endif

cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor);
fprintf(stderr, " Device %d: %s, compute capability %d.%d, VMM: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");

g_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;
Expand Down Expand Up @@ -7437,13 +7582,13 @@ inline void ggml_cuda_op_mul_mat_cublas(

ggml_cuda_pool_free(dst_f16, dst_as);

if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}

if (src1_as != 0) {
ggml_cuda_pool_free(src1_as_f16, src1_as);
}

if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
}
else {
float * src0_ddq_as_f32 = nullptr;
Expand Down Expand Up @@ -7800,14 +7945,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}

if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
}
if (src1_asf > 0) {
ggml_cuda_pool_free(src1_ddf, src1_asf);
}
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
}

if (dst->backend == GGML_BACKEND_CPU) {
Expand Down Expand Up @@ -8119,17 +8264,17 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(id));

// free buffers again when done
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
}
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
}

Expand Down Expand Up @@ -8497,12 +8642,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
}
#endif

Expand Down Expand Up @@ -8903,8 +9048,8 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
}
}

ggml_cuda_pool_free(src1_contiguous, as_src1);
ggml_cuda_pool_free(dst_contiguous, as_dst);
ggml_cuda_pool_free(src1_contiguous, as_src1);
}

if (dst->backend == GGML_BACKEND_CPU) {
Expand Down
3 changes: 0 additions & 3 deletions tests/test-grad0.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,6 @@ int main(int argc, const char ** argv) {
srand(seed);
const int nargs = 1;

int64_t ne2[4];
ne2[0] = 1;

for (int ndims = 1; ndims <= 2; ++ndims) {
x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);

Expand Down