 
 #include <initializer_list>
 
-// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
-template <int experts_per_thread, bool use_limit>
-__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
-    float max_val = -INFINITY;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int  idx    = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            max_val = max(max_val, vals[i]);
-        }
-    }
-
-    max_val = warp_reduce_max(max_val);
-
-    float sum = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int  idx    = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            const float val = expf(vals[i] - max_val);
-            vals[i]         = val;
-            sum += val;
-        } else {
-            vals[i] = 0.f;
-        }
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    const float inv_sum = 1.0f / sum;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int  idx    = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            vals[i] *= inv_sum;
-        }
-    }
-}
-
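Both the removed helper above and the rewritten kernel below rely on `warp_reduce_max` and `warp_reduce_sum` from the ggml CUDA common header, which this diff does not show. For orientation, a minimal sketch of the butterfly-shuffle reductions they correspond to (illustrative only; assumes a full 32-lane warp with every lane participating, not the actual ggml helpers):

```cpp
// Illustrative warp-wide reductions; after the loop every lane holds the warp-wide result.
__device__ float warp_reduce_sum_sketch(float v) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset);
    }
    return v;
}

__device__ float warp_reduce_max_sketch(float v) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xFFFFFFFF, v, offset));
    }
    return v;
}
```

Because every lane ends up with the reduced value, the softmax code can divide by the sum directly without an extra broadcast.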
 /*
  This kernel does the following:
- 1. optionally softmax over the logits per token [n_experts, n_tokens]
+ 1. softmax over the logits per token [n_experts, n_tokens]
  2. argmax reduce over the top-k (n_experts_used) logits
  3. write weights + ids to global memory
- 4. optionally normalize the weights or apply softmax over the selected logits
+ 4. optionally normalize the weights
 
  It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
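To make the fused pipeline concrete: per token, the kernel computes a softmax over the `n_experts` logits, picks the `n_expert_used` largest probabilities by repeated argmax, writes those weights and their expert ids to global memory, and, when `with_norm` is set, rescales the selected weights so they sum to 1. A hedged single-token CPU reference of the same math (plain C++ with illustrative names; it ignores the kernel's exact output layout and parallelization):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference for one token: softmax -> top-k by repeated argmax -> optional renormalization.
static void topk_moe_reference(const float * logits, int n_experts, int n_expert_used,
                               bool with_norm, float * weights, int32_t * ids) {
    // 1. softmax over the logits of this token
    float max_val = logits[0];
    for (int i = 1; i < n_experts; i++) {
        max_val = std::max(max_val, logits[i]);
    }
    std::vector<float> p(n_experts);
    float sum = 0.0f;
    for (int i = 0; i < n_experts; i++) {
        p[i] = std::exp(logits[i] - max_val);
        sum += p[i];
    }
    for (int i = 0; i < n_experts; i++) {
        p[i] /= sum;
    }

    // 2. top-k by repeated argmax, masking each winner with -inf
    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int i = 1; i < n_experts; i++) {
            if (p[i] > p[best]) {
                best = i;
            }
        }
        weights[k] = p[best];
        ids[k]     = best;
        wt_sum    += p[best];
        p[best]    = -INFINITY;
    }

    // 3. optionally renormalize the selected weights so they sum to 1
    if (with_norm) {
        for (int k = 0; k < n_expert_used; k++) {
            weights[k] /= wt_sum;
        }
    }
}
```

The kernel performs the same steps warp-parallel: each of the 32 lanes owns `experts_per_thread` logits, so the max, sum, and argmax become warp reductions instead of serial loops.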
-template <int n_experts, bool with_norm, bool delayed_softmax = false>
+template <int n_experts, bool with_norm>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
@@ -75,31 +30,51 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float wt[experts_per_thread];
+    float logits_r[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert  = i + threadIdx.x;
-        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
     }
 
-    // at this point, each thread holds either a portion of the softmax distribution
-    // or the raw logits. We do the argmax reduce over n_expert_used, each time marking
-    // the expert weight as -inf to exclude from the next iteration
+    max_val = warp_reduce_max(max_val);
 
-    float wt_sum = 0.f;
+    float wt[experts_per_thread];
+    float tmp = 0.f;
 
-    float output_weights[experts_per_thread];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        output_weights[i] = 0.f;
+        wt[i] = wt[i] * inv_sum;
     }
 
+    // at this point, each thread holds a portion of softmax,
+    // we do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    float output_weights[experts_per_thread];
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -146,10 +121,6 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
-    }
-
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -159,60 +130,58 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
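The `// at this point ...` comment in the kernel describes a top-k built from repeated argmax: on every iteration the warp finds the largest remaining weight, records it, and overwrites it with `-INFINITY` so the next iteration skips it. The shuffle-based part of that loop falls outside the hunks shown above; a hedged sketch of one common way to do an index-carrying warp argmax (illustrative, not necessarily the commit's exact code):

```cpp
// Illustrative warp argmax that keeps the value and its expert index together.
// Each lane starts with its own candidate; afterwards every lane holds the
// warp-wide maximum and the index it came from.
__device__ void warp_argmax_sketch(float & val, int & idx) {
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        const float other_val = __shfl_xor_sync(0xFFFFFFFF, val, offset);
        const int   other_idx = __shfl_xor_sync(0xFFFFFFFF, idx, offset);
        if (other_val > val) {
            val = other_val;
            idx = other_idx;
        }
    }
}
```

Running this `n_expert_used` times and masking each winner selects the same experts as a full sort followed by taking the first `n_expert_used` entries, which is what lets the fused kernel stand in for the ARGSORT + VIEW + GET_ROWS subgraph.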
 
-template <bool with_norm, bool delayed_softmax = false>
+template <bool with_norm>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
                                  const int n_expert_used) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
+            topk_moe_cuda<1, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
+            topk_moe_cuda<2, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
+            topk_moe_cuda<4, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
+            topk_moe_cuda<8, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
+            topk_moe_cuda<16, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
+            topk_moe_cuda<32, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
+            topk_moe_cuda<64, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
+            topk_moe_cuda<128, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
+            topk_moe_cuda<256, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
+            topk_moe_cuda<512, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
@@ -225,16 +194,15 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
-                           const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool with_norm) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     const int n_experts = logits->ne[0];
     const int n_rows = logits->ne[1];
 
-    const float * logits_d = (const float *) logits->data;
+    const float * logits_d = (const float *) logits->src[0]->data;
     float * weights_d = (float *) weights->data;
     int32_t * ids_d = (int32_t *) ids->data;
 
@@ -245,11 +213,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
-        } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
-        }
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     }
 }
 
@@ -282,27 +246,16 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
-    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
-
     if (norm) {
         return norm_ops;
     }
-
     return no_norm_ops;
 }
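`ggml_cuda_topk_moe_ops` returns the exact op sequence that the fused path replaces, so the backend's fusion logic can compare a run of graph nodes against it. The real matching code lives elsewhere in the CUDA backend; a hedged sketch of the kind of check this list enables (illustrative helper name; assumes nodes expose their op via `node->op`):

```cpp
#include <initializer_list>

// Illustrative only: does this run of graph nodes follow the expected op pattern?
static bool ops_match_sketch(const ggml_tensor * const * nodes, int n_nodes,
                             std::initializer_list<enum ggml_op> pattern) {
    if (n_nodes < (int) pattern.size()) {
        return false;
    }
    int i = 0;
    for (enum ggml_op op : pattern) {
        if (nodes[i++]->op != op) {
            return false;
        }
    }
    return true;
}

// e.g. ops_match_sketch(nodes, n_nodes, ggml_cuda_topk_moe_ops(/*norm=*/true))
```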