+#include "ggml-cuda/common.cuh"
 #include "ggml.h"
 #include "topk-moe.cuh"

    It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
 template <size_t n_experts>
-__global__ void topk_moe_cuda(const float * logits,
-                              float *       weights,
-                              int32_t *     ids,
-                              const int     n_rows,
-                              const int     n_expert_used) {
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float *       weights,
+                                                                  int32_t *     ids,
+                                                                  const int     n_rows,
+                                                                  const int     n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
     }
+
     logits += n_experts * row;
-    ids += n_experts * row;
     weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

-    constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1;
+    float logits_r[experts_per_thread];

-    const int start_expert = threadIdx.x * experts_per_thread;
-    const int end_expert   = (threadIdx.x + 1) * experts_per_thread;
-    float     max_val      = -INFINITY;
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = -INFINITY;

 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        max_val            = max(val, max_val);
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
     }

     max_val = warp_reduce_max(max_val);
@@ -43,9 +50,8 @@ __global__ void topk_moe_cuda(const float * logits,

 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        wt[i]              = expf(val - max_val);
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
         tmp += wt[i];
     }

@@ -64,29 +70,29 @@ __global__ void topk_moe_cuda(const float * logits,

     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
-        int   max_expert = start_expert;
+        int   max_expert = threadIdx.x;

 #pragma unroll
         for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = start_expert + i;
-            if (wt[i] > max_val) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
                 max_val    = wt[i];
                 max_expert = expert;
             }
         }

 #pragma unroll
-        for (int mask = warpSize / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize);
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
             if (val > max_val) {
                 max_val    = val;
                 max_expert = expert;
             }
         }

-        if (max_expert >= start_expert && max_expert < end_expert) {
-            wt[max_expert - start_expert] = -INFINITY;
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;

             weights[k] = max_val;
             ids[k]     = max_expert;
@@ -103,7 +109,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const int                   n_expert_used) {
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
-    dim3         block_dims(32, rows_per_block, 1);
+    dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();

     switch (n_expert) {
@@ -151,12 +157,14 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);

+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
     const float * logits_d  = (const float *) logits->src[0]->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;

-    const int n_experts = logits->ne[0];
-    const int n_rows    = logits->ne[1];
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

     cudaStream_t stream = ctx.stream();

@@ -165,13 +173,17 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
     float scale    = 1.0f;
     float max_bias = 0.0f;

     memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
     memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
     if (scale != 1.0f || max_bias != 0.0f) {
         return false;
     }
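
The core change shown in the kernel above is the expert-to-lane mapping: instead of giving each lane a contiguous chunk of experts (start_expert = threadIdx.x * experts_per_thread), every lane now loads a warp-strided slice of the logits into the register array logits_r, so lane t holds experts t, t + WARP_SIZE, t + 2 * WARP_SIZE, ... in consecutive register slots. After the warp-wide argmax, the lane that owns the winning expert is the one where (max_expert & (WARP_SIZE - 1)) == threadIdx.x, and that expert sits in register slot max_expert / WARP_SIZE, which works because WARP_SIZE is a power of two. Below is a minimal host-side sketch of that index math, assuming WARP_SIZE == 32; the helper names are illustrative only and not part of ggml.

// Host-side sketch of the warp-strided expert layout used by the kernel above.
// Assumptions: WARP_SIZE == 32 (a power of two); owning_lane/register_slot are
// illustrative helpers, not ggml functions.
#include <cassert>
#include <cstdio>

constexpr int WARP_SIZE = 32;

// Which lane of the warp holds a given expert (matches the kernel's ownership check).
constexpr int owning_lane(int expert)   { return expert & (WARP_SIZE - 1); }
// Which slot of that lane's logits_r[]/wt[] arrays the expert occupies
// (matches wt[max_expert / WARP_SIZE]).
constexpr int register_slot(int expert) { return expert / WARP_SIZE; }

int main() {
    const int n_experts = 128; // e.g. the n_experts = 128 instantiation

    // The load loop assigns expert (i + lane) to slot i / WARP_SIZE, so
    // lane + slot * WARP_SIZE must round-trip back to the expert index.
    for (int expert = 0; expert < n_experts; ++expert) {
        assert(owning_lane(expert) + register_slot(expert) * WARP_SIZE == expert);
    }

    printf("expert 70 -> lane %d, slot %d\n", owning_lane(70), register_slot(70)); // lane 6, slot 2
    return 0;
}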