Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 128 additions & 47 deletions csrc/moe/grouped_topk_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -444,23 +444,27 @@ __device__ inline T apply_sigmoid(T val) {
return cuda_cast<T, float>(sigmoid_accurate(f));
}

template <typename T>
template <ScoringFunc SF, typename T>
__device__ inline T apply_scoring(T val) {
if constexpr (SF == SCORING_SIGMOID) {
return apply_sigmoid(val);
} else {
return val;
}
}

template <typename T, ScoringFunc SF>
__device__ void topk_with_k2(T* output, T const* input, T const* bias,
cg::thread_block_tile<32> const& tile,
int32_t const lane_id,
int const num_experts_per_group,
int const scoring_func) {
int const num_experts_per_group) {
// Get the top2 per thread
T largest = neg_inf<T>();
T second_largest = neg_inf<T>();

if (num_experts_per_group > WARP_SIZE) {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];

if (value > largest) {
Expand All @@ -472,11 +476,7 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
} else {
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
T value = input[i];
// Apply scoring function if needed
if (scoring_func == SCORING_SIGMOID) {
value = apply_sigmoid(value);
}
T value = apply_scoring<SF>(input[i]);
value = value + bias[i];
largest = value;
}
Expand All @@ -501,13 +501,12 @@ __device__ void topk_with_k2(T* output, T const* input, T const* bias,
}
}

template <typename T>
template <typename T, ScoringFunc SF>
__global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
int64_t const num_tokens,
int64_t const num_cases,
int64_t const n_group,
int64_t const num_experts_per_group,
int const scoring_func) {
int64_t const num_experts_per_group) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;

