
Commit 46a1ef1

Revert "CUDA: topk-moe: add optional parameter for gpt-oss (ggml-org#16649)"
1 parent 7f34add commit 46a1ef1

File tree

3 files changed (+56, −131 lines)


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 31 deletions
```diff
@@ -2831,12 +2831,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #endif
 
     //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops =
-        ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
-        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
+    std::initializer_list<enum ggml_op> topk_moe_ops = ggml_cuda_topk_moe_ops(false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
 
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
@@ -2857,16 +2853,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_delayed_softmax, { node_idx + 2, node_idx + 5 })) {
-        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 5];
-
-        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
-            return true;
-        }
-    }
-
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2960,32 +2946,19 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                 ggml_tensor * weights = cgraph->nodes[i+8];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                      /*delayed softmax*/ false);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
                 i += 8;
                 continue;
             }
 
             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                 ggml_tensor * weights = cgraph->nodes[i+4];
                 ggml_tensor * selected_experts = cgraph->nodes[i+3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
-                                      /*delayed softmax*/ false);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
                 i += 4;
                 continue;
             }
 
-            if (ggml_cuda_can_fuse(cgraph, i,
-                                   ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
-                ggml_tensor * weights = cgraph->nodes[i + 5];
-                ggml_tensor * ids = cgraph->nodes[i + 1];
-
-                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
-                                      /*delayed_softmax*/ true);
-                i += 5;
-                continue;
-            }
-
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];
```
ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 49 additions & 96 deletions
```diff
@@ -4,61 +4,16 @@
 
 #include <initializer_list>
 
-// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
-template <int experts_per_thread, bool use_limit>
-__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
-    float max_val = -INFINITY;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            max_val = max(max_val, vals[i]);
-        }
-    }
-
-    max_val = warp_reduce_max(max_val);
-
-    float sum = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            const float val = expf(vals[i] - max_val);
-            vals[i] = val;
-            sum += val;
-        } else {
-            vals[i] = 0.f;
-        }
-    }
-
-    sum = warp_reduce_sum(sum);
-
-    const float inv_sum = 1.0f / sum;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        const bool active = !use_limit || (idx < limit);
-        if (active) {
-            vals[i] *= inv_sum;
-        }
-    }
-}
-
 /*
     This kernel does the following:
-    1. optionally softmax over the logits per token [n_experts, n_tokens]
+    1. softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
-    4. optionally normalize the weights or apply softmax over the selected logits
+    4. optionally normalize the weights
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm, bool delayed_softmax = false>
+template <int n_experts, bool with_norm>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,
@@ -75,31 +30,51 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float wt[experts_per_thread];
+    float logits_r[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert = i + threadIdx.x;
-        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
+        const int expert = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
     }
 
-    if constexpr (!delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
+    float max_val = logits_r[0];
+
+#pragma unroll
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val = max(val, max_val);
     }
 
-    //at this point, each thread holds either a portion of the softmax distribution
-    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
-    //the expert weight as -inf to exclude from the next iteration
+    max_val = warp_reduce_max(max_val);
 
-    float wt_sum = 0.f;
+    float wt[experts_per_thread];
+    float tmp = 0.f;
 
-    float output_weights[experts_per_thread];
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i] = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        output_weights[i] = 0.f;
+        wt[i] = wt[i] * inv_sum;
     }
 
+    //at this point, each thread holds a portion of softmax,
+    //we do the argmax reduce over n_expert_used, each time marking
+    //the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    float output_weights[experts_per_thread];
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
         int max_expert = threadIdx.x;
@@ -146,10 +121,6 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
-    if constexpr (delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
-    }
-
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -159,60 +130,58 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
-template <bool with_norm, bool delayed_softmax = false>
+template <bool with_norm>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
                                  const int n_expert_used) {
-    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm, delayed_softmax>
+            topk_moe_cuda<1, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm, delayed_softmax>
+            topk_moe_cuda<2, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm, delayed_softmax>
+            topk_moe_cuda<4, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
-            topk_moe_cuda<8, with_norm, delayed_softmax>
+            topk_moe_cuda<8, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm, delayed_softmax>
+            topk_moe_cuda<16, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm, delayed_softmax>
+            topk_moe_cuda<32, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm, delayed_softmax>
+            topk_moe_cuda<64, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm, delayed_softmax>
+            topk_moe_cuda<128, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm, delayed_softmax>
+            topk_moe_cuda<256, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm, delayed_softmax>
+            topk_moe_cuda<512, with_norm>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
@@ -225,16 +194,15 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
-                           const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool with_norm) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     const int n_experts = logits->ne[0];
     const int n_rows = logits->ne[1];
 
-    const float * logits_d = (const float *) logits->data;
+    const float * logits_d = (const float *) logits->src[0]->data;
     float * weights_d = (float *) weights->data;
     int32_t * ids_d = (int32_t *) ids->data;
 
@@ -245,11 +213,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
-        } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
-        }
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     }
 }
 
@@ -282,27 +246,16 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
-    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
-                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
-
-    GGML_ASSERT(!norm || !delayed_softmax);
-
-    if (delayed_softmax) {
-        return delayed_softmax_ops;
-    }
-
     if (norm) {
         return norm_ops;
     }
-
     return no_norm_ops;
 }
```
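
The main effect of the revert in this file is that the warp-local softmax, previously factored out into `softmax_warp_inplace` so it could also run after the top-k selection (the delayed path), is inlined back into `topk_moe_cuda` and always runs over the raw logits before the argmax reduce. Below is a standalone sketch of that warp-level softmax pattern: one warp owns a row, each lane holds a strided slice of it, and the max and sum are reduced with shuffles. The helper names (`warp_max`, `warp_sum`), the fixed `N_EXPERTS`, and the single-row launch are assumptions made for this example, not ggml's own helpers.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

constexpr int WARP      = 32;
constexpr int N_EXPERTS = 64;   // assumed to be a multiple of the warp size for this sketch

// Warp-wide max reduction via butterfly shuffles.
__device__ float warp_max(float v) {
    for (int off = WARP / 2; off > 0; off >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, off));
    }
    return v;
}

// Warp-wide sum reduction via butterfly shuffles.
__device__ float warp_sum(float v) {
    for (int off = WARP / 2; off > 0; off >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, off);
    }
    return v;
}

__global__ void warp_softmax(const float * logits, float * probs) {
    constexpr int per_thread = N_EXPERTS / WARP;
    float vals[per_thread];

    // Each lane loads a strided slice of the row.
    for (int i = 0; i < per_thread; i++) {
        vals[i] = logits[threadIdx.x + i * WARP];
    }

    // Lane-local max, then warp-wide max.
    float max_val = vals[0];
    for (int i = 1; i < per_thread; i++) {
        max_val = fmaxf(max_val, vals[i]);
    }
    max_val = warp_max(max_val);

    // Exponentiate against the max for numerical stability, then warp-wide sum.
    float sum = 0.0f;
    for (int i = 0; i < per_thread; i++) {
        vals[i] = expf(vals[i] - max_val);
        sum += vals[i];
    }
    sum = warp_sum(sum);

    // Normalize and store.
    const float inv_sum = 1.0f / sum;
    for (int i = 0; i < per_thread; i++) {
        probs[threadIdx.x + i * WARP] = vals[i] * inv_sum;
    }
}

int main() {
    float h_logits[N_EXPERTS], h_probs[N_EXPERTS];
    for (int i = 0; i < N_EXPERTS; i++) h_logits[i] = 0.01f * i;

    float *d_logits, *d_probs;
    cudaMalloc(&d_logits, sizeof(h_logits));
    cudaMalloc(&d_probs,  sizeof(h_probs));
    cudaMemcpy(d_logits, h_logits, sizeof(h_logits), cudaMemcpyHostToDevice);

    warp_softmax<<<1, WARP>>>(d_logits, d_probs);
    cudaMemcpy(h_probs, d_probs, sizeof(h_probs), cudaMemcpyDeviceToHost);

    float total = 0.0f;
    for (int i = 0; i < N_EXPERTS; i++) total += h_probs[i];
    printf("sum of probabilities: %f\n", total);   // expected: ~1.0

    cudaFree(d_logits);
    cudaFree(d_probs);
    return 0;
}
```

Unlike this sketch, the ggml kernel keeps the normalized weights in registers and feeds them directly into the argmax loop that selects the `n_expert_used` experts, rather than writing the full distribution back to global memory.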

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 3 additions & 4 deletions
```diff
@@ -6,10 +6,9 @@
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * ids,
-                           const bool with_norm,
-                           const bool delayed_softmax = false);
+                           ggml_tensor * top_k,
+                           const bool with_norm);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
```
