
Commit 9d554a8

am17an authored and pwilkin committed
CUDA: support for weight clamp in top-k norm (ggml-org#16702)
1 parent 3838596 commit 9d554a8
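In the with-norm top-k MoE path, the selected expert weights are normalized by their row sum; this commit teaches the CUDA fusion to pick up an optional GGML_OP_CLAMP on that sum, flooring the denominator before the division. A minimal host-side sketch of the resulting math (hypothetical helper, not code from this commit):

    #include <algorithm>
    #include <vector>

    // normalized_i = w_i / max(sum_j w_j, clamp_min) -- mirrors the fused
    // SUM_ROWS -> CLAMP -> DIV chain; clamp_min = -INFINITY leaves the sum as-is.
    static void normalize_topk_weights(std::vector<float> & w, float clamp_min) {
        float sum = 0.0f;
        for (float v : w) {
            sum += v;
        }
        sum = std::max(sum, clamp_min); // floor the denominator
        for (float & v : w) {
            v /= sum;
        }
    }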

4 files changed, +60 -29 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 9 additions & 8 deletions

@@ -2976,7 +2976,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 9];

         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
@@ -2986,7 +2986,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
         }
@@ -3125,17 +3125,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (!disable_fusion) {

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+8];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 9];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
+                ggml_tensor * clamp = cgraph->nodes[i + 7];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                      /*delayed softmax*/ false);
-                i += 8;
+                                      /*delayed softmax*/ false, clamp);
+                i += 9;
                 continue;
             }

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+4];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 4];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
                                       /*delayed softmax*/ false);
                 i += 4;
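For reference, the with-norm fused subgraph now spans ten nodes, which is why the weights tensor moves from i + 8 to i + 9 and the clamp is read at i + 7. The layout below is inferred from the norm_ops list in topk-moe.cu further down:

    // Inferred node offsets relative to the SOFT_MAX node at index i:
    //   i + 0: GGML_OP_SOFT_MAX
    //   i + 1: GGML_OP_RESHAPE
    //   i + 2: GGML_OP_ARGSORT
    //   i + 3: GGML_OP_VIEW      -> selected_experts
    //   i + 4: GGML_OP_GET_ROWS
    //   i + 5: GGML_OP_RESHAPE
    //   i + 6: GGML_OP_SUM_ROWS
    //   i + 7: GGML_OP_CLAMP     -> clamp (new in this commit)
    //   i + 8: GGML_OP_DIV
    //   i + 9: GGML_OP_RESHAPE   -> weights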

ggml/src/ggml-cuda/topk-moe.cu
Lines changed: 47 additions & 19 deletions

@@ -2,6 +2,7 @@
 #include "ggml.h"
 #include "topk-moe.cuh"

+#include <cmath>
 #include <initializer_list>

 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                   float * weights,
                                                                   int32_t * ids,
                                                                   const int n_rows,
-                                                                  const int n_expert_used) {
+                                                                  const int n_expert_used,
+                                                                  const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;

         for (int i = 0; i < experts_per_thread; i++) {
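The clamp is applied to the warp-reduced sum before the reciprocal, so a near-zero denominator can no longer blow up the normalized weights. A standalone sketch of the pattern, with hypothetical names and a full 32-lane warp assumed:

    // Each lane contributes one partial weight; the warp sums them, floors
    // the sum at clamp_val, and normalizes. fmaxf(x, -INFINITY) == x, so an
    // absent clamp is a no-op.
    __device__ float normalize_lane_weight(float w, float clamp_val) {
        float wt_sum = w;
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            wt_sum += __shfl_xor_sync(0xffffffff, wt_sum, offset); // butterfly reduction
        }
        wt_sum = fmaxf(wt_sum, clamp_val); // the new clamp step
        return w * (1.0f / wt_sum);
    }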
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }

 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
-                                 const int n_expert_used) {
+                                 const int n_expert_used,
+                                 const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
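Every instantiation in the switch gains the trailing clamp_val argument. The switch-to-template pattern itself turns the runtime expert count into a compile-time constant; a minimal host-only sketch of the idea (hypothetical names, not this file's code):

    #include <cstdio>

    // N is known at compile time inside process<N>(), so the real CUDA kernel
    // can keep per-expert arrays in registers and fully unroll its loops.
    template <int N>
    static void process(float clamp_val) {
        std::printf("n_expert = %d, clamp = %f\n", N, clamp_val);
    }

    static void dispatch(int n_expert, float clamp_val) {
        switch (n_expert) {
            case 1:   process<1>(clamp_val);   break;
            case 2:   process<2>(clamp_val);   break;
            case 4:   process<4>(clamp_val);   break;
            // ... one case per supported power of two ...
            case 512: process<512>(clamp_val); break;
            default:  std::printf("unsupported n_expert\n"); break;
        }
    }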
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,

     const int n_expert_used = weights->ne[1];

+    float clamp_val = -INFINITY;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
+        GGML_ASSERT(clamp == nullptr);
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
     float scale = 1.0f;
     float max_bias = 0.0f;
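Defaulting clamp_val to -INFINITY keeps the no-clamp paths numerically unchanged: max(x, -inf) == x for any non-NaN float x. A quick standalone check of that identity (illustration only, not code from the commit):

    #include <algorithm>
    #include <cassert>
    #include <limits>

    int main() {
        const float neg_inf = -std::numeric_limits<float>::infinity();
        float wt_sum = 0.25f;
        assert(std::max(wt_sum, neg_inf) == wt_sum); // absent clamp: no-op
        assert(std::max(1e-30f, 1e-4f) == 1e-4f);    // tiny sums get floored
        return 0;
    }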

@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
         return false;
     }

+    if (clamp) {
+        if (clamp->op != GGML_OP_CLAMP) {
+            return false;
+        }
+        float max_val = ggml_get_op_params_f32(clamp, 1);
+
+        if (max_val != INFINITY) {
+            return false;
+        }
+    }
+
+
     return true;
 }

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };

     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
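The validity check reads the clamp's op params: index 0 holds the lower bound (forwarded to the kernel as clamp_val) and index 1 the upper bound, which must be +INFINITY because the kernel only floors the sum. A minimal sketch of that accept rule, assuming a stand-in params struct:

    #include <cmath>

    // Hypothetical mirror of GGML_OP_CLAMP's (min, max) op params.
    struct clamp_params { float min; float max; };

    // Only a one-sided, lower-bound clamp can be fused: the kernel computes
    // max(wt_sum, min), so a finite upper bound would be silently dropped.
    static bool clamp_is_fusable(const clamp_params & p) {
        return p.max == INFINITY;
    }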

ggml/src/ggml-cuda/topk-moe.cuh
Lines changed: 3 additions & 2 deletions

@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax = false);
+                           const bool delayed_softmax = false,
+                           ggml_tensor * weight_clamp = nullptr);

-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);

 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);

tests/test-backend-ops.cpp
Lines changed: 1 addition & 0 deletions

@@ -4712,6 +4712,7 @@ struct test_topk_moe: public test_case {
         out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
         ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]

+        weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
         out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
         out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
     }
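The clamp minimum used by the test, 6.103515625e-5, is exactly 2^-14, the smallest positive normal FP16 value. A standalone sanity check of that constant (not part of the test suite):

    #include <cassert>
    #include <cmath>

    int main() {
        assert(std::ldexp(1.0, -14) == 6.103515625e-5); // 2^-14, FP16's smallest normal
        return 0;
    }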