Expand All @@ -525,21 +524,21 @@ __global__ void topk_with_k2_kernel(T* output, T* input, T const* bias,
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.wait;");
#endif
topk_with_k2(output, input, group_bias, tile, lane_id,
num_experts_per_group, scoring_func);
topk_with_k2<T, SF>(output, input, group_bias, tile, lane_id,
num_experts_per_group);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

template <typename T, typename IdxT>
template <typename T, typename IdxT, ScoringFunc SF, int NGroup = -1>
__global__ void group_idx_and_topk_idx_kernel(
T* scores, T const* group_scores, float* topk_values, IdxT* topk_indices,
T const* bias, int64_t const num_tokens, int64_t const n_group,
int64_t const topk_group, int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool renormalize,
double routed_scaling_factor, int scoring_func) {
double routed_scaling_factor) {
int32_t warp_id = threadIdx.x / WARP_SIZE;
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t case_id =
Expand All @@ -549,6 +548,11 @@ __global__ void group_idx_and_topk_idx_kernel(
topk_values += case_id * topk;
topk_indices += case_id * topk;

constexpr bool kUseStaticNGroup = (NGroup > 0);
// use int32 to avoid implicit conversion
int32_t const n_group_i32 =
kUseStaticNGroup ? NGroup : static_cast<int32_t>(n_group);

int32_t align_num_experts_per_group =
warp_topk::round_up_to_multiple_of<WARP_SIZE>(num_experts_per_group);

Expand All @@ -574,13 +578,14 @@ __global__ void group_idx_and_topk_idx_kernel(

if (case_id < num_tokens) {
// calculate group_idx
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
int32_t target_num_min =
WARP_SIZE - n_group_i32 + static_cast<int32_t>(topk_group);
// The check is necessary to avoid abnormal input
if (lane_id < n_group && is_finite(group_scores[lane_id])) {
if (lane_id < n_group_i32 && is_finite(group_scores[lane_id])) {
value = group_scores[lane_id];
}

int count_equal_to_top_value = WARP_SIZE - n_group;
int count_equal_to_top_value = WARP_SIZE - n_group_i32;
int pre_count_equal_to_top_value = 0;
// Use loop to find the largset top_group
while (count_equal_to_top_value < target_num_min) {
Expand All @@ -604,7 +609,7 @@ __global__ void group_idx_and_topk_idx_kernel(
int count_equalto_topkth_group = 0;
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
if (case_id < num_tokens && if_proceed_next_topk) {
for (int i_group = 0; i_group < n_group; i_group++) {
auto process_group = [&](int i_group) {
if ((group_scores[i_group] > topk_group_value) ||
((group_scores[i_group] == topk_group_value) &&
(count_equalto_topkth_group < num_equalto_topkth_group))) {
Expand All @@ -613,11 +618,10 @@ __global__ void group_idx_and_topk_idx_kernel(
i += WARP_SIZE) {
T candidates = neg_inf<T>();
if (i < num_experts_per_group) {
// Apply scoring function (if any) and add bias
// apply scoring function (if any) and add bias
T input = scores[offset + i];
if (is_finite(input)) {
T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input)
: input;
T score = apply_scoring<SF>(input);
candidates = score + bias[offset + i];
}
}
Expand All @@ -627,6 +631,17 @@ __global__ void group_idx_and_topk_idx_kernel(
count_equalto_topkth_group++;
}
}
};

if constexpr (kUseStaticNGroup) {
#pragma unroll
for (int i_group = 0; i_group < NGroup; ++i_group) {
process_group(i_group);
}
} else {
for (int i_group = 0; i_group < n_group_i32; ++i_group) {
process_group(i_group);
}
}
queue.done();
__syncwarp();
Expand All @@ -646,12 +661,13 @@ __global__ void group_idx_and_topk_idx_kernel(
if (i < topk) {
// Load the score value (without bias) for normalization
T input = scores[s_topk_idx[i]];
value =
(scoring_func == SCORING_SIGMOID) ? apply_sigmoid(input) : input;
value = apply_scoring<SF>(input);
s_topk_value[i] = value;
}
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
if (renormalize) {
topk_sum +=
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
}
}
}

Expand All @@ -660,13 +676,9 @@ __global__ void group_idx_and_topk_idx_kernel(
if (case_id < num_tokens) {
if (if_proceed_next_topk) {
for (int i = lane_id; i < topk; i += WARP_SIZE) {
float value;
if (renormalize) {
value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
routed_scaling_factor;
} else {
value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
}
float base = cuda_cast<float, T>(s_topk_value[i]);
float value = renormalize ? (base / topk_sum * routed_scaling_factor)
: (base * routed_scaling_factor);
topk_indices[i] = s_topk_idx[i];
topk_values[i] = value;
}
Expand All @@ -684,6 +696,45 @@ __global__ void group_idx_and_topk_idx_kernel(
#endif
}

template <typename T, typename IdxT, ScoringFunc SF>
inline void launch_group_idx_and_topk_kernel(
cudaLaunchConfig_t const& config, T* scores, T* group_scores,
float* topk_values, IdxT* topk_indices, T const* bias,
int64_t const num_tokens, int64_t const n_group, int64_t const topk_group,
int64_t const topk, int64_t const num_experts,
int64_t const num_experts_per_group, bool const renormalize,
double const routed_scaling_factor) {
auto launch = [&](auto* kernel_instance2) {
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts_per_group,
renormalize, routed_scaling_factor);
};

switch (n_group) {
case 4: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 4>);
break;
}
case 8: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 8>);
break;
}
case 16: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 16>);
break;
}
case 32: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF, 32>);
break;
}
default: {
launch(&group_idx_and_topk_idx_kernel<T, IdxT, SF>);
break;
}
}
}

template <typename T, typename IdxT>
void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
IdxT* topk_indices, T const* bias, int64_t const num_tokens,
Expand All @@ -694,7 +745,6 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
cudaStream_t const stream = 0) {
int64_t num_cases = num_tokens * n_group;
int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
auto* kernel_instance1 = &topk_with_k2_kernel<T>;
cudaLaunchConfig_t config;
config.gridDim = topk_with_k2_num_blocks;
config.blockDim = BLOCK_SIZE;
Expand All @@ -705,16 +755,33 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts / n_group,
scoring_func);
auto const sf = static_cast<ScoringFunc>(scoring_func);
int64_t const num_experts_per_group = num_experts / n_group;
auto launch_topk_with_k2 = [&](auto* kernel_instance1) {
cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores, bias,
num_tokens, num_cases, n_group, num_experts_per_group);
};
switch (sf) {
case SCORING_NONE: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_NONE>;
launch_topk_with_k2(kernel_instance1);
break;
}
case SCORING_SIGMOID: {
auto* kernel_instance1 = &topk_with_k2_kernel<T, SCORING_SIGMOID>;
launch_topk_with_k2(kernel_instance1);
break;
}
default:
// should be guarded by higher level checks.
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}

int64_t topk_with_k_group_num_blocks =
(num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
size_t dynamic_smem_in_bytes =
warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
topk);
auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
config.gridDim = topk_with_k_group_num_blocks;
config.blockDim = BLOCK_SIZE;
config.dynamicSmemBytes = dynamic_smem_in_bytes;
Expand All @@ -723,10 +790,24 @@ void invokeNoAuxTc(T* scores, T* group_scores, float* topk_values,
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
topk_values, topk_indices, bias, num_tokens, n_group,
topk_group, topk, num_experts, num_experts / n_group,
renormalize, routed_scaling_factor, scoring_func);
switch (sf) {
case SCORING_NONE: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_NONE>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
case SCORING_SIGMOID: {
launch_group_idx_and_topk_kernel<T, IdxT, SCORING_SIGMOID>(
config, scores, group_scores, topk_values, topk_indices, bias,
num_tokens, n_group, topk_group, topk, num_experts,
num_experts_per_group, renormalize, routed_scaling_factor);
break;
}
default:
TORCH_CHECK(false, "Unsupported scoring_func in invokeNoAuxTc");
}
}

#define INSTANTIATE_NOAUX_TC(T, IdxT) \
Expand Down
Loading