diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 7c394a78f03b..9494025314af 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -319,6 +319,7 @@ set(SOURCES
   "csrc/moe/marlin_moe_wna16/ops.cu"
   "csrc/moe/moe_align_kernel.cu"
   "csrc/moe/moe_fused_gate.cu"
+  "csrc/moe/kimi_k2_moe_fused_gate.cu"
   "csrc/moe/moe_sum.cu"
   "csrc/moe/moe_sum_reduce.cu"
   "csrc/moe/moe_topk_softmax_kernels.cu"
diff --git a/sgl-kernel/benchmark/bench_kimi_k2_moe_fused_gate.py b/sgl-kernel/benchmark/bench_kimi_k2_moe_fused_gate.py
new file mode 100644
index 000000000000..78b75231c3c2
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_kimi_k2_moe_fused_gate.py
@@ -0,0 +1,117 @@
+import itertools
+import math
+import os
+
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import kimi_k2_moe_fused_gate
+
+from sglang.srt.layers.moe.topk import kimi_k2_biased_topk_impl
+
+# CI environment detection
+IS_CI = (
+    os.getenv("CI", "false").lower() == "true"
+    or os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
+)
+
+
+def kimi_k2_biased_topk_torch_compile(scores, bias, topk, routed_scaling_factor):
+    """Original torch.compile-based implementation"""
+    return kimi_k2_biased_topk_impl(
+        scores,
+        scores,
+        bias,
+        topk=topk,
+        renormalize=True,
+        routed_scaling_factor=routed_scaling_factor,
+    )
+
+
+def kimi_k2_biased_topk_fused_kernel(scores, bias, topk, routed_scaling_factor):
+    """Our fused CUDA kernel implementation"""
+    return kimi_k2_moe_fused_gate(
+        scores,
+        bias,
+        topk=topk,
+        renormalize=True,
+        routed_scaling_factor=routed_scaling_factor,
+    )
+
+
+# CI environment uses simplified parameters
+if IS_CI:
+    seq_length_range = [5000]  # Only test one sequence length in CI
+else:
+    seq_length_range = [
+        1,
+        8,
+        16,
+        32,
+        64,
+        128,
+        256,
+        512,
+        1024,
+        2048,
+        4096,
+        10000,
+        15000,
+        20000,
+        25000,
+        30000,
+        35000,
+        40000,
+    ]
+
+configs = [(sq,) for sq in seq_length_range]
+
+
+@triton.testing.perf_report(
+    triton.testing.Benchmark(
+        x_names=["seq_length"],
+        x_vals=[list(_) for _ in configs],
+        line_arg="provider",
+        line_vals=["torch_compile", "fused_kernel"],
+        line_names=["Torch Compile", "Fused Kernel"],
+        styles=[("blue", "-"), ("red", "-")],
+        ylabel="us",
+        plot_name="kimi-k2-moe-fused-gate-performance",
+        args={},
+    )
+)
+def benchmark(seq_length, provider):
+    dtype = torch.float32
+    device = torch.device("cuda")
+    num_experts, topk = 384, 6  # Kimi K2 configuration
+    routed_scaling_factor = 2.872  # Kimi K2's routed scaling factor
+
+    scores = torch.randn((seq_length, num_experts), device=device, dtype=dtype)
+    bias = torch.rand(num_experts, device=device, dtype=dtype)
+
+    quantiles = [0.5, 0.2, 0.8]
+
+    if provider == "torch_compile":
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            lambda: kimi_k2_biased_topk_torch_compile(
+                scores.clone(), bias.clone(), topk, routed_scaling_factor
+            ),
+            quantiles=quantiles,
+        )
+    elif provider == "fused_kernel":
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            lambda: kimi_k2_biased_topk_fused_kernel(
+                scores.clone(), bias.clone(), topk, routed_scaling_factor
+            ),
+            quantiles=quantiles,
+        )
+
+    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
+
+
+if __name__ == "__main__":
+    print("=" * 80)
+    print("Benchmarking Kimi K2 MoE Fused Gate Performance")
+    print("=" * 80)
+    print("\nPerformance vs Sequence Length (384 experts, topk=6)")
+    benchmark.run(print_data=True, save_path=".")
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 03a7ec0151f4..a40de3e249ff 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -242,6 +242,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
 
