@@ -112,6 +112,30 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
112112 cpy_blck (cx + x_offset, cdst + dst_offset);
113113}
114114
// Element-wise converting copy between two flat buffers.
//
// Each thread handles one element: it reads element i of `cx` (viewed as an
// array of src_t), converts it with ggml_cuda_cast<dst_t>, and stores the
// result in element i of `cdst` (viewed as an array of dst_t).
//
//   cx   - source buffer, indexed as src_t[ne]
//   cdst - destination buffer, indexed as dst_t[ne]
//   ne   - number of elements to copy; threads whose index is >= ne exit early
//
// Only the x dimension of the launch configuration is used for indexing, so
// the launch must provide at least `ne` threads along x.
template <typename src_t, typename dst_t>
static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
    // blockDim.x and blockIdx.x are 32-bit unsigned values; widen before the
    // multiplication so the global index cannot wrap around when the launch
    // covers 2^32 or more elements.
    const int64_t i = (int64_t) blockDim.x*blockIdx.x + threadIdx.x;

    // Tail guard: the last block may extend past the end of the data.
    if (i >= ne) {
        return;
    }

    const src_t * x   = (const src_t *) cx;
    dst_t       * dst = (dst_t *) cdst;

    dst[i] = ggml_cuda_cast<dst_t>(x[i]);
}
128+
// Host-side launcher for cpy_flt_contiguous<src_t, dst_t>.
//
// Enqueues the kernel on `stream` with ceil(ne / CUDA_CPY_BLOCK_SIZE) blocks of
// CUDA_CPY_BLOCK_SIZE threads, so that every one of the `ne` elements is
// covered by exactly one thread.
//
//   cx     - source buffer, forwarded to the kernel unchanged
//   cdst   - destination buffer, forwarded to the kernel unchanged
//   ne     - number of elements to copy; nothing is launched when ne <= 0
//   stream - CUDA stream the kernel is enqueued on
template <typename src_t, typename dst_t>
static void ggml_cpy_flt_contiguous_cuda(
    const char * cx, char * cdst, const int64_t ne,
    cudaStream_t stream) {

    // A launch with zero blocks is an invalid configuration; an empty copy is
    // simply a no-op.
    if (ne <= 0) {
        return;
    }

    const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;

    // The grid's x dimension is a 32-bit value limited to 2^31 - 1 blocks.
    // Fail loudly instead of letting the implicit narrowing launch a truncated
    // grid that would silently copy only part of the data.
    GGML_ASSERT(num_blocks <= 2147483647);

    cpy_flt_contiguous<src_t, dst_t><<<(unsigned int) num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
        (cx, cdst, ne);
}
138+
115139template <typename src_t , typename dst_t >
116140static void ggml_cpy_flt_cuda (
117141 const char * cx, char * cdst, const int ne,
@@ -285,7 +309,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
285309 char * src0_ddc = (char *) src0->data ;
286310 char * src1_ddc = (char *) src1->data ;
287311
288- if (src0->type == src1->type && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
312+ const bool contiguous_srcs = ggml_is_contiguous (src0) && ggml_is_contiguous (src1);
313+
314+ if (src0->type == src1->type && contiguous_srcs) {
289315 GGML_ASSERT (ggml_nbytes (src0) == ggml_nbytes (src1));
290316#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
291317 if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) {
@@ -296,11 +322,19 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
296322 CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
297323 }
298324 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
299- ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
325+ ggml_cpy_flt_cuda<float , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
300326 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
301- ggml_cpy_flt_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
327+ if (contiguous_srcs) {
328+ ggml_cpy_flt_contiguous_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
329+ } else {
330+ ggml_cpy_flt_cuda<float , nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
331+ }
302332 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
303- ggml_cpy_flt_cuda<float , half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
333+ if (contiguous_srcs) {
334+ ggml_cpy_flt_contiguous_cuda<float , half> (src0_ddc, src1_ddc, ne, main_stream);
335+ } else {
336+ ggml_cpy_flt_cuda<float , half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
337+ }
304338 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
305339 ggml_cpy_f32_q8_0_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
306340 } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
@@ -327,21 +361,45 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
327361 } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
328362 ggml_cpy_q5_1_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
329363 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
330- ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
364+ ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
331365 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
332- ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
366+ if (contiguous_srcs) {
367+ ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
368+ } else {
369+ ggml_cpy_flt_cuda<half, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
370+ }
333371 } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
334- ggml_cpy_flt_cuda<half, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
372+ if (contiguous_srcs) {
373+ ggml_cpy_flt_contiguous_cuda<half, float > (src0_ddc, src1_ddc, ne, main_stream);
374+ } else {
375+ ggml_cpy_flt_cuda<half, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
376+ }
335377 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
336378 ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
337379 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
338- ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
380+ if (contiguous_srcs) {
381+ ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, main_stream);
382+ } else {
383+ ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
384+ }
339385 } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
340- ggml_cpy_flt_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
386+ if (contiguous_srcs) {
387+ ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, main_stream);
388+ } else {
389+ ggml_cpy_flt_cuda<nv_bfloat16, float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
390+ }
341391 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
342- ggml_cpy_flt_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
392+ if (contiguous_srcs) {
393+ ggml_cpy_flt_contiguous_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, main_stream);
394+ } else {
395+ ggml_cpy_flt_cuda<float , int32_t > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
396+ }
343397 } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
344- ggml_cpy_flt_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
398+ if (contiguous_srcs) {
399+ ggml_cpy_flt_contiguous_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, main_stream);
400+ } else {
401+ ggml_cpy_flt_cuda<int32_t , float > (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
402+ }
345403 } else {
346404 GGML_ABORT (" %s: unsupported type combination (%s to %s)\n " , __func__,
347405 ggml_type_name (src0->type ), ggml_type_name (src1->type ));
0 commit comments