+#include "ggml-cuda/common.cuh"
 #include "ggml.h"
 #include "topk-moe.cuh"
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
 template <size_t n_experts>
-__global__ void topk_moe_cuda(const float * logits,
-                              float * weights,
-                              int32_t * ids,
-                              const int n_rows,
-                              const int n_expert_used) {
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float * weights,
+                                                                  int32_t * ids,
+                                                                  const int n_rows,
+                                                                  const int n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
     }
+
     logits += n_experts * row;
-    ids += n_experts * row;
     weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1;
+    float logits_r[experts_per_thread];
 
-    const int start_expert = threadIdx.x * experts_per_thread;
-    const int end_expert   = (threadIdx.x + 1) * experts_per_thread;
-    float max_val = -INFINITY;
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = -INFINITY;
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        max_val            = max(val, max_val);
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
     }
 
     max_val = warp_reduce_max(max_val);
@@ -43,9 +50,8 @@ __global__ void topk_moe_cuda(const float * logits,
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        wt[i]              = expf(val - max_val);
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
         tmp += wt[i];
     }
 
@@ -64,29 +70,29 @@ __global__ void topk_moe_cuda(const float * logits,
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
-        int   max_expert = start_expert;
+        int   max_expert = threadIdx.x;
 
 #pragma unroll
         for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = start_expert + i;
-            if (wt[i] > max_val) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
                 max_val    = wt[i];
                 max_expert = expert;
             }
         }
 
 #pragma unroll
-        for (int mask = warpSize / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize);
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
             if (val > max_val) {
                 max_val    = val;
                 max_expert = expert;
             }
         }
 
-        if (max_expert >= start_expert && max_expert < end_expert) {
-            wt[max_expert - start_expert] = -INFINITY;
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
 
             weights[k] = max_val;
             ids[k]     = max_expert;
@@ -103,7 +109,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const int n_expert_used) {
     const int rows_per_block = 4;
     dim3      grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
-    dim3      block_dims(32, rows_per_block, 1);
+    dim3      block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
@@ -151,12 +157,14 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
     const float * logits_d  = (const float *) logits->src[0]->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
 
-    const int n_experts = logits->ne[0];
-    const int n_rows    = logits->ne[1];
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     cudaStream_t stream = ctx.stream();
 
@@ -165,13 +173,17 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
     memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
 
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
     if (scale != 1.0f || max_bias != 0.0f) {
         return false;
     }
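
For context, here is a minimal host-side sketch of what the fused kernel computes per row: a numerically stable softmax over the expert logits, followed by selecting the n_expert_used largest probabilities and writing them to weights and ids with the strides used above (weights rows hold n_expert_used entries, ids rows are strided by n_experts, as the new assert requires). This is illustrative only; the name topk_moe_reference is made up, and normalization by the summed exponentials is assumed from the "softmax->top-k->get_rows" description rather than shown in the hunks above.

// Illustrative CPU reference only (assumed semantics, not part of the commit):
// softmax per row, then repeated argmax to pick the top n_expert_used experts.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

static void topk_moe_reference(const float * logits, float * weights, int32_t * ids,
                               int n_rows, int n_experts, int n_expert_used) {
    std::vector<float> p(n_experts);
    for (int row = 0; row < n_rows; row++) {
        const float * x = logits + (size_t) row * n_experts;

        // softmax with max subtraction, mirroring the kernel's max/exp/sum passes
        float max_val = -INFINITY;
        for (int e = 0; e < n_experts; e++) {
            max_val = std::max(max_val, x[e]);
        }
        float sum = 0.0f;
        for (int e = 0; e < n_experts; e++) {
            p[e] = std::exp(x[e] - max_val);
            sum += p[e];
        }
        for (int e = 0; e < n_experts; e++) {
            p[e] /= sum;
        }

        // k passes of argmax; the winner is knocked out with -INFINITY,
        // analogous to wt[max_expert / WARP_SIZE] = -INFINITY in the kernel
        for (int k = 0; k < n_expert_used; k++) {
            int best = 0;
            for (int e = 1; e < n_experts; e++) {
                if (p[e] > p[best]) {
                    best = e;
                }
            }
            weights[(size_t) row * n_expert_used + k] = p[best];
            ids[(size_t) row * n_experts + k]         = (int32_t) best;
            p[best] = -INFINITY;
        }
    }
}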