Commit 7345668
CUDA: add a fused top-K MoE kernel
This kernel does the following:
1. softmax over the logits per token [n_experts, n_tokens]
2. argmax reduce over the top-k (n_experts_used) logits
3. write weights + ids to global memory

It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models.
1 parent 51abc96 commit 7345668

File tree

5 files changed: +259 -0 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 47 additions & 0 deletions
@@ -45,6 +45,7 @@
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/mean.cuh"
 #include "ggml-cuda/tsembd.cuh"
+#include "ggml-cuda/topk-moe.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv.cuh"
@@ -2825,6 +2826,40 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
+    // special case for topk-moe
+    if (ops.size() == 5 && ops.begin()[0] == GGML_OP_SOFT_MAX && ops.begin()[1] == GGML_OP_RESHAPE && ops.begin()[2] == GGML_OP_ARGSORT
+        && ops.begin()[3] == GGML_OP_VIEW && ops.begin()[4] == GGML_OP_GET_ROWS) {
+
+        for (int i = 0; i < 5; i++) {
+            if (cgraph->nodes[node_idx + i]->op != ops.begin()[i]) return false;
+        }
+
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+
+        float scale    = 1.0f;
+        float max_bias = 0.0f;
+
+        memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
+        memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
+
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
+
+        const int n_expert = softmax->ne[0];
+        // n_expert must be a power of 2
+        if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
+            return false;
+        }
+
+        return true;
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -2892,6 +2927,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return true;
     }
 
+
+
     return false;
 }

@@ -2915,6 +2952,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
         if (!disable_fusion) {
 
+            if (ggml_cuda_can_fuse(cgraph, i, {GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS}, {})) {
+
+                ggml_tensor * weights          = cgraph->nodes[i+4];
+                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts);
+                i += 4;
+                continue;
+            }
+
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
                 ggml_op ops[8];
@@ -2964,6 +3010,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
                 continue;
             }
+
         }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
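A note on the power-of-two guard in `ggml_cuda_can_fuse` above: in C++, `!=` binds tighter than `&`, so the mask expression needs its own parentheses. A minimal standalone sketch of the check (an illustration, not part of the commit):

    // n & (n - 1) clears the lowest set bit, so the result is zero
    // exactly when n is a power of two (for n > 0)
    static bool is_power_of_2(int n) {
        return n > 0 && (n & (n - 1)) == 0;
    }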

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
+#include "topk-moe.cuh"
+
+/*
+    This kernel does the following:
+    1. softmax over the logits per token [n_experts, n_tokens]
+    2. argmax reduce over the top-k (n_experts_used) logits
+    3. write weights + ids to global memory
+
+    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
+*/
+template <size_t n_experts>
+__global__ void topk_moe_cuda(const float * logits,
+                              float *       weights,
+                              int32_t *     ids,
+                              const int     n_rows,
+                              const int     n_expert_used) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    if (row >= n_rows) {
+        return;
+    }
+    logits  += n_experts * row;
+    ids     += n_experts * row;
+    weights += n_expert_used * row;
+
+    constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1;
+
+    const int start_expert = threadIdx.x * experts_per_thread;
+    const int end_expert   = (threadIdx.x + 1) * experts_per_thread;
+    float     max_val      = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int   expert = start_expert + i;
+        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
+        max_val            = max(val, max_val);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int   expert = start_expert + i;
+        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
+        wt[i]              = expf(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = warp_reduce_sum(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    // at this point, each thread holds a portion of the softmax;
+    // we do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude it from the next iteration
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val    = wt[0];
+        int   max_expert = start_expert;
+
+#pragma unroll
+        for (int i = 1; i < experts_per_thread; i++) {
+            const int expert = start_expert + i;
+            if (wt[i] > max_val) {
+                max_val    = wt[i];
+                max_expert = expert;
+            }
+        }
+
+#pragma unroll
+        for (int mask = warpSize / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize);
+            if (val > max_val) {
+                max_val    = val;
+                max_expert = expert;
+            }
+        }
+
+        if (max_expert >= start_expert && max_expert < end_expert) {
+            wt[max_expert - start_expert] = -INFINITY;
+
+            weights[k] = max_val;
+            ids[k]     = max_expert;
+        }
+    }
+}
+
+static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
+                                 const float * logits,
+                                 float *       weights,
+                                 int32_t *     ids,
+                                 const int     n_rows,
+                                 const int     n_expert,
+                                 const int     n_expert_used) {
+    const int    rows_per_block = 4;
+    dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
+    dim3         block_dims(32, rows_per_block, 1);
+    cudaStream_t stream = ctx.stream();
+
+    switch (n_expert) {
+        case 1:
+            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 2:
+            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 4:
+            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 8:
+            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 16:
+            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 32:
+            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 64:
+            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 128:
+            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 256:
+            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        case 512:
+            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            break;
+        default:
+            GGML_ASSERT(false && "fatal error");
+            break;
+    }
+}
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           ggml_tensor *               logits,
+                           ggml_tensor *               weights,
+                           ggml_tensor *               ids) {
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const float * logits_d  = (const float *) logits->src[0]->data;
+    float *       weights_d = (float *) weights->data;
+    int32_t *     ids_d     = (int32_t *) ids->data;
+
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
+    cudaStream_t stream = ctx.stream();
+
+    const int n_expert_used = weights->ne[1];
+
+    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+}
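The kernel relies on `warp_reduce_max` and `warp_reduce_sum` from `common.cuh`, which this diff does not show. A minimal sketch of the shuffle-based butterfly pattern such helpers typically follow, assuming a full 32-lane warp (an illustration, not the commit's code):

    // after log2(32) = 5 butterfly steps, every lane holds the warp-wide result
    static __device__ __forceinline__ float warp_reduce_max(float x) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x = fmaxf(x, __shfl_xor_sync(0xFFFFFFFF, x, offset, 32));
        }
        return x;
    }

    static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {
            x += __shfl_xor_sync(0xFFFFFFFF, x, offset, 32);
        }
        return x;
    }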

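For intuition, the kernel's math reduces to the following single-token host reference (a simplified sketch for illustration, not part of the commit; `topk_moe_ref` is a hypothetical name): softmax the logits, then repeat n_expert_used times taking the argmax, recording its weight and index, and masking it out.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // single-token reference: softmax, then n_expert_used rounds of argmax-and-mask
    static void topk_moe_ref(const std::vector<float> & logits, int n_expert_used,
                             std::vector<float> & weights, std::vector<int32_t> & ids) {
        const int n_experts = (int) logits.size();

        // numerically stable softmax
        float max_val = -INFINITY;
        for (float v : logits) {
            max_val = std::max(max_val, v);
        }
        std::vector<float> wt(n_experts);
        float sum = 0.0f;
        for (int e = 0; e < n_experts; e++) {
            wt[e] = std::exp(logits[e] - max_val);
            sum  += wt[e];
        }
        for (float & w : wt) {
            w /= sum;
        }

        // iterative argmax: pick the largest remaining weight, then mask it out
        weights.clear();
        ids.clear();
        for (int k = 0; k < n_expert_used; k++) {
            int best = 0;
            for (int e = 1; e < n_experts; e++) {
                if (wt[e] > wt[best]) {
                    best = e;
                }
            }
            weights.push_back(wt[best]);
            ids.push_back(best);
            wt[best] = -INFINITY;
        }
    }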
ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k);

src/llama-graph.cpp

Lines changed: 3 additions & 0 deletions
@@ -929,6 +929,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
         cb(weights, "ffn_moe_weights", il);
 
+        // call early so that softmax->top-k->get_rows can be fused
+        ggml_build_forward_expand(gf, weights);
+
         if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
            weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
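Expanding `weights` into the graph at this point pins the node order, so the softmax, argsort/view (top-k), and get_rows nodes land adjacent in the final cgraph; that consecutive layout is exactly what the `ggml_cuda_can_fuse` pattern match above walks over with `cgraph->nodes[node_idx + i]`.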

tests/test-backend-ops.cpp

Lines changed: 41 additions & 0 deletions
@@ -4403,6 +4403,42 @@ struct test_argsort : public test_case {
     }
 };
 
+struct test_topk_moe: public test_case {
+
+    const std::array<int64_t, 4> ne;
+    const int n_expert_used;
+    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1)
+        : ne(ne),
+          n_expert_used(n_expert_used) {
+        GGML_ASSERT(n_expert_used <= ne[0]);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR2(ne, n_expert_used);
+    }
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "TOPK_GATED_MOE";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        const int n_expert = ne[0];
+        const int n_tokens = ne[1];
+
+        ggml_tensor * logits = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne.data());
+        ggml_tensor * probs  = ggml_soft_max(ctx, logits);
+        ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
+
+        ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+
+        ggml_set_name(out, "out");
+        return out;
+    }
+};
+
 // GGML_OP_SUM
 struct test_sum : public test_case {
     const ggml_type type;
@@ -6557,6 +6593,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
     test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
+
+    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4));
+    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8));
+    test_cases.emplace_back(new test_topk_moe({128, 19, 1, 1}, 16));
+
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
     test_cases.emplace_back(new test_llama(2, true));
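With these cases registered, the fused path can be exercised in isolation through the op filter of `test-backend-ops` (assuming its usual `-o` flag matches the `op_desc` string), e.g. `test-backend-ops test -o TOPK_GATED_MOE`; setting `GGML_CUDA_DISABLE_FUSION` gives the unfused baseline to compare against.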
