@@ -439,7 +439,6 @@ static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullpt
 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
     cudaEvent_t events[GGML_CUDA_MAX_DEVICES][MAX_STREAMS]; // events for synchronizing multiple GPUs
-    bool copied;
 };

 // this is faster on Windows
@@ -4357,8 +4356,9 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,

 // rope == RoPE == rotary positional embedding

-static __global__ void rope_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                const int p_delta_rows, const float theta_scale) {
+template <typename T, bool has_pos>
+static __global__ void rope(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                            const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4369,8 +4369,8 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     const int i = row*ncols + col;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
-    const float p0 = p * freq_scale;
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
@@ -4382,8 +4382,9 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
     dst[i + 1] = x0*sin_theta + x1*cos_theta;
 }

-static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const int32_t * pos, const float freq_scale,
-                                     const int p_delta_rows, const float theta_scale) {
+template <typename T, bool has_pos>
+static __global__ void rope_neox(const T * x, T * dst, const int ncols, const int32_t * pos, const float freq_scale,
+                                 const int p_delta_rows, const float theta_scale) {
     const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);

     if (col >= ncols) {
@@ -4394,8 +4395,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     const int i = row*ncols + col/2;
     const int i2 = row/p_delta_rows;

-    const int p = pos != nullptr ? pos[i2] : 0;
-    const float p0 = p * freq_scale;
+    const int p = has_pos ? pos[i2] : 0;
+    const float p0 = p*freq_scale;
     const float theta = p0*powf(theta_scale, col/2);
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
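For reference (not taken from the diff itself), both kernels apply the standard RoPE rotation to one pair of values per thread; with p the position read from pos[i2] (or 0), s = freq_scale, and c the even column index, the update computed above is

\[
\theta_c = p \cdot s \cdot \texttt{theta\_scale}^{\,c/2}, \qquad
\begin{pmatrix} x'_c \\ x'_{c+1} \end{pmatrix} =
\begin{pmatrix} \cos\theta_c & -\sin\theta_c \\ \sin\theta_c & \cos\theta_c \end{pmatrix}
\begin{pmatrix} x_c \\ x_{c+1} \end{pmatrix}.
\]

The two kernels differ only in how they index the pair being rotated (i = row*ncols + col for the interleaved layout versus i = row*ncols + col/2 for the neox layout).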
@@ -5371,22 +5372,32 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }

-static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template <typename T>
+static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

-static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
+template <typename T>
+static void rope_neox_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                                const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
     GGML_ASSERT(ncols % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
     const dim3 block_nums(nrows, num_blocks_x, 1);
-    rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    if (pos == nullptr) {
+        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    } else {
+        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, theta_scale);
+    }
 }

 static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
@@ -6036,7 +6047,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const int64_t ne0 = dst->ne[0];
     const int64_t row_diff = row_high - row_low;

-    float * src0_ddq_as_f32;
+    float * src0_ddq_as_f32 = nullptr;
     size_t src0_as = 0;

     if (src0->type != GGML_TYPE_F32) {
@@ -6074,8 +6085,9 @@ inline void ggml_cuda_op_rope(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {

-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);

     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -6093,23 +6105,12 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));

     const float theta_scale = powf(freq_base, -2.0f/n_dims);
-    //const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
-
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(src1->ne[0] == ne2);
-    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);

-    int id;
-    CUDA_CHECK(cudaGetDevice(&id));
-
-    int * pos = nullptr;
+    const int32_t * pos = nullptr;
     if ((mode & 1) == 0) {
-        struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-        pos = (int *) src1_extra->data_device[id];
-        if (!src1_extra->copied) {
-            CUDA_CHECK(cudaMemcpyAsync(pos, src1->data, ggml_nbytes(src1), cudaMemcpyHostToDevice, main_stream));
-            src1_extra->copied = true;
-        }
+        GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        GGML_ASSERT(src1->ne[0] == ne2);
+        pos = (const int32_t *) src1_dd;
     }

     const bool is_neox = mode & 2;
@@ -6121,9 +6122,21 @@ inline void ggml_cuda_op_rope(
         rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, n_ctx, main_stream);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        rope_neox_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_neox_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_neox_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     } else {
-        rope_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        if (src0->type == GGML_TYPE_F32) {
+            rope_cuda((const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else if (src0->type == GGML_TYPE_F16) {
+            rope_cuda((const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, theta_scale, main_stream);
+        } else {
+            GGML_ASSERT(false);
+        }
     }

     (void) src1;
@@ -6294,7 +6307,7 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

-void ggml_cuda_set_peer_access(const int n_tokens) {
+static void ggml_cuda_set_peer_access(const int n_tokens) {
     static bool peer_access_enabled = false;

     const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;
@@ -6622,27 +6635,27 @@ static void ggml_cuda_op_mul_mat(
     }
 }

-void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
 }

-void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
 }

-void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu);
 }

-void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
 }

-void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
 }

-void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
 }

@@ -6663,7 +6676,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }

-void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
@@ -6692,7 +6705,7 @@ void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }

-void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
     GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
@@ -6726,7 +6739,7 @@ void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;

@@ -6770,11 +6783,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }

-void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }

-void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));

@@ -6822,29 +6835,29 @@ void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     (void) dst;
 }

-void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_dup(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_cpy(src0, dst, nullptr);
     (void) src1;
 }

-void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_diag_mask_inf(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_diag_mask_inf);
 }

-void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_soft_max);
 }

-void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rope);
 }

-void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }

-void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
     (void) dst;
@@ -6967,11 +6980,13 @@ static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
     return extra;
 }

-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
+static void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace, bool no_alloc) {
     if (scratch && g_scratch_size == 0) {
         return;
     }

+    tensor->backend = GGML_BACKEND_GPU;
+
     // recursively assign CUDA buffers until a compute tensor is found
     if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src[0]->op;
@@ -6983,8 +6998,6 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
         ggml_cuda_assign_buffers_impl(tensor->src[1], scratch, force_inplace, no_alloc);
     }

-    tensor->backend = GGML_BACKEND_GPU;
-
     if (scratch && no_alloc) {
         return;
     }
@@ -7069,6 +7082,15 @@ void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset)
     tensor->extra = extra;
 }

+void ggml_cuda_copy_to_device(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(ggml_is_contiguous(tensor));
+
+    struct ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    CUDA_CHECK(cudaMemcpy(extra->data_device[g_main_device], tensor->data, ggml_nbytes(tensor), cudaMemcpyHostToDevice));
+}
+
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
     ggml_cuda_assign_buffers_impl(tensor, true, false, false);
 }
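The new ggml_cuda_copy_to_device is a blocking push of a tensor's host data into the device buffer already attached through tensor->extra. A hypothetical usage sketch, not taken from this commit (ctx, n_tokens, and n_past are assumed to exist in the caller; error handling omitted): a RoPE position tensor is given a GPU buffer, filled on the host, and then copied over before the graph is evaluated.

// Hypothetical usage sketch; assumes a ggml context with host-allocated tensor
// data and an initialized CUDA backend with a main device.
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);

ggml_cuda_assign_buffers_no_scratch(pos);    // marks the tensor GGML_BACKEND_GPU and attaches a device buffer

for (int i = 0; i < n_tokens; ++i) {
    ((int32_t *) pos->data)[i] = n_past + i; // fill the host-side copy
}

ggml_cuda_copy_to_device(pos);               // synchronous cudaMemcpy of ggml_nbytes(pos) to the main device

Compared with the removed copied flag in ggml_tensor_extra_gpu, this puts the host-to-device transfer under the caller's control instead of tracking it per tensor inside the RoPE op.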