From 220761b2efe153e05f263ed30784239fd277889a Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 5 Dec 2025 20:54:20 +0000 Subject: [PATCH 1/5] first version Signed-off-by: yewentao256 --- csrc/moe/grouped_topk_kernels.cu | 115 +++++++++++++++++++++---------- 1 file changed, 80 insertions(+), 35 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 69b4c1fb11d1..b4854790b551 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) { return cuda_cast(sigmoid_accurate(f)); } -template +template +__device__ inline T apply_scoring(T val) { + if constexpr (SF == SCORING_SIGMOID) { + return apply_sigmoid(val); + } else { + return val; + } +} + +template __device__ void topk_with_k2(T* output, T const* input, T const* bias, cg::thread_block_tile<32> const& tile, int32_t const lane_id, - int const num_experts_per_group, - int const scoring_func) { + int const num_experts_per_group) { // Get the top2 per thread T largest = neg_inf(); T second_largest = neg_inf(); if (num_experts_per_group > WARP_SIZE) { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; if (value > largest) { @@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } else { for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) { - T value = input[i]; - // Apply scoring function if needed - if (scoring_func == SCORING_SIGMOID) { - value = apply_sigmoid(value); - } + T value = apply_scoring(input[i]); value = value + bias[i]; largest = value; } @@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias, } } -template +template __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, int64_t const num_tokens, int64_t const num_cases, int64_t const n_group, - int64_t const num_experts_per_group, - int const scoring_func) { + int64_t const num_experts_per_group) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; @@ -525,21 +524,20 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - topk_with_k2(output, input, group_bias, tile, lane_id, - num_experts_per_group, scoring_func); + topk_with_k2(output, input, group_bias, tile, lane_id, + num_experts_per_group); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, - int64_t const num_experts_per_group, bool renormalize, - double routed_scaling_factor, int scoring_func) { + int64_t const num_experts_per_group, double routed_scaling_factor) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = @@ -616,8 +614,7 @@ __global__ void group_idx_and_topk_idx_kernel( // Apply scoring function (if any) and add bias T input = scores[offset + i]; if (is_finite(input)) { - T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) - : input; + T score = apply_scoring(input); candidates = score + bias[offset + i]; } } @@ -646,8 +643,7 @@ __global__ void group_idx_and_topk_idx_kernel( if (i < topk) { // Load the score value (without bias) for normalization T input = scores[s_topk_idx[i]]; - value = - (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input; + value = apply_scoring(input); s_topk_value[i] = value; } topk_sum += @@ -661,7 +657,7 @@ __global__ void group_idx_and_topk_idx_kernel( if (if_proceed_next_topk) { for (int i = lane_id; i < topk; i += WARP_SIZE) { float value; - if (renormalize) { + if constexpr (Renorm) { value = cuda_cast(s_topk_value[i]) / topk_sum * routed_scaling_factor; } else { @@ -694,7 +690,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, cudaStream_t const stream = 0) { int64_t num_cases = num_tokens * n_group; int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1; - auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchConfig_t config; config.gridDim = topk_with_k2_num_blocks; config.blockDim = BLOCK_SIZE; @@ -705,16 +700,30 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts / n_group, - scoring_func); + auto const sf = static_cast(scoring_func); + switch (sf) { + case SCORING_NONE: { + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts / n_group); + break; + } + case SCORING_SIGMOID: { + auto* kernel_instance1 = &topk_with_k2_kernel; + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts / n_group); + break; + } + default: + // should be guarded by higher level checks. + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } int64_t topk_with_k_group_num_blocks = (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1; size_t dynamic_smem_in_bytes = warp_topk::calc_smem_size_for_block_wide(NUM_WARPS_PER_BLOCK, topk); - auto* kernel_instance2 = &group_idx_and_topk_idx_kernel; config.gridDim = topk_with_k_group_num_blocks; config.blockDim = BLOCK_SIZE; config.dynamicSmemBytes = dynamic_smem_in_bytes; @@ -723,10 +732,46 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - renormalize, routed_scaling_factor, scoring_func); + switch (sf) { + case SCORING_NONE: { + if (renormalize) { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + } else { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + } + break; + } + case SCORING_SIGMOID: { + if (renormalize) { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + } else { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts / n_group, + routed_scaling_factor); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc"); + } } #define INSTANTIATE_NOAUX_TC(T, IdxT) \ From 9472d905240675e050f7f3ff62554d5ef40d216b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 5 Dec 2025 21:54:46 +0000 Subject: [PATCH 2/5] further optimize Signed-off-by: yewentao256 --- csrc/moe/grouped_topk_kernels.cu | 176 ++++++++++++++++++++++--------- 1 file changed, 129 insertions(+), 47 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index b4854790b551..9a8c6496c0e8 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -532,7 +532,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, @@ -547,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel( topk_values += case_id * topk; topk_indices += case_id * topk; + constexpr bool kUseStaticNGroup = (NGroup > 0); + // use int32 to avoid implicit conversion + int32_t const n_group_i32 = + kUseStaticNGroup ? NGroup : static_cast(n_group); + int32_t align_num_experts_per_group = warp_topk::round_up_to_multiple_of(num_experts_per_group); @@ -572,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { // calculate group_idx - int32_t target_num_min = WARP_SIZE - n_group + topk_group; + int32_t target_num_min = + WARP_SIZE - n_group_i32 + static_cast(topk_group); // The check is necessary to avoid abnormal input - if (lane_id < n_group && is_finite(group_scores[lane_id])) { + if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) { value = group_scores[lane_id]; } - int count_equal_to_top_value = WARP_SIZE - n_group; + int count_equal_to_top_value = WARP_SIZE - n_group_i32; int pre_count_equal_to_top_value = 0; // Use loop to find the largset top_group while (count_equal_to_top_value < target_num_min) { @@ -602,26 +609,52 @@ __global__ void group_idx_and_topk_idx_kernel( int count_equalto_topkth_group = 0; bool if_proceed_next_topk = topk_group_value != neg_inf(); if (case_id < num_tokens && if_proceed_next_topk) { - for (int i_group = 0; i_group < n_group; i_group++) { - if ((group_scores[i_group] > topk_group_value) || - ((group_scores[i_group] == topk_group_value) && - (count_equalto_topkth_group < num_equalto_topkth_group))) { - int32_t offset = i_group * num_experts_per_group; - for (int32_t i = lane_id; i < align_num_experts_per_group; - i += WARP_SIZE) { - T candidates = neg_inf(); - if (i < num_experts_per_group) { - // Apply scoring function (if any) and add bias - T input = scores[offset + i]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidates = score + bias[offset + i]; + if constexpr (kUseStaticNGroup) { +#pragma unroll + for (int i_group = 0; i_group < NGroup; ++i_group) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = neg_inf(); + if (i < num_experts_per_group) { + // apply scoring function (if any) and add bias + T input = scores[offset + i]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidates = score + bias[offset + i]; + } } + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; } - queue.add(candidates, offset + i); } - if (group_scores[i_group] == topk_group_value) { - count_equalto_topkth_group++; + } + } else { + for (int i_group = 0; i_group < n_group_i32; ++i_group) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = neg_inf(); + if (i < num_experts_per_group) { + T input = scores[offset + i]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidates = score + bias[offset + i]; + } + } + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; + } } } } @@ -680,6 +713,62 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } +template +inline void launch_group_idx_and_topk_kernel( + cudaLaunchConfig_t const& config, T* scores, T* group_scores, + float* topk_values, IdxT* topk_indices, T const* bias, + int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, + int64_t const topk, int64_t const num_experts, + int64_t const num_experts_per_group, double const routed_scaling_factor) { + switch (n_group) { + case 4: { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + break; + } + case 8: { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + break; + } + case 16: { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + break; + } + case 32: { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + break; + } + default: { + auto* kernel_instance2 = + &group_idx_and_topk_idx_kernel; + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + break; + } + } +} + template void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, @@ -701,17 +790,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, config.numAttrs = 1; config.attrs = attrs; auto const sf = static_cast(scoring_func); + int64_t const num_experts_per_group = num_experts / n_group; switch (sf) { case SCORING_NONE: { auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts / n_group); + num_tokens, num_cases, n_group, num_experts_per_group); break; } case SCORING_SIGMOID: { auto* kernel_instance1 = &topk_with_k2_kernel; cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts / n_group); + num_tokens, num_cases, n_group, num_experts_per_group); break; } default: @@ -735,37 +825,29 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, switch (sf) { case SCORING_NONE: { if (renormalize) { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - routed_scaling_factor); + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, routed_scaling_factor); } else { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - routed_scaling_factor); + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, routed_scaling_factor); } break; } case SCORING_SIGMOID: { if (renormalize) { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - routed_scaling_factor); + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, routed_scaling_factor); } else { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts / n_group, - routed_scaling_factor); + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, routed_scaling_factor); } break; } From bcf6077709087a87bddd8b876b712c632e26ab5f Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 5 Dec 2025 22:31:02 +0000 Subject: [PATCH 3/5] reduce code Signed-off-by: yewentao256 --- csrc/moe/grouped_topk_kernels.cu | 73 +++++++++++++------------------- 1 file changed, 30 insertions(+), 43 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 9a8c6496c0e8..e09893cea762 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -609,53 +609,38 @@ __global__ void group_idx_and_topk_idx_kernel( int count_equalto_topkth_group = 0; bool if_proceed_next_topk = topk_group_value != neg_inf(); if (case_id < num_tokens && if_proceed_next_topk) { - if constexpr (kUseStaticNGroup) { -#pragma unroll - for (int i_group = 0; i_group < NGroup; ++i_group) { - if ((group_scores[i_group] > topk_group_value) || - ((group_scores[i_group] == topk_group_value) && - (count_equalto_topkth_group < num_equalto_topkth_group))) { - int32_t offset = i_group * num_experts_per_group; - for (int32_t i = lane_id; i < align_num_experts_per_group; - i += WARP_SIZE) { - T candidates = neg_inf(); - if (i < num_experts_per_group) { - // apply scoring function (if any) and add bias - T input = scores[offset + i]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidates = score + bias[offset + i]; - } + auto process_group = [&](int i_group) { + if ((group_scores[i_group] > topk_group_value) || + ((group_scores[i_group] == topk_group_value) && + (count_equalto_topkth_group < num_equalto_topkth_group))) { + int32_t offset = i_group * num_experts_per_group; + for (int32_t i = lane_id; i < align_num_experts_per_group; + i += WARP_SIZE) { + T candidates = neg_inf(); + if (i < num_experts_per_group) { + // apply scoring function (if any) and add bias + T input = scores[offset + i]; + if (is_finite(input)) { + T score = apply_scoring(input); + candidates = score + bias[offset + i]; } - queue.add(candidates, offset + i); - } - if (group_scores[i_group] == topk_group_value) { - count_equalto_topkth_group++; } + queue.add(candidates, offset + i); + } + if (group_scores[i_group] == topk_group_value) { + count_equalto_topkth_group++; } } + }; + + if constexpr (kUseStaticNGroup) { +#pragma unroll + for (int i_group = 0; i_group < NGroup; ++i_group) { + process_group(i_group); + } } else { for (int i_group = 0; i_group < n_group_i32; ++i_group) { - if ((group_scores[i_group] > topk_group_value) || - ((group_scores[i_group] == topk_group_value) && - (count_equalto_topkth_group < num_equalto_topkth_group))) { - int32_t offset = i_group * num_experts_per_group; - for (int32_t i = lane_id; i < align_num_experts_per_group; - i += WARP_SIZE) { - T candidates = neg_inf(); - if (i < num_experts_per_group) { - T input = scores[offset + i]; - if (is_finite(input)) { - T score = apply_scoring(input); - candidates = score + bias[offset + i]; - } - } - queue.add(candidates, offset + i); - } - if (group_scores[i_group] == topk_group_value) { - count_equalto_topkth_group++; - } - } + process_group(i_group); } } queue.done(); @@ -679,8 +664,10 @@ __global__ void group_idx_and_topk_idx_kernel( value = apply_scoring(input); s_topk_value[i] = value; } - topk_sum += - cg::reduce(tile, cuda_cast(value), cg::plus()); + if constexpr (Renorm) { + topk_sum += + cg::reduce(tile, cuda_cast(value), cg::plus()); + } } } From bb310e36890a81bc988d12f352560c4f08a8308b Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 5 Dec 2025 22:41:42 +0000 Subject: [PATCH 4/5] reduce code Signed-off-by: yewentao256 --- csrc/moe/grouped_topk_kernels.cu | 52 +++++++++++--------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index e09893cea762..4db6e957afef 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -707,50 +707,32 @@ inline void launch_group_idx_and_topk_kernel( int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, int64_t const num_experts_per_group, double const routed_scaling_factor) { + auto launch = [&](auto* kernel_instance2) { + cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, + topk_values, topk_indices, bias, num_tokens, n_group, + topk_group, topk, num_experts, num_experts_per_group, + routed_scaling_factor); + }; + switch (n_group) { case 4: { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + launch(&group_idx_and_topk_idx_kernel); break; } case 8: { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + launch(&group_idx_and_topk_idx_kernel); break; } case 16: { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + launch(&group_idx_and_topk_idx_kernel); break; } case 32: { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + launch(&group_idx_and_topk_idx_kernel); break; } default: { - auto* kernel_instance2 = - &group_idx_and_topk_idx_kernel; - cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, - topk_values, topk_indices, bias, num_tokens, n_group, - topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + launch(&group_idx_and_topk_idx_kernel); break; } } @@ -778,17 +760,19 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, config.attrs = attrs; auto const sf = static_cast(scoring_func); int64_t const num_experts_per_group = num_experts / n_group; + auto launch_topk_with_k2 = [&](auto* kernel_instance1) { + cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, + num_tokens, num_cases, n_group, num_experts_per_group); + }; switch (sf) { case SCORING_NONE: { auto* kernel_instance1 = &topk_with_k2_kernel; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts_per_group); + launch_topk_with_k2(kernel_instance1); break; } case SCORING_SIGMOID: { auto* kernel_instance1 = &topk_with_k2_kernel; - cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias, - num_tokens, num_cases, n_group, num_experts_per_group); + launch_topk_with_k2(kernel_instance1); break; } default: From 241e0b386ee7ea1d9b7c63c0688f996ce130680d Mon Sep 17 00:00:00 2001 From: yewentao256 Date: Fri, 5 Dec 2025 22:56:12 +0000 Subject: [PATCH 5/5] remove renorm template Signed-off-by: yewentao256 --- csrc/moe/grouped_topk_kernels.cu | 65 ++++++++++++-------------------- 1 file changed, 24 insertions(+), 41 deletions(-) diff --git a/csrc/moe/grouped_topk_kernels.cu b/csrc/moe/grouped_topk_kernels.cu index 4db6e957afef..47ee5f021eb4 100644 --- a/csrc/moe/grouped_topk_kernels.cu +++ b/csrc/moe/grouped_topk_kernels.cu @@ -532,13 +532,13 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias, #endif } -template +template __global__ void group_idx_and_topk_idx_kernel( T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, - int64_t const num_experts_per_group, double routed_scaling_factor) { + int64_t const num_experts_per_group, bool renormalize, + double routed_scaling_factor) { int32_t warp_id = threadIdx.x / WARP_SIZE; int32_t lane_id = threadIdx.x % WARP_SIZE; int32_t case_id = @@ -664,7 +664,7 @@ __global__ void group_idx_and_topk_idx_kernel( value = apply_scoring(input); s_topk_value[i] = value; } - if constexpr (Renorm) { + if (renormalize) { topk_sum += cg::reduce(tile, cuda_cast(value), cg::plus()); } @@ -676,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel( if (case_id < num_tokens) { if (if_proceed_next_topk) { for (int i = lane_id; i < topk; i += WARP_SIZE) { - float value; - if constexpr (Renorm) { - value = cuda_cast(s_topk_value[i]) / topk_sum * - routed_scaling_factor; - } else { - value = cuda_cast(s_topk_value[i]) * routed_scaling_factor; - } + float base = cuda_cast(s_topk_value[i]); + float value = renormalize ? (base / topk_sum * routed_scaling_factor) + : (base * routed_scaling_factor); topk_indices[i] = s_topk_idx[i]; topk_values[i] = value; } @@ -700,39 +696,40 @@ __global__ void group_idx_and_topk_idx_kernel( #endif } -template +template inline void launch_group_idx_and_topk_kernel( cudaLaunchConfig_t const& config, T* scores, T* group_scores, float* topk_values, IdxT* topk_indices, T const* bias, int64_t const num_tokens, int64_t const n_group, int64_t const topk_group, int64_t const topk, int64_t const num_experts, - int64_t const num_experts_per_group, double const routed_scaling_factor) { + int64_t const num_experts_per_group, bool const renormalize, + double const routed_scaling_factor) { auto launch = [&](auto* kernel_instance2) { cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores, topk_values, topk_indices, bias, num_tokens, n_group, topk_group, topk, num_experts, num_experts_per_group, - routed_scaling_factor); + renormalize, routed_scaling_factor); }; switch (n_group) { case 4: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 8: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 16: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } case 32: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } default: { - launch(&group_idx_and_topk_idx_kernel); + launch(&group_idx_and_topk_idx_kernel); break; } } @@ -795,31 +792,17 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values, config.attrs = attrs; switch (sf) { case SCORING_NONE: { - if (renormalize) { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, routed_scaling_factor); - } else { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, routed_scaling_factor); - } + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); break; } case SCORING_SIGMOID: { - if (renormalize) { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, routed_scaling_factor); - } else { - launch_group_idx_and_topk_kernel( - config, scores, group_scores, topk_values, topk_indices, bias, - num_tokens, n_group, topk_group, topk, num_experts, - num_experts_per_group, routed_scaling_factor); - } + launch_group_idx_and_topk_kernel( + config, scores, group_scores, topk_values, topk_indices, bias, + num_tokens, n_group, topk_group, topk, num_experts, + num_experts_per_group, renormalize, routed_scaling_factor); break; } default: