
Commit 17c9e7c

Review: Use better coalescing pattern, use WARP_SIZE, store logits into registers before
1 parent 1b7f1e3 commit 17c9e7c

File tree (3 files changed: +47 -31 lines)

  ggml/src/ggml-cuda/ggml-cuda.cu
  ggml/src/ggml-cuda/topk-moe.cu
  ggml/src/ggml-cuda/topk-moe.cuh


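For context on the commit message: the kernel previously gave each thread a contiguous slice of experts_per_thread logits (starting at threadIdx.x * experts_per_thread), so adjacent lanes read addresses that were experts_per_thread floats apart. After this change each lane loads experts threadIdx.x, threadIdx.x + WARP_SIZE, threadIdx.x + 2*WARP_SIZE, ... into a small per-thread array that the compiler can keep in registers, so every iteration of the load loop touches WARP_SIZE consecutive floats and coalesces. The standalone sketch below illustrates only that access pattern; the kernel name, the hard-coded WARP_SIZE, and the host code are assumptions for the example and are not part of this commit.

// coalesced_load_sketch.cu -- standalone illustration only, not part of this commit.
// Shows the strided register layout the updated topk_moe_cuda kernel moves to:
// lane t of a warp holds logits[t], logits[t + WARP_SIZE], logits[t + 2*WARP_SIZE], ...
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>

#define WARP_SIZE 32  // in ggml this comes from common.cuh; hard-coded here for the sketch

template <int n_experts>
__global__ void load_row_coalesced(const float * logits, float * row_max) {
    // one register slot per WARP_SIZE-wide chunk of the row
    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float logits_r[experts_per_thread];

    // each iteration reads WARP_SIZE consecutive floats, one per lane -> coalesced access
#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert        = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
    }

    // per-lane maximum over the register copy ...
    float max_val = -INFINITY;
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        max_val = fmaxf(max_val, logits_r[i]);
    }

    // ... followed by a butterfly reduction so every lane holds the row maximum
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
        max_val = fmaxf(max_val, __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE));
    }

    if (threadIdx.x == 0) {
        *row_max = max_val;
    }
}

int main() {
    constexpr int n_experts = 64;

    float h_logits[n_experts];
    for (int i = 0; i < n_experts; i++) {
        h_logits[i] = std::sin((float) i);  // arbitrary test data
    }

    float *d_logits = nullptr, *d_max = nullptr;
    cudaMalloc(&d_logits, n_experts * sizeof(float));
    cudaMalloc(&d_max, sizeof(float));
    cudaMemcpy(d_logits, h_logits, n_experts * sizeof(float), cudaMemcpyHostToDevice);

    load_row_coalesced<n_experts><<<1, WARP_SIZE>>>(d_logits, d_max);

    float h_max = 0.0f;
    cudaMemcpy(&h_max, d_max, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("row max = %f\n", h_max);

    cudaFree(d_logits);
    cudaFree(d_max);
    return 0;
}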
ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 1 deletion
@@ -2835,7 +2835,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
 
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        if (ggml_cuda_should_use_topk_moe(softmax)) {
+        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
         }
     }

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 40 additions & 28 deletions
@@ -1,3 +1,4 @@
+#include "ggml-cuda/common.cuh"
 #include "ggml.h"
 #include "topk-moe.cuh"
 
@@ -10,30 +11,36 @@
 It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
 template <size_t n_experts>
-__global__ void topk_moe_cuda(const float * logits,
-                              float *       weights,
-                              int32_t *     ids,
-                              const int     n_rows,
-                              const int     n_expert_used) {
+__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
+                                                                  float *       weights,
+                                                                  int32_t *     ids,
+                                                                  const int     n_rows,
+                                                                  const int     n_expert_used) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
     }
+
     logits += n_experts * row;
-    ids += n_experts * row;
     weights += n_expert_used * row;
+    ids += n_experts * row;
+
+    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1;
+    float logits_r[experts_per_thread];
 
-    const int start_expert = threadIdx.x * experts_per_thread;
-    const int end_expert   = (threadIdx.x + 1) * experts_per_thread;
-    float     max_val      = -INFINITY;
+#pragma unroll
+    for (int i = 0; i < n_experts; i += WARP_SIZE) {
+        const int expert        = i + threadIdx.x;
+        logits_r[i / WARP_SIZE] = expert < n_experts ? logits[expert] : -INFINITY;
+    }
+
+    float max_val = -INFINITY;
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        max_val            = max(val, max_val);
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
     }
 
     max_val = warp_reduce_max(max_val);
@@ -43,9 +50,8 @@ __global__ void topk_moe_cuda(const float * logits,
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int   expert = start_expert + i;
-        const float val    = (expert < n_experts) ? logits[expert] : -INFINITY;
-        wt[i]              = expf(val - max_val);
+        const float val = logits_r[i];
+        wt[i]           = expf(val - max_val);
         tmp += wt[i];
     }
 
@@ -64,29 +70,29 @@ __global__ void topk_moe_cuda(const float * logits,
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
-        int   max_expert = start_expert;
+        int   max_expert = threadIdx.x;
 
 #pragma unroll
         for (int i = 1; i < experts_per_thread; i++) {
-            const int expert = start_expert + i;
-            if (wt[i] > max_val) {
+            const int expert = threadIdx.x + i * WARP_SIZE;
+            if (expert < n_experts && wt[i] > max_val) {
                 max_val    = wt[i];
                 max_expert = expert;
             }
        }
 
 #pragma unroll
-        for (int mask = warpSize / 2; mask > 0; mask /= 2) {
-            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize);
-            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize);
+        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
+            const int   expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
             if (val > max_val) {
                 max_val    = val;
                 max_expert = expert;
            }
        }
 
-        if (max_expert >= start_expert && max_expert < end_expert) {
-            wt[max_expert - start_expert] = -INFINITY;
+        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
 
             weights[k] = max_val;
             ids[k]     = max_expert;
@@ -103,7 +109,7 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const int n_expert_used) {
     const int rows_per_block = 4;
     dim3      grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
-    dim3      block_dims(32, rows_per_block, 1);
+    dim3      block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
@@ -151,12 +157,14 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+
     const float * logits_d  = (const float *) logits->src[0]->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
 
-    const int n_experts = logits->ne[0];
-    const int n_rows    = logits->ne[1];
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
 
     cudaStream_t stream = ctx.stream();
 
@@ -165,13 +173,17 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
 }
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax) {
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
     memcpy(&scale,    (const float *) softmax->op_params + 0, sizeof(float));
     memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
 
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
     if (scale != 1.0f || max_bias != 0.0f) {
         return false;
     }
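A note on the new index arithmetic in the selection loop: with the strided register layout, slot i of lane threadIdx.x holds expert threadIdx.x + i * WARP_SIZE, so after the warp-wide argmax the owning lane is the one where (max_expert & (WARP_SIZE - 1)) == threadIdx.x (equivalent to max_expert % WARP_SIZE == threadIdx.x since WARP_SIZE is a power of two) and the register slot to invalidate is max_expert / WARP_SIZE. A small host-side check of that mapping; the file name and constants are illustrative and not part of the commit.

// index_mapping_check.cpp -- illustrative only; verifies the lane/slot <-> expert
// mapping the updated selection loop relies on. Assumes WARP_SIZE is a power of two,
// which is what makes (expert & (WARP_SIZE - 1)) equivalent to expert % WARP_SIZE.
#include <cassert>
#include <cstdio>

constexpr int WARP_SIZE = 32;

int main() {
    constexpr int n_experts          = 128;
    constexpr int experts_per_thread = n_experts / WARP_SIZE;

    for (int lane = 0; lane < WARP_SIZE; lane++) {
        for (int slot = 0; slot < experts_per_thread; slot++) {
            // forward mapping: slot i of lane t holds expert t + i * WARP_SIZE,
            // matching "const int expert = threadIdx.x + i * WARP_SIZE;" in the kernel
            const int expert = lane + slot * WARP_SIZE;

            // inverse mapping used after the warp-wide argmax: the owning lane
            // and the register slot of the winning expert
            assert((expert & (WARP_SIZE - 1)) == lane);
            assert(expert / WARP_SIZE == slot);
        }
    }

    std::printf("lane/slot <-> expert mapping is consistent for %d experts\n", n_experts);
    return 0;
}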

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 5 additions & 2 deletions
@@ -1,5 +1,8 @@
 #include "common.cuh"
 
-void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const ggml_tensor * logits, ggml_tensor * weights, ggml_tensor * top_k);
+void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
+                           const ggml_tensor * logits,
+                           ggml_tensor * weights,
+                           ggml_tensor * top_k);
 
-bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax);
+bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
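With the extra parameter, the fusion check needs both the softmax node and the weights node from the graph. A hedged usage sketch, mirroring the updated call site in ggml-cuda.cu above; the wrapper name is hypothetical, and the node_idx + 4 offset comes from that diff rather than from this header.

// Illustrative only -- a hypothetical wrapper around the new two-argument check.
// Assumes the internal ggml_cgraph definition is visible, as it is inside the CUDA
// backend, and that the topk-moe weights node sits 4 nodes after the softmax node.
static bool can_use_fused_topk_moe(const struct ggml_cgraph * cgraph, int node_idx) {
    ggml_tensor * softmax = cgraph->nodes[node_idx];
    ggml_tensor * weights = cgraph->nodes[node_idx + 4];

    // false if either the softmax input or the weights output is non-contiguous,
    // or if the softmax uses scale != 1.0f / max_bias != 0.0f
    return ggml_cuda_should_use_topk_moe(softmax, weights);
}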
