From b2c0f7f303ca902eec5c97c390f9181c5ecee7e7 Mon Sep 17 00:00:00 2001 From: Djip007 Date: Mon, 20 May 2024 13:22:18 +0200 Subject: [PATCH 1/2] update HIP_UMA #7399 add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled. - get x2 on prompt eval and x1.5 on token gen with rocm6.0 on ryzen 7940HX iGPU (780M/gfx1103) --- ggml-cuda.cu | 24 +++++++++++++++++++++--- ggml-cuda/common.cuh | 2 -- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b82167cbf72..456e7378d8f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -119,6 +119,24 @@ int ggml_cuda_get_device() { return id; } +// ggml_cuda_host_malloc +static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { +#if defined(GGML_USE_HIPBLAS) +#if defined(GGML_HIP_UMA) + auto res = hipMallocManaged(ptr, size); + if (res == hipSuccess) { + // if error we "need" to know why... + CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + } + return res; +#else + return hipMalloc(ptr, size); +#endif +#else + return cudaMalloc(ptr, size); +#endif +} + static ggml_cuda_device_info ggml_cuda_init() { #ifdef __HIP_PLATFORM_AMD__ // Workaround for a rocBLAS bug when using multiple graphics cards: @@ -271,7 +289,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size)); + CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -537,7 +555,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0 void * dev_ptr; - cudaError_t err = cudaMalloc(&dev_ptr, size); + cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device); if (err 
!= cudaSuccess) { // clear the error cudaGetLastError(); @@ -798,7 +816,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_cuda_set_device(id); char * buf; - CUDA_CHECK(cudaMalloc(&buf, size)); + CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id)); // set padding to 0 to avoid possible NaN values if (size > original_size) { diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 8f6fd71cfea..4c41ee785c2 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -80,10 +80,8 @@ #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc #ifdef GGML_HIP_UMA -#define cudaMalloc hipMallocManaged #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size) #else -#define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #endif #define cudaMemcpy hipMemcpy From a6a1abd98ebd8eea215e47c1b6547f1404ad9b7a Mon Sep 17 00:00:00 2001 From: slaren Date: Sun, 26 May 2024 18:47:42 +0200 Subject: [PATCH 2/2] simplify code, more consistent style --- ggml-cuda.cu | 10 +++------- ggml-cuda/common.cuh | 5 +---- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 456e7378d8f..6e2f525dbe2 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -119,19 +119,15 @@ int ggml_cuda_get_device() { return id; } -// ggml_cuda_host_malloc -static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { -#if defined(GGML_USE_HIPBLAS) -#if defined(GGML_HIP_UMA) +static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { + ggml_cuda_set_device(device); +#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA) auto res = hipMallocManaged(ptr, size); if (res == hipSuccess) { // if error we "need" to know why... 
CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); } return res; -#else - return hipMalloc(ptr, size); -#endif #else return cudaMalloc(ptr, size); #endif diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index 4c41ee785c2..22872ca5c1d 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -79,11 +79,8 @@ #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister #define cudaLaunchHostFunc hipLaunchHostFunc -#ifdef GGML_HIP_UMA -#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size) -#else +#define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) -#endif #define cudaMemcpy hipMemcpy #define cudaMemcpyAsync hipMemcpyAsync #define cudaMemcpyPeerAsync hipMemcpyPeerAsync