
#include <initializer_list>

+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i]         = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
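(Aside, not part of the diff: a minimal scalar sketch of what softmax_warp_inplace computes per warp. The helper name softmax_ref and the host-side setting are assumptions for illustration only; the warp version simply strides the values across lanes and replaces the serial reductions with warp_reduce_max/warp_reduce_sum.)

#include <cmath>

// Hypothetical scalar reference: softmax over vals[0..limit); entries at
// index >= limit are treated as inactive and forced to zero, matching the
// use_limit = true behaviour (use_limit = false corresponds to limit == n).
static void softmax_ref(float * vals, int n, int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; i++) {
        max_val = fmaxf(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; i++) {
        vals[i] = i < limit ? expf(vals[i] - max_val) : 0.0f;
        sum += vals[i];
    }
    for (int i = 0; i < limit; i++) {
        vals[i] /= sum;
    }
}
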
/*
 This kernel does the following:
- 1. softmax over the logits per token [n_experts, n_tokens]
+ 1. optionally softmax over the logits per token [n_experts, n_tokens]
 2. argmax reduce over the top-k (n_experts_used) logits
 3. write weights + ids to global memory
- 4. optionally normalize the weights
+ 4. optionally normalize the weights or apply softmax over the selected logits

 It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert  = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
    }

-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = expf(val - max_val);
-        tmp += wt[i];
-    }
+    // at this point, each thread holds either a portion of the softmax distribution
+    // or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude from the next iteration

-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;

-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
    }

-    // at this point, each thread holds a portion of softmax,
-    // we do the argmax reduce over n_expert_used, each time marking
-    // the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
    for (int k = 0; k < n_expert_used; k++) {
        float max_val    = wt[0];
        int   max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
        }
    }

+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,58 +159,60 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
    }
}

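(Aside, not part of the diff: a rough host-side reference of what one row of the fused kernel above computes, to make the doc comment's pipeline concrete. The function name topk_moe_ref and the std::vector scratch are assumptions for illustration only; tie-breaking in the argmax may differ from the warp reduction, and, as the launcher's static_assert below enforces, with_norm and delayed_softmax are mutually exclusive.)

#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical single-row reference: (optional) softmax over all logits,
// then greedy top-k by repeated argmax, then either renormalization of the
// selected weights or a softmax over just the k selected raw logits.
static void topk_moe_ref(const float * logits, float * weights, int32_t * ids,
                         int n_experts, int n_expert_used,
                         bool with_norm, bool delayed_softmax) {
    std::vector<float> p(logits, logits + n_experts);

    if (!delayed_softmax) {
        float mx = -INFINITY;
        for (float v : p) { mx = fmaxf(mx, v); }
        float sum = 0.0f;
        for (float & v : p) { v = expf(v - mx); sum += v; }
        for (float & v : p) { v /= sum; }
    }

    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        int best = 0;
        for (int e = 1; e < n_experts; e++) {
            if (p[e] > p[best]) { best = e; }
        }
        ids[k]     = best;
        weights[k] = p[best];
        wt_sum    += p[best];
        p[best]    = -INFINITY;  // exclude from the next iteration
    }

    if (with_norm) {
        for (int k = 0; k < n_expert_used; k++) { weights[k] /= wt_sum; }
    }

    if (delayed_softmax) {  // softmax over only the k selected logits
        float mx = -INFINITY;
        for (int k = 0; k < n_expert_used; k++) { mx = fmaxf(mx, weights[k]); }
        float sum = 0.0f;
        for (int k = 0; k < n_expert_used; k++) { weights[k] = expf(weights[k] - mx); sum += weights[k]; }
        for (int k = 0; k < n_expert_used; k++) { weights[k] /= sum; }
    }
}
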
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    switch (n_expert) {
        case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
@@ -194,15 +225,16 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *          logits,
                           ggml_tensor *                weights,
                           ggml_tensor *                ids,
-                           const bool                   with_norm) {
+                           const bool                   with_norm,
+                           const bool                   delayed_softmax) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

-    const float * logits_d  = (const float *) logits->src[0]->data;
+    const float * logits_d  = (const float *) logits->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
    }
}

@@ -246,16 +282,27 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
    return true;
}

-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };

+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
    if (norm) {
        return norm_ops;
    }
+
    return no_norm_ops;
}