@@ -215,40 +215,6 @@ static __global__ void rms_norm_back_f32(
     }
 }
 
-template <int block_size>
-static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
-    const int row = blockIdx.x*blockDim.y + threadIdx.y;
-    const int tid = threadIdx.x;
-
-    float tmp = 0.0f; // partial sum for thread in warp
-
-    for (int col = tid; col < ncols; col += block_size) {
-        const float xi = x[row*ncols + col];
-        tmp += xi * xi;
-    }
-
-    // sum up partial sums
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        __shared__ float s_sum[32];
-        int warp_id = threadIdx.x / WARP_SIZE;
-        int lane_id = threadIdx.x % WARP_SIZE;
-        if (lane_id == 0) {
-            s_sum[warp_id] = tmp;
-        }
-        __syncthreads();
-        tmp = s_sum[lane_id];
-        tmp = warp_reduce_sum(tmp);
-    }
-
-    const float mean = tmp / ncols;
-    const float scale = rsqrtf(mean + eps);
-
-    for (int col = tid; col < ncols; col += block_size) {
-        dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
-    }
-}
-
 // template <int block_size>
 // static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
 //     const int row = blockIdx.x*blockDim.y + threadIdx.y;
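
For reference, the kernel removed above uses ggml's standard two-level reduction: a shuffle-based sum within each warp, followed by a pass through shared memory when the block spans multiple warps. A minimal standalone sketch of that pattern, with a local warp_reduce_sum standing in for the helper from common.cuh and a toy block_sum kernel in place of the norm:

#define WARP_SIZE 32

// Shuffle-based sum across the 32 lanes of one warp.
static __device__ float warp_reduce_sum(float x) {
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset);
    }
    return x;
}

// Block-wide sum of n floats: each warp reduces its partials, lane 0 of each
// warp parks the result in shared memory, and a final warp_reduce_sum folds
// the per-warp partials together (a block_size of 1024 fills all 32 slots).
template <int block_size>
static __global__ void block_sum(const float * x, float * out, const int n) {
    const int tid = threadIdx.x;

    float tmp = 0.0f;
    for (int i = tid; i < n; i += block_size) {
        tmp += x[i];
    }

    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __shared__ float s_sum[32];
        if (tid % WARP_SIZE == 0) {
            s_sum[tid / WARP_SIZE] = tmp;
        }
        __syncthreads();
        tmp = warp_reduce_sum(s_sum[tid % WARP_SIZE]);
    }

    if (tid == 0) {
        *out = tmp;
    }
}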
@@ -395,19 +361,6 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
 
-
-static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
-        const int ncols, const int nrows, const float eps, cudaStream_t stream) {
-    GGML_ASSERT(ncols % WARP_SIZE == 0);
-    if (ncols < 1024) {
-        const dim3 block_dims(WARP_SIZE, 1, 1);
-        fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    } else {
-        const dim3 block_dims(1024, 1, 1);
-        fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
-    }
-}
-
 static void l2_norm_f32_cuda(
         const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
         const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
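
On the host side, the removed launcher follows the same dispatch convention as the surviving *_f32_cuda helpers in this file: one warp per row when the row is narrower than 1024 elements, otherwise a full 1024-thread block. A sketch of the equivalent dispatch for the toy block_sum kernel above (illustrative names; a single block, since it reduces one array):

static void block_sum_cuda(const float * x, float * out, const int n, cudaStream_t stream) {
    if (n < 1024) {
        // narrow input: a single warp strides over the data
        block_sum<WARP_SIZE><<<1, WARP_SIZE, 0, stream>>>(x, out, n);
    } else {
        // wide input: a full block, folded across warps via shared memory
        block_sum<1024><<<1, 1024, 0, stream>>>(x, out, n);
    }
}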
@@ -567,36 +520,6 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
 
-
-void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    if (!dst->src[1]) {
-        ggml_cuda_op_rms_norm(ctx, dst);
-        return;
-    }
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const float * src0_d = (const float *)src0->data;
-    const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
-    GGML_ASSERT(ggml_nrows(src1) == 1);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
-    float eps;
-    memcpy(&eps, dst->op_params, sizeof(float));
-
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
-}
-
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;
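
Note that ggml_cuda_op_fused_rms_norm fell back to ggml_cuda_op_rms_norm whenever src1 was absent, so the fused op computed nothing beyond rms_norm(src0) scaled elementwise by a single broadcast row src1. Callers that relied on the removed entry point can express the same result through the stock graph API; a minimal sketch, assuming ctx is an initialized ggml_context and eps matches the value stored in op_params:

static struct ggml_tensor * rms_norm_mul(struct ggml_context * ctx,
        struct ggml_tensor * x, struct ggml_tensor * y, float eps) {
    // rms_norm over the rows of x, then a broadcast multiply by the single
    // row y; this matches dst[row*ncols + col] = scale*y[col]*x[row*ncols + col]
    // from the removed kernel.
    return ggml_mul(ctx, ggml_rms_norm(ctx, x, eps), y);
}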