CUDA: add a fused top-K MoE kernel #16130
Merged
Commits (9):
a208f5c  CUDA: add a fused top-K MoE kernel (am17an)
9fc0396  Refactor into ggml_cuda_should_use_topk_moe (am17an)
8b780cc  Review: Use better coalescing pattern, use WARP_SIZE, store logits in… (am17an)
ce867aa  Review: format + micro-optimizations (am17an)
2930668  Fix bug: fix tie breakers (am17an)
240b2c1  Add optional norm + clean-up code (am17an)
e772b28  Use smem for final write (am17an)
53acfe6  Add bounds check (am17an)
33856e1  Use better memory pattern for writeback (am17an)
The first new file (+259 lines) implements the fused kernel:

```cuda
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "topk-moe.cuh"

#include <initializer_list>

/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_expert_used) logits
    3. write weights + ids to global memory
    4. optionally normalize the weights

    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
```
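In math terms (my summary of the comment above, with $E$ = `n_experts` and $K$ = `n_expert_used`), each token's logits $l_1, \dots, l_E$ are first turned into probabilities with a numerically stable softmax,

$$p_i = \frac{e^{\,l_i - m}}{\sum_{j=1}^{E} e^{\,l_j - m}}, \qquad m = \max_{j} l_j,$$

then the $K$ largest probabilities are selected greedily, $e_k = \arg\max_{i \notin \{e_1,\dots,e_{k-1}\}} p_i$, and the kernel stores $w_k = p_{e_k}$ together with the index $e_k$. When `with_norm` is set, the selected weights are renormalized so they sum to one:

$$w_k \leftarrow \frac{w_k}{\sum_{k'=1}^{K} p_{e_{k'}}}.$$

The kernel body follows.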
```cuda
template <size_t n_experts, bool with_norm>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float *       weights,
                                                                  int32_t *     ids,
                                                                  const int     n_rows,
                                                                  const int     n_expert_used) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }

    logits += n_experts * row;
    weights += n_expert_used * row;
    ids += n_experts * row;

    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float logits_r[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
    }

    float max_val = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = warp_reduce_max(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = expf(val - max_val);
        tmp += wt[i];
    }

    tmp = warp_reduce_sum(tmp);

    const float inv_sum = 1.0f / tmp;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    extern __shared__ float data_topk_shared[];
    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        int max_expert = threadIdx.x;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = threadIdx.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

#pragma unroll
        for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            wt_shared_ptr[k] = max_val;
            ids[k] = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if constexpr (with_norm) {
        wt_sum = warp_reduce_sum(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
        }
    }

    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
        weights[i] = wt_shared_ptr[i];
    }
}
```
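As a worked instance of the data layout (my illustration, assuming WARP_SIZE = 32): for `n_experts = 128`, `experts_per_thread` is 4 and lane $t$ of the warp keeps the logits and softmax values of experts

$$\{\, t,\; t + 32,\; t + 64,\; t + 96 \,\},$$

so the initial loads from `logits` are coalesced across the warp. Each of the $K$ rounds of the selection loop then reduces one candidate per lane with the `__shfl_xor_sync` butterfly, breaking ties toward the smaller expert index, and the winning lane masks its local copy of that weight to `-INFINITY` before the next round.

The host-side launcher instantiates the kernel for each supported expert count: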
```cuda
template <bool with_norm>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float *               logits,
                                 float *                     weights,
                                 int32_t *                   ids,
                                 const int                   n_rows,
                                 const int                   n_expert,
                                 const int                   n_expert_used) {
    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm>
                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}
```
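The resulting launch configuration, spelled out (my arithmetic, assuming WARP_SIZE = 32): each block is $32 \times 4$ threads, i.e. four warps handling one token row each, the grid has $\lceil n\_rows / 4 \rceil$ blocks, and the dynamic shared memory per block is $n\_expert\_used \cdot 4 \cdot 4$ bytes; a model that uses 8 experts per token therefore needs only $8 \cdot 4 \cdot 4 = 128$ bytes.

The remaining host code wires the kernel into ggml: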
```cuda
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               ids,
                           const bool                  with_norm) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->src[0]->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);

    cudaStream_t stream = ctx.stream();

    const int n_expert_used = weights->ne[1];

    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
    float scale    = 1.0f;
    float max_bias = 0.0f;

    memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
    memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (scale != 1.0f || max_bias != 0.0f) {
        return false;
    }

    // don't fuse when masks or sinks are present
    if (softmax->src[1] || softmax->src[2]) {
        return false;
    }

    const int n_expert = softmax->ne[0];
    // n_expert must be a power of 2
    if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
        return false;
    }

    return true;
}

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };

    if (norm) {
        return norm_ops;
    }
    return no_norm_ops;
}
```
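To show how these pieces are meant to fit together, here is a minimal sketch of a call site (hypothetical names and structure; the graph-matching code that calls these helpers is not part of the files shown here). It assumes the matching logic has already identified the `GGML_OP_SOFT_MAX` node of the routing chain, the tensor that receives the final expert weights, and the I32 tensor of selected expert ids:

```cpp
#include "ggml-cuda/common.cuh"
#include "topk-moe.cuh"

// Hypothetical helper: dispatch the fused kernel for an already-matched
// softmax -> top-k -> get_rows chain, falling back to the unfused ops otherwise.
static bool try_fused_topk_moe(ggml_backend_cuda_context & ctx,
                               const ggml_tensor *         softmax,  // GGML_OP_SOFT_MAX over the expert logits
                               ggml_tensor *               weights,  // per-token expert weights (F32)
                               ggml_tensor *               ids,      // selected expert indices (I32)
                               bool                        with_norm) {
    // the guard rejects scaled/biased softmax, masks/sinks, non-contiguous tensors,
    // and expert counts that are not a power of two <= 512
    if (!ggml_cuda_should_use_topk_moe(softmax, weights)) {
        return false;  // caller evaluates the original op sequence instead
    }
    // note: the wrapper reads the raw logits from softmax->src[0], so the softmax
    // node itself is passed as the "logits" argument
    ggml_cuda_op_topk_moe(ctx, softmax, weights, ids, with_norm);
    return true;
}
```

The chain to match is the one returned by `ggml_cuda_topk_moe_ops(with_norm)`: with normalization it is SOFT_MAX, RESHAPE, ARGSORT, VIEW, GET_ROWS, RESHAPE, SUM_ROWS, DIV, RESHAPE, and without normalization the list stops after GET_ROWS.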
The accompanying header (+14 lines) declares the public entry points:

```cuda
#include "common.cuh"
#include "ggml.h"

#include <initializer_list>

void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor *         logits,
                           ggml_tensor *               weights,
                           ggml_tensor *               top_k,
                           const bool                  with_norm);

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
```
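For completeness, the per-token computation can be written as a scalar reference in a few lines of C++. This is a sketch I am adding for illustration only (it is not part of the PR, and names such as `topk_moe_reference` are made up); something of this shape is useful for checking the kernel's output in a unit test:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical reference: softmax over one token's expert logits, pick the
// n_expert_used largest probabilities (ties go to the lower index), optionally renormalize.
static void topk_moe_reference(const std::vector<float> & logits, int n_expert_used, bool with_norm,
                               std::vector<float> & weights, std::vector<int32_t> & ids) {
    const int   n_experts = (int) logits.size();
    const float m         = *std::max_element(logits.begin(), logits.end());

    // numerically stable softmax
    std::vector<float> p(n_experts);
    float sum = 0.0f;
    for (int i = 0; i < n_experts; ++i) {
        p[i] = std::exp(logits[i] - m);
        sum += p[i];
    }
    for (int i = 0; i < n_experts; ++i) {
        p[i] /= sum;
    }

    // greedy top-k selection, mirroring the kernel's iterative argmax
    weights.assign(n_expert_used, 0.0f);
    ids.assign(n_expert_used, 0);
    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; ++k) {
        int best = 0;
        for (int i = 1; i < n_experts; ++i) {
            if (p[i] > p[best]) {  // strict '>' keeps the lower index on ties
                best = i;
            }
        }
        weights[k] = p[best];
        ids[k]     = best;
        wt_sum    += p[best];
        p[best]    = -INFINITY;  // exclude from the next round
    }

    // optional renormalization of the selected weights
    if (with_norm) {
        for (int k = 0; k < n_expert_used; ++k) {
            weights[k] /= wt_sum;
        }
    }
}
```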