+  m.def(
+      "kimi_k2_moe_fused_gate(Tensor input, Tensor bias, int topk, bool renormalize, "
+      "float routed_scaling_factor, bool apply_routed_scaling_factor_on_output) -> "
+      "(Tensor[])");
+  m.impl("kimi_k2_moe_fused_gate", torch::kCUDA, &kimi_k2_moe_fused_gate);
+
   m.def(
       "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
       "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
diff --git a/sgl-kernel/csrc/moe/kimi_k2_moe_fused_gate.cu b/sgl-kernel/csrc/moe/kimi_k2_moe_fused_gate.cu
new file mode 100644
index 000000000000..d4ceedab47a3
--- /dev/null
+++ b/sgl-kernel/csrc/moe/kimi_k2_moe_fused_gate.cu
@@ -0,0 +1,354 @@
+// NOTE(review): this #include list was reconstructed -- the angle-bracket
+// contents were lost in the patch transfer. Verify against moe_fused_gate.cu.
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cfloat>
+#include <cmath>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include <cutlass/numeric_types.h>
+
+using bfloat16_t = cutlass::bfloat16_t;
+using float16_t = cutlass::half_t;
+
+// Kimi K2 specific constants
+static constexpr int WARP_SIZE = 32;
+static constexpr int WARPS_PER_CTA = 6;
+static constexpr int NUM_EXPERTS = 384;
+static constexpr int VPT = 12;  // values per thread: 384 / 32 = 12
+
+// Small token optimization constants
+static constexpr int SMALL_TOKEN_THRESHOLD = 512;
+static constexpr int WARPS_PER_TOKEN_SMALL = 12;  // Use 12 warps per token for small batches
+static constexpr int THREADS_PER_BLOCK_SMALL = WARPS_PER_TOKEN_SMALL * WARP_SIZE;  // 384 threads
+
+template <typename T>
+__device__ inline bool cmp_gt(const T& a, const T& b) {
+  return static_cast<float>(a) > static_cast<float>(b);
+}
+
+template <typename T>
+__device__ inline bool cmp_eq(const T& a, const T& b) {
+  return static_cast<float>(a) == static_cast<float>(b);
+}
+
+// Small token optimized kernel: one block per token, WARPS_PER_TOKEN_SMALL
+// warps collaborate on the top-k selection for that token.
+template <typename T>
+__global__ void kimi_k2_moe_fused_gate_kernel_small_token(
+    T* input,
+    T* bias,
+    float* output_ptr,
+    int32_t* indices_ptr,
+    int64_t num_rows,
+    int64_t topk,
+    bool renormalize,
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
+  // Each block handles one token with WARPS_PER_TOKEN_SMALL warps collaborating
+  int64_t row_idx = blockIdx.x;
+  if (row_idx >= num_rows) return;
+
+  int tid = threadIdx.x;
+  int warp_id = tid / WARP_SIZE;
+  int lane_id = tid % WARP_SIZE;
+
+  // Shared memory for all warps to collaborate
+  __shared__ float shared_scores[NUM_EXPERTS];
+  __shared__ float shared_original_scores[NUM_EXPERTS];
+
+  // Each thread loads one expert (384 threads for 384 experts).
+  // The routing score is sigmoid(logit); the bias only affects *selection*,
+  // so the unbiased sigmoid is kept separately for the output weight.
+  if (tid < NUM_EXPERTS) {
+    T input_val = input[row_idx * NUM_EXPERTS + tid];
+    T bias_val = bias[tid];
+    float sigmoid_val = 1.0f / (1.0f + expf(-static_cast<float>(input_val)));
+    float biased_val = sigmoid_val + static_cast<float>(bias_val);
+    shared_scores[tid] = biased_val;
+    shared_original_scores[tid] = sigmoid_val;
+  }
+
+  __syncthreads();
+
+  // Parallel TopK: each warp owns a contiguous slice of the experts and the
+  // per-warp maxima are reduced by warp 0. Selected experts are knocked out
+  // of shared_scores with -FLT_MAX so the next round finds the next best.
+  int experts_per_warp = (NUM_EXPERTS + WARPS_PER_TOKEN_SMALL - 1) / WARPS_PER_TOKEN_SMALL;
+  int warp_start = warp_id * experts_per_warp;
+  int warp_end = min(warp_start + experts_per_warp, NUM_EXPERTS);
+
+  for (int k = 0; k < topk; k++) {
+    float max_val = -FLT_MAX;
+    int max_expert = -1;
+
+    // Each warp finds the max in its portion
+    for (int expert = warp_start + lane_id; expert < warp_end; expert += WARP_SIZE) {
+      float val = shared_scores[expert];
+      if (val > max_val) {
+        max_val = val;
+        max_expert = expert;
+      }
+    }
+
+    // Warp-level reduction; ties broken toward the smaller expert index for
+    // a deterministic result.
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+      float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset);
+      int other_expert = __shfl_down_sync(0xFFFFFFFF, max_expert, offset);
+
+      if (other_val > max_val || (other_val == max_val && other_expert < max_expert)) {
+        max_val = other_val;
+        max_expert = other_expert;
+      }
+    }
+
+    // Store warp results in shared memory
+    __shared__ float warp_max_vals[WARPS_PER_TOKEN_SMALL];
+    __shared__ int warp_max_experts[WARPS_PER_TOKEN_SMALL];
+
+    if (lane_id == 0) {
+      warp_max_vals[warp_id] = max_val;
+      warp_max_experts[warp_id] = max_expert;
+    }
+
+    __syncthreads();
+
+    // First warp reduces across all warp results
+    if (warp_id == 0) {
+      float final_max_val = -FLT_MAX;
+      int final_max_expert = -1;
+
+      if (lane_id < WARPS_PER_TOKEN_SMALL) {
+        final_max_val = warp_max_vals[lane_id];
+        final_max_expert = warp_max_experts[lane_id];
+      }
+
+      // Warp reduction
+      for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+        float other_val = __shfl_down_sync(0xFFFFFFFF, final_max_val, offset);
+        int other_expert = __shfl_down_sync(0xFFFFFFFF, final_max_expert, offset);
+
+        if (other_val > final_max_val || (other_val == final_max_val && other_expert < final_max_expert)) {
+          final_max_val = other_val;
+          final_max_expert = other_expert;
+        }
+      }
+
+      // Lane 0 writes result and marks the expert as used
+      if (lane_id == 0 && final_max_expert != -1) {
+        int64_t output_idx = row_idx * topk + k;
+        output_ptr[output_idx] = shared_original_scores[final_max_expert];
+        indices_ptr[output_idx] = final_max_expert;
+        shared_scores[final_max_expert] = -FLT_MAX;
+      }
+    }
+
+    // Makes the knockout write visible to every warp before the next round.
+    __syncthreads();
+  }
+
+  // Renormalization and output scaling (first warp, lane 0 only -- it wrote
+  // all of this row's outputs itself, so no extra sync is needed).
+  // FIX(review): the scaling factor used to be applied only inside the
+  // renormalize branch, silently dropping it when renormalize == false.
+  if (warp_id == 0 && lane_id == 0) {
+    float sum = 0.0f;
+    if (renormalize) {
+      for (int k = 0; k < topk; k++) {
+        sum += output_ptr[row_idx * topk + k];
+      }
+    }
+
+    for (int k = 0; k < topk; k++) {
+      int64_t idx = row_idx * topk + k;
+      if (renormalize && sum > 0.0f) {
+        output_ptr[idx] /= sum;
+      }
+      if (apply_routed_scaling_factor_on_output) {
+        output_ptr[idx] *= static_cast<float>(routed_scaling_factor);
+      }
+    }
+  }
+}
+
+// Large-batch kernel: one warp per token, WARPS_PER_CTA tokens per block.
+template <typename T>
+__global__ void kimi_k2_moe_fused_gate_kernel(
+    T* input,
+    T* bias,
+    float* output_ptr,
+    int32_t* indices_ptr,
+    int64_t num_rows,
+    int64_t topk,
+    bool renormalize,
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
+  int64_t row_idx = blockIdx.x * WARPS_PER_CTA + threadIdx.y;
+  if (row_idx >= num_rows) return;
+
+  int lane_id = threadIdx.x;
+  int warp_id = threadIdx.y;
+
+  // Per-warp slices: 384 * 6 * 2 floats = 18 KiB of shared memory per block.
+  __shared__ float shared_scores[NUM_EXPERTS * WARPS_PER_CTA];
+  __shared__ float shared_original_scores[NUM_EXPERTS * WARPS_PER_CTA];
+
+  float* warp_scores = shared_scores + warp_id * NUM_EXPERTS;
+  float* warp_original_scores = shared_original_scores + warp_id * NUM_EXPERTS;
+
+  for (int expert = lane_id; expert < NUM_EXPERTS; expert += WARP_SIZE) {
+    T input_val = input[row_idx * NUM_EXPERTS + expert];
+    T bias_val = bias[expert];
+    float sigmoid_val = 1.0f / (1.0f + expf(-static_cast<float>(input_val)));
+    float biased_val = sigmoid_val + static_cast<float>(bias_val);
+    warp_scores[expert] = biased_val;
+    warp_original_scores[expert] = sigmoid_val;
+  }
+
+  __syncthreads();
+
+  for (int k = 0; k < topk; k++) {
+    float max_val = -FLT_MAX;
+    int max_expert = -1;
+
+    for (int expert = lane_id; expert < NUM_EXPERTS; expert += WARP_SIZE) {
+      if (warp_scores[expert] > max_val) {
+        max_val = warp_scores[expert];
+        max_expert = expert;
+      }
+    }
+
+    // Warp reduction with deterministic smaller-index tie-break.
+    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
+      float other_val = __shfl_down_sync(0xFFFFFFFF, max_val, offset);
+      int other_expert = __shfl_down_sync(0xFFFFFFFF, max_expert, offset);
+
+      if (other_val > max_val || (other_val == max_val && other_expert < max_expert)) {
+        max_val = other_val;
+        max_expert = other_expert;
+      }
+    }
+
+    if (lane_id == 0 && max_expert != -1) {
+      int64_t output_idx = row_idx * topk + k;
+      output_ptr[output_idx] = warp_original_scores[max_expert];
+      indices_ptr[output_idx] = max_expert;
+      warp_scores[max_expert] = -FLT_MAX;
+    }
+
+    // Only this warp touches its slice; a warp-level sync is sufficient.
+    __syncwarp();
+  }
+
+  __syncthreads();
+
+  // FIX(review): apply the scaling factor independently of renormalize
+  // (previously it was dropped when renormalize == false).
+  if (lane_id == 0) {
+    float sum = 0.0f;
+    if (renormalize) {
+      for (int k = 0; k < topk; k++) {
+        sum += output_ptr[row_idx * topk + k];
+      }
+    }
+
+    for (int k = 0; k < topk; k++) {
+      int64_t idx = row_idx * topk + k;
+      if (renormalize && sum > 0.0f) {
+        output_ptr[idx] /= sum;
+      }
+      if (apply_routed_scaling_factor_on_output) {
+        output_ptr[idx] *= static_cast<float>(routed_scaling_factor);
+      }
+    }
+  }
+}
+
+// Host entry point. Returns {weights [num_rows, topk] fp32,
+// indices [num_rows, topk] int32}. Dispatches to the one-block-per-token
+// kernel for small batches and the one-warp-per-token kernel otherwise.
+std::vector<at::Tensor> kimi_k2_moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t topk,
+    bool renormalize,
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output) {
+  int64_t num_rows = input.size(0);
+  int32_t num_experts = input.size(1);
+
+  // Assert: Only support 384 experts
+  TORCH_CHECK(num_experts == 384, "kimi_k2_moe_fused_gate only supports 384 experts, but got ", num_experts);
+  TORCH_CHECK(input.dtype() == bias.dtype(), "input and bias should have the same dtype");
+  // The kernels index input as a dense row-major [num_rows, 384] buffer.
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+  TORCH_CHECK(bias.is_contiguous() && bias.numel() == num_experts, "bias must be contiguous with num_experts elements");
+
+  // Use input.device() (not bare kCUDA) so the correct device index is kept.
+  auto options = torch::TensorOptions().dtype(torch::kFloat32).device(input.device());
+  auto output = torch::empty({num_rows, topk}, options);
+  auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  bool use_small_token_kernel = num_rows <= SMALL_TOKEN_THRESHOLD;
+
+  if (use_small_token_kernel) {
+    // Small token kernel: Each block handles 1 token with multiple warps collaborating
+    int64_t num_blocks = num_rows;
+    dim3 block_dim(THREADS_PER_BLOCK_SMALL);
+
+    if (input.scalar_type() == at::kBFloat16) {
+      kimi_k2_moe_fused_gate_kernel_small_token<<<num_blocks, block_dim, 0, stream>>>(
+          reinterpret_cast<bfloat16_t*>(input.data_ptr()),
+          reinterpret_cast<bfloat16_t*>(bias.data_ptr()),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else if (input.scalar_type() == at::kHalf) {
+      kimi_k2_moe_fused_gate_kernel_small_token<<<num_blocks, block_dim, 0, stream>>>(
+          reinterpret_cast<float16_t*>(input.data_ptr()),
+          reinterpret_cast<float16_t*>(bias.data_ptr()),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else if (input.scalar_type() == at::kFloat) {
+      kimi_k2_moe_fused_gate_kernel_small_token<<<num_blocks, block_dim, 0, stream>>>(
+          input.data_ptr<float>(),
+          bias.data_ptr<float>(),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type for kimi_k2_moe_fused_gate");
+    }
+  } else {
+    int64_t num_blocks = (num_rows + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
+    dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
+
+    if (input.scalar_type() == at::kBFloat16) {
+      kimi_k2_moe_fused_gate_kernel<<<num_blocks, block_dim, 0, stream>>>(
+          reinterpret_cast<bfloat16_t*>(input.data_ptr()),
+          reinterpret_cast<bfloat16_t*>(bias.data_ptr()),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else if (input.scalar_type() == at::kHalf) {
+      kimi_k2_moe_fused_gate_kernel<<<num_blocks, block_dim, 0, stream>>>(
+          reinterpret_cast<float16_t*>(input.data_ptr()),
+          reinterpret_cast<float16_t*>(bias.data_ptr()),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else if (input.scalar_type() == at::kFloat) {
+      kimi_k2_moe_fused_gate_kernel<<<num_blocks, block_dim, 0, stream>>>(
+          input.data_ptr<float>(),
+          bias.data_ptr<float>(),
+          output.data_ptr<float>(),
+          indices.data_ptr<int32_t>(),
+          num_rows,
+          topk,
+          renormalize,
+          routed_scaling_factor,
+          apply_routed_scaling_factor_on_output);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type for kimi_k2_moe_fused_gate");
+    }
+  }
+
+  return {output, indices};
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index a3240e783d7d..c7b6388c688c 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -331,6 +331,14 @@ std::vector<at::Tensor> moe_fused_gate(
     double routed_scaling_factor,
     bool apply_routed_scaling_factor_on_output);
 
+std::vector<at::Tensor> kimi_k2_moe_fused_gate(
+    at::Tensor& input,
+    at::Tensor& bias,
+    int64_t topk,
+    bool renormalize,
+    double routed_scaling_factor,
+    bool apply_routed_scaling_factor_on_output);
+
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
     torch::Tensor& a_ptrs,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 44fe6d45c8df..a827a504e6df 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -85,6 +85,7 @@
     apply_shuffle_mul_sum,
     cutlass_fp4_group_mm,
     fp8_blockwise_scaled_grouped_mm,
+    kimi_k2_moe_fused_gate,
     moe_align_block_size,
     moe_fused_gate,
     moe_sum,
diff --git a/sgl-kernel/python/sgl_kernel/moe.py b/sgl-kernel/python/sgl_kernel/moe.py
index ff5021eee7d9..4f849451b9a4 100755
--- a/sgl-kernel/python/sgl_kernel/moe.py
+++ b/sgl-kernel/python/sgl_kernel/moe.py
@@ -111,6 +111,41 @@ def moe_fused_gate(
     )
 
 
+def kimi_k2_moe_fused_gate(
+    input_tensor,
+    bias,
+    topk,
+    renormalize=True,
+    routed_scaling_factor=1.0,
+    apply_routed_scaling_factor_on_output=False,
+):
+    """
+    Simplified fused kernel for Kimi K2 model (num_expert_group=1).
+    This kernel removes the grouped topk logic since all experts belong to a single group.
+
+    Args:
+        input_tensor: Gating output tensor [num_tokens, num_experts]
+        bias: Correction bias tensor [num_experts]
+        topk: Number of experts to select per token
+        renormalize: Whether to renormalize the topk weights
+        routed_scaling_factor: Scaling factor for expert weights
+        apply_routed_scaling_factor_on_output: If true, apply scaling factor to output
+
+    Returns:
+        Tuple of (topk_weights, topk_ids)
+        - topk_weights: [num_tokens, topk] float32 tensor
+        - topk_ids: [num_tokens, topk] int32 tensor
+    """
+    return torch.ops.sgl_kernel.kimi_k2_moe_fused_gate.default(
+        input_tensor,
+        bias,
+        topk,
+        renormalize,
+        routed_scaling_factor,
+        apply_routed_scaling_factor_on_output,
+    )
+
+
 def fp8_blockwise_scaled_grouped_mm(
     output,
     a_ptrs,
diff --git a/sgl-kernel/tests/test_kimi_k2_moe_fused_gate.py b/sgl-kernel/tests/test_kimi_k2_moe_fused_gate.py
new file mode 100644
index 000000000000..f96312a19b19
--- /dev/null
+++ b/sgl-kernel/tests/test_kimi_k2_moe_fused_gate.py
@@ -0,0 +1,124 @@
+import pytest
+import torch
+from sgl_kernel import kimi_k2_moe_fused_gate
+
+from sglang.srt.layers.moe.topk import kimi_k2_biased_topk_impl
+
+
+@pytest.mark.parametrize(
+    "seq_length",
+    list(range(1, 10))
+    + [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536],
+)
+@pytest.mark.parametrize("topk", [6])  # Kimi K2 uses topk=6
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("apply_routed_scaling_factor_on_output", [False, True])
+def test_kimi_k2_moe_fused_gate(
+    seq_length, topk, dtype, apply_routed_scaling_factor_on_output
+):
+    num_experts = 384  # Kimi K2: only support 384 experts
+    renormalize = True
+    routed_scaling_factor = 2.872  # Kimi K2's routed scaling factor
+
+    torch.manual_seed(seq_length)
+    tensor = torch.rand((seq_length, num_experts), dtype=dtype, device="cuda")
+    scores = tensor.clone()
+    bias = torch.rand(num_experts, dtype=dtype, device="cuda")
+
+    # Test our fused kernel
+    output, indices = kimi_k2_moe_fused_gate(
+        tensor,
+        bias,
+        topk=topk,
+        renormalize=renormalize,
+        routed_scaling_factor=routed_scaling_factor,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+    )
+
+    # Reference implementation
+    ref_output, ref_indices = kimi_k2_biased_topk_impl(
+        scores,
+        scores,
+        bias,
+        topk=topk,
+        renormalize=renormalize,
+        routed_scaling_factor=routed_scaling_factor,
+        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
+    )
+
+    # Check weights match (after sorting)
+    # Weights are the most important - they determine the actual MoE output
+    output_check = torch.allclose(
+        ref_output.sort()[0].to(torch.float32),
+        output.sort()[0].to(torch.float32),
+        rtol=1e-02,
+        atol=1e-03,
+    )
+
+    assert output_check, (
+        f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
+        f"num_experts {num_experts}, topk {topk}, "
+        f"apply_routed_scaling_factor_on_output {apply_routed_scaling_factor_on_output}"
+    )
+
+
+@pytest.mark.parametrize("seq_length", [1024, 4096])
+@pytest.mark.parametrize("num_experts", [384])
+@pytest.mark.parametrize("topk", [6])
+def test_kimi_k2_specific_case(seq_length, num_experts, topk):
+    """Test specifically for Kimi K2 configuration: 384 experts, topk=6"""
+    dtype = torch.float32
+    renormalize = True
+    routed_scaling_factor = 2.872
+
+    torch.manual_seed(42)
+    tensor = torch.rand((seq_length, num_experts), dtype=dtype, device="cuda")
+    scores = tensor.clone()
+    bias = torch.rand(num_experts, dtype=dtype, device="cuda")
+
+    output, indices = kimi_k2_moe_fused_gate(
+        tensor,
+        bias,
+        topk=topk,
+        renormalize=renormalize,
+        routed_scaling_factor=routed_scaling_factor,
+        apply_routed_scaling_factor_on_output=False,
+    )
+
+    ref_output, ref_indices = kimi_k2_biased_topk_impl(
+        scores,
+        scores,
+        bias,
+        topk=topk,
+        renormalize=renormalize,
+        routed_scaling_factor=routed_scaling_factor,
+        apply_routed_scaling_factor_on_output=False,
+    )
+
+    # Verify output shapes
+    assert output.shape == (seq_length, topk)
+    assert indices.shape == (seq_length, topk)
+    assert output.dtype == torch.float32
+    assert indices.dtype == torch.int32
+
+    # Verify weights are normalized (sum to 1 per token if renormalize=True)
+    if renormalize:
+        weight_sums = output.sum(dim=-1)
+        assert torch.allclose(
+            weight_sums, torch.ones_like(weight_sums), rtol=1e-3, atol=1e-4
+        )
+
+    # Check weights match (after sorting)
+    # Weights are the most important - they determine the actual MoE output
+    output_check = torch.allclose(
+        ref_output.sort()[0].to(torch.float32),
+        output.sort()[0].to(torch.float32),
+        rtol=1e-02,
+        atol=1e-03,
+    )
+
+    assert output_check, f"Output mismatch for Kimi K2 specific case"
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])