 #include <initializer_list>
 
 // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
-template <int experts_per_thread>
+template <int experts_per_thread, bool use_limit>
 __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
     float max_val = -INFINITY;
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        if (idx < limit) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
             max_val = max(max_val, vals[i]);
         }
     }
@@ -23,8 +24,9 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        if (idx < limit) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
             const float val = expf(vals[i] - max_val);
             vals[i] = val;
             sum += val;
@@ -39,8 +41,9 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        const int idx = lane + i * WARP_SIZE;
-        if (idx < limit) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
             vals[i] *= inv_sum;
         }
     }
@@ -76,12 +79,12 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert = i + threadIdx.x;
-        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
+        const int expert = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
     if constexpr (!delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread>(wt, n_experts, threadIdx.x);
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
     }
 
     // at this point, each thread holds either a portion of the softmax distribution
@@ -144,7 +147,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 
     if constexpr (delayed_softmax) {
-        softmax_warp_inplace<experts_per_thread>(output_weights, n_expert_used, threadIdx.x);
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
     }
 
 #pragma unroll
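For reference, here is a standalone sketch of the warp-local softmax pattern this diff introduces, with the compile-time use_limit switch and a one-warp test harness. It is an illustration, not the ggml source: WARP_SIZE, the warp_reduce_max/warp_reduce_sum helpers, and the test kernel are local stand-ins for the project's own definitions. With use_limit == false the bounds check compiles away, which is safe on the pre-top-k path because out-of-range logits are padded with -INFINITY and contribute exp(-inf) = 0 to the sum; the delayed path passes use_limit == true with n_expert_used as the limit.

// Standalone illustration (not the ggml file): warp-local softmax over
// experts_per_thread * WARP_SIZE register-resident values, with a compile-time
// use_limit switch. Build with e.g.: nvcc -arch=native softmax_sketch.cu
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>

#define WARP_SIZE 32

// Butterfly reductions across the warp; stand-ins for ggml's own helpers.
__device__ float warp_reduce_max(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffff, v, offset));
    }
    return v;
}

__device__ float warp_reduce_sum(float v) {
    for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xffffffff, v, offset);
    }
    return v;
}

// Lane `lane` owns the strided slots lane, lane + WARP_SIZE, ...; slots at or
// beyond `limit` are skipped when use_limit is true, and are assumed to hold
// -INFINITY (hence contribute nothing) when it is false.
template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
    float max_val = -INFINITY;
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            max_val = fmaxf(max_val, vals[i]);
        }
    }
    max_val = warp_reduce_max(max_val);

    float sum = 0.0f;
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            const float val = expf(vals[i] - max_val);
            vals[i] = val;
            sum += val;
        }
    }
    sum = warp_reduce_sum(sum);

    const float inv_sum = 1.0f / sum;
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int  idx    = lane + i * WARP_SIZE;
        const bool active = !use_limit || (idx < limit);
        if (active) {
            vals[i] *= inv_sum;
        }
    }
}

// One-warp test: softmax over the first `limit` of 64 values; inactive slots
// keep their original contents, mirroring the device function above.
__global__ void softmax_test_kernel(float * data, int limit) {
    constexpr int experts_per_thread = 2;  // 2 * WARP_SIZE = 64 slots
    float vals[experts_per_thread];
    for (int i = 0; i < experts_per_thread; i++) {
        vals[i] = data[threadIdx.x + i * WARP_SIZE];
    }
    softmax_warp_inplace<experts_per_thread, /*use_limit=*/true>(vals, limit, threadIdx.x);
    for (int i = 0; i < experts_per_thread; i++) {
        data[threadIdx.x + i * WARP_SIZE] = vals[i];
    }
}

int main() {
    float h[64];
    for (int i = 0; i < 64; i++) { h[i] = 0.1f * i; }
    float * d = nullptr;
    cudaMalloc(&d, sizeof(h));
    cudaMemcpy(d, h, sizeof(h), cudaMemcpyHostToDevice);
    softmax_test_kernel<<<1, WARP_SIZE>>>(d, /*limit=*/40);
    cudaMemcpy(h, d, sizeof(h), cudaMemcpyDeviceToHost);
    float sum = 0.0f;
    for (int i = 0; i < 40; i++) { sum += h[i]; }
    printf("sum of active probabilities = %f (expect ~1)\n", sum);
    cudaFree(d);
    return 0;
}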