From c8731401056d3e9611a73e26960859691c1303be Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Wed, 1 Apr 2026 13:46:42 +0800 Subject: [PATCH 1/3] feat: support faster rmsnorm for large hidden_dim --- .../jit_kernel/benchmark/bench_rmsnorm.py | 24 +++- python/sglang/jit_kernel/benchmark/utils.py | 19 ++- .../jit_kernel/csrc/elementwise/rmsnorm.cuh | 126 ++++++++++++++++++ python/sglang/jit_kernel/norm.py | 15 ++- 4 files changed, 165 insertions(+), 19 deletions(-) diff --git a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py index 779b8ad7e207..eaca5c66d200 100644 --- a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py +++ b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py @@ -55,9 +55,10 @@ def torch_impl_rmsnorm( ci_range=[16], ) HIDDEN_SIZE_LIST = get_benchmark_range( - full_range=[1536, 3072, 4096, 5120, 8192], - ci_range=[512, 2048], + full_range=[1536, 3072, 4096, 5120, 8192, 12288, 16384], + ci_range=[4096], ) +NUM_LAYERS = 4 LINE_VALS = ["aot", "jit", "flashinfer", "torch"] LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] @@ -81,17 +82,28 @@ def torch_impl_rmsnorm( ) def benchmark(hidden_size: int, batch_size: int, provider: str): input = torch.randn( - (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + (NUM_LAYERS, batch_size, hidden_size), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, + ) + weight = torch.randn( + (NUM_LAYERS, hidden_size), + dtype=DEFAULT_DTYPE, + device=DEFAULT_DEVICE, ) - weight = torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) FN_MAP = { "aot": sglang_aot_rmsnorm, "jit": sglang_jit_rmsnorm, "flashinfer": flashinfer_rmsnorm, "torch": torch_impl_rmsnorm, } - fn = lambda: FN_MAP[provider](input.clone(), weight) - return run_benchmark(fn) + + def f(): + fn = FN_MAP[provider] + for i in range(NUM_LAYERS): + fn(input[i], weight[i]) + + return run_benchmark(f, scale=NUM_LAYERS) if __name__ == "__main__": diff --git a/python/sglang/jit_kernel/benchmark/utils.py b/python/sglang/jit_kernel/benchmark/utils.py index c17ef4f9a0a1..3bd5e793d945 100644 --- a/python/sglang/jit_kernel/benchmark/utils.py +++ b/python/sglang/jit_kernel/benchmark/utils.py @@ -1,6 +1,6 @@ """Common utilities for jit_kernel benchmark files.""" -from typing import Callable, List, Tuple +from typing import Callable, List, Sequence, Tuple import torch import triton.testing @@ -19,25 +19,30 @@ def get_benchmark_range(full_range: List, ci_range: List) -> List: def run_benchmark( - fn: Callable, quantiles: List[float] = None + fn: Callable, + quantiles: Sequence[float] = (), + scale: float = 1.0, ) -> Tuple[float, float, float]: """Execute benchmark using CUDA graph and return times in microseconds. Args: fn: Function to benchmark quantiles: Quantiles for timing measurements [median, min, max] + scale: Scale the result down (usually num_layers). Returns: Tuple of (median_us, max_us, min_us) """ - quantiles = quantiles or DEFAULT_QUANTILES + quantiles = list(quantiles or DEFAULT_QUANTILES) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms + return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale def run_benchmark_no_cudagraph( - fn: Callable, quantiles: List[float] = None + fn: Callable, + quantiles: Sequence[float] = (), + scale: float = 1.0, ) -> Tuple[float, float, float]: - quantiles = quantiles or DEFAULT_QUANTILES + quantiles = list(quantiles or DEFAULT_QUANTILES) ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles) - return 1000 * ms, 1000 * max_ms, 1000 * min_ms + return 1000 * ms / scale, 1000 * max_ms / scale, 1000 * min_ms / scale diff --git a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh index 4f24b09736e1..ccfcd025bccd 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -47,6 +48,80 @@ __global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) { PDLTriggerSecondary(); // launch secondary kernel } +template +__global__ __launch_bounds__(kDim / 16) // optimize the occupancy + void rmsnorm_cta_half(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Float2 = packed_t; + using Storage = AlignedVector; + + constexpr auto kNumThreads = kDim / 16; + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[32]; + + PDLWaitPrimary(); + + const auto input_ptr = pointer::offset(input, blockIdx.x * input_stride); + const auto output_ptr = pointer::offset(output, blockIdx.x * output_stride); + + // Each thread loads two tiles: first half and second half + const auto input_first = gmem.load(input_ptr, 0); + const auto input_second = gmem.load(input_ptr, 1); + const auto weight_first = gmem.load(weight_ptr, 0); + const auto weight_second = gmem.load(weight_ptr, 1); + + // Compute sum of squares across both halves + float sum_of_squares = 0.0f; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [x, y] = cast(input_first[j]); + sum_of_squares += x * x + y * y; + } +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [x, y] = cast(input_second[j]); + sum_of_squares += x * x + y * y; + } + + // CTA-wide reduction + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = sum_of_squares; + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < kNumWarps ? smem[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem[tx] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + const float norm_factor = smem[warp_id]; + + // Apply norm to both halves + Storage output_first, output_second; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [ix, iy] = cast(input_first[j]); + const auto [wx, wy] = cast(weight_first[j]); + output_first[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } + +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto [ix, iy] = cast(input_second[j]); + const auto [wx, wy] = cast(weight_second[j]); + output_second[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } + + gmem.store(output_ptr, output_first, 0); + gmem.store(output_ptr, output_second, 1); + + PDLTriggerSecondary(); +} + template __global__ void rmsnorm_warp(const RMSNormParams __grid_constant__ params) { using namespace device; @@ -178,4 +253,55 @@ struct RMSNormKernel { } }; +template +struct RMSNormHalfKernel { + static_assert(kDim % 512 == 0 && sizeof(DType) == 2); + static constexpr auto kernel = rmsnorm_cta_half; + static constexpr auto kBlockSize = static_cast(kDim / 16); + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto SI = SymbolicSize{"input_stride"}; + auto SO = SymbolicSize{"output_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kDim); + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({SI, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // output + .with_strides({SO, 1}) + .with_dtype() + .with_device(device) + .verify(output); + + const auto num_tokens = static_cast(N.unwrap()); + const auto params = RMSNormParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .output = output.data_ptr(), + .input_stride = SI.unwrap(), + .output_stride = SO.unwrap(), + .num_tokens = num_tokens, + .eps = eps, + }; + + LaunchKernel(num_tokens, kBlockSize, device.unwrap()) // + .enable_pdl(kUsePDL)(kernel, params); + } +}; + } // namespace diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index 606358dd1d97..cebbd40b2a64 100644 --- a/python/sglang/jit_kernel/norm.py +++ b/python/sglang/jit_kernel/norm.py @@ -32,20 +32,23 @@ def _jit_qknorm_module(head_dim: int, dtype: torch.dtype) -> Module: _RMSNORM_WARP_SIZES = frozenset({64, 128, 256}) -_RMSNORM_MAX_HIDDEN_SIZE = 8192 +_RMSNORM_MAX_HIDDEN_SIZE = 16384 +_RMSNORM_HALF_BLOCK_MIN_SIZE = 2048 -def _is_supported_rmsnorm_hidden_size(hidden_size: int) -> bool: - return hidden_size in _RMSNORM_WARP_SIZES or ( - hidden_size > 256 - and hidden_size % 256 == 0 - and hidden_size <= _RMSNORM_MAX_HIDDEN_SIZE +def _is_supported_rmsnorm_hidden_size(d: int) -> bool: + return d in _RMSNORM_WARP_SIZES or ( + (d > 256 and d % 256 == 0 and d <= 8192) + or (d >= 8192 and d % 512 == 0 and d <= 16384) ) def _rmsnorm_kernel_class(hidden_size: int) -> str: if hidden_size in _RMSNORM_WARP_SIZES: return "RMSNormWarpKernel" + if hidden_size >= _RMSNORM_HALF_BLOCK_MIN_SIZE: + if hidden_size % 512 == 0: + return "RMSNormHalfKernel" return "RMSNormKernel" From 6440b016ec6ed270145a554491ed5ee451c44250 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Wed, 1 Apr 2026 14:22:37 +0800 Subject: [PATCH 2/3] feat: optimize blackwell a little --- .../jit_kernel/csrc/elementwise/rmsnorm.cuh | 70 ++++++++++++++++--- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh index ccfcd025bccd..2e1edd692d87 100644 --- a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh +++ b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm.cuh @@ -48,9 +48,9 @@ __global__ void rmsnorm_cta(const RMSNormParams __grid_constant__ params) { PDLTriggerSecondary(); // launch secondary kernel } +// Pre-Blackwell: 16B vector, each thread loads/stores twice template -__global__ __launch_bounds__(kDim / 16) // optimize the occupancy - void rmsnorm_cta_half(const RMSNormParams __grid_constant__ params) { +__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_double(const RMSNormParams __grid_constant__ params) { using namespace device; using Float2 = packed_t; using Storage = AlignedVector; @@ -67,13 +67,11 @@ __global__ __launch_bounds__(kDim / 16) // optimize the occupancy const auto input_ptr = pointer::offset(input, blockIdx.x * input_stride); const auto output_ptr = pointer::offset(output, blockIdx.x * output_stride); - // Each thread loads two tiles: first half and second half const auto input_first = gmem.load(input_ptr, 0); const auto input_second = gmem.load(input_ptr, 1); const auto weight_first = gmem.load(weight_ptr, 0); const auto weight_second = gmem.load(weight_ptr, 1); - // Compute sum of squares across both halves float sum_of_squares = 0.0f; #pragma unroll for (auto j = 0u; j < 4u; ++j) { @@ -86,7 +84,6 @@ __global__ __launch_bounds__(kDim / 16) // optimize the occupancy sum_of_squares += x * x + y * y; } - // CTA-wide reduction sum_of_squares = warp::reduce_sum(sum_of_squares); const auto warp_id = threadIdx.x / kWarpThreads; smem[warp_id] = sum_of_squares; @@ -100,7 +97,6 @@ __global__ __launch_bounds__(kDim / 16) // optimize the occupancy __syncthreads(); const float norm_factor = smem[warp_id]; - // Apply norm to both halves Storage output_first, output_second; #pragma unroll for (auto j = 0u; j < 4u; ++j) { @@ -108,7 +104,6 @@ __global__ __launch_bounds__(kDim / 16) // optimize the occupancy const auto [wx, wy] = cast(weight_first[j]); output_first[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); } - #pragma unroll for (auto j = 0u; j < 4u; ++j) { const auto [ix, iy] = cast(input_second[j]); @@ -122,6 +117,61 @@ __global__ __launch_bounds__(kDim / 16) // optimize the occupancy PDLTriggerSecondary(); } +// Blackwell: 32B vector, each thread loads/stores once +template +__global__ __launch_bounds__(kDim / 16) void rmsnorm_cta_wide(const RMSNormParams __grid_constant__ params) { + using namespace device; + using Float2 = packed_t; + using Storage = AlignedVector; + + constexpr auto kNumThreads = kDim / 16; + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + + const auto& [input, weight_ptr, output, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem = tile::Memory::cta(kNumThreads); + __shared__ float smem[32]; + + PDLWaitPrimary(); + + const auto input_ptr = pointer::offset(input, blockIdx.x * input_stride); + const auto output_ptr = pointer::offset(output, blockIdx.x * output_stride); + + const auto input_vec = gmem.load(input_ptr); + const auto weight_vec = gmem.load(weight_ptr); + + float sum_of_squares = 0.0f; +#pragma unroll + for (auto j = 0u; j < 8u; ++j) { + const auto [x, y] = cast(input_vec[j]); + sum_of_squares += x * x + y * y; + } + + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = sum_of_squares; + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < kNumWarps ? smem[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem[tx] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + const float norm_factor = smem[warp_id]; + + Storage output_vec; +#pragma unroll + for (auto j = 0u; j < 8u; ++j) { + const auto [ix, iy] = cast(input_vec[j]); + const auto [wx, wy] = cast(weight_vec[j]); + output_vec[j] = cast(fp32x2_t{ix * norm_factor * wx, iy * norm_factor * wy}); + } + + gmem.store(output_ptr, output_vec); + + PDLTriggerSecondary(); +} + template __global__ void rmsnorm_warp(const RMSNormParams __grid_constant__ params) { using namespace device; @@ -256,7 +306,11 @@ struct RMSNormKernel { template struct RMSNormHalfKernel { static_assert(kDim % 512 == 0 && sizeof(DType) == 2); - static constexpr auto kernel = rmsnorm_cta_half; +#if SGL_ARCH_BLACKWELL_OR_GREATER + static constexpr auto kernel = rmsnorm_cta_wide; +#else + static constexpr auto kernel = rmsnorm_cta_double; +#endif static constexpr auto kBlockSize = static_cast(kDim / 16); static void From 75d50c5fbaa468f1dce70b0b87d9900a484a8ad1 Mon Sep 17 00:00:00 2001 From: DarkSharpness <2040703891@qq.com> Date: Wed, 1 Apr 2026 22:27:07 +0800 Subject: [PATCH 3/3] misc: deduplicate test --- .../benchmark/bench_fused_add_rmsnorm.py | 75 --------- .../sglang/jit_kernel/benchmark/bench_norm.py | 83 +++++----- .../jit_kernel/benchmark/bench_rmsnorm.py | 110 ------------- python/sglang/jit_kernel/norm.py | 6 +- .../sglang/jit_kernel/tests/test_norm_jit.py | 145 ------------------ .../sglang/jit_kernel/tests/test_rmsnorm.py | 95 +++++++++--- 6 files changed, 121 insertions(+), 393 deletions(-) delete mode 100644 python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py delete mode 100644 python/sglang/jit_kernel/benchmark/bench_rmsnorm.py delete mode 100644 python/sglang/jit_kernel/tests/test_norm_jit.py diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py deleted file mode 100644 index a842be84b72b..000000000000 --- a/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm.py +++ /dev/null @@ -1,75 +0,0 @@ -import itertools - -import torch -import triton -import triton.testing -from flashinfer import fused_add_rmsnorm as fi_fused_add_rmsnorm - -from sglang.jit_kernel.benchmark.utils import run_benchmark -from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.utils import is_in_ci - -register_cuda_ci(est_time=6, suite="stage-b-kernel-benchmark-1-gpu-large") - -IS_CI = is_in_ci() - - -def sglang_jit_fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float -) -> None: - jit_fused_add_rmsnorm(input, residual, weight, eps) - - -def flashinfer_fused_add_rmsnorm( - input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float -) -> None: - fi_fused_add_rmsnorm(input, residual, weight, eps=eps) - - -DTYPE = torch.bfloat16 -DEVICE = "cuda" - -if IS_CI: - BS_LIST = [16] - HIDDEN_SIZE_LIST = [512, 2048] -else: - BS_LIST = [2**n for n in range(0, 14)] - HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] - -LINE_VALS = ["jit", "flashinfer"] -LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"] -STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] - -configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["hidden_size", "batch_size"], - x_vals=configs, - line_arg="provider", - line_vals=LINE_VALS, - line_names=LINE_NAMES, - styles=STYLES, - ylabel="us", - plot_name="fused-add-rmsnorm-performance", - args={}, - ) -) -def benchmark(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": sglang_jit_fused_add_rmsnorm, - "flashinfer": flashinfer_fused_add_rmsnorm, - } - fn = lambda: FN_MAP[provider]( - input.clone(), residual.clone(), weight, torch.finfo(torch.bfloat16).eps - ) - return run_benchmark(fn) - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/bench_norm.py b/python/sglang/jit_kernel/benchmark/bench_norm.py index d046ecf2d2a8..345388ef7fc3 100644 --- a/python/sglang/jit_kernel/benchmark/bench_norm.py +++ b/python/sglang/jit_kernel/benchmark/bench_norm.py @@ -6,40 +6,39 @@ from flashinfer.norm import fused_add_rmsnorm as fi_fused_add_rmsnorm from flashinfer.norm import rmsnorm as fi_rmsnorm -from sglang.jit_kernel.benchmark.utils import run_benchmark +from sglang.jit_kernel.benchmark.utils import get_benchmark_range, run_benchmark from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm from sglang.test.ci.ci_register import register_cuda_ci -from sglang.utils import is_in_ci -register_cuda_ci(est_time=5, suite="stage-b-kernel-benchmark-1-gpu-large") +register_cuda_ci(est_time=30, suite="stage-b-kernel-benchmark-1-gpu-large") -IS_CI = is_in_ci() DTYPE = torch.bfloat16 DEVICE = "cuda" -# JIT rmsnorm: hidden_size in {64,128,256} or (multiple of 256, <=8192) -# JIT fused_add_rmsnorm: hidden_size % 8 == 0, <=8192 -# Use multiples of 256 <=8192 to satisfy both kernels -if IS_CI: - BS_LIST = [16] - HIDDEN_SIZE_LIST = [512, 2048] -else: - BS_LIST = [2**n for n in range(0, 14)] - HIDDEN_SIZE_LIST = [1536, 3072, 4096, 5120, 8192] - -LINE_VALS = ["jit", "flashinfer"] -LINE_NAMES = ["SGL JIT Kernel", "FlashInfer"] +BS_LIST = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16, 32], +) +HIDDEN_SIZE_LIST = get_benchmark_range( + full_range=sorted([1536, *range(1024, 8192 + 1, 1024)]), + ci_range=[512, 2048], +) + +LINE_VALS = ["flashinfer", "jit"] +LINE_NAMES = ["FlashInfer", "SGL JIT Kernel"] STYLES = [("blue", "--"), ("green", "-.")] +NUM_LAYERS = 4 # avoid L2 effect -configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) +configs_0 = list(itertools.product(HIDDEN_SIZE_LIST + [16384], BS_LIST)) +configs_1 = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["hidden_size", "batch_size"], - x_vals=configs, + x_vals=configs_0, line_arg="provider", line_vals=LINE_VALS, line_names=LINE_NAMES, @@ -50,20 +49,24 @@ ) ) def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": lambda: jit_rmsnorm(input.clone(), weight), - "flashinfer": lambda: fi_rmsnorm(input.clone(), weight, out=input.clone()), - } - fn = FN_MAP[provider] - return run_benchmark(fn) + input = torch.randn( + (NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, device=DEVICE + ) + weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE) + FN_MAP = {"jit": jit_rmsnorm, "flashinfer": fi_rmsnorm} + + def f(): + fn = FN_MAP[provider] + for i in range(NUM_LAYERS): + fn(input[i], weight[i], out=input[i]) + + return run_benchmark(f, scale=NUM_LAYERS) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["hidden_size", "batch_size"], - x_vals=configs, + x_vals=configs_1, line_arg="provider", line_vals=LINE_VALS, line_names=LINE_NAMES, @@ -74,19 +77,19 @@ def benchmark_rmsnorm(hidden_size: int, batch_size: int, provider: str): ) ) def benchmark_fused_add_rmsnorm(hidden_size: int, batch_size: int, provider: str): - input = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - residual = torch.randn((batch_size, hidden_size), dtype=DTYPE, device=DEVICE) - weight = torch.randn(hidden_size, dtype=DTYPE, device=DEVICE) - FN_MAP = { - "jit": lambda: jit_fused_add_rmsnorm( - input.clone(), residual.clone(), weight, torch.finfo(DTYPE).eps - ), - "flashinfer": lambda: fi_fused_add_rmsnorm( - input.clone(), residual.clone(), weight, eps=torch.finfo(DTYPE).eps - ), - } - fn = FN_MAP[provider] - return run_benchmark(fn) + input = torch.randn( + (NUM_LAYERS, batch_size, hidden_size), dtype=DTYPE, device=DEVICE + ) + residual = torch.randn_like(input) + weight = torch.randn((NUM_LAYERS, hidden_size), dtype=DTYPE, device=DEVICE) + FN_MAP = {"jit": jit_fused_add_rmsnorm, "flashinfer": fi_fused_add_rmsnorm} + + def f(): + fn = FN_MAP[provider] + for i in range(NUM_LAYERS): + fn(input[i], residual[i], weight[i]) + + return run_benchmark(f, scale=NUM_LAYERS) if __name__ == "__main__": diff --git a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py b/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py deleted file mode 100644 index eaca5c66d200..000000000000 --- a/python/sglang/jit_kernel/benchmark/bench_rmsnorm.py +++ /dev/null @@ -1,110 +0,0 @@ -import itertools - -import torch -import triton -import triton.testing -from flashinfer import rmsnorm as fi_rmsnorm -from sgl_kernel import rmsnorm - -from sglang.jit_kernel.benchmark.utils import ( - DEFAULT_DEVICE, - DEFAULT_DTYPE, - get_benchmark_range, - run_benchmark, -) -from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm -from sglang.test.ci.ci_register import register_cuda_ci - -register_cuda_ci(est_time=21, suite="stage-b-kernel-benchmark-1-gpu-large") - - -def sglang_aot_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - rmsnorm(input, weight, out=input) - - -def sglang_jit_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - jit_rmsnorm(input, weight, output=input) - - -def flashinfer_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, -) -> None: - fi_rmsnorm(input, weight, out=input) - - -@torch.compile() -def torch_impl_rmsnorm( - input: torch.Tensor, - weight: torch.Tensor, - eps: float = 1e-6, -) -> None: - mean = input.float().pow(2).mean(dim=-1, keepdim=True) - norm = (mean + eps).rsqrt() - input.copy_(input.float() * norm * weight.float()) - - -BS_LIST = get_benchmark_range( - full_range=[2**n for n in range(0, 14)], - ci_range=[16], -) -HIDDEN_SIZE_LIST = get_benchmark_range( - full_range=[1536, 3072, 4096, 5120, 8192, 12288, 16384], - ci_range=[4096], -) -NUM_LAYERS = 4 - -LINE_VALS = ["aot", "jit", "flashinfer", "torch"] -LINE_NAMES = ["SGL AOT Kernel", "SGL JIT Kernel", "FlashInfer", "PyTorch"] -STYLES = [("orange", "-"), ("blue", "--"), ("green", "-."), ("red", ":")] - -configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["hidden_size", "batch_size"], - x_vals=configs, - line_arg="provider", - line_vals=LINE_VALS, - line_names=LINE_NAMES, - styles=STYLES, - ylabel="us", - plot_name="rmsnorm-performance", - args={}, - ) -) -def benchmark(hidden_size: int, batch_size: int, provider: str): - input = torch.randn( - (NUM_LAYERS, batch_size, hidden_size), - dtype=DEFAULT_DTYPE, - device=DEFAULT_DEVICE, - ) - weight = torch.randn( - (NUM_LAYERS, hidden_size), - dtype=DEFAULT_DTYPE, - device=DEFAULT_DEVICE, - ) - FN_MAP = { - "aot": sglang_aot_rmsnorm, - "jit": sglang_jit_rmsnorm, - "flashinfer": flashinfer_rmsnorm, - "torch": torch_impl_rmsnorm, - } - - def f(): - fn = FN_MAP[provider] - for i in range(NUM_LAYERS): - fn(input[i], weight[i]) - - return run_benchmark(f, scale=NUM_LAYERS) - - -if __name__ == "__main__": - benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index cebbd40b2a64..25b4a5f2c1b2 100644 --- a/python/sglang/jit_kernel/norm.py +++ b/python/sglang/jit_kernel/norm.py @@ -121,10 +121,10 @@ def fused_inplace_qknorm( def rmsnorm( input: torch.Tensor, weight: torch.Tensor, - output: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, eps: float = 1e-6, ) -> None: - output = output if output is not None else input + out = out if out is not None else input hidden_size = input.size(-1) if not _is_supported_rmsnorm_hidden_size(hidden_size): raise RuntimeError( @@ -133,7 +133,7 @@ def rmsnorm( f"(256, {_RMSNORM_MAX_HIDDEN_SIZE}]." ) module = _jit_rmsnorm_module(hidden_size, input.dtype) - module.rmsnorm(input, weight, output, eps) + module.rmsnorm(input, weight, out, eps) @debug_kernel_api diff --git a/python/sglang/jit_kernel/tests/test_norm_jit.py b/python/sglang/jit_kernel/tests/test_norm_jit.py deleted file mode 100644 index ebd0d3034cd9..000000000000 --- a/python/sglang/jit_kernel/tests/test_norm_jit.py +++ /dev/null @@ -1,145 +0,0 @@ -# Adapted from sgl-kernel/tests/test_norm.py - -import sys - -import pytest -import torch - -from sglang.test.ci.ci_register import register_cuda_ci - -register_cuda_ci(est_time=125, suite="stage-b-kernel-unit-1-gpu-large") -register_cuda_ci(est_time=500, suite="nightly-kernel-1-gpu", nightly=True) - -# JIT rmsnorm: fp16/bf16 only -# - Warp norm path (one warp per token): hidden_size in {64, 128, 256} -# - CTA norm path (multi-warp per token): hidden_size is a multiple of 256, > 256, and <=8192 -RMSNORM_HIDDEN_SIZES = [64, 128, 256, 512, 1024, 3072, 3584, 4096, 8192] - -# JIT fused_add_rmsnorm: fp16/bf16 only; hidden_size % 8 == 0, <=8192 -FUSED_ADD_RMSNORM_HIDDEN_SIZES = [1024, 3072, 3584, 4096, 8192] - -BS_LIST = [ - 1, - 19, - 99, - 989, - 8192, -] # 8192 ensures num_tokens > max_occupancy * kNumSM on any GPU - - -def _jit_rmsnorm(input, weight, output, eps): - from sglang.jit_kernel.norm import rmsnorm - - rmsnorm(input, weight, output=output, eps=eps) - - -def _fi_rmsnorm(input, weight, out, eps): - from flashinfer.norm import rmsnorm - - rmsnorm(input, weight, out=out, eps=eps) - - -def _jit_fused_add_rmsnorm(input, residual, weight, eps): - from sglang.jit_kernel.norm import fused_add_rmsnorm - - fused_add_rmsnorm(input, residual, weight, eps) - - -def _fi_fused_add_rmsnorm(input, residual, weight, eps): - from flashinfer.norm import fused_add_rmsnorm - - fused_add_rmsnorm(input, residual, weight, eps=eps) - - -@pytest.mark.parametrize("batch_size", BS_LIST) -@pytest.mark.parametrize("hidden_size", RMSNORM_HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("specify_out", [True, False]) -def test_rmsnorm_jit(batch_size, hidden_size, dtype, specify_out): - eps = 1e-6 - x = torch.randn(batch_size, hidden_size, device="cuda", dtype=dtype) - w = torch.randn(hidden_size, device="cuda", dtype=dtype) - - # flashinfer reference - x_ref = x.clone() - _fi_rmsnorm(x_ref, w, out=x_ref, eps=eps) - - if specify_out: - y = torch.empty_like(x) - _jit_rmsnorm(x, w, output=y, eps=eps) - else: - y = x.clone() - _jit_rmsnorm(y, w, output=y, eps=eps) - - torch.testing.assert_close(y, x_ref, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize("batch_size", BS_LIST) -@pytest.mark.parametrize("hidden_size", FUSED_ADD_RMSNORM_HIDDEN_SIZES) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_fused_add_rmsnorm_jit(batch_size, hidden_size, dtype): - eps = 1e-6 - x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") - residual = torch.randn_like(x) - weight = torch.randn(hidden_size, dtype=dtype, device="cuda") - - # flashinfer reference - x_ref = x.clone() - r_ref = residual.clone() - _fi_fused_add_rmsnorm(x_ref, r_ref, weight, eps=eps) - - x_jit = x.clone() - r_jit = residual.clone() - _jit_fused_add_rmsnorm(x_jit, r_jit, weight, eps) - - torch.testing.assert_close(x_jit, x_ref, rtol=1e-2, atol=1e-2) - torch.testing.assert_close(r_jit, r_ref, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize( - ("hidden_size", "expected"), - [ - (0, False), - (64, True), - (128, True), - (256, True), - (512, True), - (8192, True), - (16384, False), - ], -) -def test_rmsnorm_hidden_size_support(hidden_size, expected): - from sglang.jit_kernel.norm import _is_supported_rmsnorm_hidden_size - - assert _is_supported_rmsnorm_hidden_size(hidden_size) is expected - - -@pytest.mark.parametrize( - ("hidden_size", "expected"), - [ - (64, "RMSNormWarpKernel"), - (128, "RMSNormWarpKernel"), - (256, "RMSNormWarpKernel"), - (512, "RMSNormKernel"), - (8192, "RMSNormKernel"), - ], -) -def test_rmsnorm_kernel_dispatch(hidden_size, expected): - from sglang.jit_kernel.norm import _rmsnorm_kernel_class - - assert _rmsnorm_kernel_class(hidden_size) == expected - - -@pytest.mark.parametrize("hidden_size", [0, 16384]) -def test_rmsnorm_rejects_unsupported_hidden_size(hidden_size): - from sglang.jit_kernel.norm import rmsnorm - - x = torch.randn(1, hidden_size) - w = torch.randn(hidden_size) - - with pytest.raises(RuntimeError, match=f"unsupported hidden_size={hidden_size}"): - rmsnorm(x, w) - - -if __name__ == "__main__": - sys.exit(pytest.main([__file__])) diff --git a/python/sglang/jit_kernel/tests/test_rmsnorm.py b/python/sglang/jit_kernel/tests/test_rmsnorm.py index ac31a792747d..59ce90f299f2 100644 --- a/python/sglang/jit_kernel/tests/test_rmsnorm.py +++ b/python/sglang/jit_kernel/tests/test_rmsnorm.py @@ -3,49 +3,104 @@ import pytest import torch -import triton from sglang.jit_kernel.utils import get_ci_test_range from sglang.test.ci.ci_register import register_cuda_ci -register_cuda_ci(est_time=18, suite="stage-b-kernel-unit-1-gpu-large") -register_cuda_ci(est_time=120, suite="nightly-kernel-1-gpu", nightly=True) +register_cuda_ci(est_time=45, suite="stage-b-kernel-unit-1-gpu-large") +register_cuda_ci(est_time=240, suite="nightly-kernel-1-gpu", nightly=True) -def sglang_jit_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: +EPS = 1e-6 +DEVICE = "cuda" +DTYPES = [torch.float16, torch.bfloat16] + + +def sglang_jit_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + *, + output: torch.Tensor | None = None, + eps: float = EPS, +) -> None: from sglang.jit_kernel.norm import rmsnorm - rmsnorm(input, weight, output=input) + rmsnorm(input, weight, out=output, eps=eps) -def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor) -> None: +def flashinfer_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + *, + output: torch.Tensor, + eps: float = EPS, +) -> None: from flashinfer.norm import rmsnorm - rmsnorm(input, weight, out=input) + rmsnorm(input, weight, out=output, eps=eps) BS_LIST = [2**n for n in range(0, 14)] BS_LIST += [x + 1 + i for i, x in enumerate(BS_LIST)] BS_LIST = get_ci_test_range(BS_LIST, [1, 9, 256, 4109]) -HIDDEN_SIZE_LIST = get_ci_test_range( - [512, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 7168, 8192], - [512, 2048, 8192], +SUPPORTED_HIDDEN_SIZE_LIST = get_ci_test_range( + [64, 128, 256, 512, *range(1024, 8192 + 1, 1024), 2304, 2560, 12288, 16384], + [256, 1024, 16384], ) -DEVICE = "cuda" -DTYPE = torch.bfloat16 @pytest.mark.parametrize( - "batch_size,hidden_size", list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)) + "batch_size,hidden_size", + list(itertools.product(BS_LIST, SUPPORTED_HIDDEN_SIZE_LIST)), ) -def test_rmsnorm(batch_size: int, hidden_size: int) -> None: - input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=DTYPE) - weight = torch.randn(hidden_size, device=DEVICE, dtype=DTYPE) - input_sglang = input.clone() +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_rmsnorm( + batch_size: int, hidden_size: int, dtype: torch.dtype, specify_out: bool +) -> None: + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype) + input_flashinfer = input.clone() - sglang_jit_rmsnorm(input_sglang, weight) - flashinfer_rmsnorm(input_flashinfer, weight) - triton.testing.assert_close(input_sglang, input_flashinfer, atol=1e-2, rtol=1e-2) + output_flashinfer = torch.empty_like(input) + flashinfer_rmsnorm(input_flashinfer, weight, output=output_flashinfer) + + if specify_out: + output_sglang = torch.empty_like(input) + sglang_jit_rmsnorm(input, weight, output=output_sglang) + else: + output_sglang = input.clone() + sglang_jit_rmsnorm(output_sglang, weight, output=output_sglang) + + torch.testing.assert_close(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("hidden_size", [64, 128, 256, 512, 8192, 8704, 16384]) +def test_rmsnorm_hidden_size_support(hidden_size: int) -> None: + from sglang.jit_kernel.norm import _is_supported_rmsnorm_hidden_size + + assert _is_supported_rmsnorm_hidden_size(hidden_size) + + +@pytest.mark.parametrize( + ("hidden_size", "expected"), + [ + (64, "RMSNormWarpKernel"), + (128, "RMSNormWarpKernel"), + (256, "RMSNormWarpKernel"), + (512, "RMSNormKernel"), + (1536, "RMSNormKernel"), + (2048, "RMSNormHalfKernel"), + (2304, "RMSNormKernel"), # NOTE: not 512 aligned + (8192, "RMSNormHalfKernel"), + (8704, "RMSNormHalfKernel"), + (16384, "RMSNormHalfKernel"), + ], +) +def test_rmsnorm_kernel_dispatch(hidden_size: int, expected: str) -> None: + from sglang.jit_kernel.norm import _rmsnorm_kernel_class + + assert _rmsnorm_kernel_class(hidden_size) == expected if __name__ == "__main__":