diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 72cac8ec9da..18ca42ba54c 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -151,6 +151,7 @@ set(SOURCES
     "csrc/gemm/per_token_group_quant_8bit.cu"
     "csrc/gemm/per_token_quant_fp8.cu"
     "csrc/moe/moe_align_kernel.cu"
+    "csrc/moe/moe_fused_gate.cu"
     "csrc/moe/moe_topk_softmax_kernels.cu"
     "csrc/speculative/eagle_utils.cu"
     "csrc/speculative/speculative_sampling.cu"
diff --git a/sgl-kernel/benchmark/bench_moe_fused_gate.py b/sgl-kernel/benchmark/bench_moe_fused_gate.py
new file mode 100644
index 00000000000..2405c49b6c9
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_moe_fused_gate.py
@@ -0,0 +1,74 @@
+import itertools
+import math
+
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import moe_fused_gate
+
+from sglang.srt.layers.moe.topk import biased_grouped_topk
+
+
+def biased_grouped_topk_org(scores, bias, num_expert_group, topk_group, topk):
+    return biased_grouped_topk(
+        scores,
+        scores,
+        bias,
+        topk=topk,
+        renormalize=True,
+        num_expert_group=num_expert_group,
+        topk_group=topk_group,
+    )
+
+
+def biased_grouped_topk_org_kernel(scores, bias, num_expert_group, topk_group, topk):
+    return moe_fused_gate(scores, bias, num_expert_group, topk_group, topk)
+
+
+seq_length_range = [5000, 10000, 15000, 20000, 25000, 30000, 35000, 40000]
+configs = [(sq,) for sq in seq_length_range]
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["seq_length"],
+        x_vals=[list(_) for _ in configs],
+        line_arg="provider",
+        line_vals=["original", "kernel"],
+        line_names=["Original", "SGL Kernel"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="us",
+        plot_name="moe-fused-gate-performance",
+        args={},
+    )
+)
+def benchmark(seq_length, provider):
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+    num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
+
+    scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
+    bias = torch.rand(num_experts, device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "original":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: biased_grouped_topk_org(
+                scores.clone(), bias.clone(), num_expert_group, topk_group, topk
+            ),
+            quantiles=quantiles,
+        )
+    elif provider == "kernel":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: biased_grouped_topk_org_kernel(
+                scores.clone(), bias.clone(), num_expert_group, topk_group, topk
+            ),
+            quantiles=quantiles,
+        )
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    benchmark.run(print_data=True)
diff --git a/sgl-kernel/csrc/moe/moe_fused_gate.cu b/sgl-kernel/csrc/moe/moe_fused_gate.cu
new file mode 100644
index 00000000000..c8aa4811f30
--- /dev/null
+++ b/sgl-kernel/csrc/moe/moe_fused_gate.cu
@@ -0,0 +1,447 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/all.h>
+
+#include <cfloat>
+#include <type_traits>
+
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+
+template <typename T, int N>
+using AlignedArray = cutlass::AlignedArray<T, N>;
+using bfloat16_t = cutlass::bfloat16_t;
+using float16_t = cutlass::half_t;
+using float32_t = float;
+
+// QQ NOTE: to handle the case for at::Half, error: more than one operator ">" matches these operands: built-in
+// operator "arithmetic > arithmetic", function "operator>(const __half &, const __half &)"
+template <typename T>
+__device__ inline bool cmp_gt(const T& a, const T& b) {
+  if constexpr (std::is_same<T, at::Half>::value) {
+    // at::Half (or float16_t in our native case) causes ambiguity, so we cast to float.
+    return static_cast<float>(a) > static_cast<float>(b);
+  } else {
+    // For types like float, at::BFloat16, or cutlass::half_t / cutlass::bfloat16_t, assume operator> works as expected.
+    return a > b;
+  }
+}
+
+template <typename T>
+__device__ inline bool cmp_eq(const T& a, const T& b) {
+  if constexpr (std::is_same<T, at::Half>::value) {
+    return static_cast<float>(a) == static_cast<float>(b);
+  } else {
+    return a == b;
+  }
+}
+
+// Fixed constants common to both dynamic and static template versions:
+static constexpr int WARP_SIZE = 32;
+static constexpr int WARPS_PER_CTA = 6;
+static constexpr int MAX_VPT = 32;  // maximum VPT we support, >= params.VPT = num_experts / num_expert_group
+
+// Create an alias for Array using AlignedArray
+template <typename T, int N>
+using Array = AlignedArray<T, N>;
+// QQ: NOTE expression must have a constant value, this has to be >= params.VPT
+template <typename T, int N>
+using AccessType = AlignedArray<T, N>;
+
+template <typename T, typename Params>
+__device__ void moe_fused_gate_impl(
+    void* input,
+    void* bias,
+    float* output_ptr,
+    int32_t* indices_ptr,
+    int64_t num_rows,
+    int64_t topk_group,
+    int64_t topk,
+    Params params) {
+  int tidx = threadIdx.x;
+  int64_t thread_row =
+      blockIdx.x * params.ROWS_PER_CTA + threadIdx.y * params.ROWS_PER_WARP + tidx / params.THREADS_PER_ROW;
+  if (thread_row >= num_rows) {
+    return;
+  }
+
+  // Cast pointers to type T:
+  auto* input_ptr = reinterpret_cast<T*>(input);
+  auto* bias_ptr = reinterpret_cast<T*>(bias);
+  auto* thread_row_ptr = input_ptr + thread_row * params.NUM_EXPERTS;
+
+  int thread_group_idx = tidx % params.THREADS_PER_ROW;
+  int first_elt_read_by_thread = thread_group_idx * params.VPT;
+
+  // Create local arrays for the row chunk and bias chunk, and then reinterpret the read addresses as pointers to
+  // AccessType.
+  T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
+  Array<T, MAX_VPT> row_chunk;
+  AccessType<T, MAX_VPT> const* vec_thread_read_ptr = reinterpret_cast<AccessType<T, MAX_VPT> const*>(thread_read_ptr);
+
+  T* bias_thread_read_ptr = bias_ptr + first_elt_read_by_thread;
+  Array<T, MAX_VPT> bias_chunk;
+  AccessType<T, MAX_VPT> const* vec_bias_thread_read_ptr =
+      reinterpret_cast<AccessType<T, MAX_VPT> const*>(bias_thread_read_ptr);
+
+// QQ NOTE: doing the following would be slower than the loop assign and, more importantly,
+// hits misaligned-address issues when params.VPT < 8 and mismatches MAX_VPT:
+// AccessType<T, MAX_VPT>* row_chunk_vec_ptr = reinterpret_cast<AccessType<T, MAX_VPT>*>(&row_chunk);
+// row_chunk_vec_ptr[0] = vec_thread_read_ptr[0];
+#pragma unroll
+  for (int ii = 0; ii < params.VPT; ++ii) {
+    row_chunk[ii] = vec_thread_read_ptr[0][ii];
+    bias_chunk[ii] = vec_bias_thread_read_ptr[0][ii];
+  }
+
+  __syncthreads();
+
+////////////////////// Sigmoid //////////////////////
+#pragma unroll
+  for (int ii = 0; ii < params.VPT; ++ii) {
+    row_chunk[ii] = static_cast<T>(1.0f / (1.0f + expf(-float(row_chunk[ii]))));
+  }
+  __syncthreads();
+
+////////////////////// Add Bias //////////////////////
+#pragma unroll
+  for (int ii = 0; ii < params.VPT; ++ii) {
+    bias_chunk[ii] = row_chunk[ii] + bias_chunk[ii];
+  }
+
+////////////////////// Exclude Groups //////////////////////
+#pragma unroll
+  for (int k_idx = 0; k_idx < params.THREADS_PER_ROW - topk_group;
+       ++k_idx) {  // QQ NOTE Here params.THREADS_PER_ROW = num_expert_group
+    int expert = first_elt_read_by_thread;
+    // local argmax
+    T max_val = static_cast<T>(-FLT_MAX);
+    T max_val_second = static_cast<T>(-FLT_MAX);
+#pragma unroll
+    for (int ii = 0; ii < params.VPT; ++ii) {
+      T val = bias_chunk[ii];
+
+      if (cmp_gt(val, max_val)) {
+        max_val_second = max_val;
+        max_val = val;
+      } else if (cmp_gt(val, max_val_second)) {
+        max_val_second = val;
+      }
+    }
+    // QQ NOTE: currently fixed to pick the top-2 sigmoid weight values in each expert group and sum them as the
+    // group weight used to select expert groups
+    T max_sum = max_val + max_val_second;
+
+// argmin reduce
+#pragma unroll
+    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
+      T other_max_sum =
+          static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_sum), mask, params.THREADS_PER_ROW));
+      int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
+
+      // higher indices win
+      if (cmp_gt(max_sum, other_max_sum) || (cmp_eq(other_max_sum, max_sum) && other_expert > expert)) {
+        max_sum = other_max_sum;
+        expert = other_expert;
+      }
+    }
+
+    // clear the max value in the thread
+    if (k_idx < params.THREADS_PER_ROW - topk_group) {
+      int const thread_to_clear_in_group = expert / params.VPT;
+
+      if (thread_group_idx == thread_to_clear_in_group) {
+#pragma unroll
+        for (int ii = 0; ii < params.VPT; ++ii) {
+          bias_chunk[ii] = static_cast<T>(FLT_MAX);
+        }
+      }
+    }
+  }
+
+  __syncthreads();
+
+  ////////////////////// Topk //////////////////////
+  float output_sum = 0.0f;
+  for (int k_idx = 0; k_idx < topk; ++k_idx) {
+    // local argmax
+    T max_val = bias_chunk[0];
+    int expert = first_elt_read_by_thread;
+
+    if (!cmp_eq(max_val, static_cast<T>(FLT_MAX))) {
+#pragma unroll
+      for (int ii = 1; ii < params.VPT; ++ii) {
+        T val = bias_chunk[ii];
+        if (cmp_gt(val, max_val)) {
+          max_val = val;
+          expert = first_elt_read_by_thread + ii;
+        }
+      }
+    } else {
+      max_val = static_cast<T>(-FLT_MAX);
+    }
+
+// argmax reduce
+#pragma unroll
+    for (int mask = params.THREADS_PER_ROW / 2; mask > 0; mask /= 2) {
+      T other_max =
+          static_cast<T>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(max_val), mask, params.THREADS_PER_ROW));
+      int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, params.THREADS_PER_ROW);
+
+      // lower indices win
+      if (cmp_gt(other_max, max_val) || (cmp_eq(other_max, max_val) && other_expert < expert)) {
+        max_val = other_max;
+        expert = other_expert;
+      }
+    }
+
+    if (k_idx < topk) {
+      int thread_to_clear_in_group = expert / params.VPT;
+      int64_t idx = topk * thread_row + k_idx;
+
+      if (thread_group_idx == thread_to_clear_in_group) {
+        int expert_to_clear_in_thread = expert % params.VPT;
+
+        // clear the max value in the thread
+        bias_chunk[expert_to_clear_in_thread] = static_cast<T>(-FLT_MAX);
+
+        // store output
+        output_ptr[idx] = static_cast<float>(row_chunk[expert_to_clear_in_thread]);
+        indices_ptr[idx] = static_cast<int32_t>(expert);
+      }
+
+      // accumulate sum
+      if (thread_group_idx == 0) {
+        output_sum += output_ptr[idx];
+      }
+    }
+
+    __syncthreads();
+  }
+
+  ////////////////////// Rescale Output //////////////////////
+  if (thread_group_idx == 0) {
+#pragma unroll
+    for (int ii = 0; ii < topk; ++ii) {
+      int64_t const idx = topk * thread_row + ii;
+      output_ptr[idx] = static_cast<float>(static_cast<float>(output_ptr[idx]) / static_cast<float>(output_sum));
+    }
+  }
+}
+
+//------------------------------------------------------------------------------
+// Templated Kernel Version (using compile-time constants)
+//------------------------------------------------------------------------------
+template <int VPT_, int NUM_EXPERTS_, int THREADS_PER_ROW_, int ROWS_PER_WARP_, int ROWS_PER_CTA_, int WARPS_PER_CTA_>
+struct KernelParams {
+  static constexpr int VPT = VPT_;
+  static constexpr int NUM_EXPERTS = NUM_EXPERTS_;
+  static constexpr int THREADS_PER_ROW = THREADS_PER_ROW_;
+  static constexpr int ROWS_PER_WARP = ROWS_PER_WARP_;
+  static constexpr int ROWS_PER_CTA = ROWS_PER_CTA_;
+  static constexpr int WARPS_PER_CTA = WARPS_PER_CTA_;
+};
+
+template <
+    typename T,
+    int VPT,
+    int NUM_EXPERTS,
+    int THREADS_PER_ROW,
+    int ROWS_PER_WARP,
+    int ROWS_PER_CTA,
+    int WARPS_PER_CTA>
+__global__ void moe_fused_gate_kernel(
+    void* input,
+    void* bias,
+    float* output_ptr,
+    int32_t* indices_ptr,
+    int64_t num_rows,
+    int64_t topk_group,
+    int64_t topk) {
+  KernelParams<VPT, NUM_EXPERTS, THREADS_PER_ROW, ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> params;
+  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+}
+
+// Macro to compute compile-time constants and launch the kernel.
+#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP)                                                  \
+  do {                                                                                                    \
+    constexpr int VPT = (EXPERTS) / (EXPERT_GROUP);                                                       \
+    /* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */                                        \
+    constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1;       \
+    constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;                                           \
+    moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA>  \
+        <<<num_blocks, block_dim, 0, stream>>>(                                                           \
+            input.data_ptr(),                                                                             \
+            bias.data_ptr(),                                                                              \
+            output.data_ptr<float>(),                                                                     \
+            indices.data_ptr<int32_t>(),                                                                  \
+            num_rows,                                                                                     \
+            topk_group,                                                                                   \
+            topk);                                                                                        \
+    dispatched = true;                                                                                    \
+  } while (0)
+
+//------------------------------------------------------------------------------
+// Dynamic Kernel Version (parameters computed at runtime)
+//------------------------------------------------------------------------------
+struct KernelParamsDynamic {
+  int VPT;
+  int NUM_EXPERTS;
+  int THREADS_PER_ROW;
+  int ROWS_PER_WARP;
+  int ROWS_PER_CTA;
+  int WARPS_PER_CTA;
+};
+
+template <typename T>
+__global__ void moe_fused_gate_kernel_dynamic(
+    void* input,
+    void* bias,
+    float* output_ptr,
+    int32_t* indices_ptr,
+    int64_t num_rows,
+    int64_t num_experts,
+    int64_t num_expert_group,
+    int64_t topk_group,
+    int64_t topk) {
+  KernelParamsDynamic params;
+  params.NUM_EXPERTS = num_experts;             // e.g., for deepseek v3, this is 256
+  params.VPT = num_experts / num_expert_group;  // e.g., for deepseek v3, this is 256 / 8 = 32
+  params.THREADS_PER_ROW = num_expert_group;    // fixed as num_expert_group, e.g., for deepseek v3, this is 8
+  params.WARPS_PER_CTA = WARPS_PER_CTA;         // fixed as 6
+  params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group);  // WARP_SIZE is fixed as 32
+  params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
+
+  moe_fused_gate_impl<T>(input, bias, output_ptr, indices_ptr, num_rows, topk_group, topk, params);
+}
+
+//------------------------------------------------------------------------------
+// Host Launcher Function
+//------------------------------------------------------------------------------
+std::vector<at::Tensor>
+moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk) {
+  int64_t num_rows = input.size(0);
+  int32_t num_experts = input.size(1);
+  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  auto output = torch::empty({num_rows, topk}, options);
+  auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
+
+  // Compute grid dimensions based on the runtime value of num_expert_group.
+  int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
+  int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
+  int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
+
+  // Check 1: Ensure that num_experts is a power of 2.
+  TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
+
+  // Check 2: Ensure that num_experts is divisible by num_expert_group (this also means num_expert_group is a power
+  // of 2).
+  TORCH_CHECK(
+      num_experts % num_expert_group == 0,
+      "num_experts must be divisible by num_expert_group, but got ",
+      num_experts,
+      " / ",
+      num_expert_group);
+
+  int computed_vpt = num_experts / num_expert_group;
+  // Check 3: Ensure that num_experts / num_expert_group does not exceed MAX_VPT = 32, the maximum number of values
+  // per thread we can process.
+  TORCH_CHECK(
+      computed_vpt <= MAX_VPT,
+      "Per group experts: num_experts / num_expert_group = (",
+      computed_vpt,
+      ") exceeds the maximum supported (",
+      MAX_VPT,
+      ")");
+
+  // Dispatch to the templated kernel for known compile-time configurations.
+  // We currently only support:
+  //   Case 1: 256 experts, with 8 or 16 groups.
+  //   Case 2: 128 experts, with 4 or 8 groups.
+  //   Case 3: other cases, which require 8 <= num_experts / num_expert_group <= 32.
+  bool dispatched = false;
+  switch (num_experts) {
+    case 256:
+      if (num_expert_group == 8) {
+        // This is the deepseek v3 case. Here VPT = 256/8 = 32, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
+        if (input.scalar_type() == at::kBFloat16) {
+          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 8);
+        } else if (input.scalar_type() == at::kHalf) {
+          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 8);
+        } else if (input.scalar_type() == at::kFloat) {
+          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 8);
+        }
+      } else if (num_expert_group == 16) {
+        // Here VPT = 256/16 = 16, ROWS_PER_WARP = 32/16 = 2, ROWS_PER_CTA = 6 * 2 = 12.
+        if (input.scalar_type() == at::kBFloat16) {
+          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 256, 16);
+        } else if (input.scalar_type() == at::kHalf) {
+          LAUNCH_MOE_GATE_CONFIG(float16_t, 256, 16);
+        } else if (input.scalar_type() == at::kFloat) {
+          LAUNCH_MOE_GATE_CONFIG(float32_t, 256, 16);
+        }
+      }
+      break;
+    case 128:
+      if (num_expert_group == 4) {
+        // VPT = 128/4 = 32, ROWS_PER_WARP = 32/4 = 8, ROWS_PER_CTA = 6 * 8 = 48.
+        if (input.scalar_type() == at::kBFloat16) {
+          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 4);
+        } else if (input.scalar_type() == at::kHalf) {
+          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 4);
+        } else if (input.scalar_type() == at::kFloat) {
+          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 4);
+        }
+      } else if (num_expert_group == 8) {
+        // VPT = 128/8 = 16, ROWS_PER_WARP = 32/8 = 4, ROWS_PER_CTA = 6 * 4 = 24.
+        if (input.scalar_type() == at::kBFloat16) {
+          LAUNCH_MOE_GATE_CONFIG(bfloat16_t, 128, 8);
+        } else if (input.scalar_type() == at::kHalf) {
+          LAUNCH_MOE_GATE_CONFIG(float16_t, 128, 8);
+        } else if (input.scalar_type() == at::kFloat) {
+          LAUNCH_MOE_GATE_CONFIG(float32_t, 128, 8);
+        }
+      }
+      break;
+    default:
+      break;
+  }
+  if (!dispatched) {
+    // Fallback to the dynamic kernel if none of the supported combinations match.
+    // Currently we only support num_experts / num_expert_group <= 32 for the dynamic kernel.
+    if (input.scalar_type() == at::kBFloat16) {
+      moe_fused_gate_kernel_dynamic<at::BFloat16><<<num_blocks, block_dim, 0, stream>>>(
+          input.data_ptr(),
+          bias.data_ptr(),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          num_experts,
+          num_expert_group,
+          topk_group,
+          topk);
+    } else if (input.scalar_type() == at::kHalf) {
+      moe_fused_gate_kernel_dynamic<at::Half><<<num_blocks, block_dim, 0, stream>>>(
+          input.data_ptr(),
+          bias.data_ptr(),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          num_experts,
+          num_expert_group,
+          topk_group,
+          topk);
+    } else if (input.scalar_type() == at::kFloat) {
+      moe_fused_gate_kernel_dynamic<float><<<num_blocks, block_dim, 0, stream>>>(
+          input.data_ptr(),
+          bias.data_ptr(),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          num_experts,
+          num_expert_group,
+          topk_group,
+          topk);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type for moe_fused_gate");
+    }
+  }
+  return {output, indices};
+}
diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc
index 16d9adb12f2..f2d5eba73bd 100644
--- a/sgl-kernel/csrc/torch_extension.cc
+++ b/sgl-kernel/csrc/torch_extension.cc
@@ -138,6 +138,11 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 
+  m.def(
+      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
+      "(Tensor[])");
+  m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
+
   /*
    * From csrc/speculative
    */
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index f4961ab4fee..847b24ebe58 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -199,6 +199,9 @@ void topk_softmax(
     torch::Tensor& token_expert_indices,
     torch::Tensor& gating_output);
 
+std::vector<at::Tensor>
+moe_fused_gate(at::Tensor& input, at::Tensor& bias, int64_t num_expert_group, int64_t topk_group, int64_t topk);
+
 /*
  * From csrc/speculative
  */
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 1a668fee4f9..789cbcedf52 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -36,7 +36,7 @@
     sgl_per_token_group_quant_int8,
     sgl_per_token_quant_fp8,
 )
-from sgl_kernel.moe import moe_align_block_size, topk_softmax
+from sgl_kernel.moe import moe_align_block_size, moe_fused_gate, topk_softmax
 from sgl_kernel.sampling import (
     min_p_sampling_from_probs,
     top_k_renorm_prob,
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index 84c79fdb782..24066424f61 100644
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -32,3 +32,15 @@ def topk_softmax(
     torch.ops.sgl_kernel.topk_softmax.default(
         topk_weights, topk_ids, token_expert_indices, gating_output
     )
+
+
+def moe_fused_gate(input_tensor, bias, num_expert_group, topk_group, topk):
+    # This fused kernel selects the topk experts in a hierarchical 2-layer fashion:
+    # it splits the experts into num_expert_group groups and uses the sum of the top-2 expert weights in each group
+    # as the group weight to select topk_group expert groups, then picks the topk experts within the selected groups.
+    # The number of experts is inferred from the input tensor shape. We currently only support a power-of-2 number
+    # of experts that is divisible by num_expert_group, with num_experts / num_expert_group <= 32 for now.
+    # For unsupported cases, we suggest using biased_grouped_topk in sglang.srt.layers.moe.topk instead.
+    return torch.ops.sgl_kernel.moe_fused_gate(
+        input_tensor, bias, num_expert_group, topk_group, topk
+    )
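Note: the wrapper's comments above describe the hierarchical selection; the snippet below restates the same semantics in plain PyTorch for reference only. It is a minimal sketch (the helper name moe_fused_gate_ref is hypothetical, not part of this PR), roughly mirroring what biased_grouped_topk computes: sigmoid gating, bias-adjusted top-2 group scoring, group selection, top-k within the kept groups, then renormalization of the un-biased weights.

import torch

def moe_fused_gate_ref(scores, bias, num_expert_group, topk_group, topk):
    # scores: [num_tokens, num_experts] router logits, bias: [num_experts]
    num_tokens, num_experts = scores.shape
    probs = scores.sigmoid()                    # sigmoid gating weights
    biased = probs + bias                       # bias-adjusted weights, used only for selection
    grouped = biased.view(num_tokens, num_expert_group, -1)
    # score each group by the sum of its top-2 biased weights
    group_scores = grouped.topk(2, dim=-1).values.sum(dim=-1)
    keep = group_scores.topk(topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, keep, 1.0)
    expert_mask = group_mask.unsqueeze(-1).expand_as(grouped).reshape(num_tokens, num_experts)
    masked = biased.masked_fill(expert_mask == 0, float("-inf"))
    # pick topk experts from the surviving groups; output weights come from the un-biased sigmoid values
    topk_idx = masked.topk(topk, dim=-1).indices
    topk_w = probs.gather(1, topk_idx)
    topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
    return topk_w, topk_idx.to(torch.int32)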
diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 7b21ba59390..f516758c14d 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -161,6 +161,7 @@ def copy_deepgemm_to_build_lib(self):
         "csrc/gemm/per_token_quant_fp8.cu",
         "csrc/gemm/per_tensor_quant_fp8.cu",
         "csrc/moe/moe_align_kernel.cu",
+        "csrc/moe/moe_fused_gate.cu",
         "csrc/moe/moe_topk_softmax_kernels.cu",
         "csrc/speculative/eagle_utils.cu",
         "csrc/speculative/speculative_sampling.cu",
diff --git a/sgl-kernel/tests/test_moe_fused_gate.py b/sgl-kernel/tests/test_moe_fused_gate.py
new file mode 100644
index 00000000000..3d6221bf426
--- /dev/null
+++ b/sgl-kernel/tests/test_moe_fused_gate.py
@@ -0,0 +1,72 @@
+import pytest
+import torch
+from sgl_kernel import moe_fused_gate
+
+from sglang.srt.layers.moe.topk import biased_grouped_topk
+
+
+@pytest.mark.parametrize(
+    "seq_length",
+    list(range(1, 10))
+    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
+)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.bfloat16])
+@pytest.mark.parametrize(
+    "params",
+    [
+        (128, 4, 2, 4),
+        (256, 8, 4, 8),  # deepseek v3
+        (512, 16, 8, 16),
+    ],
+)
+def test_moe_fused_gate_combined(seq_length, dtype, params):
+    num_experts, num_expert_group, topk_group, topk = params
+
+    torch.manual_seed(seq_length)
+    tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
+    scores = tensor.clone()
+    bias = torch.rand(num_experts).to(dtype).cuda()
+
+    output, indices = moe_fused_gate(
+        tensor,
+        bias,
+        num_expert_group=num_expert_group,
+        topk_group=topk_group,
+        topk=topk,
+    )
+    ref_output, ref_indices = biased_grouped_topk(
+        scores,
+        scores,
+        bias,
+        topk=topk,
+        renormalize=True,
+        num_expert_group=num_expert_group,
+        topk_group=topk_group,
+        compiled=False,
+    )
+
+    idx_check = torch.allclose(
+        ref_indices.sort()[0].to(torch.int32),
+        indices.sort()[0].to(torch.int32),
+        rtol=1e-04,
+        atol=1e-05,
+    )
+    output_check = torch.allclose(
+        ref_output.sort()[0].to(torch.float32),
+        output.sort()[0].to(torch.float32),
+        rtol=1e-04,
+        atol=1e-05,
+    )
+
+    assert idx_check, (
+        f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
+        f"params {params}"
+    )
+    assert output_check, (
+        f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
+        f"params {params}"
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
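For reference, a minimal usage sketch of the new op (assuming a CUDA device and an sgl-kernel build that includes this kernel), using the DeepSeek-V3 shapes exercised in the benchmark and tests above:

import torch
from sgl_kernel import moe_fused_gate

num_tokens, num_experts = 16, 256
scores = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.bfloat16)
bias = torch.rand(num_experts, device="cuda", dtype=torch.bfloat16)

# returns float32 weights and int32 expert indices, both of shape [num_tokens, topk]
weights, indices = moe_fused_gate(scores, bias, num_expert_group=8, topk_group=4, topk=8)
print(weights.shape, indices.shape)  # torch.Size([16, 8]) torch.Size([16, 8])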