Skip to content

Commit 2de54df

Browse files
Committed: "add parameter to avoid runtime branch" (commit 2de54df, 1 parent: 5632159)

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55
#include <initializer_list>
66

77
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
8-
template <int experts_per_thread>
8+
template <int experts_per_thread, bool use_limit>
99
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
1010
float max_val = -INFINITY;
1111

1212
#pragma unroll
1313
for (int i = 0; i < experts_per_thread; i++) {
14-
const int idx = lane + i * WARP_SIZE;
15-
if (idx < limit) {
14+
const int idx = lane + i * WARP_SIZE;
15+
const bool active = !use_limit || (idx < limit);
16+
if (active) {
1617
max_val = max(max_val, vals[i]);
1718
}
1819
}
@@ -23,8 +24,9 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
2324

2425
#pragma unroll
2526
for (int i = 0; i < experts_per_thread; i++) {
26-
const int idx = lane + i * WARP_SIZE;
27-
if (idx < limit) {
27+
const int idx = lane + i * WARP_SIZE;
28+
const bool active = !use_limit || (idx < limit);
29+
if (active) {
2830
const float val = expf(vals[i] - max_val);
2931
vals[i] = val;
3032
sum += val;
@@ -39,8 +41,9 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
3941

4042
#pragma unroll
4143
for (int i = 0; i < experts_per_thread; i++) {
42-
const int idx = lane + i * WARP_SIZE;
43-
if (idx < limit) {
44+
const int idx = lane + i * WARP_SIZE;
45+
const bool active = !use_limit || (idx < limit);
46+
if (active) {
4447
vals[i] *= inv_sum;
4548
}
4649
}
@@ -76,12 +79,12 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
7679

7780
#pragma unroll
7881
for (int i = 0; i < n_experts; i += WARP_SIZE) {
79-
const int expert = i + threadIdx.x;
80-
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
82+
const int expert = i + threadIdx.x;
83+
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
8184
}
8285

8386
if constexpr (!delayed_softmax) {
84-
softmax_warp_inplace<experts_per_thread>(wt, n_experts, threadIdx.x);
87+
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
8588
}
8689

8790
//at this point, each thread holds either a portion of the softmax distribution
@@ -144,7 +147,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
144147
}
145148

146149
if constexpr (delayed_softmax) {
147-
softmax_warp_inplace<experts_per_thread>(output_weights, n_expert_used, threadIdx.x);
150+
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
148151
}
149152

150153
#pragma unroll

Commit comments: 0