@@ -36,14 +36,19 @@ typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs
// Host-side dispatch types, selected per ggml tensor type.
// to_fp32_cuda_t: dequantizes/converts k values at x into fp32 at y on the given stream.
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
// dequantize_mul_mat_vec_cuda_t: fused dequantize + matrix-vector product —
// vx is an nrows x ncols quantized matrix, y the fp32 vector, dst the fp32 result.
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
3838
// QK = number of values in a block after dequantization
// QR = QK / number of stored quants per block, i.e. how many dequantized
//      values each stored quant yields (2 for the 4/5-bit formats, 1 for q8_0)

#define QK4_0 32 // values per q4_0 block
#define QR4_0 2  // two 4-bit quants are packed into each qs byte
// q4_0: 32 values stored as 16 nibble-packed bytes plus a single fp32 scale.
typedef struct {
    float   d;              // delta (scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
4549
4650#define QK4_1 32
51+ #define QR4_1 2
4752typedef struct {
4853 float d; // delta
4954 float m; // min
@@ -52,6 +57,7 @@ typedef struct {
5257static_assert (sizeof (block_q4_1) == sizeof (float ) * 2 + QK4_1 / 2 , " wrong q4_1 block size/padding" );
5358
5459#define QK5_0 32
60+ #define QR5_0 2
5561typedef struct {
5662 half d; // delta
5763 uint8_t qh[4 ]; // 5-th bit of quants
@@ -60,6 +66,7 @@ typedef struct {
6066static_assert (sizeof (block_q5_0) == sizeof (ggml_fp16_t ) + sizeof (uint32_t ) + QK5_0 / 2 , " wrong q5_0 block size/padding" );
6167
6268#define QK5_1 32
69+ #define QR5_1 2
6370typedef struct {
6471 half d; // delta
6572 half m; // min
@@ -69,6 +76,7 @@ typedef struct {
6976static_assert (sizeof (block_q5_1) == 2 * sizeof (ggml_fp16_t ) + sizeof (uint32_t ) + QK5_1 / 2 , " wrong q5_1 block size/padding" );
7077
#define QK8_0 32 // values per q8_0 block
#define QR8_0 1  // 8-bit quants: one stored quant per dequantized value
// q8_0: 32 signed 8-bit quants plus a single fp32 scale.
typedef struct {
    float  d;          // delta (scale)
    int8_t qs[QK8_0];  // quants
} block_q8_0;
@@ -124,6 +132,44 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
124132 v1 = x1*d;
125133}
126134
// Dequantizes two values from a q5_1 block: 5-bit quants whose 4 low bits live
// in the nibbles of qs and whose 5th bit is packed into the 32-bit qh field;
// result = quant * d + m.
// ib  = block index into vx, iqs = byte index within the block.
// v0 gets the value from the low nibble of qs[iqs] (5th bit = qh bit iqs),
// v1 gets the value from the high nibble      (5th bit = qh bit iqs + 16).
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q5_1 * x = (const block_q5_1 *) vx;

    const float d = x[ib].d; // scale (stored as half, promoted to float)
    const float m = x[ib].m; // offset/min (stored as half, promoted to float)

    // copy the 4 packed high-bit bytes into one 32-bit word
    uint32_t qh;
    memcpy(&qh, x[ib].qh, sizeof(qh));

    // move each value's 5th bit into bit position 4 (mask 0x10)
    const uint8_t xh_0 = ((qh >> (iqs +  0)) << 4) & 0x10; // qh bit iqs
    const uint8_t xh_1 = ((qh >> (iqs + 12))     ) & 0x10; // qh bit iqs + 16

    const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); // low nibble  | high bit
    const int32_t x1 = ((x[ib].qs[iqs] >>  4) | xh_1); // high nibble | high bit

    v0 = x0*d + m;
    v1 = x1*d + m;
}
153+
// Dequantizes two consecutive q8_0 values: each quant is a signed byte,
// scaled by the block's fp32 delta. ib = block index, iqs = quant index
// of the first of the two values within the block.
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const block_q8_0 * blocks = (const block_q8_0 *) vx;

    const float scale = blocks[ib].d;
    const int8_t q0   = blocks[ib].qs[iqs + 0];
    const int8_t q1   = blocks[ib].qs[iqs + 1];

    v0 = scale * q0;
    v1 = scale * q1;
}
165+
// "Dequantization" for f16 data: converts the two consecutive half values at
// elements ib and ib + 1 to fp32.
// NOTE(review): iqs is unused here, so ib must already be the index of the
// first element of the pair — that only holds if the kernel template is
// instantiated with qk == 1 (and qr == 1); verify the template arguments at
// the call site.
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
    const half * x = (const half *) vx;

    v0 = __half2float(x[ib + 0]);
    v1 = __half2float(x[ib + 1]);
}
172+
127173static __global__ void dequantize_block_q4_0 (const void * vx, float * y) {
128174 static const int qk = QK4_0;
129175
@@ -224,18 +270,20 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
224270 }
225271}
226272
227- template <int block_size, int qk, dequantize_kernel_t dequantize_kernel>
273+ template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
228274static __global__ void dequantize_mul_mat_vec (const void * vx, const float * y, float * dst, const int ncols) {
229275 const int row = blockIdx .x ;
230276 const int tid = threadIdx .x ;
231277
278+ const int y_offset = qr == 1 ? 1 : qk/2 ;
279+
232280 __shared__ float tmp[block_size]; // separate sum for each thread
233281 tmp[tid] = 0 ;
234282
235283 for (int i = 0 ; i < ncols/block_size; i += 2 ) {
236284 const int col = i*block_size + 2 *tid;
237285 const int ib = (row*ncols + col)/qk; // block index
238- const int iqs = (col%qk)/2 ; // quant index
286+ const int iqs = (col%qk)/qr ; // quant index
239287 const int iybs = col - col%qk; // y block start index
240288
241289 // dequantize
@@ -244,7 +292,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
244292
245293 // matrix multiplication
246294 tmp[tid] += v0 * y[iybs + iqs + 0 ];
247- tmp[tid] += v1 * y[iybs + iqs + qk/ 2 ];
295+ tmp[tid] += v1 * y[iybs + iqs + y_offset ];
248296 }
249297
250298 // sum up partial sums and write back result
@@ -287,17 +335,32 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
287335
// Launches the fused dequantize+matvec kernel for q4_0 data:
// one thread block per output row, CUDA_DMMV_BLOCK_SIZE threads per block.
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
        <<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
292341
// Launches the fused dequantize+matvec kernel for q4_1 data:
// one thread block per output row, CUDA_DMMV_BLOCK_SIZE threads per block.
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
        <<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
297347
// Launches the fused dequantize+matvec kernel for q5_0 data:
// one thread block per output row, CUDA_DMMV_BLOCK_SIZE threads per block.
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
        <<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
353+
// Launches the fused dequantize+matvec kernel for q5_1 data:
// one thread block per output row, CUDA_DMMV_BLOCK_SIZE threads per block.
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
        <<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
359+
// Launches the fused dequantize+matvec kernel for q8_0 data:
// one thread block per output row, CUDA_DMMV_BLOCK_SIZE threads per block.
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    const dim3 num_blocks(nrows, 1, 1);
    const dim3 block_dims(CUDA_DMMV_BLOCK_SIZE, 1, 1);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
        <<<num_blocks, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
302365
303366// TODO: optimize
@@ -313,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
313376 convert_fp16_to_fp32<<<k, 1 , 0 , stream>>> (x, y);
314377}
315378
// Launches the matvec kernel on raw f16 data (no dequantization).
// convert_f16 ignores its iqs argument and reads the two halves at x[ib] and
// x[ib + 1], so the kernel must be instantiated with qk == 1 (and qr == 1):
// that makes ib the global element index (row*ncols + col). The previous
// instantiation with qk == 32 made ib a 32-element block index, which
// convert_f16 then used directly as an element index — reading the wrong data.
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
    dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 1, 1, convert_f16>
        <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
}
384+
316385static to_fp32_cuda_t ggml_get_to_fp32_cuda (ggml_type type) {
317386 switch (type) {
318387 case GGML_TYPE_Q4_0:
@@ -340,6 +409,12 @@ static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_t
340409 return dequantize_mul_mat_vec_q4_1_cuda;
341410 case GGML_TYPE_Q5_0:
342411 return dequantize_mul_mat_vec_q5_0_cuda;
412+ case GGML_TYPE_Q5_1:
413+ return dequantize_mul_mat_vec_q5_1_cuda;
414+ case GGML_TYPE_Q8_0:
415+ return dequantize_mul_mat_vec_q8_0_cuda;
416+ case GGML_TYPE_F16:
417+ return dequantize_mul_mat_vec_q8_0_cuda;
343418 default :
344419 return nullptr ;
345420 }
0 commit comments