@@ -66,25 +66,30 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
 
-// static __global__ void cpy_q8_0_f32(const char * cx, float * dst, const int ne,
-//         const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
-//     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
-//
-//     if (i >= ne) {
-//         return;
-//     }
-//
-//     const int64_t i03 = i/(ne00 * ne01 * ne02);
-//     const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
-//     const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
-//     const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
-//
-//     const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
-//     const int ib = i00/QK8_0;
-//     const int iq = i00%QK8_0;
-//
-//     dst[i00*ne01 + i01 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
-// }
+template <typename dst_t>
+static __global__ void k_cpy_q8_0_to_float(const char * cx, dst_t * dst, const int ne,
+        const int ne00, const int ne01, const int ne02, const int nb01, const int nb02, const int nb03) {
+    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
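+    // unravel the flat element index i into 4D tensor coordinates (i00 varies fastest)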
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne00*ne01) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne00*ne01 - i01*ne00;
+
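+    // find the q8_0 block holding element i00 of row (i01, i02, i03);
+    // each block_q8_0 packs QK8_0 int8 quants plus a single fp16 scale d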
+    const block_q8_0 * q8 = (const block_q8_0 *)(cx + i01*nb01 + i02*nb02 + i03*nb03);
+    const int ib = i00/QK8_0;
+    const int iq = i00%QK8_0;
+
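+    // dequantize as scale * quant in fp32; narrow to bf16 only when the destination type requires it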
+    if constexpr (std::is_same_v<dst_t, nv_bfloat16>) {
+        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __float2bfloat16(__half2float(q8[ib].d)*q8[ib].qs[iq]);
+    } else {
+        dst[i00 + i01*ne00 + i02*ne00*ne01 + i03*ne00*ne01*ne02] = __half2float(q8[ib].d)*q8[ib].qs[iq];
+    }
+}
 
 static __global__ void k_transpose_q8_0(const char * cx, char * cdst,
         const int ne10, const int ne11, const int ne12,
@@ -532,23 +537,26 @@ static void transpose_q8_0(ggml_backend_cuda_context & ctx, const ggml_tensor *
         (const char *)src->data, (char *)dst->data,
         dst->ne[0], dst->ne[1], dst->ne[2], src->nb[0], src->nb[2], src->nb[3],
         dst->nb[1], dst->nb[2], dst->nb[3]);
+}
 
-    // auto ne = ggml_nelements(dst);
-    // ggml_cuda_pool_alloc<float> dst_f32(ctx.pool(), ne);
-    // const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    // auto aux_src = *dst;
-    // aux_src.nb[0] = sizeof(float);
-    // aux_src.nb[1] = aux_src.nb[0]*aux_src.ne[0];
-    // aux_src.nb[2] = aux_src.nb[1]*aux_src.ne[1];
-    // aux_src.nb[3] = aux_src.nb[2]*aux_src.ne[2];
-    // cpy_q8_0_f32<<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-    //     ((const char *)src->data, dst_f32.get(), ne,
-    //      src->ne[1], src->ne[0], src->ne[2], src->nb[0], src->nb[2], src->nb[3]);
-    // CUDA_CHECK(cudaGetLastError());
-    // aux_src.type = GGML_TYPE_F32;
-    // ggml_cpy_f32_q8_0_cuda((const char *)dst_f32.get(), (char *)dst->data, ne, dst->ne[0], dst->ne[1], dst->ne[2],
-    //     aux_src.nb[0], aux_src.nb[1], aux_src.nb[2], aux_src.nb[3],
-    //     dst->ne[0], dst->ne[1], dst->ne[2], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], stream);
+static void copy_q8_0_to_float(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) {
+    auto stream = ctx.stream();
+    auto num_blocks = ggml_nelements(dst)/QK8_0;
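+    // one thread per element: QK8_0 threads per CUDA block, so each CUDA block maps to
+    // exactly one q8_0 quantization block (ne is a multiple of QK8_0 for q8_0 tensors)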
+    if (dst->type == GGML_TYPE_F16) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (half *)dst->data, ggml_nelements(dst),
+                src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else if (dst->type == GGML_TYPE_F32) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (float *)dst->data, ggml_nelements(dst),
+                src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else if (dst->type == GGML_TYPE_BF16) {
+        k_cpy_q8_0_to_float<<<num_blocks, QK8_0, 0, stream>>>((const char *)src->data, (nv_bfloat16 *)dst->data, ggml_nelements(dst),
+                src->ne[0], src->ne[1], src->ne[2], src->nb[1], src->nb[2], src->nb[3]);
+    }
+    else {
+        GGML_ABORT("fatal error");
+    }
 }
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
@@ -607,8 +615,13 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
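+        // bf16 elements are 16-bit like f16, so the bitwise f16 copy kernel also serves bf16 -> bf16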
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_f16_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
+               (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
+        copy_q8_0_to_float(ctx, src0, src1);
     } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
         if (src1->type == GGML_TYPE_F16) {
             auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);
@@ -670,6 +683,9 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_f32_f16<cpy_1_f16_f16>;
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+    } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 &&
+               (src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_BF16 || src1->type == GGML_TYPE_F32)) {
+        return (void*) copy_q8_0_to_float;
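+        // note: unlike the branches above this returns the host-side dispatcher rather than
+        // a kernel; callers presumably use the pointer only to identify the copy routine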
     } else if (ggml_is_contiguous(src0) && ggml_are_same_shape(src0, src1)) {
         if (src1->type == GGML_TYPE_F16) {
             auto to_fp16 = ggml_get_to_fp16_cuda(src0->type);