From 6f48bbd558d6476eafd94bee91dbaf91ac75e63b Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 13:57:58 +0800 Subject: [PATCH 1/6] add moe_wna16_marlin_gemm_v2 --- .../fused_moe_triton/moe_align_block_size.py | 16 +- .../benchmark/bench_moe_align_block_size.py | 15 +- sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 + sgl-kernel/csrc/moe/moe_align_kernel.cu | 90 ++++++--- sgl-kernel/tests/test_moe_align.py | 181 ++++++++++++++++++ 5 files changed, 278 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py index ce1cae66e9e8..f86867077783 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -15,7 +15,10 @@ def moe_align_block_size( - topk_ids: torch.Tensor, block_size: int, num_experts: int + topk_ids: torch.Tensor, + block_size: int, + num_experts: int, + pad_to_block_size: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -26,6 +29,9 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. + - pad_to_block_size: Whether to pad the sorted_ids size to a multiple + of block_size. For small batch sizes, setting this to False can + save memory. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -54,7 +60,15 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ + # Optimization 1: More precise memory allocation for small batches + # Calculate the minimum required size max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + # Only round up to block_size if explicitly requested + # This saves memory for small batch sizes + if pad_to_block_size: + max_num_tokens_padded = triton.cdiv(max_num_tokens_padded, block_size) * block_size + sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index 2156c5cd41a7..a99e58c0f79d 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -380,7 +380,11 @@ def benchmark(num_tokens, num_experts, topk, provider): if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser( + description="Benchmark moe_align_block_size kernel. " + "Includes optimizations: " + "1) Precise memory allocation 2) Parallel init 3) EP mode filtering 4) expert_ids padding" + ) parser.add_argument( "--save_path", type=str, @@ -418,6 +422,15 @@ def benchmark(num_tokens, num_experts, topk, provider): num_experts = args.num_experts topk = args.topk + print("\n" + "=" * 80) + print("MoE Align Block Size Kernel Benchmark") + print("Includes optimizations:") + print(" 1. Precise memory allocation for small batches") + print(" 2. Parallel initialization of sorted_token_ids") + print(" 3. EP mode invalid expert filtering") + print(" 4. expert_ids padding") + print("=" * 80 + "\n") + calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk) if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index b249f64156da..84148e7df526 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,6 +24,7 @@ #endif #include "kernel.h" +#include "marlin_template.h" #include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index 92fd342707e6..bcca07033f4a 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -29,12 +29,17 @@ __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel) { + size_t numel, + int32_t num_experts) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; + // Filter out invalid experts (for EP mode) + if (expert_id < 0 || expert_id > num_experts) { + continue; + } int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } @@ -63,7 +68,8 @@ __global__ void moe_align_block_size_kernel( size_t numel, int32_t* __restrict__ cumsum, bool pad_sorted_token_ids, - const int32_t scan_size) { + const int32_t scan_size, + const int32_t max_num_tokens_padded) { extern __shared__ int32_t smem[]; int32_t* shared_counts = smem; // [num_experts] int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] @@ -73,14 +79,29 @@ __global__ void moe_align_block_size_kernel( const size_t tid = threadIdx.x; const size_t stride = blockDim.x; + // Optimization 2: Parallel initialization of sorted_token_ids + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + if (tid < num_experts) { shared_counts[tid] = 0; } __syncthreads(); + // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } atomicAdd(&shared_counts[expert_id], 1); } @@ -200,6 +221,7 @@ __global__ void moe_align_block_size_kernel( if (tid <= num_experts) { cumsum[tid] = prefix[tid]; } + // fill expert_ids const int32_t num_blocks = s_total_tokens_post_pad / block_size; for (int32_t i = tid; i < num_blocks; i += stride) { @@ -216,14 +238,11 @@ __global__ void moe_align_block_size_kernel( expert_ids[i] = left - 2; } - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } + // Optimization 4: Fill remaining expert_ids with -1 (invalid expert) + const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; + const int32_t fill_start_idx = num_blocks + tid; + for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { + expert_ids[i] = -1; } } @@ -236,7 +255,8 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t num_experts, int32_t block_size, size_t numel, - bool pad_sorted_token_ids) { + bool pad_sorted_token_ids, + const int32_t max_num_tokens_padded) { const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -244,12 +264,28 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t* cumsum = shared_mem; int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + // Optimization 2: Parallel initialization of sorted_token_ids + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } + } + for (int i = 0; i < num_experts; ++i) { tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; } + // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { - ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1]; + int32_t expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } + ++tokens_cnts[(threadIdx.x + 1) * num_experts + expert_id]; } __syncthreads(); @@ -279,20 +315,22 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } + // Optimization 4: Fill remaining expert_ids with -1 + const int32_t num_valid_blocks = (*total_tokens_post_pad + block_size - 1) / block_size; + const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; + const int32_t fill_start_idx = num_valid_blocks + tid; + for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { + expert_ids[i] = -1; } __syncthreads(); + // Optimization 3: Filter out invalid experts in sorting phase for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; + if (expert_id < 0 || expert_id > num_experts) { + continue; + } int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[threadIdx.x * num_experts + expert_id]; @@ -314,6 +352,9 @@ void moe_align_block_size( threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + // Optimization 1: Pass the actual allocated size to kernel + const int32_t max_num_tokens_padded = sorted_token_ids.size(0); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -330,7 +371,8 @@ void moe_align_block_size( num_experts, block_size, topk_ids.numel(), - pad_sorted_token_ids); + pad_sorted_token_ids, + max_num_tokens_padded); } else { auto align_kernel = moe_align_block_size_kernel; @@ -346,7 +388,8 @@ void moe_align_block_size( topk_ids.numel(), cumsum_buffer.data_ptr(), pad_sorted_token_ids, - scan_size); + scan_size, + max_num_tokens_padded); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -358,7 +401,8 @@ void moe_align_block_size( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), cumsum_buffer.data_ptr(), - topk_ids.numel()); + topk_ids.numel(), + num_experts); } }); } diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 40a37f563278..25ce5c98e332 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -268,5 +268,186 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) +# Additional optimization tests +def test_memory_allocation_optimization(): + """ + Test precise memory allocation for small batches + """ + num_tokens = 10 + num_experts = 8 + topk = 2 + block_size = 64 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + + # Test with pad_to_block_size=False (should use less memory) + max_num_tokens_padded_no_pad = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + # Test with pad_to_block_size=True (should round up) + max_num_tokens_padded_with_pad = triton.cdiv(max_num_tokens_padded_no_pad, block_size) * block_size + + assert max_num_tokens_padded_no_pad < max_num_tokens_padded_with_pad, \ + "Without padding should use less memory" + + +def test_parallel_initialization(): + """ + Test parallel initialization of sorted_token_ids + """ + num_tokens = 100 + num_experts = 16 + topk = 4 + block_size = 64 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Run with pad_sorted_token_ids=True + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Check that padding values are correctly set to numel + valid_count = num_tokens_post_pad.item() + padding_values = sorted_ids[valid_count:valid_count+10] + + # All padding values should be equal to numel (the sentinel value) + assert torch.all(padding_values == topk_ids.numel()), \ + f"Padding values should all be {topk_ids.numel()}, got {padding_values}" + + +def test_invalid_expert_filtering(): + """ + Test filtering out invalid experts (for EP mode) + """ + num_tokens = 50 + num_experts = 16 + topk = 2 + block_size = 32 + + # Create topk_ids with some invalid expert IDs + topk_ids = torch.randint(0, num_experts + 5, (num_tokens, topk), dtype=torch.int32, device="cuda") + + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Should not crash even with invalid expert IDs + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Count how many invalid expert IDs we had + invalid_count = torch.sum(topk_ids >= num_experts).item() + valid_count = topk_ids.numel() - invalid_count + + +def test_expert_ids_padding(): + """ + Test filling remaining expert_ids with -1 + """ + num_tokens = 30 + num_experts = 8 + topk = 2 + block_size = 32 + + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Calculate the number of valid blocks + valid_blocks = (num_tokens_post_pad.item() + block_size - 1) // block_size + + # Check that remaining expert_ids are filled with -1 + remaining_expert_ids = expert_ids[valid_blocks:] + + # All remaining blocks should have expert_id = -1 + if len(remaining_expert_ids) > 0: + assert torch.all(remaining_expert_ids == -1), \ + f"Remaining expert_ids should be -1, got {remaining_expert_ids[:10]}" + + +@pytest.mark.parametrize( + "num_tokens,num_experts,topk,block_size", + [ + (1, 8, 1, 32), # Small batch + (10, 16, 2, 64), # Medium batch + (100, 32, 4, 128), # Larger batch + (1000, 64, 8, 256), # Large batch + ] +) +def test_all_optimizations_combined(num_tokens, num_experts, topk, block_size): + """ + Test all optimizations work together correctly + """ + topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") + max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) + + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + expert_ids = torch.empty( + (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") + cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") + + # Should complete successfully with all optimizations + moe_align_block_size( + topk_ids, + num_experts + 1, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + cumsum_buffer, + pad_sorted_token_ids=True, + ) + + # Basic sanity checks + assert num_tokens_post_pad.item() > 0, "Should have some valid tokens" + assert num_tokens_post_pad.item() % block_size == 0, "Result should be aligned to block_size" + + if __name__ == "__main__": pytest.main([__file__]) From eeea208cb0b9af4ed9fd70e58c4969f12b625684 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 14:04:31 +0800 Subject: [PATCH 2/6] Revert "add moe_wna16_marlin_gemm_v2" This reverts commit 6f48bbd558d6476eafd94bee91dbaf91ac75e63b. --- .../fused_moe_triton/moe_align_block_size.py | 16 +- .../benchmark/bench_moe_align_block_size.py | 15 +- sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu | 1 - sgl-kernel/csrc/moe/moe_align_kernel.cu | 90 +++------ sgl-kernel/tests/test_moe_align.py | 181 ------------------ 5 files changed, 25 insertions(+), 278 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py index f86867077783..ce1cae66e9e8 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py @@ -15,10 +15,7 @@ def moe_align_block_size( - topk_ids: torch.Tensor, - block_size: int, - num_experts: int, - pad_to_block_size: bool = False, + topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Aligns the token distribution across experts to be compatible with block @@ -29,9 +26,6 @@ def moe_align_block_size( top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. - - pad_to_block_size: Whether to pad the sorted_ids size to a multiple - of block_size. For small batch sizes, setting this to False can - save memory. Returns: - sorted_token_ids: A tensor containing the sorted token indices according @@ -60,15 +54,7 @@ def moe_align_block_size( - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. """ - # Optimization 1: More precise memory allocation for small batches - # Calculate the minimum required size max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - # Only round up to block_size if explicitly requested - # This saves memory for small batch sizes - if pad_to_block_size: - max_num_tokens_padded = triton.cdiv(max_num_tokens_padded, block_size) * block_size - sorted_ids = torch.empty( (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device ) diff --git a/sgl-kernel/benchmark/bench_moe_align_block_size.py b/sgl-kernel/benchmark/bench_moe_align_block_size.py index a99e58c0f79d..2156c5cd41a7 100644 --- a/sgl-kernel/benchmark/bench_moe_align_block_size.py +++ b/sgl-kernel/benchmark/bench_moe_align_block_size.py @@ -380,11 +380,7 @@ def benchmark(num_tokens, num_experts, topk, provider): if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark moe_align_block_size kernel. " - "Includes optimizations: " - "1) Precise memory allocation 2) Parallel init 3) EP mode filtering 4) expert_ids padding" - ) + parser = argparse.ArgumentParser() parser.add_argument( "--save_path", type=str, @@ -422,15 +418,6 @@ def benchmark(num_tokens, num_experts, topk, provider): num_experts = args.num_experts topk = args.topk - print("\n" + "=" * 80) - print("MoE Align Block Size Kernel Benchmark") - print("Includes optimizations:") - print(" 1. Precise memory allocation for small batches") - print(" 2. Parallel initialization of sorted_token_ids") - print(" 3. EP mode invalid expert filtering") - print(" 4. expert_ids padding") - print("=" * 80 + "\n") - calculate_diff(num_tokens=num_tokens, num_experts=num_experts, topk=topk) if not args.skip_full_benchmark and not IS_CI: # Skip full benchmark in CI diff --git a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu index 84148e7df526..b249f64156da 100644 --- a/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu +++ b/sgl-kernel/csrc/moe/marlin_moe_wna16/ops.cu @@ -24,7 +24,6 @@ #endif #include "kernel.h" -#include "marlin_template.h" #include "kernel_marlin.cuh" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index bcca07033f4a..92fd342707e6 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -29,17 +29,12 @@ __global__ void count_and_sort_expert_tokens_kernel( const scalar_t* __restrict__ topk_ids, int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, - size_t numel, - int32_t num_experts) { + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; - // Filter out invalid experts (for EP mode) - if (expert_id < 0 || expert_id > num_experts) { - continue; - } int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); sorted_token_ids[rank_post_pad] = i; } @@ -68,8 +63,7 @@ __global__ void moe_align_block_size_kernel( size_t numel, int32_t* __restrict__ cumsum, bool pad_sorted_token_ids, - const int32_t scan_size, - const int32_t max_num_tokens_padded) { + const int32_t scan_size) { extern __shared__ int32_t smem[]; int32_t* shared_counts = smem; // [num_experts] int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] @@ -79,29 +73,14 @@ __global__ void moe_align_block_size_kernel( const size_t tid = threadIdx.x; const size_t stride = blockDim.x; - // Optimization 2: Parallel initialization of sorted_token_ids - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } - } - if (tid < num_experts) { shared_counts[tid] = 0; } __syncthreads(); - // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } atomicAdd(&shared_counts[expert_id], 1); } @@ -221,7 +200,6 @@ __global__ void moe_align_block_size_kernel( if (tid <= num_experts) { cumsum[tid] = prefix[tid]; } - // fill expert_ids const int32_t num_blocks = s_total_tokens_post_pad / block_size; for (int32_t i = tid; i < num_blocks; i += stride) { @@ -238,11 +216,14 @@ __global__ void moe_align_block_size_kernel( expert_ids[i] = left - 2; } - // Optimization 4: Fill remaining expert_ids with -1 (invalid expert) - const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; - const int32_t fill_start_idx = num_blocks + tid; - for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { - expert_ids[i] = -1; + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } } } @@ -255,8 +236,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t num_experts, int32_t block_size, size_t numel, - bool pad_sorted_token_ids, - const int32_t max_num_tokens_padded) { + bool pad_sorted_token_ids) { const size_t tid = threadIdx.x; const size_t stride = blockDim.x; @@ -264,28 +244,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( int32_t* cumsum = shared_mem; int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); - // Optimization 2: Parallel initialization of sorted_token_ids - if (pad_sorted_token_ids) { - Vec fill_vec; - fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; - int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; - Vec* out_ptr = reinterpret_cast(sorted_token_ids); - for (int32_t i = tid; i < total_vecs; i += stride) { - out_ptr[i] = fill_vec; - } - } - for (int i = 0; i < num_experts; ++i) { tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; } - // Optimization 3: Filter out invalid experts for (size_t i = tid; i < numel; i += stride) { - int32_t expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } - ++tokens_cnts[(threadIdx.x + 1) * num_experts + expert_id]; + ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1]; } __syncthreads(); @@ -315,22 +279,20 @@ __global__ void moe_align_block_size_small_batch_expert_kernel( } } - // Optimization 4: Fill remaining expert_ids with -1 - const int32_t num_valid_blocks = (*total_tokens_post_pad + block_size - 1) / block_size; - const int32_t expert_ids_size = (max_num_tokens_padded + block_size - 1) / block_size; - const int32_t fill_start_idx = num_valid_blocks + tid; - for (int32_t i = fill_start_idx; i < expert_ids_size; i += stride) { - expert_ids[i] = -1; + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = tid; i < total_vecs; i += stride) { + out_ptr[i] = fill_vec; + } } __syncthreads(); - // Optimization 3: Filter out invalid experts in sorting phase for (size_t i = tid; i < numel; i += stride) { int32_t expert_id = topk_ids[i] + 1; - if (expert_id < 0 || expert_id > num_experts) { - continue; - } int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; sorted_token_ids[rank_post_pad] = i; ++tokens_cnts[threadIdx.x * num_experts + expert_id]; @@ -352,9 +314,6 @@ void moe_align_block_size( threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - // Optimization 1: Pass the actual allocated size to kernel - const int32_t max_num_tokens_padded = sorted_token_ids.size(0); - DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); @@ -371,8 +330,7 @@ void moe_align_block_size( num_experts, block_size, topk_ids.numel(), - pad_sorted_token_ids, - max_num_tokens_padded); + pad_sorted_token_ids); } else { auto align_kernel = moe_align_block_size_kernel; @@ -388,8 +346,7 @@ void moe_align_block_size( topk_ids.numel(), cumsum_buffer.data_ptr(), pad_sorted_token_ids, - scan_size, - max_num_tokens_padded); + scan_size); const int block_threads = std::min(256, (int)threads); const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; @@ -401,8 +358,7 @@ void moe_align_block_size( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), cumsum_buffer.data_ptr(), - topk_ids.numel(), - num_experts); + topk_ids.numel()); } }); } diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 25ce5c98e332..40a37f563278 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -268,186 +268,5 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): torch.testing.assert_close(actual, expected, atol=2e-2, rtol=0) -# Additional optimization tests -def test_memory_allocation_optimization(): - """ - Test precise memory allocation for small batches - """ - num_tokens = 10 - num_experts = 8 - topk = 2 - block_size = 64 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - - # Test with pad_to_block_size=False (should use less memory) - max_num_tokens_padded_no_pad = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - # Test with pad_to_block_size=True (should round up) - max_num_tokens_padded_with_pad = triton.cdiv(max_num_tokens_padded_no_pad, block_size) * block_size - - assert max_num_tokens_padded_no_pad < max_num_tokens_padded_with_pad, \ - "Without padding should use less memory" - - -def test_parallel_initialization(): - """ - Test parallel initialization of sorted_token_ids - """ - num_tokens = 100 - num_experts = 16 - topk = 4 - block_size = 64 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Run with pad_sorted_token_ids=True - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Check that padding values are correctly set to numel - valid_count = num_tokens_post_pad.item() - padding_values = sorted_ids[valid_count:valid_count+10] - - # All padding values should be equal to numel (the sentinel value) - assert torch.all(padding_values == topk_ids.numel()), \ - f"Padding values should all be {topk_ids.numel()}, got {padding_values}" - - -def test_invalid_expert_filtering(): - """ - Test filtering out invalid experts (for EP mode) - """ - num_tokens = 50 - num_experts = 16 - topk = 2 - block_size = 32 - - # Create topk_ids with some invalid expert IDs - topk_ids = torch.randint(0, num_experts + 5, (num_tokens, topk), dtype=torch.int32, device="cuda") - - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Should not crash even with invalid expert IDs - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Count how many invalid expert IDs we had - invalid_count = torch.sum(topk_ids >= num_experts).item() - valid_count = topk_ids.numel() - invalid_count - - -def test_expert_ids_padding(): - """ - Test filling remaining expert_ids with -1 - """ - num_tokens = 30 - num_experts = 8 - topk = 2 - block_size = 32 - - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Calculate the number of valid blocks - valid_blocks = (num_tokens_post_pad.item() + block_size - 1) // block_size - - # Check that remaining expert_ids are filled with -1 - remaining_expert_ids = expert_ids[valid_blocks:] - - # All remaining blocks should have expert_id = -1 - if len(remaining_expert_ids) > 0: - assert torch.all(remaining_expert_ids == -1), \ - f"Remaining expert_ids should be -1, got {remaining_expert_ids[:10]}" - - -@pytest.mark.parametrize( - "num_tokens,num_experts,topk,block_size", - [ - (1, 8, 1, 32), # Small batch - (10, 16, 2, 64), # Medium batch - (100, 32, 4, 128), # Larger batch - (1000, 64, 8, 256), # Large batch - ] -) -def test_all_optimizations_combined(num_tokens, num_experts, topk, block_size): - """ - Test all optimizations work together correctly - """ - topk_ids = torch.randint(0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda") - max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1) - - sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") - expert_ids = torch.empty( - (triton.cdiv(max_num_tokens_padded, block_size),), dtype=torch.int32, device="cuda" - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda") - cumsum_buffer = torch.empty((num_experts + 2,), dtype=torch.int32, device="cuda") - - # Should complete successfully with all optimizations - moe_align_block_size( - topk_ids, - num_experts + 1, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - cumsum_buffer, - pad_sorted_token_ids=True, - ) - - # Basic sanity checks - assert num_tokens_post_pad.item() > 0, "Should have some valid tokens" - assert num_tokens_post_pad.item() % block_size == 0, "Result should be aligned to block_size" - - if __name__ == "__main__": pytest.main([__file__]) From 51d1a9e08ea461ba70a199622425018ad17011dc Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 19:06:11 +0800 Subject: [PATCH 3/6] upd --- benchmark/kernels/quantization/README.md | 59 +++++++ .../quantization/tuning_block_wise_kernel.py | 35 +++- .../srt/layers/quantization/configs/README.md | 16 ++ .../srt/layers/quantization/fp8_kernel.py | 151 +++++++++++------- 4 files changed, 197 insertions(+), 64 deletions(-) create mode 100644 benchmark/kernels/quantization/README.md create mode 100644 python/sglang/srt/layers/quantization/configs/README.md diff --git a/benchmark/kernels/quantization/README.md b/benchmark/kernels/quantization/README.md new file mode 100644 index 000000000000..0c8babbddcd2 --- /dev/null +++ b/benchmark/kernels/quantization/README.md @@ -0,0 +1,59 @@ +# W8A8 Block-wise Quantization Kernel Tuning + +Auto-tune Triton FP8/INT8 block-wise quantization kernels for optimal performance. + +## Quick Start + +**Default (DeepSeek-V3):** +```bash +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --tp-size 8 +``` + +**Custom Model (specify N and K):** +```bash +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600 +``` + +## Parameters + +- `--N`, `--K`: Weight matrix dimensions (N=output_dim, K=input_dim). If not specified, uses `--tp-size` for DeepSeek-V3 +- `--tp-size`: Tensor parallelism size for DeepSeek-V3 (default: 8) +- `--input-type`: `fp8` or `int8` (default: fp8) +- `--block-n`, `--block-k`: Block quantization granularity (default: 128) +- `--batch-size`: Test single batch size (optional) + +## How to Calculate N and K + +For a linear layer `y = xW^T` where `x` is (M, K) and `W` is (N, K): +- **N**: Output features (weight matrix output dimension) +- **K**: Input features (weight matrix input dimension) + +**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) +```bash +# QKV projection: Q(8192) + K(1024) + V(1024) = 10240 +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 10240 --K 5120 + +# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200 +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 51200 --K 5120 + +# MLP down projection +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 25600 + +# O projection (if separate from QKV) +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 8192 +``` + +## Output + +Generates JSON config files saved to `python/sglang/srt/layers/quantization/configs/`: +``` +N={N},K={K},device_name={DEVICE},dtype=fp8_w8a8,block_shape=[128,128].json +``` + +Config maps batch size to optimal kernel parameters: +```json +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, ...}, + "2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, ...} +} +``` diff --git a/benchmark/kernels/quantization/tuning_block_wise_kernel.py b/benchmark/kernels/quantization/tuning_block_wise_kernel.py index 1b51e54b779c..0a5e7fb534b9 100644 --- a/benchmark/kernels/quantization/tuning_block_wise_kernel.py +++ b/benchmark/kernels/quantization/tuning_block_wise_kernel.py @@ -84,6 +84,8 @@ def w8a8_block_matmul( C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) + needs_masking = bool(K % config["BLOCK_SIZE_K"] != 0) + def grid(META): return ( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), @@ -127,6 +129,7 @@ def grid(META): Bs.stride(1), Bs.stride(0), **config, + needs_masking=needs_masking, ) return C @@ -428,7 +431,13 @@ def main(args): batch_sizes = [args.batch_size] num_gpus = 1 # If only one batch size, use only one GPU - weight_shapes = get_weight_shapes(args.tp_size) + # Support manual N and K specification + if args.N is not None and args.K is not None: + weight_shapes = [(args.N, args.K)] + print(f"Using manually specified weight shape: N={args.N}, K={args.K}") + else: + weight_shapes = get_weight_shapes(args.tp_size) + print(f"Using predefined weight shapes for TP size {args.tp_size}") batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus) @@ -453,7 +462,25 @@ def main(args): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--tp-size", "-tp", type=int, default=8) + parser.add_argument( + "--tp-size", + "-tp", + type=int, + default=8, + help="Tensor parallelism size (ignored if --N and --K are specified)", + ) + parser.add_argument( + "--N", + type=int, + default=None, + help="Output dimension of weight matrix (number of columns)", + ) + parser.add_argument( + "--K", + type=int, + default=None, + help="Input dimension of weight matrix (number of rows)", + ) parser.add_argument( "--input-type", type=str, choices=["fp8", "int8"], default="fp8" ) @@ -471,4 +498,8 @@ def main(args): ) args = parser.parse_args() + # Validate arguments + if (args.N is None) != (args.K is None): + parser.error("--N and --K must be specified together or not at all") + main(args) diff --git a/python/sglang/srt/layers/quantization/configs/README.md b/python/sglang/srt/layers/quantization/configs/README.md new file mode 100644 index 000000000000..718c9adb93d3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/README.md @@ -0,0 +1,16 @@ +# W8A8 Block FP8 Kernel Configurations + +This directory contains optimized kernel configurations for the W8A8 block FP8 matrix multiplication kernel. + +## Configuration File Format + +Configuration files are named using the following pattern: +``` +N={N},K={K},device_name={DEVICE_NAME},dtype=fp8_w8a8,block_shape=[{BLOCK_N},{BLOCK_K}].json +``` + +Where: +- `N`: Output dimension (number of columns in weight matrix) +- `K`: Input dimension (number of columns in activation matrix) +- `DEVICE_NAME`: GPU device name with spaces replaced by underscores (e.g., `NVIDIA_H100_80GB_HBM3`) +- `BLOCK_N`, `BLOCK_K`: Block quantization granularity (typically `[128,128]`) diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 4a3d1093d32f..3f8053be35b9 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -723,6 +723,7 @@ def _w8a8_block_fp8_matmul( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + needs_masking: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B` with block-wise quantization, and store the result in output @@ -748,20 +749,25 @@ def _w8a8_block_fp8_matmul( As_ptrs = As + offs_am * stride_As_m offs_bsn = offs_bn // group_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n + scale_step_k = BLOCK_SIZE_K // group_k accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + if needs_masking: + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) - k_start = k * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + a_s = tl.load(As_ptrs) + b_s = tl.load(Bs_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + As_ptrs += scale_step_k * stride_As_k + Bs_ptrs += scale_step_k * stride_Bs_k if C.dtype.element_ty == tl.bfloat16: c = accumulator.to(tl.bfloat16) @@ -808,6 +814,7 @@ def _w8a8_block_fp8_matmul_unrolledx4( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + needs_masking: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B` with block-wise quantization, and store the result in output @@ -833,94 +840,111 @@ def _w8a8_block_fp8_matmul_unrolledx4( As_ptrs = As + offs_am * stride_As_m offs_bsn = offs_bn // group_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n + scale_step_k = BLOCK_SIZE_K // group_k accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # manually unroll to 4 iterations UNROLL_FACTOR = 4 for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * UNROLL_FACTOR)): # 1st iteration - a = tl.load( - a_ptrs, - mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, - other=0.0, - ) + if needs_masking: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR) * BLOCK_SIZE_K, + other=0.0, + ) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) - k_start = (k * UNROLL_FACTOR) * BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + a_s = tl.load(As_ptrs) + b_s = tl.load(Bs_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + As_ptrs += scale_step_k * stride_As_k + Bs_ptrs += scale_step_k * stride_Bs_k # 2nd iteration - a = tl.load( - a_ptrs, - mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, - other=0.0, - ) + if needs_masking: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 1) * BLOCK_SIZE_K, + other=0.0, + ) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) - k_start = k_start + BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + a_s = tl.load(As_ptrs) + b_s = tl.load(Bs_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + As_ptrs += scale_step_k * stride_As_k + Bs_ptrs += scale_step_k * stride_Bs_k # 3rd iteration - a = tl.load( - a_ptrs, - mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, - other=0.0, - ) + if needs_masking: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 2) * BLOCK_SIZE_K, + other=0.0, + ) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) - k_start = k_start + BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + a_s = tl.load(As_ptrs) + b_s = tl.load(Bs_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + As_ptrs += scale_step_k * stride_As_k + Bs_ptrs += scale_step_k * stride_Bs_k # 4th iteration - a = tl.load( - a_ptrs, - mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, - other=0.0, - ) - b = tl.load( - b_ptrs, - mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, - other=0.0, - ) + if needs_masking: + a = tl.load( + a_ptrs, + mask=offs_k[None, :] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - (k * UNROLL_FACTOR + 3) * BLOCK_SIZE_K, + other=0.0, + ) + else: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) - k_start = k_start + BLOCK_SIZE_K - offs_ks = k_start // group_k - a_s = tl.load(As_ptrs + offs_ks * stride_As_k) - b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + a_s = tl.load(As_ptrs) + b_s = tl.load(Bs_ptrs) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk + As_ptrs += scale_step_k * stride_As_k + Bs_ptrs += scale_step_k * stride_Bs_k if C.dtype.element_ty == tl.bfloat16: c = accumulator.to(tl.bfloat16) @@ -1115,6 +1139,8 @@ def w8a8_block_fp8_matmul_triton( "num_stages": 3, } + needs_masking = bool(K % config["BLOCK_SIZE_K"] != 0) + def grid(META): return ( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), @@ -1144,6 +1170,7 @@ def grid(META): Bs.stride(1), Bs.stride(0), **config, + needs_masking=needs_masking, ) return C From b21d2448089e94a3e8689cf1d9ad54af2cf54d95 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sat, 29 Nov 2025 23:21:29 +0800 Subject: [PATCH 4/6] upd --- benchmark/kernels/quantization/README.md | 18 ++++++++++++- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++++++++++++++++++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 26 +++++++++++++++++++ 5 files changed, 121 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json diff --git a/benchmark/kernels/quantization/README.md b/benchmark/kernels/quantization/README.md index 0c8babbddcd2..606fb303f9be 100644 --- a/benchmark/kernels/quantization/README.md +++ b/benchmark/kernels/quantization/README.md @@ -28,7 +28,7 @@ For a linear layer `y = xW^T` where `x` is (M, K) and `W` is (N, K): - **N**: Output features (weight matrix output dimension) - **K**: Input features (weight matrix input dimension) -**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) +**Example: Qwen3-VL-32B** (hidden_size=5120, intermediate_size=25600, num_heads=64, num_kv_heads=8, head_dim=128) and TP=1 ```bash # QKV projection: Q(8192) + K(1024) + V(1024) = 10240 python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 10240 --K 5120 @@ -43,6 +43,22 @@ python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 2 python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 8192 ``` +If TP=8: + +```bash +# QKV projection: Q(8192) + K(1024) + V(1024) = 10240 / TP=8 +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 1280 --K 5120 + +# MLP gate+up (SwiGLU): 2 * intermediate_size = 51200 / TP=8 +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 6400 --K 5120 + +# MLP down projection +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 3200 + +# O projection (if separate from QKV) +python benchmark/kernels/quantization/tuning_block_wise_kernel.py --N 5120 --K 1024 +``` + ## Output Generates JSON config files saved to `python/sglang/srt/layers/quantization/configs/`: diff --git a/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..5c0c8d76195f --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1280,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..15e91cde59a3 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=5120,K=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..c714b7f1928c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=5120,K=3200,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..f33809b0ad05 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=6400,K=5120,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,26 @@ +{ + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} From 93d07c3f845a0bcc734d9fbd57c8a137a7811407 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sun, 7 Dec 2025 20:01:06 +0800 Subject: [PATCH 5/6] upd --- benchmark/kernels/quantization/README.md | 17 ++++++++++++ test/srt/quant/test_fp8_kernel.py | 33 ++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/benchmark/kernels/quantization/README.md b/benchmark/kernels/quantization/README.md index 606fb303f9be..acf6f0b0d128 100644 --- a/benchmark/kernels/quantization/README.md +++ b/benchmark/kernels/quantization/README.md @@ -2,6 +2,23 @@ Auto-tune Triton FP8/INT8 block-wise quantization kernels for optimal performance. +## When to Use Triton FP8 Block-wise Quantization Kernel vs DeepGEMM + +**Use Triton FP8 Block-wise Quantization Kernel when:** +- Output dtype is NOT `bfloat16` (e.g., `float16`, `float32`) +- DeepGEMM is disabled (environment variable `SGLANG_ENABLE_JIT_DEEPGEMM=0`) +- Running on GPUs with compute capability < SM90 (DeepGEMM requires SM90+) +- You need cross-platform compatibility (Triton works on both NVIDIA and AMD GPUs) + +**Use DeepGEMM when:** +- Output dtype is `bfloat16` AND DeepGEMM is enabled +- Running on NVIDIA GPUs with compute capability >= SM90 (e.g., H100, H200) +- Need maximum performance for production workloads (DeepGEMM is highly optimized for Hopper architecture) + +**Note:** DeepGEMM requires CUDA compute capability >= 9.0 (SM90+). It is specifically optimized for NVIDIA Hopper GPUs (H100/H200). + +The kernel selection logic in SGLang automatically chooses DeepGEMM when conditions are met (see `w8a8_block_fp8_matmul` function in `fp8_kernel.py`), otherwise falls back to Triton implementation. + ## Quick Start **Default (DeepSeek-V3):** diff --git a/test/srt/quant/test_fp8_kernel.py b/test/srt/quant/test_fp8_kernel.py index 42502277b1bf..2a5aba50e6d3 100644 --- a/test/srt/quant/test_fp8_kernel.py +++ b/test/srt/quant/test_fp8_kernel.py @@ -121,6 +121,39 @@ def test_w8a8_block_fp8_matmul(self): ) torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) + def test_w8a8_block_fp8_matmul_with_masking(self): + """Test needs_masking=True case where K is not divisible by BLOCK_SIZE_K. + + This test uses float16 output dtype to ensure it goes through the Triton kernel path + rather than DeepGEMM (which requires bfloat16 output). + """ + if torch.cuda.get_device_capability()[0] < 9: + return + + # Use K that is NOT divisible by 128 to trigger needs_masking=True + M = 256 + K_with_remainder = 512 + 64 # Not divisible by 128 + N = 1024 + group_size = 128 + output_dtype = torch.float16 # Use float16 to force Triton path (not DeepGEMM) + + A, A_quant_gt, A_scale_gt = self._make_A( + M=M, K=K_with_remainder, group_size=group_size, out_dtype=self.quant_type + ) + B, B_quant_gt, B_scale_gt = self._make_B( + K=K_with_remainder, N=N, group_size=group_size, out_dtype=self.quant_type + ) + C_gt = A.to(output_dtype) @ B.to(output_dtype) + C = w8a8_block_fp8_matmul( + A=A_quant_gt, + B=B_quant_gt.T.contiguous(), + As=A_scale_gt, + Bs=B_scale_gt.T.contiguous(), + block_size=[128, 128], + output_dtype=output_dtype, + ) + torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) + if __name__ == "__main__": unittest.main() From 4bb5508bc460d35d517cc33dcf545404a8be7ed7 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Sun, 7 Dec 2025 20:18:38 +0800 Subject: [PATCH 6/6] revert --- test/srt/quant/test_fp8_kernel.py | 33 ------------------------------- 1 file changed, 33 deletions(-) diff --git a/test/srt/quant/test_fp8_kernel.py b/test/srt/quant/test_fp8_kernel.py index 2a5aba50e6d3..42502277b1bf 100644 --- a/test/srt/quant/test_fp8_kernel.py +++ b/test/srt/quant/test_fp8_kernel.py @@ -121,39 +121,6 @@ def test_w8a8_block_fp8_matmul(self): ) torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) - def test_w8a8_block_fp8_matmul_with_masking(self): - """Test needs_masking=True case where K is not divisible by BLOCK_SIZE_K. - - This test uses float16 output dtype to ensure it goes through the Triton kernel path - rather than DeepGEMM (which requires bfloat16 output). - """ - if torch.cuda.get_device_capability()[0] < 9: - return - - # Use K that is NOT divisible by 128 to trigger needs_masking=True - M = 256 - K_with_remainder = 512 + 64 # Not divisible by 128 - N = 1024 - group_size = 128 - output_dtype = torch.float16 # Use float16 to force Triton path (not DeepGEMM) - - A, A_quant_gt, A_scale_gt = self._make_A( - M=M, K=K_with_remainder, group_size=group_size, out_dtype=self.quant_type - ) - B, B_quant_gt, B_scale_gt = self._make_B( - K=K_with_remainder, N=N, group_size=group_size, out_dtype=self.quant_type - ) - C_gt = A.to(output_dtype) @ B.to(output_dtype) - C = w8a8_block_fp8_matmul( - A=A_quant_gt, - B=B_quant_gt.T.contiguous(), - As=A_scale_gt, - Bs=B_scale_gt.T.contiguous(), - block_size=[128, 128], - output_dtype=output_dtype, - ) - torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) - if __name__ == "__main__": unittest.main()