Commit 68a5b60 — Make Q8_0 KV cache work with mla=2,fa on CUDA (#264)

Author: Iwan Kawrakow (ikawrakow)
Co-authored-by: Iwan Kawrakow <[email protected]>
Parent: f4ebf13
5 files changed: +117 −46 lines

ggml/src/ggml-cuda.cu — 8 additions, 0 deletions

```diff
@@ -3395,6 +3395,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     if (op->op == GGML_OP_MOE_FUSED_UP_GATE && a->type != op->src[1]->type) {
         return false;
     }
+    //==================================================================
+    //if (ggml_is_quantized(a->type) && ggml_is_quantized(b->type)) {
+    //    return false;
+    //}
+    //==================================================================
     if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16 && !ggml_is_quantized(a->type)) {
         return false;
     }
@@ -3496,6 +3501,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
         return true;
     }
+    if (ggml_is_quantized(src0_type) && (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_F32)) {
+        return true;
+    }
     if (ggml_is_contiguous(op->src[0]) && ggml_are_same_shape(op->src[0], op->src[1])) {
         if (src1_type == GGML_TYPE_F16 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F32) {
             return true;
```
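
The effect of the new check is that the CUDA backend now claims `GGML_OP_CPY` support whenever the source is any quantized type and the destination is F16 or F32 (the quantized → BF16 case is handled by the explicit `ggml_cuda_cpy` dispatch further down). The snippet below is only an illustrative restatement of that rule as a standalone predicate; the reduced `Type` enum, `is_quantized`, and `cpy_supported` are made-up names for the example, not ggml API.

```cpp
#include <cstdio>

// Hypothetical, reduced stand-ins for ggml's type enum and ggml_is_quantized();
// only the cases relevant to this commit are modeled.
enum class Type { F32, F16, BF16, Q8_0, Q4_0 };

static bool is_quantized(Type t) { return t == Type::Q8_0 || t == Type::Q4_0; }

// Mirrors the added acceptance rule: a quantized src may be copied (dequantized)
// into an F16 or F32 destination.
static bool cpy_supported(Type src0, Type src1) {
    if (src0 == Type::F16 && src1 == Type::F32) return true;
    if (is_quantized(src0) && (src1 == Type::F16 || src1 == Type::F32)) return true;
    return false;
}

int main() {
    std::printf("Q8_0 -> F16: %d\n", cpy_supported(Type::Q8_0, Type::F16)); // 1
    std::printf("Q8_0 -> F32: %d\n", cpy_supported(Type::Q8_0, Type::F32)); // 1
    std::printf("Q8_0 -> BF16: %d\n", cpy_supported(Type::Q8_0, Type::BF16)); // 0 (dispatched elsewhere)
    return 0;
}
```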

ggml/src/ggml-cuda/binbcast.cu — 22 additions, 1 deletion

```diff
@@ -282,7 +282,28 @@ static void ggml_cuda_op_bin_bcast(
 }
 
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+    GGML_ASSERT(dst->type == dst->src[0]->type);
+    if (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16) {
+        ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
+        return;
+    }
+    auto src = dst->src[0];
+    auto bs = ggml_blck_size(src->type);
+    auto ts = ggml_type_size(src->type);
+    if (src->nb[0] != ts || ts*(src->ne[0]/bs) % 2 != 0) {
+        fprintf(stderr, "%s: unsupported case type = %s, nb[0] = %zu, type_size = %zu\n", __func__, ggml_type_name(src->type), src->nb[0], ts);
+        GGML_ABORT("fatal error");
+    }
+    auto aux_src = *src;
+    aux_src.type = GGML_TYPE_F16;
+    aux_src.ne[0] = ts*(src->ne[0]/bs)/2;
+    aux_src.nb[0] = 2;
+    auto aux_dst = *dst;
+    aux_dst.type = GGML_TYPE_F16;
+    aux_dst.ne[0] = ts*(dst->ne[0]/bs)/2;
+    aux_dst.nb[0] = 2;
+    aux_dst.src[0] = &aux_src;
+    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(&aux_dst, &aux_src, &aux_dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
 }
 
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
```
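The repeat path does not dequantize anything: a quantized row whose blocks are densely packed is just an opaque byte string, so it is reinterpreted as a row of fp16 values (2 bytes each) and handed to the existing F16 broadcast kernel, which copies it verbatim. The check `ts*(ne[0]/bs) % 2` guards that the row byte count is even. A minimal sketch of the arithmetic, using the Q8_0 constants from ggml (32 quants per block, 34 bytes per block) and an example row length of 64 elements (e.g. the per-head RoPE slice repeated in the MLA graph):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne0 = 64;       // example row length in elements (illustrative value)
    const int64_t bs  = 32;       // ggml_blck_size(GGML_TYPE_Q8_0)
    const size_t  ts  = 2 + 32;   // ggml_type_size(GGML_TYPE_Q8_0): fp16 scale + 32 int8 quants = 34 bytes

    const size_t row_bytes = ts * (ne0 / bs);   // bytes occupied by one densely packed quantized row
    // The repeat trick above only works if the row is an even number of bytes,
    // so that it can be viewed as a row of fp16 "elements":
    if (row_bytes % 2 != 0) { std::printf("unsupported\n"); return 1; }

    const int64_t ne0_as_f16 = (int64_t)(row_bytes / 2);   // what the patch stores in aux_src.ne[0]
    std::printf("%lld Q8_0 elements -> %zu bytes -> %lld fp16 'elements' per row\n",
                (long long)ne0, row_bytes, (long long)ne0_as_f16);
    return 0;
}
```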

ggml/src/ggml-cuda/concat.cu — 19 additions, 5 deletions

```diff
@@ -209,6 +209,10 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     if (dim == 0 && src0->nb[0] == ggml_type_size(src0->type) && src1->nb[0] == ggml_type_size(src1->type) &&
         src0->nb[1] % sizeof(float) == 0 && src1->nb[1] % sizeof(float) == 0) {
+        auto bs = ggml_blck_size(dst->type);
+        auto ts = ggml_type_size(dst->type);
+        auto ne00_eff = (src0->ne[0]/bs)*ts/sizeof(float);
+        auto ne0_eff = (dst->ne[0]/bs)*ts/sizeof(float);
         if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
             //if (dst->ne[1] >= 65536 || dst->ne[2] >= 65536) {
             //    fprintf(stderr, "%s: ne1 = %ld, ne2 = %ld exceed max. blocks when computing %s\n", __func__, dst->ne[1], dst->ne[2], dst->name);
@@ -217,25 +221,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             const float * src0_d = (const float *)src0->data;
             const float * src1_d = (const float *)src1->data;
             float * dst_d = (float *)dst->data;
+            //printf("%s(%s, %s): %ld %zu %zu %ld %zu %zu\n", __func__, src0->name, src1->name, src0->ne[0], src0->nb[0], src0->nb[1], dst->ne[0], dst->nb[0], dst->nb[1]);
             for (int i3 = 0; i3 < dst->ne[3]; i3++) {
                 concat_f32_cuda(
                     src0_d + i3 * (src0->nb[3] / 4),
                     src1_d + i3 * (src1->nb[3] / 4),
                     dst_d  + i3 * ( dst->nb[3] / 4),
-                    src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
-                    dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
+                    ne00_eff, src0->ne[1], src0->ne[2],
+                    ne0_eff, dst->ne[1], dst->ne[2], dim, stream);
+                    //src0->nb[1]/sizeof(float), src0->ne[1], src0->ne[2],
+                    //dst->nb[1]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
+                    //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2],
+                    //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dim, stream);
             }
         } else {
+            //printf("%s(not contiguous): %s(%s) and %s(%s)\n", __func__, src0->name, ggml_type_name(src0->type), src1->name, ggml_type_name(src1->type));
+            auto ne10_eff = (src1->ne[0]/bs)*ts/sizeof(float);
             dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]);
             concat_f32_non_cont<<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(
                 (const char *)src0->data,
                 (const char *)src1->data,
                 (      char *)dst->data,
-                src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
+                ne00_eff, src0->ne[1], src0->ne[2], src0->ne[3],
+                //src0->ne[0]*src0->nb[0]/sizeof(float), src0->ne[1], src0->ne[2], src0->ne[3],
                 sizeof(float), src0->nb[1], src0->nb[2], src0->nb[3],
-                src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
+                ne10_eff, src1->ne[1], src1->ne[2], src1->ne[3],
+                //src1->ne[0]*src1->nb[0]/sizeof(float), src1->ne[1], src1->ne[2], src1->ne[3],
                 sizeof(float), src1->nb[1], src1->nb[2], src1->nb[3],
-                dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
+                ne0_eff, dst->ne[1], dst->ne[2], dst->ne[3],
+                //dst->ne[0]*dst->nb[0]/sizeof(float), dst->ne[1], dst->ne[2], dst->ne[3],
                 sizeof(float), dst->nb[1], dst->nb[2], dst->nb[3], dim);
         }
         return;
```
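
The old code computed the per-row width as `ne[0]*nb[0]/sizeof(float)`, which over-counts for block-quantized types (where `nb[0]` is the whole block size, not the element size). The new `ne00_eff`/`ne10_eff`/`ne0_eff` express each row as its byte length divided by `sizeof(float)`, so the existing float-copy kernels move the right amount of data. A minimal sketch of why this stays exact for a dim-0 concat, with Q8_0 constants from ggml and example row lengths (512 and 64 elements are illustrative values only):

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t bs = 32;   // ggml_blck_size(GGML_TYPE_Q8_0)
    const size_t  ts = 34;   // ggml_type_size(GGML_TYPE_Q8_0)

    const int64_t ne00 = 512;          // src0 row length in elements (example)
    const int64_t ne10 = 64;           // src1 row length in elements (example)
    const int64_t ne0  = ne00 + ne10;  // dst row length for a dim-0 concat

    // Row widths in float-sized units, exactly as ne00_eff/ne10_eff/ne0_eff above.
    const size_t ne00_eff = (ne00 / bs) * ts / sizeof(float);
    const size_t ne10_eff = (ne10 / bs) * ts / sizeof(float);
    const size_t ne0_eff  = (ne0  / bs) * ts / sizeof(float);

    // For dim == 0 the quantized rows are appended byte-wise, so the effective
    // widths add up and the float kernels can be reused without dequantizing.
    assert(ne00_eff + ne10_eff == ne0_eff);
    std::printf("effective widths: %zu + %zu = %zu floats per row\n", ne00_eff, ne10_eff, ne0_eff);
    return 0;
}
```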

ggml/src/ggml-cuda/cpy.cu — 51 additions, 35 deletions

```diff
@@ -66,25 +66,30 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-//static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
-//    const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
-//    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
-//
-//    if (i >= ne) {
-//        return;
-//    }
-//
-//    const int64_t i03 = i/(ne00 * ne01 * ne02);
-//    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
-//    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
-//    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
-//
-//    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
-//    const int ib = i00/QK8_0;
-//    const int iq = i00%QK8_0;
-//
-//    dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
-//}
+template <typename dst_t>
+static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
+
+    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
+    const int ib = i00/QK8_0;
+    const int iq = i00%QK8_0;
+
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]);
+    } else {
+        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
+    }
+}
 
 static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
     const int ne10, const int ne11, const int ne12,
@@ -532,23 +537,26 @@ static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor *
             (const char *)src->data, (char *)dst->data,
             dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
             dst->nb[1], dst->nb[2], dst->nb[3]);
+}
 
-    //auto ne = ggml_nelements(dst);
-    //ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
-    //const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    //auto aux_src = *dst;
-    //aux_src.nb[0] = sizeof(float);
-    //aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
-    //aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
-    //aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
-    //cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-    //    ((const char *)src->data, dst_f32.get(), ne,
-    //     src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
-    //CUDA_CHECK(cudaGetLastError());
-    //aux_src.type = GGML_TYPE_F32;
-    //ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
-    //    aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
-    //    dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
+static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
+    auto stream = ctx.stream();
+    auto num_blocks = ggml_nelements(dst)/QK8_0;
+    if (dst->type == GGML_TYPE_F16) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst),
+            src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else if (dst->type == GGML_TYPE_F32) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst),
+            src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else if (dst->type == GGML_TYPE_BF16) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst),
+            src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else {
+        GGML_ABORT("fatal error");
+    }
 }
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@@ -607,8 +615,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
+               (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
+        copy_q8_0_to_float(ctx, src0, src1);
     } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
         if (src1->type == GGML_TYPE_F16) {
             auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
@@ -670,6 +683,9 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
+               (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
+        return (void*)copy_q8_0_to_float;
    } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
         if (src1->type == GGML_TYPE_F16) {
             auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
```
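
The previously commented-out kernel is revived as a templated `k_cpy_q8_0_to_float` so the same body serves F32, F16, and BF16 destinations; `ggml_cuda_cpy` and `ggml_cuda_cpy_fn` then route same-shape Q8_0 → float copies (i.e. `ggml_cast` of the quantized KV cache) to it. Per element the math is just scale times quant. Below is a hedged CPU reference of that per-element computation, not the CUDA kernel itself; `block_q8_0_ref` mirrors ggml's `block_q8_0` but stores the scale directly as a float to keep the sketch self-contained.

```cpp
#include <cstdint>
#include <vector>

constexpr int QK8_0 = 32;   // quants per Q8_0 block, as in ggml

struct block_q8_0_ref {
    float  d;               // block scale (fp16 in ggml, float here for simplicity)
    int8_t qs[QK8_0];       // quantized values
};

// Dequantize a contiguous row of n elements (n must be a multiple of QK8_0).
// Same expression as the kernel's __half2float(q8[ib].d)*q8[ib].qs[iq].
static void dequantize_row_q8_0_ref(const block_q8_0_ref * blocks, float * dst, int64_t n) {
    for (int64_t i = 0; i < n; ++i) {
        const block_q8_0_ref & b = blocks[i / QK8_0];
        dst[i] = b.d * b.qs[i % QK8_0];
    }
}

int main() {
    block_q8_0_ref b{0.5f, {}};
    for (int i = 0; i < QK8_0; ++i) b.qs[i] = int8_t(i - 16);
    std::vector<float> out(QK8_0);
    dequantize_row_q8_0_ref(&b, out.data(), QK8_0);   // out[0] == -8.0f, out[31] == 7.5f
    return 0;
}
```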

src/llama.cpp — 17 additions, 5 deletions

```diff
@@ -13771,9 +13771,21 @@ struct llm_build_context {
     auto kv_cache_rope = ggml_view_3d(ctx0, kv_self.kv_l[il], n_embd_head_qk_rope, n_kv, 1,
             kv_self.kv_l[il]->nb[1], kv_self.kv_l[il]->nb[2], ggml_row_size(kv_self.kv_l[il]->type, kv_lora_rank));
 
+    // There is still an issue with one or more of the ops GGML_OP_REPEAT, GGML_OP_CONCAT, GGML_OP_CPY on CUDA when
+    // the KV cache is quantized. Hence, in that case we will simply use fp16 for now.
+    // The downside of the following line is that fp16 will be used even if attention is computed on the CPU
+    // if the build is with CUDA enabled.
+    auto kv_type = lctx.backends.size() == 1 && lctx.backends.front() == lctx.backend_cpu ? kv_self.kv_l[il]->type : GGML_TYPE_F16;
+
     ggml_tensor repeater;
     repeater.ne[0] = n_embd_head_qk_rope; repeater.ne[1] = n_kv; repeater.ne[2] = n_max_head; repeater.ne[3] = 1;
-    auto k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+    ggml_tensor * k_rope;
+    if (kv_cache_rope->type == kv_type) {
+        k_rope = ggml_repeat(ctx0, kv_cache_rope, &repeater);
+    } else {
+        auto kv_cache_rope_f16 = ggml_cast(ctx0, kv_cache_rope, GGML_TYPE_F16);
+        k_rope = ggml_repeat(ctx0, kv_cache_rope_f16, &repeater);
+    }
     cb(k_rope, "k_rope", il);
 
     auto q = ggml_concat(ctx0, q_nope, q_rope, 0);
@@ -13796,15 +13808,15 @@ struct llm_build_context {
             ggml_row_size(kv_f32->type, n_embd_head_qk_nope));
     cb(v_f32, "v_f32", il);
 
-    auto v = ggml_cast(ctx0, v_f32, kv_self.kv_l[il]->type);
-    cb(v, "v", il);
-
     auto k_nope_f32 = ggml_view_3d(ctx0, kv_f32, n_embd_head_qk_nope, n_kv, n_max_head,
             ggml_row_size(kv_f32->type, n_max_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
             ggml_row_size(kv_f32->type, n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
     cb(k_nope_f32, "k_nope_f32", il);
 
-    auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_self.kv_l[il]->type);
+    auto v = ggml_cast(ctx0, v_f32, kv_type);
+    cb(v, "v", il);
+
+    auto k_nope = ggml_cast(ctx0, k_nope_f32, kv_type);
     cb(k_nope, "k_nope", il);
 
     ggml_build_forward_expand(gf, k_nope);
```
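
In short, the graph builder now picks the type it casts K/V to per layer: the cache's own (possibly quantized) type only when the context runs on the CPU backend alone, otherwise F16; and when the stored RoPE slice does not match that type it is first cast to F16 before being repeated. A minimal sketch of those two decisions, with simplified stand-ins for the `lctx` members (`backends`, `backend_cpu`) and ggml's type enum:

```cpp
#include <vector>

enum class Type { F16, Q8_0 };
using Backend = int;   // stand-in for ggml_backend_t

// 1) Pick the type K/V are cast to: the cache's own type only for a pure-CPU run,
//    otherwise F16, because REPEAT/CONCAT/CPY on quantized data are not yet fully
//    covered on CUDA.
static Type choose_kv_type(const std::vector<Backend> & backends, Backend backend_cpu, Type cache_type) {
    const bool cpu_only = backends.size() == 1 && backends.front() == backend_cpu;
    return cpu_only ? cache_type : Type::F16;
}

// 2) The RoPE part of the cache is repeated across heads; if its stored type differs
//    from kv_type it is first cast (dequantized) to F16, which the new Q8_0 -> F16
//    copy path in cpy.cu makes possible on CUDA.
static bool repeat_needs_cast(Type cache_type, Type kv_type) {
    return cache_type != kv_type;
}

int main() {
    const std::vector<Backend> gpu_run = {0, 1};                                   // e.g. CUDA + CPU backends
    const Type kv_type = choose_kv_type(gpu_run, /*backend_cpu=*/1, Type::Q8_0);   // -> F16
    return repeat_needs_cast(Type::Q8_0, kv_type) ? 0 : 1;                          // cast needed
}
```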
