@@ -428,7 +428,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
428428    char  * src0_ddc = (char  *) src0->data ;
429429    char  * src1_ddc = (char  *) src1->data ;
430430
431-     if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
431+     if  (src0->type  == src1->type  && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
432+         GGML_ASSERT (ggml_nbytes (src0) == ggml_nbytes (src1));
433+         CUDA_CHECK (cudaMemcpyAsync (src1_ddc, src0_ddc, ggml_nbytes (src0), cudaMemcpyDeviceToDevice, main_stream));
434+     } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
432435        ggml_cpy_f32_f32_cuda  (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
433436    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F16) {
434437        ggml_cpy_f32_f16_cuda  (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
@@ -449,9 +452,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
449452    } else  if  (src0->type  == GGML_TYPE_F16 && src1->type  == GGML_TYPE_F32) {
450453        ggml_cpy_f16_f32_cuda  (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
451454    } else  {
452-         fprintf (stderr,  " %s: unsupported type combination (%s to %s)\n " 
455+         GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " 
453456                ggml_type_name (src0->type ), ggml_type_name (src1->type ));
454-         GGML_ABORT (" fatal error" 
455457    }
456458}
457459
@@ -461,29 +463,30 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
461463}
462464
463465void * ggml_cuda_cpy_fn (const  ggml_tensor * src0, ggml_tensor * src1) {
464-     if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
465-             return  (void *) cpy_f32_f16<cpy_1_f32_f32>;
466+     if  (src0->type  == src1->type  && ggml_is_contiguous (src0) && ggml_is_contiguous (src1)) {
467+         return  nullptr ;
468+     } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F32) {
469+         return  (void *) cpy_f32_f16<cpy_1_f32_f32>;
466470    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_F16) {
467-              return  (void *) cpy_f32_f16<cpy_1_f32_f16>;
471+         return  (void *) cpy_f32_f16<cpy_1_f32_f16>;
468472    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_Q8_0) {
469-              return  (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
473+         return  (void *) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
470474    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_Q4_0) {
471-              return  (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
475+         return  (void *) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
472476    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_Q4_1) {
473-              return  (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
477+         return  (void *) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
474478    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_Q5_0) {
475-              return  (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
479+         return  (void *) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
476480    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_IQ4_NL) {
477-              return  (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
481+         return  (void *) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
478482    } else  if  (src0->type  == GGML_TYPE_F32 && src1->type  == GGML_TYPE_Q5_1) {
479-              return  (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
483+         return  (void *) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
480484    } else  if  (src0->type  == GGML_TYPE_F16 && src1->type  == GGML_TYPE_F16) {
481-              return  (void *) cpy_f32_f16<cpy_1_f32_f16>;
485+         return  (void *) cpy_f32_f16<cpy_1_f32_f16>;
482486    } else  if  (src0->type  == GGML_TYPE_F16 && src1->type  == GGML_TYPE_F32) {
483-              return  (void *) cpy_f32_f16<cpy_1_f16_f32>;
487+         return  (void *) cpy_f32_f16<cpy_1_f16_f32>;
484488    } else  {
485-         fprintf (stderr,  " %s: unsupported type combination (%s to %s)\n " 
489+         GGML_ABORT ( " %s: unsupported type combination (%s to %s)\n " 
486490                ggml_type_name (src0->type ), ggml_type_name (src1->type ));
487-         GGML_ABORT (" fatal error" 
488491    }
489492}
0 commit comments