 #include "ggml.h"
 #include "topk-moe.cuh"
 
+#include <cmath>
 #include <initializer_list>
 
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                    float * weights,
                                                                    int32_t * ids,
                                                                    const int n_rows,
-                                                                   const int n_expert_used) {
+                                                                   const int n_expert_used,
+                                                                   const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }
 
 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
-                                 const int n_expert_used) {
+                                 const int n_expert_used,
+                                 const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     const int n_expert_used = weights->ne[1];
 
+    float clamp_val = -INFINITY;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
+        GGML_ASSERT(clamp == nullptr);
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale = 1.0f;
     float max_bias = 0.0f;
 
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }
 
+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }
 
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                 GGML_OP_VIEW, GGML_OP_GET_ROWS };
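
A minimal host-side sketch of what the fused clamp adds, assuming clamp_val carries the CLAMP op's minimum (op param 0) and that -INFINITY means "no clamp"; normalize_topk_weights_ref is a hypothetical reference helper for illustration, not part of the kernel:

// Reference sketch (not the CUDA kernel): normalize the top-k expert weights while
// clamping the denominator from below, mirroring the GGML_OP_CLAMP (min = clamp_val,
// max = +INFINITY) that now sits between GGML_OP_SUM_ROWS and GGML_OP_DIV in norm_ops.
// Passing clamp_val = -INFINITY reproduces the previous, unclamped behaviour.
#include <algorithm>
#include <cmath>
#include <vector>

static void normalize_topk_weights_ref(std::vector<float> & weights, float clamp_val) {
    float wt_sum = 0.0f;
    for (float w : weights) {
        wt_sum += w;                       // GGML_OP_SUM_ROWS over the selected experts
    }
    wt_sum = std::max(wt_sum, clamp_val);  // same effect as the kernel's wt_sum = max(wt_sum, clamp_val)
    const float inv_sum = 1.0f / wt_sum;
    for (float & w : weights) {
        w *= inv_sum;                      // GGML_OP_DIV by the clamped row sum
    }
}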