@@ -147,9 +147,41 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <Accelerate/Accelerate.h>
 #elif defined(GGML_USE_OPENBLAS)
 #include <cblas.h>
-#elif defined(GGML_USE_CUBLAS)
+#elif defined(GGML_USE_CUBLAS) || defined(GGML_USE_HIPBLAS)
+
+#if defined(GGML_USE_HIPBLAS)
+#include "hipblas/hipblas.h"
+#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
+#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
+#define CUBLAS_OP_N HIPBLAS_OP_N
+#define CUBLAS_OP_T HIPBLAS_OP_T
+#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
+#define cublasCreate hipblasCreate
+#define cublasGemmEx hipblasGemmEx
+#define cublasHandle_t hipblasHandle_t
+#define cublasSetStream hipblasSetStream
+#define cublasSgemm hipblasSgemm
+#define cublasStatus_t hipblasStatus_t
+#define CUDA_R_16F HIPBLAS_R_16F
+#define CUDA_R_32F HIPBLAS_R_32F
+#define cudaError_t hipError_t
+#define cudaFree hipFree
+#define cudaGetErrorString hipGetErrorString
+#define cudaGetLastError hipGetLastError
+#define cudaMalloc hipMalloc
+#define cudaMemcpyAsync hipMemcpyAsync
+#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
+#define cudaStream_t hipStream_t
+#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
+#define cudaStreamNonBlocking hipStreamNonBlocking
+#define cudaStreamSynchronize hipStreamSynchronize
+#define cudaSuccess hipSuccess
+#define GGML_USE_CUBLAS
+#else
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
+#endif
 #include "ggml-cuda.h"
 
 #define CUDA_CHECK(err)                                               \
@@ -8040,9 +8072,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else if (type == GGML_TYPE_Q4_2) {
             dequantize_row_q_cuda = dequantize_row_q4_2_cuda;
         }
-        else if (type == GGML_TYPE_Q4_3) {
-            dequantize_row_q_cuda = dequantize_row_q4_3_cuda;
-        }
         else {
             GGML_ASSERT(false);
         }
@@ -8076,7 +8105,6 @@ static void ggml_compute_forward_mul_mat_q_f32(
         const float * x = wdata;
 #endif
 
-
 #if defined(GGML_USE_CUBLAS)
         // copy data to device
         CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
0 commit comments