
Commit 8c01a63

am17an authored and pwilkin committed
CUDA: topk-moe: add optional parameter for gpt-oss (ggml-org#16649)
1 parent: 5f157a9 · commit: 8c01a63

4 files changed: +153 −62 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu
Lines changed: 31 additions & 4 deletions

@@ -2824,8 +2824,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
+    std::initializer_list<enum ggml_op> topk_moe_ops =
+        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
+        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
@@ -2846,6 +2850,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2939,19 +2953,32 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                 ggml_tensor * weights = cgraph->nodes[i+8];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                      /*delayed softmax*/ false);
                 i += 8;
                 continue;
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                 ggml_tensor * weights = cgraph->nodes[i+4];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                                      /*delayed softmax*/ false);
                 i += 4;
                 continue;
             }
 
+            if (ggml_cuda_can_fuse(cgraph, i,
+                                   ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+                ggml_tensor * weights = cgraph->nodes[i + 5];
+                ggml_tensor * ids = cgraph->nodes[i + 1];
+
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                      /*delayed_softmax*/ true);
+                i += 5;
+                continue;
+            }
+
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];
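As a reading aid for the offsets used above: the delayed-softmax variant fuses a six-node subgraph, and the op sequence below is taken from ggml_cuda_topk_moe_ops further down in topk-moe.cu. The per-op comments are my reading of the gpt-oss routing graph rather than anything stated in the commit, so treat this as a sketch:

// offsets_sketch.cpp -- illustrative only, not part of the commit.
#include <cstdio>

int main() {
    // Offsets are relative to node_idx, the first node of the candidate subgraph.
    const char * delayed_softmax_pattern[] = {
        "GGML_OP_ARGSORT",   // node_idx + 0: rank the router logits per token
        "GGML_OP_VIEW",      // node_idx + 1: top-k slice -> selected expert ids
        "GGML_OP_GET_ROWS",  // node_idx + 2: gather the logits of the selected experts
        "GGML_OP_RESHAPE",   // node_idx + 3
        "GGML_OP_SOFT_MAX",  // node_idx + 4: softmax over the selected logits only
        "GGML_OP_RESHAPE",   // node_idx + 5: final routing weights
    };
    for (int i = 0; i < 6; i++) {
        std::printf("node_idx + %d: %s\n", i, delayed_softmax_pattern[i]);
    }
    return 0;
}

This lines up with the fused path in the evaluate loop, which reads the ids from node i + 1 and the weights from node i + 5 before skipping ahead with i += 5.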

ggml/src/ggml-cuda/topk-moe.cu
Lines changed: 96 additions & 49 deletions

@@ -4,16 +4,61 @@
 
 #include <initializer_list>
 
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int idx = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
 /*
 This kernel does the following:
-1. softmax over the logits per token [n_experts, n_tokens]
+1. optionally softmax over the logits per token [n_experts, n_tokens]
 2. argmax reduce over the top-k (n_experts_used) logits
 3. write weights + ids to global memory
-4. optionally normalize the weights
+4. optionally normalize the weights or apply softmax over the selected logits
 
 It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
     }
 
-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i] = expf(val - max_val);
-        tmp += wt[i];
-    }
+    //at this point, each thread holds either a portion of the softmax distribution
+    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
 
-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;
 
-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
     }
 
-    //at this point, each thread holds a portion of softmax,
-    //we do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
         int max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,58 +159,60 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
                                  const int n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
         case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
@@ -194,15 +225,16 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
-                           const bool with_norm) {
+                           const bool with_norm,
+                           const bool delayed_softmax) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     const int n_experts = logits->ne[0];
     const int n_rows = logits->ne[1];
 
-    const float * logits_d = (const float *) logits->src[0]->data;
+    const float * logits_d = (const float *) logits->data;
     float * weights_d = (float *) weights->data;
     int32_t * ids_d = (int32_t *) ids->data;
 
@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
     }
 }
 
@@ -246,16 +282,27 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
     if (norm) {
         return norm_ops;
     }
+
     return no_norm_ops;
 }
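The new softmax_warp_inplace helper works on the register-strided layout where lane l of the warp owns elements l, l + WARP_SIZE, l + 2*WARP_SIZE, ..., and it leans on the ggml-cuda warp_reduce_max / warp_reduce_sum primitives. The standalone sketch below shows the same shuffle-based pattern for the simplest case of one value per lane; the butterfly reductions here are the conventional implementation I would expect those helpers to use, not code copied from ggml, and the file is only meant as a compile-and-run illustration.

// warp_softmax_sketch.cu -- illustrative only, not part of the commit.
#include <cstdio>

#define WARP_SIZE 32

// Conventional butterfly max-reduction across the 32 lanes of a warp.
__device__ static float warp_max(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset));
    }
    return v;
}

// Conventional butterfly sum-reduction across the 32 lanes of a warp.
__device__ static float warp_sum(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// Softmax over exactly one value per lane: subtract the warp max for stability,
// exponentiate, then divide by the warp-wide sum.
__global__ void softmax_one_warp(const float * in, float * out) {
    const float x   = in[threadIdx.x];
    const float m   = warp_max(x);
    const float e   = expf(x - m);
    const float sum = warp_sum(e);
    out[threadIdx.x] = e / sum;
}

int main() {
    float h_in[WARP_SIZE];
    float h_out[WARP_SIZE];
    for (int i = 0; i < WARP_SIZE; i++) {
        h_in[i] = (float) i / WARP_SIZE;  // arbitrary test logits
    }

    float * d_in  = nullptr;
    float * d_out = nullptr;
    cudaMalloc((void **) &d_in,  sizeof(h_in));
    cudaMalloc((void **) &d_out, sizeof(h_out));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);

    softmax_one_warp<<<1, WARP_SIZE>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

    float total = 0.f;
    for (int i = 0; i < WARP_SIZE; i++) {
        total += h_out[i];
    }
    std::printf("sum of softmax outputs = %f (should be ~1.0)\n", total);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The kernel in the commit does the same three passes (max, exp-and-sum, scale) over experts_per_thread strided registers per lane, and the use_limit template flag is what lets the delayed path restrict the softmax to the first n_expert_used selected weights.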

ggml/src/ggml-cuda/topk-moe.cuh
Lines changed: 4 additions & 3 deletions

@@ -6,9 +6,10 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * top_k,
-                           const bool with_norm);
+                           ggml_tensor * ids,
+                           const bool with_norm,
+                           const bool delayed_softmax = false);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
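Because delayed_softmax defaults to false in both declarations, existing callers of ggml_cuda_op_topk_moe and ggml_cuda_topk_moe_ops keep compiling unchanged. The self-contained stub below mirrors the dispatch that ggml_cuda_op_topk_moe now performs, just to make the with_norm / delayed_softmax combinations explicit; the names, signatures, and printouts here are placeholders rather than the ggml API, and the real code launches the CUDA kernel instead of printing.

// dispatch_sketch.cpp -- illustrative stand-in for the ggml_cuda_op_topk_moe dispatch.
#include <cassert>
#include <cstdio>

template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_stub(int n_rows, int n_expert, int n_expert_used) {
    // Mirrors the static_assert added to launch_topk_moe_cuda in the commit.
    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
    std::printf("launch: with_norm=%d delayed_softmax=%d rows=%d experts=%d used=%d\n",
                (int) with_norm, (int) delayed_softmax, n_rows, n_expert, n_expert_used);
}

static void op_topk_moe_stub(int n_rows, int n_expert, int n_expert_used,
                             bool with_norm, bool delayed_softmax = false) {
    assert(!(with_norm && delayed_softmax));  // same invariant as GGML_ASSERT(!norm || !delayed_softmax)
    if (with_norm) {
        launch_topk_moe_stub<true>(n_rows, n_expert, n_expert_used);
    } else if (delayed_softmax) {
        launch_topk_moe_stub<false, true>(n_rows, n_expert, n_expert_used);
    } else {
        launch_topk_moe_stub<false, false>(n_rows, n_expert, n_expert_used);
    }
}

int main() {
    // Old-style call site: the defaulted flag keeps this compiling as before.
    op_topk_moe_stub(/*n_rows=*/8, /*n_expert=*/32, /*n_expert_used=*/4, /*with_norm=*/true);
    // gpt-oss-style call site: softmax is applied after the top-k selection.
    op_topk_moe_stub(/*n_rows=*/8, /*n_expert=*/32, /*n_expert_used=*/4,
                     /*with_norm=*/false, /*delayed_softmax=*/true);
    return 0;
}

The compile-time check rejects the with_norm + delayed_softmax combination outright, matching the GGML_ASSERT(!norm || !delayed_softmax) guard in ggml_cuda_topk_moe_ops.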
