From 6f48bbd558d6476eafd94bee91dbaf91ac75e63b Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 13:57:58 +0800 Subject: [PATCH 1/3] add moe_wna16_marlin_gemm_v2 --- .../fused_moe_triton/moe_align_block_size.py | 16 +- .../benchmark/bench_moe_align_block_size.py | 15 +- sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 + sgl-kernel/csrc/moe/moe_align_kernel.cu | 90 ++++++--- sgl-kernel/tests/test_moe_align.py | 181 ++++++++++++++++++ 5 files changed, 278 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py index ce1cae66e9e8..f86867077783 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -15,7 +15,10 @@ def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + pad_to_block_size: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -26,6 +29,9 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. + - pad_to_block_size: Whether to pad the sorted_ids size to a multiple + of block_size. For small batch sizes, setting this to False can + save memory. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -54,7 +60,15 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ + # Optimization 1: More precise memory allocation for small batches + # Calculate the minimum required size max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + # Only round up to block_size if explicitly requested + # This saves memory for small batch sizes + if pad_to_block_size: + max_num_tokens_padded = triton.cdiv(max_num_tokens_padded, block_size) * block_size + sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index 2156c5cd41a7..a99e58c0f79d 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -380,7 +380,11 @@ def benchmark(num_tokens, num_experts, topk, provider): if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Benchmark moe_align_block_size kernel. " + "Includes optimizations: " + "1) Precise memory allocation 2) Parallel init 3) EP mode filtering 4) expert_ids padding" + ) parser.add_argument( "--save_path", type=str, @@ -418,6 +422,15 @@ def benchmark(num_tokens, num_experts, topk, provider): num_experts = args.num_experts topk = args.topk + print("\n" + "=" * 80) + print("MoE Align Block Size Kernel Benchmark") + print("Includes optimizations:") + print(" 1. Precise memory allocation for small batches") + print(" 2. Parallel initialization of sorted_token_ids") + print(" 3. EP mode invalid expert filtering") + print(" 4. expert_ids padding") + print("=" * 80 + "\n") + calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk) if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index b249f64156da..84148e7df526 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,6 +24,7 @@ #endif #include "kernel.h" +#include "marlin_template.h" #include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index 92fd342707e6..bcca07033f4a 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -29,12 +29,17 @@ __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel) { + size_t numel, + int32_t num_experts) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; + // Filter out invalid experts (for EP mode) + if (expert_id < 0 || expert_id > num_experts) { + continue; + } int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } @@ -63,7 +68,8 @@ __global__ void moe_align_block_size_kernel( size_t numel, int32_t* __restrict__ cumsum, bool pad_sorted_token_ids, - const int32_t scan_size) { + const int32_t scan_size, + const int32_t max_num_tokens_padded) { extern __shared__ int32_t smem[]; int32_t* shared_counts = smem; // [num_experts] int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] @@ -73,14 +79,29 @@ __global__ void moe_align_block_size_kernel( const size_t tid = threadIdx.x; const size_t stride = blockDim.x; + // Optimization 2: Parallel initialization of sorted_token_ids + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + if (tid < num_experts) { shared_counts[tid] = 0; } __syncthreads(); + // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } atomicAdd(&shared_counts[expert_id], 1); } @@ -200,6 +221,7 @@ __global__ void moe_align_block_size_kernel( if (tid <= num_experts) { cumsum[tid] = prefix[tid]; } + // fill expert_ids const int32_t num_blocks = s_total_tokens_post_pad / block_size; for (int32_t i = tid; i < num_blocks; i += stride) { @@ -216,14 +238,11 @@ __global__ void moe_align_block_size_kernel( expert_ids[i] = left - 2; } - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } + // Optimization 4: Fill remaining expert_ids with -1 (invalid expert) + const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; + const int32_t fill_start_idx = num_blocks + tid; + for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { + expert_ids[i] = -1; } } @@ -236,7 +255,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t num_experts, int32_t block_size, size_t numel, - bool pad_sorted_token_ids) { + bool pad_sorted_token_ids, + const int32_t max_num_tokens_padded) { const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -244,12 +264,28 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t* cumsum = shared_mem; int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + // Optimization 2: Parallel initialization of sorted_token_ids + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + for (int i = 0; i < num_experts; ++i) { tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; } + // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { - ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1]; + int32_t expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } + ++tokens_cnts[(threadIdx.x + 1) * num_experts + expert_id]; } __syncthreads(); @@ -279,20 +315,22 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } + // Optimization 4: Fill remaining expert_ids with -1 + const int32_t num_valid_blocks = (*total_tokens_post_pad + block_size - 1) / block_size; + const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; + const int32_t fill_start_idx = num_valid_blocks + tid; + for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { + expert_ids[i] = -1; } __syncthreads(); + // Optimization 3: Filter out invalid experts in sorting phase for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[threadIdx.x * num_experts + expert_id]; @@ -314,6 +352,9 @@ void moe_align_block_size( threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + // Optimization 1: Pass the actual allocated size to kernel + const int32_t max_num_tokens_padded = sorted_token_ids.size(0); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -330,7 +371,8 @@ void moe_align_block_size( num_experts, block_size, topk_ids.numel(), - pad_sorted_token_ids); + pad_sorted_token_ids, + max_num_tokens_padded); } else { auto align_kernel = moe_align_block_size_kernel; @@ -346,7 +388,8 @@ void moe_align_block_size( topk_ids.numel(), cumsum_buffer.data_ptr(), pad_sorted_token_ids, - scan_size); + scan_size, + max_num_tokens_padded); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -358,7 +401,8 @@ void moe_align_block_size( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), cumsum_buffer.data_ptr(), - topk_ids.numel()); + topk_ids.numel(), + num_experts); } }); } diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 40a37f563278..25ce5c98e332 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -268,5 +268,186 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) +# Additional optimization tests +def test_memory_allocation_optimization(): + """ + Test precise memory allocation for small batches + """ + num_tokens = 10 + num_experts = 8 + topk = 2 + block_size = 64 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + + # Test with pad_to_block_size=False (should use less memory) + max_num_tokens_padded_no_pad = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + # Test with pad_to_block_size=True (should round up) + max_num_tokens_padded_with_pad = triton.cdiv(max_num_tokens_padded_no_pad, block_size) * block_size + + assert max_num_tokens_padded_no_pad < max_num_tokens_padded_with_pad, \ + "Without padding should use less memory" + + +def test_parallel_initialization(): + """ + Test parallel initialization of sorted_token_ids + """ + num_tokens = 100 + num_experts = 16 + topk = 4 + block_size = 64 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Run with pad_sorted_token_ids=True + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Check that padding values are correctly set to numel + valid_count = num_tokens_post_pad.item() + padding_values = sorted_ids[valid_count:valid_count+10] + + # All padding values should be equal to numel (the sentinel value) + assert torch.all(padding_values == topk_ids.numel()), \ + f"Padding values should all be {topk_ids.numel()}, got {padding_values}" + + +def test_invalid_expert_filtering(): + """ + Test filtering out invalid experts (for EP mode) + """ + num_tokens = 50 + num_experts = 16 + topk = 2 + block_size = 32 + + # Create topk_ids with some invalid expert IDs + topk_ids = torch.randint(0, num_experts + 5, (num_tokens, topk), dtype=torch.int32, device="cuda") + + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Should not crash even with invalid expert IDs + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Count how many invalid expert IDs we had + invalid_count = torch.sum(topk_ids >= num_experts).item() + valid_count = topk_ids.numel() - invalid_count + + +def test_expert_ids_padding(): + """ + Test filling remaining expert_ids with -1 + """ + num_tokens = 30 + num_experts = 8 + topk = 2 + block_size = 32 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Calculate the number of valid blocks + valid_blocks = (num_tokens_post_pad.item() + block_size - 1) // block_size + + # Check that remaining expert_ids are filled with -1 + remaining_expert_ids = expert_ids[valid_blocks:] + + # All remaining blocks should have expert_id = -1 + if len(remaining_expert_ids) > 0: + assert torch.all(remaining_expert_ids == -1), \ + f"Remaining expert_ids should be -1, got {remaining_expert_ids[:10]}" + + +@pytest.mark.parametrize( + "num_tokens,num_experts,topk,block_size", + [ + (1, 8, 1, 32), # Small batch + (10, 16, 2, 64), # Medium batch + (100, 32, 4, 128), # Larger batch + (1000, 64, 8, 256), # Large batch + ] +) +def test_all_optimizations_combined(num_tokens, num_experts, topk, block_size): + """ + Test all optimizations work together correctly + """ + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Should complete successfully with all optimizations + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Basic sanity checks + assert num_tokens_post_pad.item() > 0, "Should have some valid tokens" + assert num_tokens_post_pad.item() % block_size == 0, "Result should be aligned to block_size" + + if __name__ == "__main__": pytest.main([__file__]) From eeea208cb0b9af4ed9fd70e58c4969f12b625684 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 14:04:31 +0800 Subject: [PATCH 2/3] Revert "add moe_wna16_marlin_gemm_v2" This reverts commit 6f48bbd558d6476eafd94bee91dbaf91ac75e63b. --- .../fused_moe_triton/moe_align_block_size.py | 16 +- .../benchmark/bench_moe_align_block_size.py | 15 +- sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 - sgl-kernel/csrc/moe/moe_align_kernel.cu | 90 +++------ sgl-kernel/tests/test_moe_align.py | 181 ------------------ 5 files changed, 25 insertions(+), 278 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py index f86867077783..ce1cae66e9e8 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -15,10 +15,7 @@ def moe_align_block_size( - topk_ids: torch.Tensor, - block_size: int, - num_experts: int, - pad_to_block_size: bool = False, + topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -29,9 +26,6 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. - - pad_to_block_size: Whether to pad the sorted_ids size to a multiple - of block_size. For small batch sizes, setting this to False can - save memory. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -60,15 +54,7 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ - # Optimization 1: More precise memory allocation for small batches - # Calculate the minimum required size max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - # Only round up to block_size if explicitly requested - # This saves memory for small batch sizes - if pad_to_block_size: - max_num_tokens_padded = triton.cdiv(max_num_tokens_padded, block_size) * block_size - sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index a99e58c0f79d..2156c5cd41a7 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -380,11 +380,7 @@ def benchmark(num_tokens, num_experts, topk, provider): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark moe_align_block_size kernel. " - "Includes optimizations: " - "1) Precise memory allocation 2) Parallel init 3) EP mode filtering 4) expert_ids padding" - ) + parser = argparse.ArgumentParser() parser.add_argument( "--save_path", type=str, @@ -422,15 +418,6 @@ def benchmark(num_tokens, num_experts, topk, provider): num_experts = args.num_experts topk = args.topk - print("\n" + "=" * 80) - print("MoE Align Block Size Kernel Benchmark") - print("Includes optimizations:") - print(" 1. Precise memory allocation for small batches") - print(" 2. Parallel initialization of sorted_token_ids") - print(" 3. EP mode invalid expert filtering") - print(" 4. expert_ids padding") - print("=" * 80 + "\n") - calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk) if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index 84148e7df526..b249f64156da 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,7 +24,6 @@ #endif #include "kernel.h" -#include "marlin_template.h" #include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index bcca07033f4a..92fd342707e6 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -29,17 +29,12 @@ __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel, - int32_t num_experts) { + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; - // Filter out invalid experts (for EP mode) - if (expert_id < 0 || expert_id > num_experts) { - continue; - } int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } @@ -68,8 +63,7 @@ __global__ void moe_align_block_size_kernel( size_t numel, int32_t* __restrict__ cumsum, bool pad_sorted_token_ids, - const int32_t scan_size, - const int32_t max_num_tokens_padded) { + const int32_t scan_size) { extern __shared__ int32_t smem[]; int32_t* shared_counts = smem; // [num_experts] int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] @@ -79,29 +73,14 @@ __global__ void moe_align_block_size_kernel( const size_t tid = threadIdx.x; const size_t stride = blockDim.x; - // Optimization 2: Parallel initialization of sorted_token_ids - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } - } - if (tid < num_experts) { shared_counts[tid] = 0; } __syncthreads(); - // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } atomicAdd(&shared_counts[expert_id], 1); } @@ -221,7 +200,6 @@ __global__ void moe_align_block_size_kernel( if (tid <= num_experts) { cumsum[tid] = prefix[tid]; } - // fill expert_ids const int32_t num_blocks = s_total_tokens_post_pad / block_size; for (int32_t i = tid; i < num_blocks; i += stride) { @@ -238,11 +216,14 @@ __global__ void moe_align_block_size_kernel( expert_ids[i] = left - 2; } - // Optimization 4: Fill remaining expert_ids with -1 (invalid expert) - const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; - const int32_t fill_start_idx = num_blocks + tid; - for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { - expert_ids[i] = -1; + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } } } @@ -255,8 +236,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t num_experts, int32_t block_size, size_t numel, - bool pad_sorted_token_ids, - const int32_t max_num_tokens_padded) { + bool pad_sorted_token_ids) { const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -264,28 +244,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t* cumsum = shared_mem; int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); - // Optimization 2: Parallel initialization of sorted_token_ids - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } - } - for (int i = 0; i < num_experts; ++i) { tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; } - // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } - ++tokens_cnts[(threadIdx.x + 1) * num_experts + expert_id]; + ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1]; } __syncthreads(); @@ -315,22 +279,20 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } - // Optimization 4: Fill remaining expert_ids with -1 - const int32_t num_valid_blocks = (*total_tokens_post_pad + block_size - 1) / block_size; - const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; - const int32_t fill_start_idx = num_valid_blocks + tid; - for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { - expert_ids[i] = -1; + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } } __syncthreads(); - // Optimization 3: Filter out invalid experts in sorting phase for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[threadIdx.x * num_experts + expert_id]; @@ -352,9 +314,6 @@ void moe_align_block_size( threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - // Optimization 1: Pass the actual allocated size to kernel - const int32_t max_num_tokens_padded = sorted_token_ids.size(0); - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -371,8 +330,7 @@ void moe_align_block_size( num_experts, block_size, topk_ids.numel(), - pad_sorted_token_ids, - max_num_tokens_padded); + pad_sorted_token_ids); } else { auto align_kernel = moe_align_block_size_kernel; @@ -388,8 +346,7 @@ void moe_align_block_size( topk_ids.numel(), cumsum_buffer.data_ptr(), pad_sorted_token_ids, - scan_size, - max_num_tokens_padded); + scan_size); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -401,8 +358,7 @@ void moe_align_block_size( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), cumsum_buffer.data_ptr(), - topk_ids.numel(), - num_experts); + topk_ids.numel()); } }); } diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 25ce5c98e332..40a37f563278 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -268,186 +268,5 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) -# Additional optimization tests -def test_memory_allocation_optimization(): - """ - Test precise memory allocation for small batches - """ - num_tokens = 10 - num_experts = 8 - topk = 2 - block_size = 64 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - - # Test with pad_to_block_size=False (should use less memory) - max_num_tokens_padded_no_pad = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - # Test with pad_to_block_size=True (should round up) - max_num_tokens_padded_with_pad = triton.cdiv(max_num_tokens_padded_no_pad, block_size) * block_size - - assert max_num_tokens_padded_no_pad < max_num_tokens_padded_with_pad, \ - "Without padding should use less memory" - - -def test_parallel_initialization(): - """ - Test parallel initialization of sorted_token_ids - """ - num_tokens = 100 - num_experts = 16 - topk = 4 - block_size = 64 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Run with pad_sorted_token_ids=True - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Check that padding values are correctly set to numel - valid_count = num_tokens_post_pad.item() - padding_values = sorted_ids[valid_count:valid_count+10] - - # All padding values should be equal to numel (the sentinel value) - assert torch.all(padding_values == topk_ids.numel()), \ - f"Padding values should all be {topk_ids.numel()}, got {padding_values}" - - -def test_invalid_expert_filtering(): - """ - Test filtering out invalid experts (for EP mode) - """ - num_tokens = 50 - num_experts = 16 - topk = 2 - block_size = 32 - - # Create topk_ids with some invalid expert IDs - topk_ids = torch.randint(0, num_experts + 5, (num_tokens, topk), dtype=torch.int32, device="cuda") - - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Should not crash even with invalid expert IDs - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Count how many invalid expert IDs we had - invalid_count = torch.sum(topk_ids >= num_experts).item() - valid_count = topk_ids.numel() - invalid_count - - -def test_expert_ids_padding(): - """ - Test filling remaining expert_ids with -1 - """ - num_tokens = 30 - num_experts = 8 - topk = 2 - block_size = 32 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Calculate the number of valid blocks - valid_blocks = (num_tokens_post_pad.item() + block_size - 1) // block_size - - # Check that remaining expert_ids are filled with -1 - remaining_expert_ids = expert_ids[valid_blocks:] - - # All remaining blocks should have expert_id = -1 - if len(remaining_expert_ids) > 0: - assert torch.all(remaining_expert_ids == -1), \ - f"Remaining expert_ids should be -1, got {remaining_expert_ids[:10]}" - - -@pytest.mark.parametrize( - "num_tokens,num_experts,topk,block_size", - [ - (1, 8, 1, 32), # Small batch - (10, 16, 2, 64), # Medium batch - (100, 32, 4, 128), # Larger batch - (1000, 64, 8, 256), # Large batch - ] -) -def test_all_optimizations_combined(num_tokens, num_experts, topk, block_size): - """ - Test all optimizations work together correctly - """ - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Should complete successfully with all optimizations - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Basic sanity checks - assert num_tokens_post_pad.item() > 0, "Should have some valid tokens" - assert num_tokens_post_pad.item() % block_size == 0, "Result should be aligned to block_size" - - if __name__ == "__main__": pytest.main([__file__]) From 842bbecdeca68e504c375a1ac2c6ccd03af7318b Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Fri, 5 Dec 2025 18:39:08 +0800 Subject: [PATCH 3/3] Fix the bug where profiler traces generated by SGLang diffusion pipeline cannot jump from CUDA kernels to Python code in Chrome trace viewer --- .../multimodal_gen/runtime/pipelines_core/stages/denoising.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py index aae4af5033a8..ed2f7aab8051 100644 --- a/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py +++ b/python/sglang/multimodal_gen/runtime/pipelines_core/stages/denoising.py @@ -599,9 +599,7 @@ def start_profile(self, batch: Req): active=batch.num_profiled_timesteps, repeat=5, ), - on_trace_ready=lambda _: torch.profiler.tensorboard_trace_handler( - f"./logs" - ), + on_trace_ready=None, record_shapes=True, with_stack=True, )