diff --git a/python/sglang/jit_kernel/benchmark/bench_dual_gemm.py b/python/sglang/jit_kernel/benchmark/bench_dual_gemm.py new file mode 100644 index 000000000000..acbdfccc733e --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_dual_gemm.py @@ -0,0 +1,92 @@ +import itertools + +import sgl_kernel +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import run_benchmark +from sglang.jit_kernel.cutedsl_dual_gemm import cutedsl_dual_gemm +from sglang.srt.compilation.fusion.ops.triton_ops.dual_gemm import dual_gemm + +DEVICE = "cuda" +# M_LIST = [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024] +M_LIST = [1, 8, 32, 64, 128, 256, 512, 1024] +K_LIST = [4096] +N_LIST = [4096, 11008] +DTYPE_LIST = [torch.bfloat16] + +configs = list(itertools.product(M_LIST, K_LIST, N_LIST, DTYPE_LIST)) + + +def _is_sm90(): + return torch.cuda.get_device_capability()[0] >= 9 + + +def cutedsl_fn(x, w_gate, w_up, out): + cutedsl_dual_gemm(x, w_gate, w_up, out) + + +def triton_fn(x, w, out): + out.copy_(dual_gemm(x, w)) + + +def reference_fn(x, w, out): + mm_result = torch.mm(x, w) # (M, 2*N) + sgl_kernel.silu_and_mul(mm_result, out) # (M, N) + + +def warmup(): + if not _is_sm90(): + return + for K, N, dtype in set((K, N, dtype) for _, K, N, dtype in configs): + M = M_LIST[0] + x = torch.randn((M, K), dtype=dtype, device=DEVICE) + w_gate = torch.randn((K, N), dtype=dtype, device=DEVICE) + w_up = torch.randn((K, N), dtype=dtype, device=DEVICE) + out = torch.empty((M, N), dtype=dtype, device=DEVICE) + cutedsl_dual_gemm(x, w_gate, w_up, out) + torch.cuda.synchronize() + + +LINE_VALS = ["cutedsl", "triton", "reference"] +LINE_NAMES = ["CuteDSL Dual GEMM", "Triton Dual GEMM", "Reference (mm + silu_and_mul)"] +STYLES = [("blue", "-"), ("orange", "--"), ("red", ":")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "K", "N", "dtype"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + 
styles=STYLES, + ylabel="us", + plot_name="dual-gemm-performance", + args={}, + ) +) +def benchmark(M: int, K: int, N: int, dtype: torch.dtype, provider: str): + if provider == "cutedsl" and not _is_sm90(): + return 0.0, 0.0, 0.0 + + x = torch.randn((M, K), dtype=dtype, device=DEVICE) + w_gate = torch.randn((K, N), dtype=dtype, device=DEVICE) + w_up = torch.randn((K, N), dtype=dtype, device=DEVICE) + out = torch.empty((M, N), dtype=dtype, device=DEVICE) + + # Concat for triton and reference (they take w as (K, 2*N)) + w = torch.cat([w_gate, w_up], dim=1) + + FN_MAP = { + "cutedsl": lambda: cutedsl_fn(x, w_gate, w_up, out), + "triton": lambda: triton_fn(x, w, out), + "reference": lambda: reference_fn(x, w, out), + } + return run_benchmark(FN_MAP[provider]) + + +if __name__ == "__main__": + warmup() + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/bench_dual_gemm_fp8.py b/python/sglang/jit_kernel/benchmark/bench_dual_gemm_fp8.py new file mode 100644 index 000000000000..82b944b91754 --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_dual_gemm_fp8.py @@ -0,0 +1,166 @@ +import itertools + +import sgl_kernel +import torch +import triton +import triton.testing + +from sglang.jit_kernel.benchmark.utils import run_benchmark +from sglang.jit_kernel.cutedsl_dual_gemm import cutedsl_dual_gemm +from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8 +from sglang.srt.compilation.fusion.ops.triton_ops.dual_gemm import dual_gemm_kernel + +DTYPE = torch.bfloat16 +DEVICE = "cuda" +M_LIST = [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024] +K_LIST = [1024, 2048, 4096] +N_LIST = [1024, 2048, 4096, 8192] + +configs = list(itertools.product(M_LIST, K_LIST, N_LIST)) + + +def _is_sm90(): + return torch.cuda.get_device_capability()[0] >= 9 + + +def cutedsl_fn(x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale): + cutedsl_dual_gemm(x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale) + + +def triton_fn(x_fp8, 
w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale): + M_val, K_val = x_fp8.shape + N_val = w_gate_fp8.shape[1] + + def grid(META): + return ( + triton.cdiv(M_val, META["BLOCK_SIZE_M"]) + * triton.cdiv(N_val, META["BLOCK_SIZE_N"]), + ) + + dual_gemm_kernel[grid]( + x_fp8, + w_gate_fp8, + w_up_fp8, + out_fp8, + x_scale, + w_scale, + o_scale, + True, + x_fp8.stride(0), + x_fp8.stride(1), + w_gate_fp8.stride(0), + w_gate_fp8.stride(1), + out_fp8.stride(0), + out_fp8.stride(1), + M_val, + K_val, + N_val, + torch.finfo(torch.float8_e4m3fn).min, + torch.finfo(torch.float8_e4m3fn).max, + 128, + 64, + 128, + 1, + num_warps=4, + num_stages=4, + ) + + +def reference_fp8_fn( + x_fp8, w_fp8_col_major, out_bf16, out_fp8, x_scale, w_scale, o_scale +): + # Step 1: FP8 scaled matmul -> BF16 + # w_fp8_col_major is (K, 2*N) column-major for _scaled_mm + mm_result = torch._scaled_mm( + x_fp8, + w_fp8_col_major, + scale_a=x_scale, + scale_b=w_scale, + out_dtype=torch.bfloat16, + ) + # Step 2: silu_and_mul + sgl_kernel.silu_and_mul(mm_result, out_bf16) + # Step 3: per-tensor FP8 quantize + per_tensor_quant_fp8(out_bf16, out_fp8, o_scale, is_static=True) + + +def warmup(): + if not _is_sm90(): + return + for K, N in set((K, N) for _, K, N in configs): + M = M_LIST[0] + x = torch.randn((M, K), dtype=DTYPE, device=DEVICE) + x_fp8 = x.to(torch.float8_e4m3fn) + w_gate_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to( + torch.float8_e4m3fn + ) + w_up_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to( + torch.float8_e4m3fn + ) + out_fp8 = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=DEVICE) + x_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + w_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + o_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + cutedsl_dual_gemm( + x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale + ) + torch.cuda.synchronize() + + +LINE_VALS = ["cutedsl", "triton", "reference"] 
+LINE_NAMES = [ + "CuteDSL Dual GEMM FP8", + "Triton Dual GEMM FP8", + "Reference (scaled_mm + silu_and_mul + quant_fp8)", +] +STYLES = [("blue", "-"), ("orange", "--"), ("red", ":")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "K", "N"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="dual-gemm-fp8-performance", + args={}, + ) +) +def benchmark(M: int, K: int, N: int, provider: str): + if provider == "cutedsl" and not _is_sm90(): + return 0.0, 0.0, 0.0 + + x_bf16 = torch.randn((M, K), dtype=DTYPE, device=DEVICE) + x_fp8 = x_bf16.to(torch.float8_e4m3fn) + w_gate_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(torch.float8_e4m3fn) + w_up_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(torch.float8_e4m3fn) + out_fp8 = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=DEVICE) + x_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + w_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + o_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE) + + # For reference: _scaled_mm expects b as column-major (2*N, K).t() + w_cat = torch.cat([w_gate_fp8, w_up_fp8], dim=1) # (K, 2*N) + w_fp8_t = w_cat.t().contiguous().t() # (K, 2*N) column-major + out_bf16 = torch.empty((M, N), dtype=DTYPE, device=DEVICE) + + FN_MAP = { + "cutedsl": lambda: cutedsl_fn( + x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale + ), + "triton": lambda: triton_fn( + x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale + ), + "reference": lambda: reference_fp8_fn( + x_fp8, w_fp8_t, out_bf16, out_fp8, x_scale, w_scale, o_scale + ), + } + return run_benchmark(FN_MAP[provider]) + + +if __name__ == "__main__": + warmup() + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm_quant.py b/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm_quant.py new file 
mode 100644 index 000000000000..9343e3436faa --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_fused_add_rmsnorm_quant.py @@ -0,0 +1,157 @@ +import itertools + +import torch +import triton +import triton.testing +from flashinfer import fused_add_rmsnorm_quant as fi_fused_add_rmsnorm_quant +from sgl_kernel import fused_add_rms_norm_static_fp8_quant + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) +from sglang.jit_kernel.norm import fused_add_rmsnorm as jit_fused_add_rmsnorm +from sglang.jit_kernel.norm import ( + fused_add_rmsnorm_quant as jit_fused_add_rmsnorm_quant, +) +from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8 as jit_quant + +FP8_DTYPE = torch.float8_e4m3fn +FP8_E4M3_MAX = 448.0 + + +def sglang_aot_fused_add_rmsnorm_quant( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + fused_add_rms_norm_static_fp8_quant(out, input, residual, weight, scale) + + +def sglang_jit_fused_add_rmsnorm_quant( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + jit_fused_add_rmsnorm_quant(out, input, residual, weight, scale) + + +def flashinfer_fused_add_rmsnorm_quant( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + fi_fused_add_rmsnorm_quant(out, input, residual, weight, scale) + + +def sglang_unfused_jit( + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + # fused_add_rmsnorm: residual += input, input = rmsnorm(residual) + jit_fused_add_rmsnorm(input, residual, weight) + # input now holds the normed result, quantize it + jit_quant(input, out, scale, is_static=True) + + +@torch.compile() +def torch_impl_fused_add_rmsnorm_quant( + input: torch.Tensor, 
+ residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, + eps: float = 1e-6, +) -> None: + residual.add_(input) + x = residual.float() + mean = x.pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + normed = x * norm * weight.float() + inv_scale = 1.0 / scale + out.copy_((normed * inv_scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(FP8_DTYPE)) + + +BS_LIST = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16], +) +HIDDEN_SIZE_LIST = get_benchmark_range( + full_range=[1536, 3072, 4096, 5120, 8192], + ci_range=[512, 2048], +) + +LINE_VALS = [ + "fused_aot", + "fused_jit", + "fused_flashinfer", + "unfused_jit", + "unfused_torch", +] +LINE_NAMES = [ + "SGL AOT Fused", + "SGL JIT Fused", + "FlashInfer Fused", + "SGL JIT Unfused", + "PyTorch Unfused", +] +STYLES = [ + ("orange", "-"), + ("blue", "--"), + ("purple", "-."), + ("green", "-."), + ("red", ":"), +] + +configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="fused-add-rmsnorm-quant-performance", + args={}, + ) +) +def benchmark(hidden_size: int, batch_size: int, provider: str): + input = torch.randn( + (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + residual = torch.randn( + (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + weight = torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + scale = torch.tensor([4.0], dtype=torch.float32, device=DEFAULT_DEVICE) + out = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE, device=DEFAULT_DEVICE) + FN_MAP = { + "fused_aot": sglang_aot_fused_add_rmsnorm_quant, + "fused_jit": sglang_jit_fused_add_rmsnorm_quant, + "fused_flashinfer": flashinfer_fused_add_rmsnorm_quant, + "unfused_jit": 
sglang_unfused_jit, + "unfused_torch": torch_impl_fused_add_rmsnorm_quant, + } + fn = lambda: FN_MAP[provider]( + input.clone(), residual.clone(), weight, scale, out.clone() + ) + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/benchmark/bench_rmsnorm_quant.py b/python/sglang/jit_kernel/benchmark/bench_rmsnorm_quant.py new file mode 100644 index 000000000000..ec7fa5ee244a --- /dev/null +++ b/python/sglang/jit_kernel/benchmark/bench_rmsnorm_quant.py @@ -0,0 +1,164 @@ +import itertools + +import torch +import triton +import triton.testing +from flashinfer import rmsnorm_quant as fi_rmsnorm_quant +from sgl_kernel import rms_norm_static_fp8_quant + +from sglang.jit_kernel.benchmark.utils import ( + DEFAULT_DEVICE, + DEFAULT_DTYPE, + get_benchmark_range, + run_benchmark, +) +from sglang.jit_kernel.norm import rmsnorm as jit_rmsnorm +from sglang.jit_kernel.norm import rmsnorm_quant as jit_rmsnorm_quant +from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8 as jit_quant + +DEVICE = "cuda" +FP8_DTYPE = torch.float8_e4m3fn +# maximum value for e4m3fn for clamping in kernel +FP8_E4M3_MAX = 448.0 +# FP8 is low precision, so the tolerance needs to be higher +TOLERANCE = {"atol": 1.5e-1, "rtol": 1.5e-1} +FP_TOLERANCE = {"atol": 1e-4, "rtol": 1e-4} + + +def scaled_fp8_conversion_ref( + val: torch.Tensor, scale: torch.Tensor, fp8_dtype: torch.dtype +) -> torch.Tensor: + """Helper function matching the scaled_fp8_conversion device function.""" + quant_scale = 1.0 / scale + + x = val * quant_scale + + r = torch.clamp(x, min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) + + if r.dtype != fp8_dtype: + return r.to(fp8_dtype) + return r + + +def sglang_aot_rmsnorm_quant( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + rms_norm_static_fp8_quant(out, input, weight, scale) + + +def sglang_jit_rmsnorm_quant( + input: torch.Tensor, + weight: 
torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + jit_rmsnorm_quant(out, input, weight, scale) + + +def sglang_unfused_jit( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + temp = input.clone() + jit_rmsnorm(temp, weight, output=temp) + jit_quant(temp, out, scale, is_static=True) + + +def flashinfer_rmsnorm_quant( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, +) -> None: + fi_rmsnorm_quant(out, input, weight, scale) + + +@torch.compile() +def torch_impl_rmsnorm_quant( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + out: torch.Tensor, + eps: float = 1e-6, +) -> None: + mean = input.float().pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + out.copy_( + scaled_fp8_conversion_ref( + (input.float() * norm * weight.float()), scale, FP8_DTYPE + ) + ) + + +BS_LIST = get_benchmark_range( + full_range=[2**n for n in range(0, 14)], + ci_range=[16], +) +HIDDEN_SIZE_LIST = get_benchmark_range( + full_range=[1536, 3072, 4096, 5120, 8192], + ci_range=[512, 2048], +) + +LINE_VALS = [ + "fused_aot", + "fused_jit", + "fused_flashinfer", + "unfused_jit", + "unfused_torch", +] +LINE_NAMES = [ + "SGL AOT Fused", + "SGL JIT Fused", + "FlashInfer Fused", + "SGL JIT Unfused", + "PyTorch Unfused", +] +STYLES = [ + ("orange", "-"), + ("blue", "--"), + ("purple", "-."), + ("green", "-."), + ("red", ":"), +] + +configs = list(itertools.product(HIDDEN_SIZE_LIST, BS_LIST)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["hidden_size", "batch_size"], + x_vals=configs, + line_arg="provider", + line_vals=LINE_VALS, + line_names=LINE_NAMES, + styles=STYLES, + ylabel="us", + plot_name="rmsnorm-performance", + args={}, + ) +) +def benchmark(hidden_size: int, batch_size: int, provider: str): + input = torch.randn( + (batch_size, hidden_size), dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE + ) + weight = 
torch.randn(hidden_size, dtype=DEFAULT_DTYPE, device=DEFAULT_DEVICE) + scale = torch.tensor([4.0], dtype=torch.float32, device=DEFAULT_DEVICE) + out = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE, device=DEFAULT_DEVICE) + FN_MAP = { + "fused_aot": sglang_aot_rmsnorm_quant, + "fused_jit": sglang_jit_rmsnorm_quant, + "fused_flashinfer": flashinfer_rmsnorm_quant, + "unfused_jit": sglang_unfused_jit, + "unfused_torch": torch_impl_rmsnorm_quant, + } + fn = lambda: FN_MAP[provider](input, weight, scale, out.clone()) + return run_benchmark(fn) + + +if __name__ == "__main__": + benchmark.run(print_data=True) diff --git a/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm_quant.cuh b/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm_quant.cuh new file mode 100644 index 000000000000..e5b6e559436c --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/fused_add_rmsnorm_quant.cuh @@ -0,0 +1,213 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +namespace { + +template +struct FusedAddRMSNormQuantVecTypeTrait; + +template <> +struct FusedAddRMSNormQuantVecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct FusedAddRMSNormQuantVecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct FusedAddRMSNormQuantVecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template <> +struct FusedAddRMSNormQuantVecTypeTrait { + using packed_t = packed_t; + using vec_t = device::AlignedVector; +}; + +template +__global__ void fused_add_rmsnorm_quant_reg_kernel( + T* __restrict__ input, + T* __restrict__ residual, + const T* __restrict__ weight, + fp8_e4m3_t* __restrict__ output, + const float* __restrict__ scale, + int vec_hidden_size, + float eps) { + constexpr int inner_loop = VEC_SIZE_IN_BYTE == 16 ? 
4 : 8; + // Number of fp8 elements per thread: each inner_loop iteration yields 2 scalars + constexpr int fp8_per_thread = inner_loop * 2; + + __shared__ float shared_memory[32]; // Used for CTA reduce + + using vec_t = typename FusedAddRMSNormQuantVecTypeTrait::vec_t; + using packed_t = typename FusedAddRMSNormQuantVecTypeTrait::packed_t; + vec_t v; // Save input + vec_t v_res; // Save residual + vec_t v_weight; // Save weight + + auto token_id = blockIdx.x; + float2 acc_square = make_float2(0.0f, 0.0f); // Sum of squares for each thread + + if (threadIdx.x < vec_hidden_size) { + // Compute address + vec_t* p = reinterpret_cast(input) + token_id * vec_hidden_size; + vec_t* p_res = reinterpret_cast(residual) + token_id * vec_hidden_size; + const vec_t* p_weight = reinterpret_cast(weight); + + // Load data + v = p[threadIdx.x]; + v_res = p_res[threadIdx.x]; + v_weight = p_weight[threadIdx.x]; + + for (int i = 0; i < inner_loop; i++) { + float2 val = device::cast(v[i]); + float2 res = device::cast(v_res[i]); + float2 inp_res = make_float2(val.x + res.x, val.y + res.y); + acc_square.x += inp_res.x * inp_res.x; + acc_square.y += inp_res.y * inp_res.y; + v[i] = device::cast(inp_res); + } + + // Store inp+res to residual + p_res[threadIdx.x] = v; + } + + // CTA Reduce + // Step 0: Warp Reduce + auto cg_warp = cooperative_groups::tiled_partition<32>(cooperative_groups::this_thread_block()); + float warp_sum = cooperative_groups::reduce(cg_warp, acc_square.x + acc_square.y, cooperative_groups::plus()); + + float* buffer = shared_memory; + if (threadIdx.x % 32 == 0) { + buffer[threadIdx.x / 32] = warp_sum; // Write warp_sum to buffer + } + + // Step 1: CTA Reduce + __syncthreads(); + if (threadIdx.x < 32) { + float cta_sum = cooperative_groups::reduce( + cg_warp, (threadIdx.x < blockDim.x / 32) ? 
buffer[threadIdx.x] : 0.0f, cooperative_groups::plus()); + buffer[threadIdx.x] = + rsqrtf(eps + cta_sum * (1.0f / static_cast(vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T))))); + } + __syncthreads(); + + // Compute RMSNorm + FP8 quantization + if (threadIdx.x < vec_hidden_size) { + float rsqrt_square_sum = buffer[threadIdx.x / 32]; // Read rsqrt from Shared Memory(Broadcast) + const float inv_scale = 1.0f / (*scale); + + device::AlignedVector fp8_out; +#pragma unroll + for (int i = 0; i < inner_loop; i++) { + float2 valf = device::cast(v[i]); + float2 weightf = device::cast(v_weight[i]); + float val0 = valf.x * weightf.x * rsqrt_square_sum * inv_scale; + float val1 = valf.y * weightf.y * rsqrt_square_sum * inv_scale; + fp8_out[i * 2 + 0] = static_cast( + device::math::max(-device::math::FP8_E4M3_MAX, device::math::min(val0, device::math::FP8_E4M3_MAX))); + fp8_out[i * 2 + 1] = static_cast( + device::math::max(-device::math::FP8_E4M3_MAX, device::math::min(val1, device::math::FP8_E4M3_MAX))); + } + + // Vectorized fp8 store + // Each thread handles fp8_per_thread fp8 elements at the corresponding position + auto* output_row = output + token_id * vec_hidden_size * (VEC_SIZE_IN_BYTE / sizeof(T)); + fp8_out.store(reinterpret_cast(output_row), threadIdx.x); + } +} + +template +struct FusedAddRMSNormQuantKernel { + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView residual, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView scale, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto device = SymbolicDevice{}; + device.set_options(); + + TensorMatcher({N, D}) // input + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(input); + TensorMatcher({D}) // weight + .with_dtype() + .with_device(device) + .verify(weight); + TensorMatcher({N, D}) // residual + .with_strides({D, 1}) + .with_dtype() + 
.with_device(device) + .verify(residual); + TensorMatcher({N, D}) // output (fp8) + .with_strides({D, 1}) + .with_dtype() + .with_device(device) + .verify(output); + TensorMatcher({1}) // scale + .with_dtype() + .with_device(device) + .verify(scale); + + auto cc_major = host::runtime::get_cc_major(device.unwrap().device_id); + int hidden_size = static_cast(D.unwrap()); + if ((cc_major <= 9 && hidden_size <= 8192) || (cc_major >= 10 && hidden_size <= 12288)) { + int max_vec_size_byte = cc_major >= 10 ? 32 : 16; + int elements_in_vec = max_vec_size_byte / sizeof(DType); + int vec_hidden_size = hidden_size / elements_in_vec; + uint threads = (vec_hidden_size + 31) / 32 * 32; + + // Runtime check + host::RuntimeCheck( + hidden_size % elements_in_vec == 0, + "hidden_size", + hidden_size, + " can not align to elements_in_vec ", + elements_in_vec); + + // Launch kernel + auto kernel = max_vec_size_byte == 32 ? fused_add_rmsnorm_quant_reg_kernel + : fused_add_rmsnorm_quant_reg_kernel; + LaunchKernel(static_cast(N.unwrap()), threads, device.unwrap()) + .enable_pdl(false)( + kernel, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(residual.data_ptr()), + reinterpret_cast(weight.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(scale.data_ptr()), + vec_hidden_size, + eps); + } else { + host::RuntimeCheck(false, "Large hidden_sizes are not supported for now."); + } + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/csrc/elementwise/rmsnorm_quant.cuh b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm_quant.cuh new file mode 100644 index 000000000000..cd05e3fe4d17 --- /dev/null +++ b/python/sglang/jit_kernel/csrc/elementwise/rmsnorm_quant.cuh @@ -0,0 +1,144 @@ +#include +#include + +#include +#include +#include +#include +#include +#include + +#include + +#include + +namespace { + +struct RMSNormQuantParams { + const void* input; + const void* __restrict__ weight; + void* output; + const float* __restrict__ scale; + int64_t 
input_stride; + int64_t output_stride; + uint32_t num_tokens; + float eps; +}; + +template +__global__ void rmsnorm_quant_cta(const RMSNormQuantParams __grid_constant__ params) { + using namespace device; + using InputStorage = norm::StorageType; + + constexpr auto kNumThreads = host::norm::get_cta_threads(); + constexpr auto kNumWarps = kNumThreads / kWarpThreads; + // Each thread loads 4 packed elements (16B) for input/weight. + // packed_t is 2 scalars (4B), so 4 of them = 8 scalars per thread. + // For fp8 output: 8 scalars * 1 byte = 8 bytes per thread. + constexpr auto kFP8VecSize = 8u; + + const auto& [input, weight_ptr, output, scale_ptr, input_stride, output_stride, num_tokens, eps] = params; + const auto gmem_in = tile::Memory::cta(kNumThreads); + __shared__ float smem[norm::kSmemBufferSize]; + + PDLWaitPrimary(); + + const float inv_scale = 1.0f / (*scale_ptr); + + for (uint32_t i = blockIdx.x; i < num_tokens; i += gridDim.x) { + const auto input_ptr = pointer::offset(input, i * input_stride); + const auto input_vec = gmem_in.load(input_ptr); + const auto weight_vec = gmem_in.load(weight_ptr); + + // Compute sum of squares for RMSNorm + float sum_of_squares = 0.0f; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto fp32_pair = cast(input_vec[j]); + sum_of_squares += fp32_pair.x * fp32_pair.x; + sum_of_squares += fp32_pair.y * fp32_pair.y; + } + + // Warp reduce then CTA reduce + sum_of_squares = warp::reduce_sum(sum_of_squares); + const auto warp_id = threadIdx.x / kWarpThreads; + smem[warp_id] = sum_of_squares; + __syncthreads(); + if (warp_id == 0) { + const auto tx = threadIdx.x; + const auto local_sum = tx < kNumWarps ? 
smem[tx] : 0.0f; + sum_of_squares = warp::reduce_sum(local_sum); + smem[32] = math::rsqrt(sum_of_squares / kDim + eps); + } + __syncthreads(); + const float norm_factor = smem[32]; + + // Apply norm, scale, clamp and convert to fp8 + AlignedVector output_vec; +#pragma unroll + for (auto j = 0u; j < 4u; ++j) { + const auto fp32_pair = cast(input_vec[j]); + const auto fp32_weight = cast(weight_vec[j]); + const float val0 = fp32_pair.x * norm_factor * fp32_weight.x * inv_scale; + const float val1 = fp32_pair.y * norm_factor * fp32_weight.y * inv_scale; + output_vec[j * 2 + 0] = + static_cast(math::max(-math::FP8_E4M3_MAX, math::min(val0, math::FP8_E4M3_MAX))); + output_vec[j * 2 + 1] = + static_cast(math::max(-math::FP8_E4M3_MAX, math::min(val1, math::FP8_E4M3_MAX))); + } + + // Vectorized fp8 store (8 bytes) + auto* output_row = pointer::offset(output, i * output_stride); + const auto gmem_out = tile::Memory>::cta(kNumThreads); + gmem_out.store(output_row, output_vec); + } + + PDLTriggerSecondary(); +} + +template +struct RMSNormQuantKernel { + static_assert(host::norm::should_use_cta(), "Hidden size must be > 256 for RMSNormQuant"); + static constexpr auto kernel = rmsnorm_quant_cta; + + static void + run(const tvm::ffi::TensorView input, + const tvm::ffi::TensorView weight, + const tvm::ffi::TensorView output, + const tvm::ffi::TensorView scale, + float eps) { + using namespace host; + auto N = SymbolicSize{"num_tokens"}; + auto D = SymbolicSize{"hidden_size"}; + auto SI = SymbolicSize{"input_stride"}; + auto SO = SymbolicSize{"output_stride"}; + auto device = SymbolicDevice{}; + D.set_value(kDim); + device.set_options(); + + TensorMatcher({N, D}).with_strides({SI, 1}).with_dtype().with_device(device).verify(input); + TensorMatcher({D}).with_dtype().with_device(device).verify(weight); + TensorMatcher({N, D}).with_strides({SO, 1}).with_dtype().with_device(device).verify(output); + TensorMatcher({1}).with_dtype().with_device(device).verify(scale); + + const auto 
num_tokens = static_cast(N.unwrap()); + const auto params = RMSNormQuantParams{ + .input = input.data_ptr(), + .weight = weight.data_ptr(), + .output = output.data_ptr(), + .scale = static_cast(scale.data_ptr()), + .input_stride = SI.unwrap(), + .output_stride = SO.unwrap(), + .num_tokens = num_tokens, + .eps = eps, + }; + + static constexpr auto kNumThreads = norm::get_cta_threads(); + static const uint32_t max_occupancy = runtime::get_blocks_per_sm(kernel, kNumThreads); + static const uint32_t kNumSM = runtime::get_sm_count(device.unwrap().device_id); + const auto num_blocks = std::min(num_tokens, max_occupancy * kNumSM); + LaunchKernel(num_blocks, kNumThreads, device.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/python/sglang/jit_kernel/cutedsl_dual_gemm.py b/python/sglang/jit_kernel/cutedsl_dual_gemm.py new file mode 100644 index 000000000000..80862e4e801d --- /dev/null +++ b/python/sglang/jit_kernel/cutedsl_dual_gemm.py @@ -0,0 +1,1212 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +""" +CuteDSL SM90 Dual GEMM kernel: fuses SiLU(X @ W_gate) * (X @ W_up) into a +single kernel using TMA + WGMMA on Hopper GPUs. 
+ +Architecture: + - Producer warpgroup: TMA loads A, B0, B1 into SMEM (A shared between GEMMs) + - Consumer warpgroup: 2x WGMMA per K-tile (reusing A descriptor) + - Epilogue: fused SiLU*mul in registers, then R2S + TMA store + +Supports optional FP8 quantized mode: + - FP8 (e4m3fn) inputs with per-tensor x_scale and w_scale + - FP8 (e4m3fn) output with o_scale for requantization +""" + +import math +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.hopper_helpers as sm90_utils +import torch +from cutlass.cute.nvgpu.common import CopyUniversalOp +from cutlass.cute.runtime import from_dlpack +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +from sglang.srt.utils.custom_op import register_custom_op + +# Monkey-patch PersistentTileSchedulerParams to fix attribute name mismatches +# between JIT-traced __init__ (which creates _raster_along_m, cluster_shape_m_fdd, +# cluster_shape_n_fdd) and __extract_mlir_values__ (which expects raster_along_m, +# cluster_shape_major_fdd, cluster_shape_minor_fdd). 
+_PTSP_ATTR_FIXES = { + "raster_along_m": "_raster_along_m", + "cluster_shape_major_fdd": "cluster_shape_m_fdd", + "cluster_shape_minor_fdd": "cluster_shape_n_fdd", +} +_orig_extract = utils.PersistentTileSchedulerParams.__extract_mlir_values__ + + +def _patched_extract(self): + for expected, actual in _PTSP_ATTR_FIXES.items(): + if not hasattr(self, expected) and hasattr(self, actual): + setattr(self, expected, getattr(self, actual)) + return _orig_extract(self) + + +utils.PersistentTileSchedulerParams.__extract_mlir_values__ = _patched_extract + +_COMPILED_KERNELS = {} +_WEIGHT_CACHE = {} # id(w) -> (w_ref, w_transposed, cute_tensor) +_KERNEL_OBJ_CACHE = ( + {} +) # (acc_dtype, tile_mn, cluster_mn, use_fp8) -> HopperDualGemmKernel +_DUMMY_SCALE_CACHE = {} # device -> (dummy_tensor, cute_dummy) +_RAW_SCALE_CACHE = {} # data_ptr -> (scale_f32, cute_tensor) +_STREAM_CACHE = {} # cuda_stream_int -> CUstream + + +class HopperDualGemmKernel: + """SM90 persistent dual GEMM: computes SiLU(A @ B0) * (A @ B1). + + Uses warp specialization with dedicated DMA and MMA warp groups. + DMA warp group loads A, B0, B1 via TMA; MMA warp groups compute + both GEMMs and fused SiLU*mul epilogue. Persistent tile scheduler + keeps CTAs alive across multiple output tiles. 
+ """ + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + tile_shape_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_fp8_scales: bool = False, + swizzle_size: int = 1, + raster_along_m: bool = True, + ): + self.acc_dtype = acc_dtype + self.cluster_shape_mn = cluster_shape_mn + self.use_fp8_scales = use_fp8_scales + self.swizzle_size = swizzle_size + self.raster_along_m = raster_along_m + self.tile_shape_mnk = (*tile_shape_mn, 1) + self.atom_layout_mnk = ( + (2, 1, 1) + if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128 + else (1, 1, 1) + ) + + # Warp specialization: 1 DMA warp group + N MMA warp groups + self.num_dma_warp_groups = 1 + self.num_mma_warp_groups = math.prod(self.atom_layout_mnk) + self.num_warps_per_warp_group = 4 + self.num_threads_per_warp_group = self.num_warps_per_warp_group * 32 + self.threads_per_cta = ( + self.num_dma_warp_groups + self.num_mma_warp_groups + ) * self.num_threads_per_warp_group + self.load_warp_id = 0 + self.epi_store_warp_id = ( + self.num_dma_warp_groups * self.num_warps_per_warp_group + ) + self.load_register_requirement = 40 + self.mma_register_requirement = 232 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") + + self.occupancy = 1 + self.ab_stage = None + self.epi_stage = None + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + self.tiled_mma = None + self.shared_storage = None + self.buffer_align_bytes = 128 + + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + + self.num_mma_threads = ( + self.num_mma_warp_groups * self.num_threads_per_warp_group + ) + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=self.num_mma_threads + ) + + def _setup_attributes(self): + if self.tile_shape_mnk[0] not in [64, 128]: + raise ValueError("CTA tile shape M must be 64/128") + if self.tile_shape_mnk[1] not in [64, 
128, 256]: + raise ValueError("CTA tile shape N must be 64/128/256") + + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.tile_shape_mnk = ( + self.tile_shape_mnk[0], + self.tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1)) + self.num_mcast_ctas_a = self.cluster_shape_mn[1] + self.num_mcast_ctas_b = self.cluster_shape_mn[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + is_cooperative = self.atom_layout_mnk == (2, 1, 1) + self.epi_tile = sm90_utils.compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative + ) + + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.smem_capacity, + self.occupancy, + ) + + self.a_smem_layout_staged = sm90_utils.make_smem_layout_a( + self.a_layout, + self.tile_shape_mnk, + self.a_dtype, + self.ab_stage, + ) + self.b_smem_layout_staged = sm90_utils.make_smem_layout_b( + self.b_layout, + self.tile_shape_mnk, + self.b_dtype, + self.ab_stage, + ) + self.epi_smem_layout_staged = sm90_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.epi_stage, + ) + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: tuple[int, int], + c_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + epi_stage = 2 + c_bytes_per_stage = cute.size(epi_tile) * c_dtype.width // 8 + epi_bytes = c_bytes_per_stage * epi_stage + a_shape = 
cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + # 3 buffers per stage: A + B0 + B1 + bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + 2 * cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + epi_bytes) + ) // bytes_per_stage + return ab_stage, epi_stage + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b0: cute.Tensor, + b1: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + x_scale: cute.Tensor = None, + w_scale_gate: cute.Tensor = None, + w_scale_up: cute.Tensor = None, + o_scale: cute.Tensor = None, + ): + self.a_dtype = a.element_type + self.b_dtype = b0.element_type + self.c_dtype = c.element_type + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b0) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + self._setup_attributes() + + # TMA atoms for A, B0, B1 (loads) + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mn[1], + ) + tma_atom_b0, tma_tensor_b0 = self._make_tma_atoms_and_tensors( + b0, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mn[0], + ) + tma_atom_b1, tma_tensor_b1 = self._make_tma_atoms_and_tensors( + b1, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mn[0], + ) + + # TMA atom for C (store) + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + tile_sched_params, grid = self._compute_grid( + c, + self.tile_shape_mnk, + self.cluster_shape_mn, + self.swizzle_size, + self.raster_along_m, + max_active_clusters, + ) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: 
cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB0: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB1: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged), + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b0, + tma_tensor_b0, + tma_atom_b1, + tma_tensor_b1, + tma_atom_c, + tma_tensor_c, + self.tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + tile_sched_params, + x_scale, + w_scale_gate, + w_scale_up, + o_scale, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b0: cute.CopyAtom, + mB0_nkl: cute.Tensor, + tma_atom_b1: cute.CopyAtom, + mB1_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: utils.PersistentTileSchedulerParams, + m_x_scale: cute.Tensor = None, + m_w_scale_gate: cute.Tensor = None, + m_w_scale_up: cute.Tensor = None, + m_o_scale: cute.Tensor = None, + ): + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch TMA descriptors + if warp_idx == 0: + 
cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b0) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b1) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + # Multicast masks + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + 2 * cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # ===================================================================== + # Allocate SMEM and create pipeline + # ===================================================================== + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = ( + mcast_size * self.num_mma_warp_groups * self.num_warps_per_warp_group + ) + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cute.make_layout((1, 
*cta_layout_mnk.shape)), + defer_sync=True, + ) + + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # ===================================================================== + # SMEM tensors + # ===================================================================== + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB0 = storage.sB0.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sB1 = storage.sB1.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + # Separate sC buffer for persistent kernel (DMA may load next tile's A + # while MMA is still writing epilogue) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + + # ===================================================================== + # Partition global tensors (all tiles) + # ===================================================================== + gA_mkl = cute.local_tile( + mA_mkl, + cute.slice_(self.tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + gB0_nkl = cute.local_tile( + mB0_nkl, + cute.slice_(self.tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + gB1_nkl = cute.local_tile( + mB1_nkl, + cute.slice_(self.tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + gC_mnl = cute.local_tile( + mC_mnl, + cute.slice_(self.tile_shape_mnk, (None, None, 0)), + (None, None, None), + ) + + # ===================================================================== + # TMA load partitions + # ===================================================================== + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + cute.group_modes(sA, 0, 2), + cute.group_modes(gA_mkl, 0, 2), + ) + + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, 
(None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + tB0sB0, tB0gB0 = cute.nvgpu.cpasync.tma_partition( + tma_atom_b0, + b_cta_crd, + b_cta_layout, + cute.group_modes(sB0, 0, 2), + cute.group_modes(gB0_nkl, 0, 2), + ) + tB1sB1, tB1gB1 = cute.nvgpu.cpasync.tma_partition( + tma_atom_b1, + b_cta_crd, + b_cta_layout, + cute.group_modes(sB1, 0, 2), + cute.group_modes(gB1_nkl, 0, 2), + ) + + # ===================================================================== + # MMA thread partitions + # ===================================================================== + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + mma_warp_group_thread_layout = cute.make_layout( + self.num_mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma = tiled_mma.get_slice( + mma_warp_group_thread_layout(warp_group_idx - self.num_dma_warp_groups) + ) + + tCsA = thr_mma.partition_A(sA) + tCsB0 = thr_mma.partition_B(sB0) + tCsB1 = thr_mma.partition_B(sB1) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB0 = tiled_mma.make_fragment_B(tCsB0) + tCrB1 = tiled_mma.make_fragment_B(tCsB1) + + tCgC = thr_mma.partition_C(gC_mnl) + acc_shape = tCgC.shape[:3] + acc_gate = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + acc_up = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + num_k_blocks = cute.size(tCrA, mode=[2]) + + # ===================================================================== + # Cluster wait + # ===================================================================== + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + is_dma_warp_group = warp_group_idx < self.num_dma_warp_groups + + # ===================================================================== + # DMA warp group: persistent TMA loader + # ===================================================================== + if is_dma_warp_group: + cute.arch.setmaxregister_decrease(self.load_register_requirement) + + if warp_idx == 
self.load_warp_id: + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])] + tB0gB0_nkl = tB0gB0[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] + tB1gB1_nkl = tB1gB1[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] + + mainloop_producer_state.reset_count() + + for k_tile in range(k_tile_cnt): + mainloop_pipeline.producer_acquire(mainloop_producer_state) + + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + tB0gB0_k = tB0gB0_nkl[(None, mainloop_producer_state.count)] + tB0sB0_pipe = tB0sB0[(None, mainloop_producer_state.index)] + tB1gB1_k = tB1gB1_nkl[(None, mainloop_producer_state.count)] + tB1sB1_pipe = tB1sB1[(None, mainloop_producer_state.index)] + + bar_ptr = mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ) + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=bar_ptr, + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b0, + tB0gB0_k, + tB0sB0_pipe, + tma_bar_ptr=bar_ptr, + mcast_mask=b_mcast_mask, + ) + cute.copy( + tma_atom_b1, + tB1gB1_k, + tB1sB1_pipe, + tma_bar_ptr=bar_ptr, + mcast_mask=b_mcast_mask, + ) + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mainloop_pipeline.producer_tail(mainloop_producer_state) + + # ===================================================================== + # MMA warp group: persistent compute + epilogue + # ===================================================================== + if not is_dma_warp_group: + 
cute.arch.setmaxregister_increase(self.mma_register_requirement) + + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + mainloop_consumer_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + # Epilogue setup + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + if cutlass.const_expr(self.c_dtype.width == 8): + copy_atom_C = cute.make_copy_atom(CopyUniversalOp(), self.c_dtype) + else: + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_Atom) + + thr_copy_r2s = tiled_copy_r2s.get_slice( + tidx - self.num_dma_warp_groups * self.num_threads_per_warp_group + ) + tRS_sD = thr_copy_r2s.partition_D(sC) + tRS_rAcc_gate = tiled_copy_r2s.retile(acc_gate) + tRS_rAcc_up = tiled_copy_r2s.retile(acc_up) + + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype) + tRS_rD_out = cute.make_rmem_tensor(tRS_rD_layout.shape, self.c_dtype) + size_tRS_rD = cute.size(tRS_rD) + + k_pipe_mmas = 1 + prologue_mma_cnt = min(k_pipe_mmas, k_tile_cnt) + + # TMA store pipeline + tma_store_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.num_mma_threads, + ) + tma_store_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=tma_store_producer_group, + ) + + if cutlass.const_expr(self.use_fp8_scales): + input_scale_gate_val = 
m_x_scale[0] * m_w_scale_gate[0] + input_scale_up_val = m_x_scale[0] * m_w_scale_up[0] + output_scale_inv_val = cutlass.Float32(1.0) / m_o_scale[0] + fp8_max = cutlass.Float32(448.0) + fp8_min = cutlass.Float32(-448.0) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)] + + # ============= MAINLOOP ============= + mainloop_consumer_read_state.reset_count() + mainloop_consumer_release_state.reset_count() + acc_gate.fill(0.0) + acc_up.fill(0.0) + + # --- Prologue: dual GEMM --- + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + cute.nvgpu.warpgroup.fence() + for k_tile in range(prologue_mma_cnt): + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + tiled_mma, + acc_gate, + tCrA[k_block_coord], + tCrB0[k_block_coord], + acc_gate, + ) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + tiled_mma, + acc_up, + tCrA[k_block_coord], + tCrB1[k_block_coord], + acc_up, + ) + cute.nvgpu.warpgroup.commit_group() + + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + mainloop_consumer_read_state.advance() + + # --- Steady state: dual GEMM --- + for k_tile in range(prologue_mma_cnt, k_tile_cnt): + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + + cute.nvgpu.warpgroup.fence() + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + tiled_mma, + acc_gate, + tCrA[k_block_coord], + tCrB0[k_block_coord], + acc_gate, + ) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + cute.gemm( + 
tiled_mma, + acc_up, + tCrA[k_block_coord], + tCrB1[k_block_coord], + acc_up, + ) + cute.nvgpu.warpgroup.commit_group() + + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + mainloop_consumer_read_state.advance() + + cute.nvgpu.warpgroup.wait_group(0) + for k_tile in range(prologue_mma_cnt): + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + + # ============= EPILOGUE: fused SiLU + multiply ============= + tCgC_for_tma = cute.zipped_divide(gC_mnl_slice, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + tCgC_for_tma, + ) + + epi_tile_num = cute.size(tCgC_for_tma, mode=[1]) + epi_tile_shape = tCgC_for_tma.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) + + num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num + for epi_idx in cutlass.range_constexpr(epi_tile_num): + if cutlass.const_expr(self.use_fp8_scales): + for epi_v in cutlass.range_constexpr(size_tRS_rD): + idx = epi_idx * size_tRS_rD + epi_v + g = tRS_rAcc_gate[idx] * input_scale_gate_val + u = tRS_rAcc_up[idx] * input_scale_up_val + neg_g = -g + exp_neg_g = cute.exp(neg_g) + sigmoid_g = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + exp_neg_g + ) + result = g * sigmoid_g * u * output_scale_inv_val + tRS_rD[epi_v] = cutlass.max( + cutlass.min(result, fp8_max), fp8_min + ) + else: + for epi_v in cutlass.range_constexpr(size_tRS_rD): + idx = epi_idx * size_tRS_rD + epi_v + g = tRS_rAcc_gate[idx] + u = tRS_rAcc_up[idx] + neg_g = -g + exp_neg_g = cute.exp(neg_g) + sigmoid_g = cutlass.Float32(1.0) / ( + cutlass.Float32(1.0) + exp_neg_g + ) + tRS_rD[epi_v] = g * sigmoid_g * u + + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + epi_buffer = (num_prev_epi_tiles 
+ epi_idx) % cute.size( + tRS_sD, mode=[3] + ) + cute.copy( + tiled_copy_r2s, + tRS_rD_out, + tRS_sD[(None, None, None, epi_buffer)], + ) + + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + if warp_idx == self.epi_store_warp_id: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + tma_store_pipeline.producer_commit() + tma_store_pipeline.producer_acquire() + + self.epilog_sync_barrier.arrive_and_wait() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + tma_store_pipeline.producer_tail() + + # ===================================================================== + # Static helper methods + # ===================================================================== + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + swizzle_size: int, + raster_along_m: bool, + max_active_clusters: cutlass.Constexpr, + ) -> tuple: + c_shape = cute.slice_(tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, + cluster_shape_mnl, + swizzle_size, + raster_along_m, + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = 
cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + epi_tile, + ) + return tma_atom_c, tma_tensor_c + + +def _select_tile_config(M): + # (64,64) tiles give many more pipeline stages (8 vs 2-3 for 128x128) + # due to the 3-buffer-per-stage design (A+B0+B1). More stages = better + # latency hiding for TMA loads. + return (64, 64), (1, 1) + + +def _to_3d_k_major(t2d: torch.Tensor) -> torch.Tensor: + """Convert 2D row-major (R, C) tensor to 3D (R, C, 1) with strides (C, 1, R*C). + + This matches the dense_gemm convention for K-major tensors where the + contiguous (stride-1) dimension is dim 1. + """ + return t2d.unsqueeze(0).permute(1, 2, 0) + + +def _make_cute_tensor(t3d: torch.Tensor, cutlass_dtype=None) -> "cute.Tensor": + """Create a cute.Tensor from a 3D torch tensor, with optional dtype override for FP8. + + DLPack doesn't support float8 types, so FP8 tensors are viewed as uint8 + and the element_type is overridden on the cute.Tensor. + """ + is_fp8 = t3d.dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + if is_fp8: + t3d_for_dlpack = t3d.view(torch.uint8) + else: + t3d_for_dlpack = t3d + + ct = from_dlpack(t3d_for_dlpack, assumed_align=16) + if is_fp8 and cutlass_dtype is not None: + ct.element_type = cutlass_dtype + ct = ct.mark_layout_dynamic(leading_dim=1) + return ct + + +def _get_cached_weight(w: torch.Tensor, fp8_dtype) -> "cute.Tensor": + """Cache weight transpose and cute tensor creation. 
+ + Weights are static in inference, so we cache the transposed (N, K) + tensor and its cute.Tensor wrapper keyed by object identity. + This avoids GPU transpose kernels being captured in CUDA graphs. + """ + key = id(w) + cached = _WEIGHT_CACHE.get(key) + if cached is not None and cached[0] is w: + return cached[2] + w_t = w.t().contiguous() + w_3d = _to_3d_k_major(w_t) + ct = _make_cute_tensor(w_3d, fp8_dtype) + # Store strong ref to w to prevent id reuse while cached + _WEIGHT_CACHE[key] = (w, w_t, ct) + return ct + + +def _get_cached_dummy_scale(device) -> "cute.Tensor": + """Cache dummy scale tensor for non-FP8 mode.""" + cached = _DUMMY_SCALE_CACHE.get(device) + if cached is not None: + return cached[1] + dummy = torch.zeros(1, device=device, dtype=torch.float32) + ct = from_dlpack(dummy, assumed_align=4) + _DUMMY_SCALE_CACHE[device] = (dummy, ct) + return ct + + +def _get_cached_raw_scale(scale: torch.Tensor, idx: int = 0) -> "cute.Tensor": + """Cache from_dlpack wrapper for a scalar extracted from a scale tensor. + + Args: + scale: Scale tensor (may be scalar or per-channel). + idx: Element index to extract (default 0). 
+ """ + key = (scale.data_ptr(), idx) + cached = _RAW_SCALE_CACHE.get(key) + if cached is not None: + return cached[1] + scale_f32 = scale.flatten()[idx].float().reshape(1).contiguous() + ct = from_dlpack(scale_f32, assumed_align=4) + _RAW_SCALE_CACHE[key] = (scale_f32, ct) + return ct + + +def _get_cached_stream(torch_stream) -> cuda.CUstream: + """Cache CUstream wrapper.""" + stream_int = torch_stream.cuda_stream + cached = _STREAM_CACHE.get(stream_int) + if cached is not None: + return cached + stream = cuda.CUstream(stream_int) + _STREAM_CACHE[stream_int] = stream + return stream + + +def cutedsl_dual_gemm( + x: torch.Tensor, + w_gate: torch.Tensor, + w_up: torch.Tensor, + out: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +): + """Compute out = SiLU(x @ w_gate) * (x @ w_up) using SM90 CuteDSL kernel. + + Args: + x: Input tensor (M, K), row-major, BF16/FP16 or FP8 (e4m3fn) + w_gate: Gate weight (K, N), row-major, BF16/FP16 or FP8 (e4m3fn) + w_up: Up weight (K, N), row-major, BF16/FP16 or FP8 (e4m3fn) + out: Output tensor (M, N), row-major, BF16/FP16 or FP8 (e4m3fn) + x_scale: Per-tensor input scale (scalar float32), required for FP8 + w_scale: Per-tensor weight scale (scalar float32), required for FP8 + o_scale: Per-tensor output scale (scalar float32), required for FP8 + """ + M, K = x.shape + _, N = w_gate.shape + assert w_up.shape == (K, N), f"w_up shape mismatch: {w_up.shape} vs ({K}, {N})" + assert out.shape == (M, N), f"out shape mismatch: {out.shape} vs ({M}, {N})" + + use_fp8_scales = x_scale is not None and w_scale is not None and o_scale is not None + + tile_shape_mn, cluster_shape_mn = _select_tile_config(M) + acc_dtype = cutlass.Float32 + + # Cache kernel object to avoid re-creating it every call + kern_key = (acc_dtype, tile_shape_mn, cluster_shape_mn, use_fp8_scales) + kernel = _KERNEL_OBJ_CACHE.get(kern_key) + if kernel is None: + kernel = 
HopperDualGemmKernel( + acc_dtype, tile_shape_mn, cluster_shape_mn, use_fp8_scales=use_fp8_scales + ) + _KERNEL_OBJ_CACHE[kern_key] = kernel + + # Determine cutlass dtype for FP8 override + fp8_ab_dtype = cutlass.Float8E4M3FN if use_fp8_scales else None + fp8_c_dtype = cutlass.Float8E4M3FN if use_fp8_scales else None + + # Cache weight transposes to avoid GPU copy ops in CUDA graph + mB0 = _get_cached_weight(w_gate, fp8_ab_dtype) + mB1 = _get_cached_weight(w_up, fp8_ab_dtype) + + # x and out change each call — create fresh cute tensors (CPU-only, no GPU ops) + x_3d = _to_3d_k_major(x) + out_3d = _to_3d_k_major(out) + mA = _make_cute_tensor(x_3d, fp8_ab_dtype) + mC = _make_cute_tensor(out_3d, fp8_c_dtype) + + stream = _get_cached_stream(torch.cuda.current_stream()) + + # Prepare scale tensors for the kernel + if use_fp8_scales: + m_x_scale = _get_cached_raw_scale(x_scale) + m_o_scale = _get_cached_raw_scale(o_scale) + # Support both per-tensor (scalar) and per-channel w_scale. + # Per-channel scales have one value per column of the combined + # [w_gate | w_up] weight; gate uses index 0, up uses index N. 
+ if w_scale.numel() == 1: + m_w_scale_gate = _get_cached_raw_scale(w_scale) + m_w_scale_up = m_w_scale_gate + else: + m_w_scale_gate = _get_cached_raw_scale(w_scale, 0) + m_w_scale_up = _get_cached_raw_scale(w_scale, N) + else: + dummy = _get_cached_dummy_scale(x.device) + m_x_scale = dummy + m_w_scale_gate = dummy + m_w_scale_up = dummy + m_o_scale = dummy + + # max_active_clusters: use device SM count for optimal persistent occupancy + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + + compile_key = ( + x.dtype, + out.dtype, + tile_shape_mn, + cluster_shape_mn, + K, + N, + use_fp8_scales, + sm_count, + ) + + if compile_key not in _COMPILED_KERNELS: + _COMPILED_KERNELS[compile_key] = cute.compile( + kernel, + mA, + mB0, + mB1, + mC, + sm_count, + stream, + m_x_scale, + m_w_scale_gate, + m_w_scale_up, + m_o_scale, + ) + + compiled_kernel = _COMPILED_KERNELS[compile_key] + # max_active_clusters is Constexpr — baked into compiled kernel, not passed at runtime + compiled_kernel( + mA, mB0, mB1, mC, stream, m_x_scale, m_w_scale_gate, m_w_scale_up, m_o_scale + ) + + +def _cutedsl_dual_gemm_fake( + x: torch.Tensor, + w: torch.Tensor, + out: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fake impl for torch.compile shape inference.""" + return out + + +@register_custom_op( + op_name="cutedsl_dual_gemm", + mutates_args=["out"], + fake_impl=_cutedsl_dual_gemm_fake, +) +def cutedsl_dual_gemm_fused_op( + x: torch.Tensor, + w: torch.Tensor, + out: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fused dual GEMM op matching the Triton dual_gemm interface. + + Splits combined weight w = [w_gate | w_up] (shape K, 2*N) and calls + the CuteDSL kernel. Used as a torch custom op for torch.compile fusion. 
+ """ + N = w.shape[1] // 2 + w_gate, w_up = torch.split(w, N, dim=1) + cutedsl_dual_gemm(x, w_gate, w_up, out, x_scale, w_scale, o_scale) + return out + + +def _run_test(M, K, N, dtype=torch.bfloat16): + torch.manual_seed(42) + # Use smaller values to reduce BF16 accumulation error + x = torch.randn(M, K, device="cuda", dtype=dtype) * 0.1 + w_gate = torch.randn(K, N, device="cuda", dtype=dtype) * 0.1 + w_up = torch.randn(K, N, device="cuda", dtype=dtype) * 0.1 + out = torch.empty(M, N, device="cuda", dtype=dtype) + + cutedsl_dual_gemm(x, w_gate, w_up, out) + torch.cuda.synchronize() + + ref = torch.nn.functional.silu(x @ w_gate) * (x @ w_up) + torch.testing.assert_close(out, ref, atol=1.0, rtol=0.03) + print(f" PASS (M={M}, K={K}, N={N}, dtype={dtype})") + + +def _run_fp8_test(M, K, N): + torch.manual_seed(42) + # Create FP8 inputs by quantizing random data + x_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32) * 0.1 + w_gate_fp32 = torch.randn(K, N, device="cuda", dtype=torch.float32) * 0.1 + w_up_fp32 = torch.randn(K, N, device="cuda", dtype=torch.float32) * 0.1 + + # Compute per-tensor scales + x_scale = x_fp32.abs().max() / 448.0 + w_scale = max(w_gate_fp32.abs().max(), w_up_fp32.abs().max()) / 448.0 + o_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32) + + # Quantize to FP8 + x_fp8 = (x_fp32 / x_scale).to(torch.float8_e4m3fn) + w_gate_fp8 = (w_gate_fp32 / w_scale).to(torch.float8_e4m3fn) + w_up_fp8 = (w_up_fp32 / w_scale).to(torch.float8_e4m3fn) + out_fp8 = torch.empty(M, N, device="cuda", dtype=torch.float8_e4m3fn) + + x_scale_t = x_scale.reshape(1).to(torch.float32) + w_scale_t = w_scale.reshape(1).to(torch.float32) + + cutedsl_dual_gemm( + x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale_t, w_scale_t, o_scale + ) + torch.cuda.synchronize() + + # Reference: dequantize, compute in FP32, requantize + x_deq = x_fp8.float() * x_scale + wg_deq = w_gate_fp8.float() * w_scale + wu_deq = w_up_fp8.float() * w_scale + ref_fp32 = 
torch.nn.functional.silu(x_deq @ wg_deq) * (x_deq @ wu_deq) + ref_fp8 = torch.clamp(ref_fp32 / o_scale, -448.0, 448.0).to(torch.float8_e4m3fn) + + # Compare in float (relaxed tolerance for FP8 WGMMA vs float32 matmul differences) + torch.testing.assert_close(out_fp8.float(), ref_fp8.float(), atol=8.0, rtol=0.15) + print(f" PASS FP8 (M={M}, K={K}, N={N})") + + +if __name__ == "__main__": + print("Testing CuteDSL Dual GEMM kernel...") + # Test various shapes + for M, K, N in [ + (128, 4096, 11008), + (256, 4096, 11008), + (64, 4096, 4096), + ]: + _run_test(M, K, N, torch.bfloat16) + # Test FP16 + _run_test(128, 4096, 4096, torch.float16) + # Test FP8 + print("Testing FP8 mode...") + _run_fp8_test(128, 4096, 4096) + print("All tests passed!") diff --git a/python/sglang/jit_kernel/norm.py b/python/sglang/jit_kernel/norm.py index e3b2aee1b110..a0c1939fa5c1 100644 --- a/python/sglang/jit_kernel/norm.py +++ b/python/sglang/jit_kernel/norm.py @@ -11,6 +11,7 @@ load_jit, make_cpp_args, ) +from sglang.srt.utils.custom_op import register_custom_op if TYPE_CHECKING: from tvm_ffi.module import Module @@ -49,6 +50,33 @@ def _jit_fused_add_rmsnorm_module(dtype: torch.dtype) -> Module: ) +@cache_once +def _jit_fused_add_rmsnorm_quant_module(dtype: torch.dtype) -> Module: + args = make_cpp_args(dtype) + return load_jit( + "fused_add_rmsnorm_quant", + *args, + cuda_files=["elementwise/fused_add_rmsnorm_quant.cuh"], + cuda_wrappers=[ + ( + "fused_add_rmsnorm_quant", + f"FusedAddRMSNormQuantKernel<{args}>::run", + ) + ], + ) + + +@cache_once +def _jit_rmsnorm_quant_module(hidden_size: int, dtype: torch.dtype) -> Module: + args = make_cpp_args(hidden_size, is_arch_support_pdl(), dtype) + return load_jit( + "rmsnorm_quant", + *args, + cuda_files=["elementwise/rmsnorm_quant.cuh"], + cuda_wrappers=[("rmsnorm_quant", f"RMSNormQuantKernel<{args}>::run")], + ) + + @cache_once def _jit_qknorm_across_heads_module(dtype: torch.dtype) -> Module: args = make_cpp_args(dtype) @@ -112,6 +140,34 @@ 
def fused_add_rmsnorm( module.fused_add_rmsnorm(input, residual, weight, eps) +@register_custom_op(op_name="jit_rmsnorm_quant", mutates_args=["out"]) +def rmsnorm_quant( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, +) -> None: + hidden_size = input.size(-1) + module = _jit_rmsnorm_quant_module(hidden_size, input.dtype) + module.rmsnorm_quant(input, weight, out, scale.view(-1), eps) + + +@register_custom_op( + op_name="jit_fused_add_rmsnorm_quant", mutates_args=["out", "residual"] +) +def fused_add_rmsnorm_quant( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, +) -> None: + module = _jit_fused_add_rmsnorm_quant_module(input.dtype) + module.fused_add_rmsnorm_quant(input, residual, weight, out, scale.view(-1), eps) + + def fused_inplace_qknorm_across_heads( q: torch.Tensor, k: torch.Tensor, diff --git a/python/sglang/jit_kernel/tests/test_cutedsl_dual_gemm.py b/python/sglang/jit_kernel/tests/test_cutedsl_dual_gemm.py new file mode 100644 index 000000000000..5bcca3dc4efd --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_cutedsl_dual_gemm.py @@ -0,0 +1,173 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +"""Correctness tests for CuteDSL SM90 dual GEMM kernel.""" + +import itertools + +import pytest +import torch + +try: + import cutlass # noqa: F401 + + from sglang.jit_kernel.cutedsl_dual_gemm import cutedsl_dual_gemm + + CUTEDSL_AVAILABLE = True +except ImportError: + CUTEDSL_AVAILABLE = False + cutedsl_dual_gemm = None + +SM90_AVAILABLE = ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 +) + +DEVICE = "cuda" + +M_LIST = [1, 16, 64, 128, 512, 2048] +K_LIST = [4096, 8192] +N_LIST = [11008, 14336] +DTYPE_LIST = [torch.bfloat16, torch.float16] + + +@pytest.mark.skipif(not SM90_AVAILABLE, reason="Requires SM90+ (Hopper) GPU") +@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason="CuTe DSL not available") +@pytest.mark.parametrize( + "M,K,N,dtype", + list(itertools.product(M_LIST, K_LIST, N_LIST, DTYPE_LIST)), +) +def test_cutedsl_dual_gemm_correctness( + M: int, K: int, N: int, dtype: torch.dtype +) -> None: + """Test CuteDSL dual GEMM against PyTorch reference.""" + torch.manual_seed(42) + x = torch.randn(M, K, device=DEVICE, dtype=dtype) * 0.1 + w_gate = torch.randn(K, N, device=DEVICE, dtype=dtype) * 0.1 + w_up = torch.randn(K, N, device=DEVICE, dtype=dtype) * 0.1 + out = torch.empty(M, N, device=DEVICE, dtype=dtype) + + cutedsl_dual_gemm(x, w_gate, w_up, out) + torch.cuda.synchronize() + + ref = torch.nn.functional.silu(x @ w_gate) * (x @ w_up) + + torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-1) + + +FP8_M_LIST = [64, 128, 512] +FP8_K_LIST = [4096] +FP8_N_LIST = [4096, 11008] + + +@pytest.mark.skipif(not SM90_AVAILABLE, reason="Requires SM90+ (Hopper) GPU") +@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason="CuTe DSL not available") +@pytest.mark.parametrize( + "M,K,N", + list(itertools.product(FP8_M_LIST, FP8_K_LIST, FP8_N_LIST)), +) +def test_cutedsl_dual_gemm_fp8_correctness(M: int, K: int, N: int) -> None: + """Test CuteDSL dual GEMM FP8 mode against PyTorch reference.""" + torch.manual_seed(42) + + # Create reference 
data in FP32, then quantize + x_fp32 = torch.randn(M, K, device=DEVICE, dtype=torch.float32) * 0.1 + w_gate_fp32 = torch.randn(K, N, device=DEVICE, dtype=torch.float32) * 0.1 + w_up_fp32 = torch.randn(K, N, device=DEVICE, dtype=torch.float32) * 0.1 + + # Per-tensor scales + x_scale = x_fp32.abs().max() / 448.0 + w_scale = max(w_gate_fp32.abs().max(), w_up_fp32.abs().max()) / 448.0 + o_scale = torch.tensor(1.0, device=DEVICE, dtype=torch.float32) + + # Quantize to FP8 + x_fp8 = (x_fp32 / x_scale).to(torch.float8_e4m3fn) + w_gate_fp8 = (w_gate_fp32 / w_scale).to(torch.float8_e4m3fn) + w_up_fp8 = (w_up_fp32 / w_scale).to(torch.float8_e4m3fn) + out_fp8 = torch.empty(M, N, device=DEVICE, dtype=torch.float8_e4m3fn) + + x_scale_t = x_scale.reshape(1).to(torch.float32) + w_scale_t = w_scale.reshape(1).to(torch.float32) + + cutedsl_dual_gemm( + x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale_t, w_scale_t, o_scale + ) + torch.cuda.synchronize() + + # Reference: dequantize, compute in FP32, requantize + x_deq = x_fp8.float() * x_scale + wg_deq = w_gate_fp8.float() * w_scale + wu_deq = w_up_fp8.float() * w_scale + ref_fp32 = torch.nn.functional.silu(x_deq @ wg_deq) * (x_deq @ wu_deq) + ref_fp8 = torch.clamp(ref_fp32 / o_scale, -448.0, 448.0).to(torch.float8_e4m3fn) + + # Relaxed tolerance: FP8 WGMMA vs float32 matmul have inherent precision diffs + torch.testing.assert_close(out_fp8.float(), ref_fp8.float(), atol=8.0, rtol=0.15) + + +@pytest.mark.skipif(not SM90_AVAILABLE, reason="Requires SM90+ (Hopper) GPU") +@pytest.mark.skipif(not CUTEDSL_AVAILABLE, reason="CuTe DSL not available") +@pytest.mark.parametrize( + "M,K,N,dtype", + list(itertools.product(M_LIST, K_LIST, N_LIST, DTYPE_LIST)), +) +def test_cutedsl_dual_gemm_matches_triton( + M: int, K: int, N: int, dtype: torch.dtype +) -> None: + """Compare CuteDSL output against the Triton dual_gemm kernel.""" + from sglang.srt.compilation.fusion.ops.triton_ops.dual_gemm import ( + dual_gemm_kernel, + ) + + torch.manual_seed(42) 
+ x = torch.randn(M, K, device=DEVICE, dtype=dtype) * 0.1 + w_gate = torch.randn(K, N, device=DEVICE, dtype=dtype) * 0.1 + w_up = torch.randn(K, N, device=DEVICE, dtype=dtype) * 0.1 + + # CuteDSL path + out_cutedsl = torch.empty(M, N, device=DEVICE, dtype=dtype) + cutedsl_dual_gemm(x, w_gate, w_up, out_cutedsl) + torch.cuda.synchronize() + + # Triton path + import triton + + out_triton = torch.empty(M, N, device=DEVICE, dtype=dtype) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_K"]), + ) + + dual_gemm_kernel[grid]( + x, + w_gate, + w_up, + out_triton, + None, + None, + None, + False, + x.stride(0), + x.stride(1), + w_gate.stride(0), + w_gate.stride(1), + out_triton.stride(0), + out_triton.stride(1), + M, + K, + N, + torch.finfo(dtype).min, + torch.finfo(dtype).max, + 128, + 64, + 128, + 1, + num_warps=4, + num_stages=4, + ) + torch.cuda.synchronize() + + torch.testing.assert_close(out_cutedsl, out_triton, atol=1e-1, rtol=1e-1) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm_quant.py b/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm_quant.py new file mode 100644 index 000000000000..40c45e231e6c --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_fused_add_rmsnorm_quant.py @@ -0,0 +1,105 @@ +import itertools + +import pytest +import torch + +from sglang.jit_kernel.utils import get_ci_test_range + + +def reference_fused_add_rmsnorm_quant(input, residual, weight, scale, eps=1e-6): + updated_residual = residual + input + x = updated_residual.float() + mean = x.pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + normed = x * norm * weight.float() + inv_scale = 1.0 / scale + quantized = (normed * inv_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + return quantized, updated_residual + + +HIDDEN_SIZE_LIST = get_ci_test_range( + [512, 1024, 1536, 2048, 3072, 4096, 5120, 8192], + [512, 2048, 8192], +) 
+BS_LIST = get_ci_test_range( + [1, 4, 16, 64, 256], + [1, 16, 256], +) +DTYPE_LIST = [torch.bfloat16, torch.float16] +SCALE_LIST = get_ci_test_range( + [0.1, 1.0, 4.0, 10.0], + [0.1, 1.0], +) +DEVICE = "cuda" + + +@pytest.mark.parametrize( + "batch_size,hidden_size,dtype,scale_val", + list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST, DTYPE_LIST, SCALE_LIST)), +) +def test_fused_add_rmsnorm_quant_correctness( + batch_size: int, hidden_size: int, dtype: torch.dtype, scale_val: float +) -> None: + from sglang.jit_kernel.norm import fused_add_rmsnorm_quant + + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + residual = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype) + scale = torch.tensor([scale_val], device=DEVICE, dtype=torch.float32) + + original_residual = residual.clone() + + output = torch.empty(input.shape, dtype=torch.float8_e4m3fn, device=input.device) + fused_add_rmsnorm_quant(output, input, residual, weight, scale) + expected_output, expected_residual = reference_fused_add_rmsnorm_quant( + input, original_residual, weight, scale + ) + + assert output.dtype == torch.float8_e4m3fn + torch.testing.assert_close( + output.float(), expected_output.float(), atol=1.5e-1, rtol=1.5e-1 + ) + torch.testing.assert_close( + residual, input + original_residual, atol=1e-4, rtol=1e-4 + ) + + +@pytest.mark.parametrize( + "batch_size,hidden_size", + list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)), +) +def test_fused_add_rmsnorm_quant_matches_separate_ops( + batch_size: int, hidden_size: int +) -> None: + from sglang.jit_kernel.norm import fused_add_rmsnorm, fused_add_rmsnorm_quant + from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8 + + dtype = torch.bfloat16 + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + residual = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, 
device=DEVICE, dtype=dtype) + scale = torch.tensor([1.0], device=DEVICE, dtype=torch.float32) + + # Fused path + residual_fused = residual.clone() + fused_output = torch.empty( + input.shape, dtype=torch.float8_e4m3fn, device=input.device + ) + fused_add_rmsnorm_quant(fused_output, input, residual_fused, weight, scale) + + # Separate path: fused_add_rmsnorm then per-tensor fp8 quantize + input_sep = input.clone() + residual_sep = residual.clone() + fused_add_rmsnorm(input_sep, residual_sep, weight) + separate_output = torch.empty_like(input_sep, dtype=torch.float8_e4m3fn) + per_tensor_quant_fp8(input_sep, separate_output, scale, is_static=True) + + assert fused_output.dtype == torch.float8_e4m3fn + torch.testing.assert_close(residual_fused, residual_sep, atol=1e-2, rtol=1e-2) + torch.testing.assert_close( + fused_output.float(), separate_output.float(), atol=1.5e-1, rtol=1.5e-1 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/jit_kernel/tests/test_rmsnorm_quant.py b/python/sglang/jit_kernel/tests/test_rmsnorm_quant.py new file mode 100644 index 000000000000..e6e70ac88121 --- /dev/null +++ b/python/sglang/jit_kernel/tests/test_rmsnorm_quant.py @@ -0,0 +1,91 @@ +import itertools + +import pytest +import torch + +from sglang.jit_kernel.utils import get_ci_test_range + + +def reference_rmsnorm_quant(input, weight, scale, eps=1e-6): + x = input.float() + mean = x.pow(2).mean(dim=-1, keepdim=True) + norm = (mean + eps).rsqrt() + normed = x * norm * weight.float() + inv_scale = 1.0 / scale + quantized = (normed * inv_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn) + return quantized + + +HIDDEN_SIZE_LIST = get_ci_test_range( + [512, 1024, 1536, 2048, 3072, 4096, 5120, 8192], + [512, 2048, 8192], +) +BS_LIST = get_ci_test_range( + [1, 4, 16, 64, 256], + [1, 16, 256], +) +DTYPE_LIST = [torch.bfloat16, torch.float16] +SCALE_LIST = get_ci_test_range( + [0.1, 1.0, 4.0, 10.0], + [0.1, 1.0], +) +DEVICE = "cuda" + + 
+@pytest.mark.parametrize( + "batch_size,hidden_size,dtype,scale_val", + list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST, DTYPE_LIST, SCALE_LIST)), +) +def test_rmsnorm_quant_correctness( + batch_size: int, hidden_size: int, dtype: torch.dtype, scale_val: float +) -> None: + from sglang.jit_kernel.norm import rmsnorm_quant as jit_rmsnorm_quant + + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype) + scale = torch.tensor([scale_val], device=DEVICE, dtype=torch.float32) + + output = torch.empty(input.shape, dtype=torch.float8_e4m3fn, device=input.device) + jit_rmsnorm_quant(output, input, weight, scale) + expected = reference_rmsnorm_quant(input, weight, scale) + + assert output.dtype == torch.float8_e4m3fn + torch.testing.assert_close( + output.float(), expected.float(), atol=1.5e-1, rtol=1.5e-1 + ) + + +@pytest.mark.parametrize( + "batch_size,hidden_size", + list(itertools.product(BS_LIST, HIDDEN_SIZE_LIST)), +) +def test_rmsnorm_quant_matches_separate_ops(batch_size: int, hidden_size: int) -> None: + from sglang.jit_kernel.norm import rmsnorm + from sglang.jit_kernel.norm import rmsnorm_quant as jit_rmsnorm_quant + from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8 + + dtype = torch.bfloat16 + input = torch.randn(batch_size, hidden_size, device=DEVICE, dtype=dtype) + weight = torch.randn(hidden_size, device=DEVICE, dtype=dtype) + scale = torch.tensor([1.0], device=DEVICE, dtype=torch.float32) + + # Fused path + fused_output = torch.empty( + input.shape, dtype=torch.float8_e4m3fn, device=input.device + ) + jit_rmsnorm_quant(fused_output, input, weight, scale) + + # Separate path: rmsnorm then quantize + normed = input.clone() + rmsnorm(normed, weight, output=normed) + separate_output = torch.empty_like(normed, dtype=torch.float8_e4m3fn) + per_tensor_quant_fp8(normed, separate_output, scale, is_static=True) + + assert fused_output.dtype == torch.float8_e4m3fn 
+ torch.testing.assert_close( + fused_output.float(), separate_output.float(), atol=1.5e-1, rtol=1.5e-1 + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/python/sglang/srt/compilation/backend.py b/python/sglang/srt/compilation/backend.py index 8af025707f55..7f1de8ba06fa 100644 --- a/python/sglang/srt/compilation/backend.py +++ b/python/sglang/srt/compilation/backend.py @@ -20,6 +20,7 @@ from sglang.srt.compilation.compiler_interface import EagerAdapter, InductorAdaptor from sglang.srt.compilation.cuda_piecewise_backend import CUDAPiecewiseBackend from sglang.srt.compilation.npu_piecewise_backend import NPUPiecewiseBackend +from sglang.srt.compilation.pass_config import PassConfig from sglang.srt.compilation.pass_manager import PostGradPassManager from sglang.srt.utils.common import is_npu, rank0_log @@ -373,13 +374,14 @@ class SGLangBackend: def __init__( self, config: CompilationConfig, + pass_config: PassConfig, graph_pool: Any, ): rank0_log(f"Initializing SGLangBackend") assert graph_pool is not None self.graph_pool = graph_pool - self.post_grad_pass_manager = PostGradPassManager() + self.post_grad_pass_manager = PostGradPassManager(pass_config) self.sym_tensor_indices = [] self.input_buffers = [] diff --git a/python/sglang/srt/compilation/compile.py b/python/sglang/srt/compilation/compile.py index 46a9240fb259..1a46bb6e847d 100644 --- a/python/sglang/srt/compilation/compile.py +++ b/python/sglang/srt/compilation/compile.py @@ -7,8 +7,8 @@ from typing import Any, Callable, Optional, Union import torch - from sglang.srt.compilation.compilation_config import CompilationConfig +from sglang.srt.compilation.pass_config import PassConfig from sglang.srt.compilation.piecewise_context_manager import is_in_piecewise_cuda_graph from sglang.srt.utils.common import rank0_log @@ -115,6 +115,7 @@ def install_torch_compiled( dynamic_arg_dims: dict[str, Union[int, list[int]]] | None = None, backend_factory: 
Optional[Callable[[torch.fx.GraphModule, list], Callable]] = None, compile_config: CompilationConfig = None, + pass_config: PassConfig = None, fullgraph: bool = True, graph_pool: Any = None, ): @@ -129,9 +130,9 @@ def install_torch_compiled( if backend_factory is None: from sglang.srt.compilation.backend import SGLangBackend - backend_factory = lambda gm, ex: SGLangBackend(compile_config, graph_pool)( - gm, ex - ) + backend_factory = lambda gm, ex: SGLangBackend( + compile_config, pass_config, graph_pool + )(gm, ex) compiled_codes: list[type(original_code)] = [] state = {"compiled": False, "compiled_callable": None} diff --git a/python/sglang/srt/compilation/fusion/ops/__init__.py b/python/sglang/srt/compilation/fusion/ops/__init__.py new file mode 100644 index 000000000000..fc8818b0ffb7 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/ops/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from .flashinfer_ops import register_flashinfer_fused_ops +from .triton_ops import register_triton_fused_ops + + +def register_fused_ops() -> None: + register_triton_fused_ops() + register_flashinfer_fused_ops() diff --git a/python/sglang/srt/compilation/fusion/ops/flashinfer_ops.py b/python/sglang/srt/compilation/fusion/ops/flashinfer_ops.py new file mode 100644 index 000000000000..0f09d086dc64 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/ops/flashinfer_ops.py @@ -0,0 +1,87 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Optional + +import torch + +from sglang.srt.utils import ( + direct_register_custom_op, + is_flashinfer_rmsnorm_quant_kernels_available, +) + + +def register_flashinfer_fused_ops(): + if is_flashinfer_rmsnorm_quant_kernels_available(): + import flashinfer + + def _flashinfer_rms_norm_quant( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, + enable_pdl: Optional[bool] = None, + ) -> torch.Tensor: + return flashinfer.norm.rmsnorm_quant( + out, input, weight, scale, eps, enable_pdl + ) + + def _flashinfer_rms_norm_quant_fake( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float, + enable_pdl: Optional[bool], + ) -> None: + pass + + direct_register_custom_op( + op_name="flashinfer_rmsnorm_quant", + op_func=_flashinfer_rms_norm_quant, + mutates_args=["out"], + fake_impl=_flashinfer_rms_norm_quant_fake, + ) + + def _flashinfer_fused_add_rmsnorm_quant( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, + enable_pdl: Optional[bool] = None, + ) -> torch.Tensor: + return flashinfer.norm.fused_add_rmsnorm_quant( + out, input, residual, weight, scale, eps, enable_pdl + ) + + def _flashinfer_fused_add_rmsnorm_quant_fake( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + eps: float = 1e-6, + enable_pdl: Optional[bool] = None, + ) -> None: + pass + + direct_register_custom_op( + op_name="flashinfer_fused_add_rmsnorm_quant", + op_func=_flashinfer_fused_add_rmsnorm_quant, + mutates_args=["out", "residual"], + fake_impl=_flashinfer_fused_add_rmsnorm_quant_fake, + ) diff --git a/python/sglang/srt/compilation/fusion/ops/triton_ops/__init__.py b/python/sglang/srt/compilation/fusion/ops/triton_ops/__init__.py new file 
mode 100644 index 000000000000..b592f06adbb2 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/ops/triton_ops/__init__.py @@ -0,0 +1,40 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from typing import Optional + +import torch + +from sglang.srt.utils import direct_register_custom_op + +from .dual_gemm import dual_gemm, dual_gemm_fake + + +def dual_gemm_fwd( + x: torch.Tensor, + w: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.ops.sglang.dual_gemm(x, w, x_scale, w_scale, o_scale) + + +def register_triton_fused_ops(): + direct_register_custom_op( + op_name="triton_dual_gemm", + op_func=dual_gemm, + mutates_args=[], + fake_impl=dual_gemm_fake, + ) diff --git a/python/sglang/srt/compilation/fusion/ops/triton_ops/dual_gemm.py b/python/sglang/srt/compilation/fusion/ops/triton_ops/dual_gemm.py new file mode 100644 index 000000000000..81f5802081f5 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/ops/triton_ops/dual_gemm.py @@ -0,0 +1,182 @@ +# Copyright 2024-2025 Ben Fattori and SGLang Team +# +# Adapted from https://github.com/fattorib/fusedswiglu +# +# Permission is hereby granted, free of charge, to any person obtaining a +# copy of this software and associated documentation files (the “Software”), +# to deal in 
the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL +# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +# ============================================================================== + +# Computes fused dual GEMM forward pass adapted from `GLU Variants Improve Transformer` +# + +from typing import Optional + +import torch +import triton +import triton.language as tl + + +@triton.jit +def silu(x): + return x * tl.sigmoid(x) + + +@triton.jit +def dual_gemm_kernel( + x_ptr, + w_gate_ptr, + w_up_ptr, + o_ptr, + x_scale_ptr, + w_scale_ptr, + o_scale_ptr, + USE_SCALE: tl.constexpr, + xrow_stride: tl.constexpr, + xcol_stride: tl.constexpr, + wrow_stride: tl.constexpr, + wcol_stride: tl.constexpr, + orow_stride: tl.constexpr, + ocol_stride: tl.constexpr, + dim_m, + dim_n, + dim_k, + min_val: tl.constexpr, + max_val: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + # fmt: on + pid = tl.program_id(0) + + num_pid_row = tl.cdiv(dim_m, BLOCK_SIZE_M) + num_pid_col = tl.cdiv(dim_k, BLOCK_SIZE_K) + + num_pid_in_group = GROUP_SIZE_M * num_pid_col + group_id = pid // num_pid_in_group + first_pid_row = 
group_id * GROUP_SIZE_M + group_size_row = min(num_pid_row - first_pid_row, GROUP_SIZE_M) + pid_row = first_pid_row + (pid % group_size_row) + pid_col = (pid % num_pid_in_group) // group_size_row + + acc_gate = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32) + acc_up = tl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_K], dtype=tl.float32) + + x_block_ptr = tl.make_block_ptr( + x_ptr, + shape=(dim_m, dim_n), + strides=(xrow_stride, xcol_stride), + offsets=(pid_row * BLOCK_SIZE_M, 0), + block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0), + ) + w_gate_block_ptr = tl.make_block_ptr( + w_gate_ptr, + shape=(dim_n, dim_k), + strides=(wrow_stride, wcol_stride), + offsets=( + 0, + pid_col * BLOCK_SIZE_K, + ), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_K), + order=(1, 0), + ) + w_up_block_ptr = tl.make_block_ptr( + w_up_ptr, + shape=(dim_n, dim_k), + strides=(wrow_stride, wcol_stride), + offsets=( + 0, + pid_col * BLOCK_SIZE_K, + ), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_K), + order=(1, 0), + ) + + for _ in range(0, tl.cdiv(dim_n, BLOCK_SIZE_N)): + x_block = tl.load(x_block_ptr, boundary_check=(0, 1)) + + w_gate_block = tl.load(w_gate_block_ptr, boundary_check=(0, 1)) + w_up_block = tl.load(w_up_block_ptr, boundary_check=(0, 1)) + + acc_gate += tl.dot(x_block, w_gate_block, allow_tf32=False) + acc_up += tl.dot(x_block, w_up_block, allow_tf32=False) + + x_block_ptr = tl.advance(x_block_ptr, offsets=(0, BLOCK_SIZE_N)) + w_gate_block_ptr = tl.advance(w_gate_block_ptr, offsets=(BLOCK_SIZE_N, 0)) + w_up_block_ptr = tl.advance(w_up_block_ptr, offsets=(BLOCK_SIZE_N, 0)) + + if USE_SCALE: + scale = tl.load(w_scale_ptr) * tl.load(x_scale_ptr) + o_scale_inv = 1.0 / tl.load(o_scale_ptr) + acc_up = (acc_up * scale) * silu(acc_gate * scale) + acc_up = tl.clamp(acc_up * o_scale_inv, min_val, max_val) + else: + acc_up *= silu(acc_gate) + + offs_out_m = pid_row * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_k = pid_col * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask = 
(offs_out_m[:, None] < dim_m) & (offs_out_k[None, :] < dim_k) + o_ptrs = ( + (o_ptr) + offs_out_m[:, None] * orow_stride + offs_out_k[None, :] * ocol_stride + ) + tl.store(o_ptrs, acc_up.to(o_ptr.type.element_ty), mask=mask) + + +def dual_gemm( + x: torch.Tensor, + w: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + M, N, K = x.shape[0], x.shape[1], w.shape[1] // 2 + out = torch.empty((M, K), device=x.device, dtype=x.dtype) + w_gate, w_up = torch.split(w, w.shape[1] // 2, dim=1) + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(K, META["BLOCK_SIZE_K"]), + ) + + # fmt: off + dual_gemm_kernel[grid]( + x,w_gate,w_up,out,x_scale,w_scale,o_scale, + x_scale is not None and w_scale is not None and o_scale is not None, + x.stride(0),x.stride(1), + w_up.stride(0),w_up.stride(1), + out.stride(0),out.stride(1), + M,N,K, + torch.finfo(out.dtype).min, + torch.finfo(out.dtype).max, + 128,64,128,1,num_warps=4,num_stages=4, + ) + + # fmt: on + return out + + +def dual_gemm_fake( + x: torch.Tensor, + w: torch.Tensor, + x_scale: Optional[torch.Tensor] = None, + w_scale: Optional[torch.Tensor] = None, + o_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + M, N, K = x.shape[0], x.shape[1], w.shape[1] // 2 + return torch.empty((M, K), device=x.device, dtype=x.dtype) diff --git a/python/sglang/srt/compilation/fusion/passes/__init__.py b/python/sglang/srt/compilation/fusion/passes/__init__.py new file mode 100644 index 000000000000..b3da8f86d45c --- /dev/null +++ b/python/sglang/srt/compilation/fusion/passes/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from .fused_activation import FusedActivationPass +from .rmsnorm_quant import RMSNormQuantPass diff --git a/python/sglang/srt/compilation/fusion/passes/fused_activation.py b/python/sglang/srt/compilation/fusion/passes/fused_activation.py new file mode 100644 index 000000000000..439581e277f3 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/passes/fused_activation.py @@ -0,0 +1,130 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2 + +from sglang.srt.compilation.fusion.pattern import OpPattern, pattern_builder +from sglang.srt.compilation.fusion.pattern.dual_gemm_pattern import ( + DualGemmFp8PatternRegistery, + DualGemmPatternRegistery, +) +from sglang.srt.compilation.fusion.pattern.gemm_fp8_pattern import ( + CutlassFp8ScaledMMPattern, + GemmFp8PatternRegistery, + TorchScaledMMPattern, +) +from sglang.srt.compilation.fusion.pattern.quant_fp8_pattern import ( + PerTensorQuantFp8Pattern, + QuantFp8PatternRegistery, + StaticQuantFp8Pattern, +) +from sglang.srt.compilation.inductor_pass import SGLangPatternMatcherInductorPass + + +class FusedActivationPass(SGLangPatternMatcherInductorPass): + def register_dual_gemm_replacement_pattern(self, dual_gemm_op: OpPattern) -> None: + def pattern(x, w, out): + mm = torch.ops.aten.mm.default(x, w) + silu_and_mul = auto_functionalized_v2( + torch.ops.sgl_kernel.silu_and_mul.default, + input=mm, + _out_base_index=0, + _all_bases=[out], + ) + return silu_and_mul[1] + + def replacement(x, w, out): + dual_gemm_fp8_op_result = dual_gemm_op.pattern(x, w, out) + return dual_gemm_fp8_op_result + + M, K, N = 16, 16, 16 + example_inputs = [ + torch.empty(M, K).half().cuda(), # X + torch.empty(K, N).half().cuda().T, # W.T + torch.empty(M, N // 2).half().cuda(), # out + ] + + self.register_replacement_pattern(pattern, replacement, example_inputs) + + def register_dual_gemm_fp8_replacement_pattern( + self, + quant_fp8_op: OpPattern, + gemm_fp8_op: OpPattern, + dual_gemm_fp8_op: OpPattern, + ) -> None: + def pattern(x, w, x_scale, w_scale, o_scale, out, output_q): + gemm_fp8_op_result = gemm_fp8_op.pattern(x, w, x_scale, w_scale, out.dtype) + silu_and_mul = auto_functionalized_v2( + torch.ops.sgl_kernel.silu_and_mul.default, + input=gemm_fp8_op_result, + _out_base_index=0, + _all_bases=[out], + 
) + quant_fp8_op_result = quant_fp8_op.pattern(silu_and_mul, output_q, o_scale) + return ( + quant_fp8_op_result[0], + quant_fp8_op_result[1], + ) + + def replacement(x, w, x_scale, w_scale, o_scale, out, output_q): + dual_gemm_fp8_op_result = dual_gemm_fp8_op.pattern( + x, w, x_scale, w_scale, o_scale, output_q + ) + if quant_fp8_op.op_type == StaticQuantFp8Pattern: + repeated_o_scale = o_scale.view(1, 1).expand(x.shape[0], 1) + else: + repeated_o_scale = o_scale + return dual_gemm_fp8_op_result, repeated_o_scale + + M, K, N = 16, 16, 16 + if quant_fp8_op.op_type == StaticQuantFp8Pattern: + SM, SN = M, N + else: + SM, SN = 1, 1 + + example_inputs = [ + torch.empty(M, K, device="cuda", dtype=torch.float8_e4m3fn), # X + torch.empty(K, N, device="cuda", dtype=torch.float8_e4m3fn).T, # W.T + torch.empty(SM, 1, device="cuda", dtype=torch.float32), # X_Scale [M, 1] + torch.empty(SN, 1, device="cuda", dtype=torch.float32), # W_Scale [N, 1] + torch.empty( + 1, device="cuda", dtype=torch.float32 + ), # O_Scale (or (1,1) if needed) + torch.empty(M, N // 2, device="cuda", dtype=torch.float16), # out + torch.empty( + M, N // 2, device="cuda", dtype=torch.float8_e4m3fn + ), # output_q + ] + + self.register_replacement_pattern(pattern, replacement, example_inputs) + + def build_pass(self): + pattern_builder( + self.register_dual_gemm_replacement_pattern, + [DualGemmPatternRegistery], + ) + + pattern_builder( + self.register_dual_gemm_fp8_replacement_pattern, + [ + QuantFp8PatternRegistery, + GemmFp8PatternRegistery, + DualGemmFp8PatternRegistery, + ], + ignore_combinations=[ + (StaticQuantFp8Pattern, TorchScaledMMPattern), + (PerTensorQuantFp8Pattern, CutlassFp8ScaledMMPattern), + ], + ) diff --git a/python/sglang/srt/compilation/fusion/passes/rmsnorm_quant.py b/python/sglang/srt/compilation/fusion/passes/rmsnorm_quant.py new file mode 100644 index 000000000000..9e785ec7bd53 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/passes/rmsnorm_quant.py @@ -0,0 +1,140 @@ +# 
Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2 + +from sglang.jit_kernel.utils import is_arch_support_pdl +from sglang.srt.compilation.fusion.pattern import OpPattern, pattern_builder +from sglang.srt.compilation.fusion.pattern.quant_fp8_pattern import ( + QuantFp8PatternRegistery, + StaticQuantFp8Pattern, +) +from sglang.srt.compilation.fusion.pattern.rmsnorm_quant_fp8_pattern import ( + FusedAddRmsnormQuantFp8PatternRegistery, + RmsnormQuantFp8PatternRegistery, +) +from sglang.srt.compilation.inductor_pass import SGLangPatternMatcherInductorPass + + +class RMSNormQuantPass(SGLangPatternMatcherInductorPass): + def register_rmsnorm_quant_replacement_pattern( + self, quant_fp8_op: OpPattern, rmsnorm_quant_fp8_op: OpPattern + ): + def pattern(x, rms_result, weight, scale, eps, output): + rmsnorm = auto_functionalized_v2( + torch.ops.sgl_kernel.rmsnorm.default, + input=x, + weight=weight, + eps=eps, + enable_pdl=is_arch_support_pdl(), + _output_base_index=0, + _all_bases=[rms_result], + ) + quant_fp8_op_result = quant_fp8_op.pattern(rmsnorm, output, scale) + return ( + quant_fp8_op_result[0], + quant_fp8_op_result[1], + ) + + def replacement(x, rms_result, weight, scale, eps, output): + rmsnorm_quant_fp8_op_result = rmsnorm_quant_fp8_op.pattern( + x, weight, scale, 
eps, output + ) + if quant_fp8_op.op_type == StaticQuantFp8Pattern: + repeated_scale = scale.view(1, 1).expand(x.shape[0], 1) + else: + repeated_scale = scale + return rmsnorm_quant_fp8_op_result[1], repeated_scale + + M, N, K = 16, 16, 16 + example_inputs = [ + torch.empty(M, K).half().cuda(), + torch.empty(N, K).half().cuda(), + torch.empty(M).half().cuda(), + torch.empty(()).cuda(), + torch.empty(M, N).to(dtype=torch.float8_e4m3fn).cuda(), + ] + + for eps in self.pass_config.rms_norm_eps: + self.register_replacement_pattern( + pattern, replacement, example_inputs, scalar_workaround={"eps": eps} + ) + + def register_fused_add_rmsnorm_quant_replacement_pattern( + self, + quant_fp8_op: OpPattern, + fused_add_rmsnorm_quant_fp8_op: OpPattern, + ): + def pattern(x, residual, weight, scale, result, eps): + fused_add_rmsnorm = auto_functionalized_v2( + torch.ops.sgl_kernel.fused_add_rmsnorm.default, + weight=weight, + eps=eps, + enable_pdl=is_arch_support_pdl(), + _input_base_index=0, + _residual_base_index=1, + _all_bases=[x, residual], + ) + + quant_fp8_op_result = quant_fp8_op.pattern(fused_add_rmsnorm, result, scale) + + return ( + quant_fp8_op_result[0], + quant_fp8_op_result[1], + fused_add_rmsnorm[1], + fused_add_rmsnorm[2], + ) + + def replacement(x, residual, weight, scale, result, eps): + fused_add_rmsnorm_quant_fp8_op_result = ( + fused_add_rmsnorm_quant_fp8_op.pattern( + x, residual, weight, scale, result, eps + ) + ) + if quant_fp8_op.op_type == StaticQuantFp8Pattern: + repeated_scale = scale.view(1, 1).expand(x.shape[0], 1) + else: + repeated_scale = scale + return ( + fused_add_rmsnorm_quant_fp8_op_result[1], + repeated_scale, + fused_add_rmsnorm_quant_fp8_op_result[2], + fused_add_rmsnorm_quant_fp8_op_result[2], + ) + + M, N, K = 16, 16, 16 + example_inputs = [ + torch.empty(M, K).half().cuda(), + torch.empty(N, K).half().cuda(), + torch.empty(M).half().cuda(), + torch.empty(()).cuda(), + torch.empty(M, N).to(dtype=torch.float8_e4m3fn).cuda(), + ] + + 
for eps in self.pass_config.rms_norm_eps: + self.register_replacement_pattern( + pattern, replacement, example_inputs, scalar_workaround={"eps": eps} + ) + + def build_pass(self): + pattern_builder( + self.register_rmsnorm_quant_replacement_pattern, + [QuantFp8PatternRegistery, RmsnormQuantFp8PatternRegistery], + ) + pattern_builder( + self.register_fused_add_rmsnorm_quant_replacement_pattern, + [QuantFp8PatternRegistery, FusedAddRmsnormQuantFp8PatternRegistery], + ) diff --git a/python/sglang/srt/compilation/fusion/pattern/__init__.py b/python/sglang/srt/compilation/fusion/pattern/__init__.py new file mode 100644 index 000000000000..c15e33e8565f --- /dev/null +++ b/python/sglang/srt/compilation/fusion/pattern/__init__.py @@ -0,0 +1,73 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import inspect +import itertools +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, List, Optional, Set, Tuple, Type + + +class OpPatternBase(ABC): + @staticmethod + @abstractmethod + def pattern(*args, **kwargs): + pass + + +@dataclass +class OpPattern: + op_type: Type[OpPatternBase] + pattern: Callable + + +class OpPatternRegistery(ABC): + def __init__(self): + self._patterns: List[OpPattern] = [] + self.build_op_pattern_registery() + + def register_op_pattern(self, op: Type[OpPattern]): + self._patterns.append(OpPattern(op, op.pattern)) + + @abstractmethod + def build_op_pattern_registery(self): + pass + + @property + def patterns(self): + return self._patterns + + +def pattern_builder( + builder: Callable, + op_pattern_registeries: List[OpPatternRegistery], + ignore_combinations: Optional[List[Tuple[Type[OpPatternBase], ...]]] = None, +): + num_args = sum( + 1 + for param in inspect.signature(builder).parameters.values() + if param.kind in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD) + and param.default is param.empty + ) + assert ( + len(op_pattern_registeries) == num_args + ), f"Expected {num_args} op_pattern_registeries, got {len(op_pattern_registeries)}" + + patterns = list(map(lambda x: x.patterns, op_pattern_registeries)) + for ops in itertools.product(*patterns): + if ignore_combinations: + ops_key = set(map(lambda x: x.op_type, ops)) + if any(ops_key.issuperset(t) for t in ignore_combinations): + continue + builder(*ops) diff --git a/python/sglang/srt/compilation/fusion/pattern/dual_gemm_pattern.py b/python/sglang/srt/compilation/fusion/pattern/dual_gemm_pattern.py new file mode 100644 index 000000000000..58fece7d3414 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/pattern/dual_gemm_pattern.py @@ -0,0 +1,100 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from abc import abstractmethod + +import torch + +from sglang.srt.compilation.fusion.pattern import OpPatternBase, OpPatternRegistery + + +def _is_cutedsl_dual_gemm_available(): + try: + from sglang.jit_kernel.cutedsl_dual_gemm import ( # noqa: F401 + cutedsl_dual_gemm_fused_op, + ) + + return True + except Exception: + return False + + +class _DualGemmPattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, w, out): + pass + + +# op: fused_ops/triton_ops/dual_gemm +# TODO: This is most probably broken, fix it for oss +class TritonFusedOpsDualGemmPattern(_DualGemmPattern): + @staticmethod + def pattern(x, w, out): + return torch.ops.sglang.triton_dual_gemm.default(x, w) + + +# op: cutedsl_dual_gemm (registered via @register_custom_op decorator) +class CuteDSLDualGemmPattern(_DualGemmPattern): + @staticmethod + def pattern(x, w, out): + return torch.ops.sglang.cutedsl_dual_gemm.default(x, w, out) + + +class _DualGemmPatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + if _is_cutedsl_dual_gemm_available(): + self.register_op_pattern(CuteDSLDualGemmPattern) + else: + self.register_op_pattern(TritonFusedOpsDualGemmPattern) + + +DualGemmPatternRegistery = _DualGemmPatternRegistery() + + +class _DualGemmFp8Pattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, w, x_scale, w_scale, o_scale, output_q): + pass + + +# op: 
fused_ops/triton_ops/dual_gemm +# TODO: This is most probably broken, fix it for oss +class TritonFusedOpsDualGemmFp8Pattern(_DualGemmFp8Pattern): + @staticmethod + def pattern(x, w, x_scale, w_scale, o_scale, output_q): + return torch.ops.sglang.triton_dual_gemm.default( + x, w, x_scale, w_scale, o_scale + ) + + +# op: cutedsl_dual_gemm (registered via @register_custom_op decorator) +class CuteDSLDualGemmFp8Pattern(_DualGemmFp8Pattern): + @staticmethod + def pattern(x, w, x_scale, w_scale, o_scale, output_q): + return torch.ops.sglang.cutedsl_dual_gemm.default( + x, w, output_q, x_scale, w_scale, o_scale + ) + + +class _DualGemmFp8PatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + if _is_cutedsl_dual_gemm_available(): + self.register_op_pattern(CuteDSLDualGemmFp8Pattern) + else: + self.register_op_pattern(TritonFusedOpsDualGemmFp8Pattern) + + +DualGemmFp8PatternRegistery = _DualGemmFp8PatternRegistery() diff --git a/python/sglang/srt/compilation/fusion/pattern/gemm_fp8_pattern.py b/python/sglang/srt/compilation/fusion/pattern/gemm_fp8_pattern.py new file mode 100644 index 000000000000..522ff43861df --- /dev/null +++ b/python/sglang/srt/compilation/fusion/pattern/gemm_fp8_pattern.py @@ -0,0 +1,53 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from abc import abstractmethod + +import torch + +from sglang.srt.compilation.fusion.pattern import OpPatternBase, OpPatternRegistery + + +class _GemmFp8Pattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, w, x_scale, w_scale, out_dtype): + pass + + +# op: aten/_scaled_mm +class TorchScaledMMPattern(_GemmFp8Pattern): + @staticmethod + def pattern(x, w, x_scale, w_scale, out_dtype): + return torch.ops.aten._scaled_mm.default( + x, w, x_scale, w_scale, None, None, out_dtype + ) + + +# op: sgl-kernel/fp8_scaled_mm +class CutlassFp8ScaledMMPattern(_GemmFp8Pattern): + @staticmethod + def pattern(x, w, x_scale, w_scale, out_dtype): + return torch.ops.sgl_kernel.fp8_scaled_mm.default( + x, w, x_scale, w_scale, out_dtype, None + ) + + +class _GemmFp8PatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + self.register_op_pattern(TorchScaledMMPattern) + self.register_op_pattern(CutlassFp8ScaledMMPattern) + + +GemmFp8PatternRegistery = _GemmFp8PatternRegistery() diff --git a/python/sglang/srt/compilation/fusion/pattern/quant_fp8_pattern.py b/python/sglang/srt/compilation/fusion/pattern/quant_fp8_pattern.py new file mode 100644 index 000000000000..27076c8e3105 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/pattern/quant_fp8_pattern.py @@ -0,0 +1,74 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from abc import abstractmethod + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2 + +from sglang.srt.compilation.fusion.pattern import OpPatternBase, OpPatternRegistery + + +class _QuantFp8Pattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, output, scale): + pass + + +# op: sgl-kernel/per_tensor_quant_fp8 +class PerTensorQuantFp8Pattern(_QuantFp8Pattern): + @staticmethod + def pattern(x, output, scale): + per_tensor_quant_fp8 = auto_functionalized_v2( + torch.ops.sglang.per_tensor_quant_fp8.default, + input=x[1], + is_static=True, + _output_q_base_index=0, + _output_s_base_index=1, + _all_bases=[output, scale], + ) + + return ( + per_tensor_quant_fp8[1], + per_tensor_quant_fp8[2], + ) + + +# op: fp8_kernels.py/static_quant_fp8_fwd +class StaticQuantFp8Pattern(_QuantFp8Pattern): + @staticmethod + def pattern(x, output, scale): + static_quant_fp8 = auto_functionalized_v2( + torch.ops.sglang.static_quant_fp8.default, + x=x[1], + x_s=scale, + repeat_scale=True, + _x_q_base_index=0, + _all_bases=[output], + ) + + return ( + static_quant_fp8[1], + static_quant_fp8[0], + ) + + +class _QuantFp8PatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + self.register_op_pattern(PerTensorQuantFp8Pattern) + self.register_op_pattern(StaticQuantFp8Pattern) + + +QuantFp8PatternRegistery = _QuantFp8PatternRegistery() diff --git a/python/sglang/srt/compilation/fusion/pattern/rmsnorm_quant_fp8_pattern.py b/python/sglang/srt/compilation/fusion/pattern/rmsnorm_quant_fp8_pattern.py new file mode 100644 index 000000000000..3ab910f47c83 --- /dev/null +++ b/python/sglang/srt/compilation/fusion/pattern/rmsnorm_quant_fp8_pattern.py @@ -0,0 +1,167 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from abc import abstractmethod + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized_v2 + +from sglang.jit_kernel.utils import is_arch_support_pdl +from sglang.srt.compilation.fusion.pattern import OpPatternBase, OpPatternRegistery +from sglang.srt.utils import is_flashinfer_rmsnorm_quant_kernels_available + + +def _is_jit_rmsnorm_quant_available(): + try: + from sglang.jit_kernel.norm import rmsnorm_quant # noqa: F401 + + return True + except Exception: + return False + + +class _RmsnormQuantFp8Pattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, weight, scale, eps, output): + pass + + +# op: jit/rmsnorm_quant +class JitRmsnormQuantFp8Pattern(_RmsnormQuantFp8Pattern): + @staticmethod + def pattern(x, weight, scale, eps, output): + return auto_functionalized_v2( + torch.ops.sglang.jit_rmsnorm_quant.default, + input=x, + weight=weight, + scale=scale, + eps=eps, + _out_base_index=0, + _all_bases=[output], + ) + + +# op: flashinfer/rmsnorm_quant +class FlashinferRmsnormQuantFp8Pattern(_RmsnormQuantFp8Pattern): + @staticmethod + def pattern(x, weight, scale, eps, output): + return auto_functionalized_v2( + torch.ops.sglang.flashinfer_rmsnorm_quant.default, + input=x, + weight=weight, + scale=scale, + eps=eps, + enable_pdl=is_arch_support_pdl(), + _out_base_index=0, + _all_bases=[output], + ) + + +# op: sgl-kernel/rmsnorm_quant +class SglangRmsnormQuantFp8Pattern(_RmsnormQuantFp8Pattern): + @staticmethod + def 
pattern(x, weight, scale, eps, output): + return auto_functionalized_v2( + torch.ops.sgl_kernel.rms_norm_static_fp8_quant.default, + input=x, + weight=weight, + scale=scale, + epsilon=eps, + _result_base_index=0, + _all_bases=[output], + ) + + +class _RmsnormQuantFp8PatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + if _is_jit_rmsnorm_quant_available(): + self.register_op_pattern(JitRmsnormQuantFp8Pattern) + elif is_flashinfer_rmsnorm_quant_kernels_available(): + self.register_op_pattern(FlashinferRmsnormQuantFp8Pattern) + else: + self.register_op_pattern(SglangRmsnormQuantFp8Pattern) + + +RmsnormQuantFp8PatternRegistery = _RmsnormQuantFp8PatternRegistery() + + +class _FusedAddRmsnormQuantFp8Pattern(OpPatternBase): + @staticmethod + @abstractmethod + def pattern(x, residual, weight, scale, result, eps): + pass + + +# op: jit/fused_add_rmsnorm_quant +class JitFusedAddRmsnormQuantFp8Pattern(_FusedAddRmsnormQuantFp8Pattern): + @staticmethod + def pattern(x, residual, weight, scale, result, eps): + return auto_functionalized_v2( + torch.ops.sglang.jit_fused_add_rmsnorm_quant.default, + input=x, + residual=residual, + weight=weight, + scale=scale, + eps=eps, + _out_base_index=0, + _residual_base_index=1, + _all_bases=[result, residual], + ) + + +# op: flashinfer/fused_add_rmsnorm_quant +class FlashinferFusedAddRmsnormQuantFp8Pattern(_FusedAddRmsnormQuantFp8Pattern): + @staticmethod + def pattern(x, residual, weight, scale, result, eps): + return auto_functionalized_v2( + torch.ops.sglang.flashinfer_fused_add_rmsnorm_quant.default, + input=x, + weight=weight, + scale=scale, + eps=eps, + enable_pdl=is_arch_support_pdl(), + _out_base_index=0, + _residual_base_index=1, + _all_bases=[result, residual], + ) + + +# op: sgl-kernel/fused_add_rmsnorm_quant +class SglangFusedAddRmsnormQuantFp8Pattern(_FusedAddRmsnormQuantFp8Pattern): + @staticmethod + def pattern(x, residual, weight, scale, result, eps): + return auto_functionalized_v2( + 
torch.ops.sgl_kernel.fused_add_rms_norm_static_fp8_quant.default, + input=x, + weight=weight, + scale=scale, + epsilon=eps, + _result_base_index=0, + _residual_base_index=1, + _all_bases=[result, residual], + ) + + +class _FusedAddRmsnormQuantFp8PatternRegistery(OpPatternRegistery): + def build_op_pattern_registery(self): + if _is_jit_rmsnorm_quant_available(): + self.register_op_pattern(JitFusedAddRmsnormQuantFp8Pattern) + elif is_flashinfer_rmsnorm_quant_kernels_available(): + self.register_op_pattern(FlashinferFusedAddRmsnormQuantFp8Pattern) + else: + self.register_op_pattern(SglangFusedAddRmsnormQuantFp8Pattern) + + +FusedAddRmsnormQuantFp8PatternRegistery = _FusedAddRmsnormQuantFp8PatternRegistery() diff --git a/python/sglang/srt/compilation/inductor_pass.py b/python/sglang/srt/compilation/inductor_pass.py index acbde65bf8ab..d38d02f2cb87 100644 --- a/python/sglang/srt/compilation/inductor_pass.py +++ b/python/sglang/srt/compilation/inductor_pass.py @@ -6,6 +6,7 @@ import logging import time import types +from abc import abstractmethod from contextlib import contextmanager from typing import Any, Callable, Optional, Union @@ -13,6 +14,14 @@ from torch import fx from torch._dynamo.utils import lazy_format_graph_code from torch._inductor.custom_graph_pass import CustomGraphPass +from torch._inductor.pattern_matcher import ( + PatternMatcherPass, + fwd_only, + register_replacement, +) +from torch.fx.experimental.proxy_tensor import make_fx + +from sglang.srt.compilation.pass_config import PassConfig logger = logging.getLogger(__name__) @@ -119,7 +128,13 @@ def __init__( self.pass_name = self.__class__.__name__ def dump_graph(self, graph: torch.fx.Graph, stage: str): - lazy_format_graph_code(stage, graph.owning_module) + return lazy_format_graph_code( + stage, + graph.owning_module, + include_stride=True, + include_device=True, + colored=True, + ) def begin(self): self._start_time = time.perf_counter_ns() @@ -138,3 +153,93 @@ def __init__(self, name: str): def 
__call__(self, graph: torch.fx.Graph): self.dump_graph(graph, self.name) + + +class SGLangPatternMatcherInductorPass(SGLangInductorPass): + def __init__(self, pass_config: PassConfig): + self.pass_config = pass_config + self.pass_name = self.__class__.__name__ + self.patterns = PatternMatcherPass(self.pass_name) + self.build_pass() + + def __call__(self, graph: torch.fx.Graph): + if self.pass_config.enable_torch_compile_graph_trace_logs: + logger.info("%s", str(self.dump_graph(graph, f"Before_{self.pass_name}"))) + + self.begin() + count = self.patterns.apply(graph) + self.end_and_log(count) + + if count > 0 and self.pass_config.enable_torch_compile_graph_trace_logs: + logger.info("%s", str(self.dump_graph(graph, f"After_{self.pass_name}"))) + + @abstractmethod + def build_pass(self) -> None: + pass + + def register_replacement_pattern( + self, pattern: Callable, replacement: Callable, example_inputs: Any, **kwargs + ) -> None: + register_replacement( + search_fn=pattern, + replace_fn=replacement, + example_inputs=example_inputs, + trace_fn=fwd_only, + pass_dicts=self.patterns, + **kwargs, + ) + + if self.pass_config.enable_torch_compile_graph_trace_logs: + scalar_workaround = kwargs.get("scalar_workaround", {}) + trace_inputs = self._build_trace_inputs( + pattern, example_inputs, scalar_workaround + ) + + pattern_trace = types.SimpleNamespace( + owning_module=make_fx(pattern, tracing_mode="symbolic")(*trace_inputs) + ) + replacement_trace = types.SimpleNamespace( + owning_module=make_fx(replacement, tracing_mode="symbolic")( + *trace_inputs + ) + ) + + logger.info( + "%s", str(self.dump_graph(pattern_trace, f"{self.pass_name}_Pattern")) + ) + logger.info( + "%s", + str( + self.dump_graph(replacement_trace, f"{self.pass_name}_Replacement") + ), + ) + + def end_and_log(self, count: int): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug( + "%s completed in %.1f ms, matched %s times", 
self.pass_name, + duration_ms, + count, + ) + + @staticmethod + def _build_trace_inputs(fn, example_inputs, scalar_workaround): + """Insert scalar_workaround values into example_inputs at the + correct positions based on the function signature.""" + if not scalar_workaround: + return example_inputs + + params = list(inspect.signature(fn).parameters.keys()) + # Map scalar param names to their positional index + scalar_positions = { + name: idx for idx, name in enumerate(params) if name in scalar_workaround + } + + result = list(example_inputs) + # Insert in order of position (so earlier inserts don't shift later indices) + for name, idx in sorted(scalar_positions.items(), key=lambda x: x[1]): + result.insert(idx, scalar_workaround[name]) + + return result diff --git a/python/sglang/srt/compilation/pass_config.py b/python/sglang/srt/compilation/pass_config.py new file mode 100644 index 000000000000..a184c9be232a --- /dev/null +++ b/python/sglang/srt/compilation/pass_config.py @@ -0,0 +1,70 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import hashlib +import json +import logging +from dataclasses import asdict, dataclass +from typing import Optional + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.server_args import ServerArgs + +logger = logging.getLogger(__name__) + + +@dataclass +class PassConfig: + device: Optional[str] + model_dtype: Optional[str] + + enable_fusion: bool + disable_rmsnorm_quant_pass: bool + disable_fused_activation_pass: bool + + rms_norm_eps: list[float] + + enable_torch_compile_graph_trace_logs: bool + + def uuid(self): + encoded = json.dumps(asdict(self), sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).hexdigest() + + @staticmethod + def from_server_args_and_model_config( + server_args: ServerArgs, model_config: ModelConfig + ): + disable_rmsnorm_quant_pass = True + rms_norm_eps = [] + if not server_args.disable_rmsnorm_quant_pass: + disable_rmsnorm_quant_pass = False + if model_config.hf_config.rms_norm_eps is not None: + rms_norm_eps.append(model_config.hf_config.rms_norm_eps) + else: + logger.warning( + "RMSNorm epsilon value not found in hugging face config, " + "registering fusion passes for default (1e-05, 1e-06) values." 
+ ) + rms_norm_eps.append(1e-05) + rms_norm_eps.append(1e-06) + + return PassConfig( + device=server_args.device if server_args.device else None, + model_dtype=server_args.dtype if server_args.dtype else None, + enable_fusion=server_args.enable_torch_compile_fusion, + disable_rmsnorm_quant_pass=disable_rmsnorm_quant_pass, + disable_fused_activation_pass=server_args.disable_fused_activation_pass, + enable_torch_compile_graph_trace_logs=server_args.enable_torch_compile_graph_trace_logs, + rms_norm_eps=rms_norm_eps, + ) diff --git a/python/sglang/srt/compilation/pass_manager.py b/python/sglang/srt/compilation/pass_manager.py index 9173976f1878..dfadaa125f55 100644 --- a/python/sglang/srt/compilation/pass_manager.py +++ b/python/sglang/srt/compilation/pass_manager.py @@ -4,13 +4,16 @@ from torch import fx as fx -from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass +from sglang.srt.compilation.fusion.ops import register_fused_ops +from sglang.srt.compilation.fusion.passes import FusedActivationPass, RMSNormQuantPass + +# from sglang.srt.compilation.fix_functionalization import FixFunctionalizationPass from sglang.srt.compilation.inductor_pass import ( CustomGraphPass, InductorPass, SGLangInductorPass, - get_pass_context, ) +from sglang.srt.compilation.pass_config import PassConfig logger = logging.getLogger(__name__) @@ -30,23 +33,38 @@ class PostGradPassManager(CustomGraphPass): This way, all passes operate on a functionalized graph. 
""" - def __init__(self): + def __init__(self, pass_config: PassConfig): + self.pass_config = pass_config self.passes: list[SGLangInductorPass] = [] def __call__(self, graph: fx.Graph): - shape = get_pass_context().runtime_shape + logger.debug("Running custom inductor passes.") + + # TODO pass context is not set when running the pass manager + # directly, i.e during torch compile in cuda graph runner + # shape = get_pass_context().runtime_shape for pass_ in self.passes: - if pass_.is_applicable_for_shape(shape): - pass_(graph) + # if pass_.is_applicable_for_shape(shape): + pass_(graph) + # TODO: not required if using auto_functionalized_v2 # always run fix_functionalization last - self.fix_functionalization(graph) + # self.fix_functionalization(graph) + + def configure(self): + register_fused_ops() + + # self.fix_functionalization = FixFunctionalizationPass() + if self.pass_config.enable_fusion: + if not self.pass_config.disable_rmsnorm_quant_pass: + self.passes.append(RMSNormQuantPass(self.pass_config)) + + if not self.pass_config.disable_fused_activation_pass: + self.passes.append(FusedActivationPass(self.pass_config)) - def configure( - self, - ): - self.pass_config = dict() - self.fix_functionalization = FixFunctionalizationPass() + logger.debug( + f"Passes Configured: {list(map(lambda x: x.pass_name, self.passes))}" + ) def add(self, pass_: InductorPass): assert isinstance(pass_, InductorPass) @@ -58,9 +76,8 @@ def uuid(self): affects compilation caching. Its uuid depends on the UUIDs of all dependent passes and the pass config. See InductorPass for more info. 
""" - pass_manager_uuid = "fshdakhsa" - state = {"pass_config": pass_manager_uuid, "passes": []} + state = {"pass_config": self.pass_config.uuid(), "passes": []} for pass_ in self.passes: state["passes"].append(pass_.uuid()) - state["passes"].append(self.fix_functionalization.uuid()) + # state["passes"].append(self.fix_functionalization.uuid()) return InductorPass.hash_dict(state) diff --git a/python/sglang/srt/layers/attention/dummy_backend.py b/python/sglang/srt/layers/attention/dummy_backend.py new file mode 100644 index 000000000000..6edb5deba2a4 --- /dev/null +++ b/python/sglang/srt/layers/attention/dummy_backend.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union + +import torch + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend + +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + + +class DummyAttentionBackend(AttentionBackend): + + def set_out(self, out: torch.Tensor): + self.out = out + + def init_forward_metadata(self, forward_batch: ForwardBatch): + """Init the metadata for a forward pass.""" + raise NotImplementedError() + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + """Init the global shared states for cuda graph.""" + raise NotImplementedError() + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + """Init the metadata for a forward pass for capturing a cuda graph.""" + raise NotImplementedError() + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: 
torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + """Init the metadata for a forward pass for replaying a cuda graph.""" + raise NotImplementedError() + + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" + raise NotImplementedError() + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + out_view = self.out.view(-1) + out_view[0] = q.view(-1)[0] + k.view(-1)[0] + v.view(-1)[0] + return self.out + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): + out_view = self.out.view(-1) + out_view[0] = q.view(-1)[0] + k.view(-1)[0] + v.view(-1)[0] + return self.out diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py index 1466bac6bec4..e00f4874a6df 100644 --- a/python/sglang/srt/layers/quantization/fp8_kernel.py +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -31,6 +31,7 @@ from sglang.srt.layers import deep_gemm_wrapper from sglang.srt.utils import ( ceil_align, + direct_register_custom_op, get_bool_env_var, get_device_core_count, get_device_name, @@ -646,18 +647,19 @@ def _static_quant_fp8( tl.store(y_s_repeat_ptr, y_s) -def static_quant_fp8( +def _static_quant_fp8_fwd( x: torch.Tensor, + x_q: torch.Tensor, x_s: torch.Tensor, repeat_scale: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> torch.Tensor: """Function to perform static quantization using the given scale on an input tensor `x`. 
It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. Args: - x: The input tensor with ndim >= 2. + x: The input tensor with ndim >= 2. x_s: The quantization scale. repeat_scale: Whether to broadcast per-tensor scale to per-channel scale. dtype: The dype of output tensor. @@ -668,7 +670,6 @@ def static_quant_fp8( assert x.is_contiguous(), "`x` is not contiguous" assert x_s.numel() == 1, "only supports per-tensor scale" - x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) M = x.numel() // x.shape[-1] N = x.shape[-1] if repeat_scale: @@ -680,6 +681,10 @@ else: x_s_repeat = None + # convert scalars to vectors because torch.compile can't deal with tl.load on a scalar + if x_s.shape == (): + x_s = x_s.unsqueeze(0) + BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) @@ -699,6 +704,43 @@ num_stages=num_stages, ) x_s = x_s_repeat if repeat_scale else x_s + return x_s + + +def _static_quant_fp8_fake( + x: torch.Tensor, + x_q: torch.Tensor, + x_s: torch.Tensor, + repeat_scale: bool = False, +) -> torch.Tensor: + M = x.numel() // x.shape[-1] + if repeat_scale: + x_s_repeat = torch.empty( + (M, 1), + device=x.device, + dtype=torch.float32, + ) + else: + x_s_repeat = None + x_s = x_s_repeat if repeat_scale else x_s + return x_s + + +direct_register_custom_op( + op_name="static_quant_fp8", + op_func=_static_quant_fp8_fwd, + mutates_args=["x_q"], + fake_impl=_static_quant_fp8_fake, +) + + +def static_quant_fp8( + x: torch.Tensor, + x_s: torch.Tensor, + repeat_scale: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) + x_s = torch.ops.sglang.static_quant_fp8(x, x_q, x_s, repeat_scale) + return x_q, x_s diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 
449bc867067e..d285c2e163e2 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -115,14 +115,27 @@ def forward( k = k.view(-1, self.tp_k_head_num, self.v_head_dim) if forward_batch.forward_mode.is_extend() and get_forward_context() is not None: - if self.qk_head_dim != self.v_head_dim: - output = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) - else: - output = torch.empty_like(q) - unified_attention_with_output( - q, k, v, output, save_kv_cache, self.layer_id, **kwargs + return unified_attention_with_output( + q, + k, + v, + self.tp_q_head_num * self.v_head_dim, + save_kv_cache, + self.layer_id, + **kwargs, + ) + elif ( + forward_batch.forward_mode.is_decode() and get_forward_context() is not None + ): + return unified_attention_with_output( + q, + k, + v, + self.tp_q_head_num * self.v_head_dim, + save_kv_cache, + self.layer_id, + **kwargs, ) - return output else: return forward_batch.attn_backend.forward( q, @@ -135,20 +148,35 @@ def forward( ) -@register_custom_op(mutates_args=["output"]) +def _unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_out_features: int, + save_kv_cache: bool, + layer_id: int, + *, + q_rope: Optional[torch.Tensor] = None, + k_rope: Optional[torch.Tensor] = None, + sinks: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return query.new_empty(query.shape[0], num_out_features) + + +@register_custom_op(fake_impl=_unified_attention_fake) @register_split_op() def unified_attention_with_output( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, - output: torch.Tensor, + num_out_features: int, save_kv_cache: bool, layer_id: int, *, q_rope: Optional[torch.Tensor] = None, k_rope: Optional[torch.Tensor] = None, sinks: Optional[torch.Tensor] = None, -) -> None: +) -> torch.Tensor: context = get_forward_context() forward_batch = context.forward_batch attention_layers = context.attention_layers @@ -162,12 +190,6 @@ 
def unified_attention_with_output( if sinks is not None: kwargs["sinks"] = sinks - ret = forward_batch.attn_backend.forward( + return forward_batch.attn_backend.forward( query, key, value, attention_layer, forward_batch, save_kv_cache, **kwargs ) - assert ( - output.numel() == ret.numel() - ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}" - - output.view(ret.shape).copy_(ret) - return diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 40fe210df2d8..988ef0e8fb06 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -96,10 +96,10 @@ def _set_kv_buffer_impl( store_dtype: torch.dtype, device_module: Any, alt_stream: Optional[torch.cuda.Stream] = None, - same_kv_dim: bool = True, + use_store_cache: bool = False, ) -> None: row_bytes = row_dim * store_dtype.itemsize - if (_is_cuda or _is_hip) and same_kv_dim and can_use_store_cache(row_bytes): + if use_store_cache: return store_cache( k.view(-1, row_dim), v.view(-1, row_dim), @@ -779,6 +779,12 @@ def __init__( # for store_cache JIT kernel self.row_dim = self.head_num * self.head_dim self.same_kv_dim = self.head_dim == self.v_head_dim + row_bytes = self.row_dim * self.store_dtype.itemsize + self.use_store_cache = ( + (_is_cuda or _is_hip) + and self.same_kv_dim + and can_use_store_cache(row_bytes) + ) def _init_kv_copy_and_warmup(self): # Heuristics for KV copy tiling @@ -1013,7 +1019,7 @@ def set_kv_buffer( store_dtype=self.store_dtype, device_module=self.device_module, alt_stream=self.alt_stream, - same_kv_dim=self.same_kv_dim, + use_store_cache=self.use_store_cache, ) def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index ebe4c41d757b..0fefe22609ec 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ 
b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -30,6 +30,7 @@ from torch.profiler import ProfilerActivity, profile from sglang.srt.batch_overlap.two_batch_overlap import TboCudaGraphRunnerPlugin +from sglang.srt.compilation.piecewise_context_manager import set_forward_context from sglang.srt.constants import GPU_MEMORY_TYPE_CUDA_GRAPH from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.distributed.device_communicators.pynccl_allocator import ( @@ -402,10 +403,25 @@ def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): _to_torch(sub, reverse, num_tokens) +def _torch_compile_wrapper(forward): + return torch.compile( + torch.no_grad()(forward), + mode=os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"), + dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"), + fullgraph=True, + ) + + +def torch_compile(model: torch.nn.Module, server_args, model_config): + set_torch_compile_config(server_args, model_config) + model.forward = _torch_compile_wrapper(model.forward) + + @contextmanager def patch_model( model: torch.nn.Module, enable_compile: bool, + enable_fusion: bool, num_tokens: int, tp_group: GroupCoordinator, ): @@ -414,28 +430,24 @@ def patch_model( try: if enable_compile: - _to_torch(model, reverse=False, num_tokens=num_tokens) + if not enable_fusion: + _to_torch(model, reverse=False, num_tokens=num_tokens) backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. # We found the custom allreduce is much faster than the built-in allreduce in torch, # even with ENABLE_INTRA_NODE_COMM=1. 
# tp_group.ca_comm = None - yield torch.compile( - torch.no_grad()(model.forward), - mode=os.environ.get( - "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs" - ), - dynamic=_is_hip and get_bool_env_var("SGLANG_TORCH_DYNAMIC_SHAPE"), - ) + yield _torch_compile_wrapper(model.forward) else: yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True, num_tokens=num_tokens) + if not enable_fusion: + _to_torch(model, reverse=True, num_tokens=num_tokens) tp_group.ca_comm = backup_ca_comm -def set_torch_compile_config(): +def set_torch_compile_config(server_args, model_config): import torch._dynamo.config import torch._inductor.config @@ -443,6 +455,18 @@ def set_torch_compile_config(): torch._inductor.config.triton.unique_kernel_names = True torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future + if server_args.enable_torch_compile_fusion: + from sglang.srt.compilation.pass_config import PassConfig + from sglang.srt.compilation.pass_manager import PostGradPassManager + + pass_config = PassConfig.from_server_args_and_model_config( + server_args, model_config + ) + pass_manager = PostGradPassManager(pass_config) + pass_manager.configure() + + torch._inductor.config.post_grad_custom_post_pass = pass_manager + # FIXME: tmp workaround torch._dynamo.config.accumulated_cache_size_limit = 1024 if hasattr(torch._dynamo.config, "cache_size_limit"): @@ -510,6 +534,9 @@ def __init__(self, model_runner: ModelRunner): self.graphs = {} self.output_buffers = {} self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.enable_torch_compile_fusion = ( + model_runner.server_args.enable_torch_compile_fusion + ) self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) @@ -591,7 +618,9 @@ def 
__init__(self, model_runner: ModelRunner): self.encoder_len_fill_value = 0 if self.enable_torch_compile: - set_torch_compile_config() + set_torch_compile_config( + self.model_runner.server_args, self.model_runner.model_config + ) if self.model_runner.server_args.enable_lora: self.model_runner.lora_manager.init_cuda_graph_batch_info( @@ -630,6 +659,12 @@ def __init__(self, model_runner: ModelRunner): self.tbo_plugin = TboCudaGraphRunnerPlugin() + # Forward context for unified attention op + self.attention_layers = model_runner.attention_layers + self.quant_config = getattr(model_runner.model, "quant_config", None) + self.moe_layers = model_runner.moe_layers + self.moe_fusions = model_runner.moe_fusions + # Speculative_inference if ( model_runner.spec_algorithm.is_eagle3() @@ -777,6 +812,7 @@ def _capture_one_stream(stream_idx: Optional[int] = None): with patch_model( self.model_runner.model, bs in self.compile_bs, + self.enable_torch_compile_fusion, num_tokens=bs * self.num_tokens_per_bs, tp_group=self.model_runner.tp_group, ) as forward: @@ -992,12 +1028,19 @@ def run_once(): {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) - logits_output_or_pp_proxy_tensors = forward( - input_ids, - forward_batch.positions, + with set_forward_context( forward_batch, - **kwargs, - ) + self.attention_layers, + self.quant_config, + self.moe_layers, + self.moe_fusions, + ): + logits_output_or_pp_proxy_tensors = forward( + input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) return logits_output_or_pp_proxy_tensors self.deepep_adapter.capture(is_extend_in_batch=False) @@ -1139,7 +1182,15 @@ def replay( graph_key = f"{get_current_stream_idx()}_{self.bs}" else: graph_key = self.bs - self.graphs[graph_key].replay() + + with set_forward_context( + forward_batch, + self.attention_layers, + self.quant_config, + self.moe_layers, + self.moe_fusions, + ): + self.graphs[graph_key].replay() output = self.output_buffers[graph_key] if isinstance(output, 
LogitsProcessorOutput): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 02b856c6f0c6..f1e2419d2af9 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -125,6 +125,7 @@ CudaGraphRunner, DecodeInputBuffers, set_torch_compile_config, + torch_compile, ) from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -614,6 +615,9 @@ def initialize(self, pre_model_load_memory: float): # Init routed experts capturer self.init_routed_experts_capturer() + # Collect attention and MoE layers from the model + self.collect_attention_and_moe_layers() + if self.device == "cuda" or self.device == "musa": self.init_cublas() self.init_attention_backend() @@ -2205,6 +2209,8 @@ def init_device_graphs(self): return if self.device != "cpu" and self.server_args.disable_cuda_graph: + if self.server_args.enable_torch_compile: + torch_compile(self.model, self.server_args, self.model_config) return if self.device == "cpu" and not self.server_args.enable_torch_compile: @@ -2238,36 +2244,18 @@ def init_device_graphs(self): f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB." 
) - def init_piecewise_cuda_graphs(self): - """Initialize piecewise CUDA graph runner.""" - self.piecewise_cuda_graph_runner = None - - if self.server_args.disable_piecewise_cuda_graph: - logger.info( - "Disable piecewise CUDA graph because --disable-piecewise-cuda-graph is set" - ) - return + def collect_attention_and_moe_layers(self): + """Collect attention layers and MoE layers from the model for use by + CudaGraphRunner and PiecewiseCudaGraphRunner.""" + self.attention_layers = [] + self.moe_layers = [] + self.moe_fusions = [] - # Disable piecewise CUDA graph for non-language models if not hasattr(self.model, "model"): - logger.warning( - "Disable piecewise CUDA graph because the model is not a language model" - ) return - # Disable piecewise CUDA graph for non capture size - if not self.server_args.piecewise_cuda_graph_tokens: - logger.warning( - "Disable piecewise CUDA graph because the capture size is not set" - ) - return - - # Collect attention layers and moe layers from the model self.model.model = resolve_language_model(self.model) language_model = getattr(self.model, "language_model", self.model) - self.attention_layers = [] - self.moe_layers = [] - self.moe_fusions = [] for layer in language_model.model.layers: attn_layer = None if hasattr(layer, "self_attn"): @@ -2319,6 +2307,30 @@ def init_piecewise_cuda_graphs(self): self.moe_layers.append(moe_block) self.moe_fusions.append(moe_fusion) + def init_piecewise_cuda_graphs(self): + """Initialize piecewise CUDA graph runner.""" + self.piecewise_cuda_graph_runner = None + + if self.server_args.disable_piecewise_cuda_graph: + logger.info( + "Disable piecewise CUDA graph because --disable-piecewise-cuda-graph is set" + ) + return + + # Disable piecewise CUDA graph for non-language models + if not hasattr(self.model, "model"): + logger.warning( + "Disable piecewise CUDA graph because the model is not a language model" + ) + return + + # Disable piecewise CUDA graph for non capture size + if not 
self.server_args.piecewise_cuda_graph_tokens: + logger.warning( + "Disable piecewise CUDA graph because the capture size is not set" + ) + return + if len(self.attention_layers) < self.model_config.num_hidden_layers: # TODO(yuwei): support Non-Standard GQA log_info_on_rank0( diff --git a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py index a7f6840d3e8b..83f0a6627dad 100644 --- a/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/piecewise_cuda_graph_runner.py @@ -20,14 +20,14 @@ import logging from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, TYPE_CHECKING, Union import torch import tqdm - from sglang.srt.batch_overlap.two_batch_overlap import TboCudaGraphRunnerPlugin from sglang.srt.compilation.compilation_config import CompilationConfig from sglang.srt.compilation.compile import install_torch_compiled +from sglang.srt.compilation.pass_config import PassConfig from sglang.srt.compilation.piecewise_context_manager import ( enable_piecewise_cuda_graph, enable_piecewise_cuda_graph_compile, @@ -177,6 +177,10 @@ def __init__(self, model_runner: ModelRunner): self.model_runner.server_args.piecewise_cuda_graph_compiler, self.model_runner.server_args.enable_torch_compile_debug_mode, ) + self.pass_config = PassConfig.from_server_args_and_model_config( + self.model_runner.server_args, + self.model_runner.model_config, + ) if get_moe_a2a_backend().is_deepep() or get_moe_a2a_backend().is_mooncake(): self.compile_config.add_split_op( "sglang.moe_forward_piecewise_cuda_graph_impl" @@ -289,6 +293,7 @@ def __init__(self, model_runner: ModelRunner): fullgraph=True, dynamic_arg_dims=None, compile_config=self.compile_config, + pass_config=self.pass_config, graph_pool=get_global_graph_memory_pool(), ) diff --git 
a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 4be11541a924..a5cef9862a6a 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -73,7 +73,7 @@ def enable_hf_transfer(): if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: try: # enable hf hub transfer if available - import hf_transfer # type: ignore # noqa + import hf_transfer # type: ignore # noqa huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True except ImportError: @@ -1167,10 +1167,19 @@ def initialize_dummy_weights( generator = torch.Generator(device=param.data.device) generator.manual_seed(seed) if torch.finfo(param.data.dtype).bits < 16: - # uniform_ doesn't support < 16-bit datatypes (FP8) + # uniform_ doesn't support < 16-bit datatypes (FP8). + # Use a wider range so values survive the cast back to + # low-precision types (e.g. FP8 e4m3fn smallest subnormal + # is ~0.002, so the default [-1e-3, 1e-3] would all round + # to zero). 
dtype = param.data.dtype + finfo = torch.finfo(dtype) + lo = max(low, -finfo.max) + hi = min(high, finfo.max) + if hi - lo < 2 * finfo.tiny: + lo, hi = -finfo.tiny * 2, finfo.tiny * 2 tmp_param = param.data.to(torch.float16) - tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype) + tmp_param = tmp_param.uniform_(lo, hi, generator=generator).to(dtype) param.data.copy_(tmp_param) else: param.uniform_(low, high, generator=generator) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 448c7d6c44f7..c38fa56295c9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -625,6 +625,10 @@ class ServerArgs: enable_single_batch_overlap: bool = False tbo_token_distribution_threshold: float = 0.48 enable_torch_compile: bool = False + enable_torch_compile_fusion: bool = False + disable_rmsnorm_quant_pass: bool = False + disable_fused_activation_pass: bool = False + enable_torch_compile_graph_trace_logs: bool = False disable_piecewise_cuda_graph: bool = False enforce_piecewise_cuda_graph: bool = False enable_torch_compile_debug_mode: bool = False @@ -5204,6 +5208,26 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Enable debug mode for torch compile", ) + parser.add_argument( + "--enable-torch-compile-fusion", + action="store_true", + help="Enables operator fusion using custom torch compile/inductor passes. 
Experimental feature.", + ) + parser.add_argument( + "--disable-rmsnorm-quant-pass", + action="store_true", + help="Disables the rmsnorm quant pass for torch compile fusion, to be used with --enable-torch-compile-fusion.", + ) + parser.add_argument( + "--disable-fused-activation-pass", + action="store_true", + help="Disables the fused activation pass for torch compile fusion, to be used with --enable-torch-compile-fusion.", + ) + parser.add_argument( + "--enable-torch-compile-graph-trace-logs", + action="store_true", + help="Enables logging of traced before and after graphs for fusion passes, to be used with log level debug.", + ) parser.add_argument( "--disable-piecewise-cuda-graph", action="store_true", @@ -5363,6 +5387,8 @@ def add_cli_args(parser: argparse.ArgumentParser): nargs="+", help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess.", ) + + # Debug tensor dumps parser.add_argument( "--enable-deterministic-inference", action="store_true", diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index cd504fc775c8..ab3b817376f7 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -315,6 +315,17 @@ def is_flashinfer_available(): return importlib.util.find_spec("flashinfer") is not None and is_cuda() +def is_flashinfer_rmsnorm_quant_kernels_available(): + return False # TODO: remove this once flashinfer fixes land + + if importlib.util.find_spec("flashinfer") is not None: + import flashinfer.norm + + return hasattr(flashinfer.norm, "rmsnorm_quant") + else: + return False + + def is_nvidia_cublas_version_ge_12_9(): """ temporary fix for issue #11272 (cublas 12.9+) diff --git a/python/sglang/test/model_bench.py b/python/sglang/test/model_bench.py new file mode 100644 index 000000000000..8affbcd6f57a --- /dev/null +++ b/python/sglang/test/model_bench.py @@ -0,0 +1,542 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import logging +import math +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from types import SimpleNamespace +from typing import Callable, Optional, Self + +import torch +from transformers import PretrainedConfig + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig, LoadFormat +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel, +) +from sglang.srt.layers.attention.dummy_backend import DummyAttentionBackend +from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, ReqToTokenPool +from sglang.srt.model_executor.cuda_graph_runner import torch_compile +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.model_loader import get_model +from sglang.srt.model_loader.loader import _get_quantization_config +from sglang.srt.model_loader.utils import get_model_architecture +from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from 
sglang.srt.utils import configure_logger + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelBenchArgs: + num_tokens: int + forward_mode: ForwardMode + init_device: str = "cuda" + exec_device: str = "cuda" + use_real_weights: bool = False + warmup_iters: int = 0 + bench_iters: int = 1 + disable_nvtx_tracing: bool = False + + @staticmethod + def from_args(args): + model_bench_args = ModelBenchArgs( + num_tokens=args.num_tokens, + init_device=args.init_device, + exec_device=args.exec_device, + warmup_iters=args.warmup_iters, + bench_iters=args.bench_iters, + forward_mode=ForwardMode(args.forward_mode), + disable_nvtx_tracing=args.disable_nvtx_tracing, + ) + return model_bench_args + + +class ModelBench(ABC): + def __init__( + self, + server_args: ServerArgs, + bench_args: ModelBenchArgs, + initializer: Callable[ + [Self, PretrainedConfig, Optional[QuantizationConfig]], torch.nn.Module + ], + ) -> None: + self._server_args = server_args + self._bench_args = bench_args + self._initializer = initializer + + def __enter__(self): + set_global_server_args_for_scheduler(self._server_args) + + configure_logger(self._server_args, " Model Bench") + + logger.info("====================") + logger.info(f"{self._server_args}") + logger.info("====================") + logger.info(f"{self._bench_args=}") + logger.info("====================") + + if self._bench_args.init_device != self._bench_args.exec_device: + logger.warning( + "Init and exec device are different, data will be moved and may impact measurements." 
+ ) + + torch.set_default_device(self._bench_args.init_device) + + # distributed setup required for parallel layers + init_distributed_environment( + backend="nccl", + world_size=self._server_args.tp_size * self._server_args.pp_size, + rank=0, + local_rank=self._server_args.base_gpu_id, + distributed_init_method=f"tcp://127.0.0.1:{self._server_args.nccl_port}", + timeout=self._server_args.dist_timeout, + ) + initialize_model_parallel( + tensor_model_parallel_size=self._server_args.tp_size, + pipeline_model_parallel_size=self._server_args.pp_size, + expert_model_parallel_size=self._server_args.ep_size, + duplicate_tp_group=self._server_args.enable_pdmux, + ) + + # Pre processing required for model loader + self._device_config = DeviceConfig(device=self._bench_args.init_device) + self._model_config = ModelConfig.from_server_args(self._server_args) + self._model_class, _ = get_model_architecture(self._model_config) + self._load_config = LoadConfig( + load_format=( + LoadFormat.AUTO + if self._bench_args.use_real_weights + else LoadFormat.DUMMY + ) + ) + self._quant_config = _get_quantization_config( + self._model_config, self._load_config + ) + self._hf_config = self._model_config.hf_config + + self._init_model() + self._init_memory_pool() + self._init_attention_backend() + + logger.info( + f"ModelBench initialized for {self._server_args.model_path}, class: {self._model_class}, quant_config: {self._quant_config}" + ) + + return self + + def __exit__(self, exc_type, exc_value, traceback): + destroy_model_parallel() + destroy_distributed_environment() + + def torch_compile(self): + torch_compile(self._model, self._server_args, self._model_config) + + # TODO: automate calling this in default exec or have warnings + # if not called in cases where it's required like fa3 + def prepare_exec(self, forward_batch): + if type(self._attn_backend) is FlashAttentionBackend: + self._attn_backend.init_forward_metadata(forward_batch) + + def default_exec(self, *args): + inputs = 
self._ensure_inputs_on_exec_device(*args) + + # model modifications + if self._server_args.enable_torch_compile: + self.torch_compile() + + # warmup + for iter in range(self._bench_args.warmup_iters): + _ = self._model(*inputs) + + # benchmark + tic = time.perf_counter() + for iter in range(self._bench_args.bench_iters): + if not self._bench_args.disable_nvtx_tracing: + torch.cuda.nvtx.range_push("RIG_PROFILE") + result = self._model(*inputs) + if not self._bench_args.disable_nvtx_tracing: + torch.cuda.nvtx.range_pop() + torch.cuda.synchronize() + latency = time.perf_counter() - tic + throughput = self._bench_args.bench_iters / latency + logger.info( + f"Total latency: {latency:6.5f} s, throughput: {throughput:9.2f} iters/sec" + ) + return result + + def _init_model(self): + """Initialize the model using the custom initializer via monkey-patching. + + Monkey-patches ``loader._initialize_model`` so that ``get_model()`` + delegates model construction to ``self._initializer`` instead of + the default ``model_class(**kwargs)`` path. This lets callers + supply a partial model (e.g. just the MLP) while still going + through the full weight-loading and device-placement pipeline. + The original ``_initialize_model`` is always restored in the + ``finally`` block. 
+ """ + import sglang.srt.model_loader.loader as loader_module + + original_initialize_model = loader_module._initialize_model + + def _custom_initialize_model(model_config, load_config, quant_config=None): + return self._initializer(self, model_config.hf_config, quant_config) + + loader_module._initialize_model = _custom_initialize_model + try: + self._model = get_model( + model_config=self._model_config, + load_config=self._load_config, + device_config=self._device_config, + ) + finally: + loader_module._initialize_model = original_initialize_model + + def _init_attention_backend(self): + if self._server_args.attention_backend == "dummy": + self._attn_backend = DummyAttentionBackend() + self._attn_backend.set_out( + torch.rand( + ( + self._bench_args.num_tokens, + self._model_config.hf_config.hidden_size, + ), + dtype=self._model_config.dtype, + ) + ) + elif self._server_args.attention_backend == "fa3": + mock_model_runner = SimpleNamespace( + mock_obj_name=ModelRunner.__name__, + model_config=self._model_config, + sliding_window_size=self._model_config.attention_chunk_size, + server_args=self._server_args, + device=self._bench_args.exec_device, + req_to_token_pool=self._req_to_token_pool, + kv_cache_dtype=self._kv_cache_dtype, + is_hybrid=False, + page_size=self._server_args.page_size, + ) + self._attn_backend = FlashAttentionBackend(model_runner=mock_model_runner) # type: ignore + else: + logger.debug( + "No/Invalid attention backend was specified in server args hence none was initialized." 
+ ) + self._attn_backend = None + + # TODO: Expose more bench args controlling the memory pool initialization + def _init_memory_pool(self): + if self._server_args.kv_cache_dtype == "auto": + kv_cache_quant_algo = getattr( + self._quant_config, "kv_cache_quant_algo", None + ) + if ( + isinstance(kv_cache_quant_algo, str) + and kv_cache_quant_algo.upper() == "FP8" + ): + self._kv_cache_dtype = torch.float8_e4m3fn + else: + self._kv_cache_dtype = self._model_config.dtype + elif self._server_args.kv_cache_dtype == "fp8_e5m2": + self._kv_cache_dtype = torch.float8_e5m2 + elif self._server_args.kv_cache_dtype == "fp8_e4m3": + self._kv_cache_dtype = torch.float8_e4m3fn + else: + raise ValueError( + f"Unsupported kv_cache_dtype: {self._server_args.kv_cache_dtype}." + ) + + logger.info(f"Using KV cache dtype: {self._kv_cache_dtype}") + + # TODO: Mock req_to_token pool if needed + self._req_to_token_pool = ReqToTokenPool( + size=self._bench_args.num_tokens, + max_context_len=self._bench_args.num_tokens, + device=self._bench_args.exec_device, + enable_memory_saver=False, + ) + + # TODO: Mock token_to_kv_pool (KVCache) if needed + self._token_to_kv_pool = MHATokenToKVPool( + # Page index 0 is reserved (for something), KV cache allocation happens from page index 1 + size=( + 1 + math.ceil(self._bench_args.num_tokens / self._server_args.page_size) + ) + * self._server_args.page_size, + page_size=self._server_args.page_size, + dtype=self._kv_cache_dtype, + head_num=self.get_num_kv_heads(), + head_dim=self._model_config.head_dim, + # TODO: This is gonna depend on the model architecture run using the bench + # essentially this number should be same as number of attention layers + # should be configurable as part of the bench args + layer_num=1, + device=self._bench_args.exec_device, + enable_memory_saver=False, + ) + + def _ensure_inputs_on_exec_device(self, *args) -> list: + inputs = [] + # move to exec device if required + if self._bench_args.exec_device != 
self._bench_args.init_device: + self._model.to(self._bench_args.exec_device) + + if type(self._attn_backend) is DummyAttentionBackend: + self._attn_backend.out_to(self._bench_args.exec_device) + + if type(self._attn_backend) is FlashAttentionBackend: + self._attn_backend.forward_metadata.cache_seqlens_int32 = ( + self._attn_backend.forward_metadata.cache_seqlens_int32.to( + self._bench_args.exec_device + ) + ) + self._attn_backend.forward_metadata.cu_seqlens_q = ( + self._attn_backend.forward_metadata.cu_seqlens_q.to( + self._bench_args.exec_device + ) + ) + self._attn_backend.forward_metadata.cu_seqlens_k = ( + self._attn_backend.forward_metadata.cu_seqlens_k.to( + self._bench_args.exec_device + ) + ) + + for arg in args: + if type(arg) is ForwardBatch or ( + hasattr(arg, "mock_obj_name") + and arg.mock_obj_name == ForwardBatch.__name__ + ): + arg.out_cache_loc.to(self._bench_args.exec_device) + arg.seq_lens.to(self._bench_args.exec_device) + arg.req_pool_indices.to(self._bench_args.exec_device) + inputs.append(arg) + continue + + if hasattr(arg, "to"): + inputs.append(arg.to(self._bench_args.exec_device)) + else: + inputs.append(arg) + else: + inputs = args + return inputs + + @property + def model_config(self): + return self._model_config + + @property + def model(self): + return self._model + + @property + def bench_args(self): + return self._bench_args + + @property + def attn_backend(self): + return self._attn_backend + + @property + def hf_config(self): + return self._hf_config + + @abstractmethod + def get_num_kv_heads(self) -> int: + pass + + +class LlamaBench(ModelBench): + def __init__( + self, + server_args: ServerArgs, + bench_args: ModelBenchArgs, + initializer: Callable[ + [Self, PretrainedConfig, Optional[QuantizationConfig]], torch.nn.Module + ], + ) -> None: + super().__init__(server_args, bench_args, initializer) + + def get_num_kv_heads(self): + return self._hf_config.num_key_value_heads + + def init_attention(self): + from 
sglang.srt.layers.radix_attention import RadixAttention + + return RadixAttention( + self._hf_config.num_attention_heads, + self._hf_config.head_dim, + self._hf_config.head_dim**-0.5, + num_kv_heads=self._hf_config.num_key_value_heads, + layer_id=0, # TODO: incerement this based on number of instantiations + quant_config=self._quant_config, + prefix="llama_attention", + ) + + def init_qkv_parallel_linear(self): + from sglang.srt.layers.linear import QKVParallelLinear + + return QKVParallelLinear( + self._hf_config.hidden_size, + self._hf_config.head_dim, + self._hf_config.num_attention_heads, + self._hf_config.num_key_value_heads, + bias=getattr(self._hf_config, "attention_bias", False) + or getattr(self._hf_config, "bias", False), + quant_config=self._quant_config, + prefix="llama_qkv_proj", + ) + + def init_o_parallel_linear(self): + from sglang.srt.layers.linear import RowParallelLinear + + return RowParallelLinear( + self._hf_config.num_attention_heads * self._hf_config.head_dim, + self._hf_config.hidden_size, + bias=getattr(self._hf_config, "attention_bias", False) + or getattr(self._hf_config, "bias", False), + quant_config=self._quant_config, + prefix="llama_o_proj", + ) + + def init_rope(self): + from sglang.srt.layers.rotary_embedding import get_rope + + head_dim = self._hf_config.hidden_size // self._hf_config.num_attention_heads + return get_rope( + head_size=head_dim, + rotary_dim=int(getattr(self._hf_config, "partial_rotary_factor", 1)) + * head_dim, + max_position=getattr(self._hf_config, "max_position_embeddings", 8192), + base=getattr(self._hf_config, "rope_theta", 10000), + rope_scaling=getattr(self._hf_config, "rope_scaling", None), + is_neox_style=getattr(self._hf_config, "rope_is_neox_style", True), + ) + + def init_norm(self): + from sglang.srt.layers.layernorm import RMSNorm + + return RMSNorm(self._hf_config.hidden_size, eps=self._hf_config.rms_norm_eps) + + def init_decoder(self): + from sglang.srt.models.llama import LlamaDecoderLayer + + 
return LlamaDecoderLayer( + config=self._hf_config, # type: ignore + quant_config=self._quant_config, + prefix="llama_decoder", + ) + + def init_mlp(self): + from sglang.srt.models.llama import LlamaMLP + + return LlamaMLP( + self._hf_config.hidden_size, + self._hf_config.intermediate_size, # type: ignore + self._hf_config.hidden_act, # type: ignore + self._quant_config, + "llama_mlp", + ) + + def get_rand_input_forward_batch(self): + forward_batch = SimpleNamespace( + mock_obj_name=ForwardBatch.__name__, + forward_mode=self._bench_args.forward_mode, + attn_backend=self._attn_backend, + out_cache_loc=torch.arange( + start=self._server_args.page_size, + end=self._server_args.page_size + self._bench_args.num_tokens, + dtype=torch.int64, + ), + token_to_kv_pool=self._token_to_kv_pool, + req_to_token_pool=self._req_to_token_pool, + req_pool_indices=torch.zeros((1), dtype=torch.int64), + seq_lens=torch.tensor([self._bench_args.num_tokens], dtype=torch.int64), + seq_lens_cpu=torch.tensor( + [self._bench_args.num_tokens], dtype=torch.int64, device="cpu" + ), + batch_size=1, # TODO: Add support of batch size as a bench arg ? 
+ extend_prefix_lens_cpu=[], + encoder_lens=None, + spec_info=None, + ) + return forward_batch + + def get_rand_input_positions(self): + positions = torch.randint( + 100, (self._bench_args.num_tokens,), dtype=torch.int64 + ) + return positions + + def get_rand_input_hidden_states(self): + hidden_states = torch.rand( + (self._bench_args.num_tokens, self._model_config.hf_config.hidden_size), + dtype=self._model_config.dtype, + ) + return hidden_states + + def get_rand_input_q(self): + num_heads = self._model_config.num_attention_heads + head_dim = self._model_config.hidden_size // num_heads + q = torch.rand( + (self._bench_args.num_tokens, num_heads * head_dim), + dtype=self._model_config.dtype, + ) + return q + + def get_rand_input_k(self): + num_heads = self._model_config.num_attention_heads + num_kv_heads = self._model_config.num_key_value_heads + head_dim = self._model_config.hidden_size // num_heads + k = torch.rand( + (self._bench_args.num_tokens, num_kv_heads * head_dim), + dtype=self._model_config.dtype, + ) + return k + + def get_rand_input_v(self): + num_heads = self._model_config.num_attention_heads + num_kv_heads = self._model_config.num_key_value_heads + head_dim = self._model_config.hidden_size // num_heads + v = torch.rand( + (self._bench_args.num_tokens, num_kv_heads * head_dim), + dtype=self._model_config.dtype, + ) + return v + + def get_rand_attn_output(self): + attn_output = torch.rand( + ( + self._bench_args.num_tokens, + self._model_config.hf_config.hidden_size, + ), + dtype=self._model_config.dtype, + ) + return attn_output + + def get_rand_residual(self): + residual = torch.rand( + (self._bench_args.num_tokens, self._model_config.hf_config.hidden_size), + dtype=self._model_config.dtype, + ) + return residual diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 743c29104b51..63e0511954b8 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -272,6 +272,7 @@ set(SOURCES 
"csrc/elementwise/fused_add_rms_norm_kernel.cu" "csrc/elementwise/pos_enc.cu" "csrc/elementwise/topk.cu" + "csrc/elementwise/rms_quant/layernorm_quant_kernels.cu" "csrc/expert_specialization/es_fp8_blockwise.cu" "csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu" "csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index cdce0064b2f0..9852a8c3fa51 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -69,6 +69,18 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()"); m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); + m.def( + "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, " + "Tensor scale, float epsilon) -> " + "()"); + m.impl("rms_norm_static_fp8_quant", torch::kCUDA, &rms_norm_static_fp8_quant); + + m.def( + "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, " + "Tensor! residual, Tensor weight, " + "Tensor scale, float epsilon) -> ()"); + m.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + m.def("gemma_rmsnorm(Tensor! 
output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()"); m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); diff --git a/sgl-kernel/csrc/elementwise/rms_quant/cub_helpers.h b/sgl-kernel/csrc/elementwise/rms_quant/cub_helpers.h new file mode 100644 index 000000000000..60329c7bdbfd --- /dev/null +++ b/sgl-kernel/csrc/elementwise/rms_quant/cub_helpers.h @@ -0,0 +1,17 @@ +#pragma once + +#ifndef USE_ROCM +#include +#if CUB_VERSION >= 200800 +#include +using CubAddOp = cuda::std::plus<>; +using CubMaxOp = cuda::maximum<>; +#else // if CUB_VERSION < 200800 +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; +#endif // CUB_VERSION +#else +#include +using CubAddOp = cub::Sum; +using CubMaxOp = cub::Max; +#endif // USE_ROCM diff --git a/sgl-kernel/csrc/elementwise/rms_quant/layernorm_quant_kernels.cu b/sgl-kernel/csrc/elementwise/rms_quant/layernorm_quant_kernels.cu new file mode 100644 index 000000000000..98d157bcfbd0 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/rms_quant/layernorm_quant_kernels.cu @@ -0,0 +1,277 @@ +/* + * This file contains the CUDA kernels for the fused quantized layernorm. + * The kernels correspond to the kernels in layernorm_kernels.cu, except they + * also produce quantized output directly. + * Currently, only static fp8 quantization is supported. + */ + +#include +#include +#include + +#include + +#include "cub_helpers.h" +#include "type_convert.cuh" +#include "utils.h" + +#define AT_DISPATCH_FP8_TYPES(TYPE, NAME, ...) 
\ + switch (TYPE) { \ + case at::ScalarType::Float8_e4m3fn: { \ + using fp8_t = at::Float8_e4m3fn; \ + __VA_ARGS__(); \ + break; \ + } \ + /* For ROCm or future CUDA FP8 variants */ \ + case at::ScalarType::Float8_e4m3fnuz: { \ + using fp8_t = at::Float8_e4m3fnuz; \ + __VA_ARGS__(); \ + break; \ + } \ + default: \ + AT_ERROR(NAME, " not implemented for '", toString(TYPE), "'"); \ + } + +template +__device__ __forceinline__ fp8_type scaled_fp8_conversion(float val, float scale) { + float x = 0.0f; + x = val * scale; + float r = fmaxf(-FP8_E4M3_MAX, fminf(x, FP8_E4M3_MAX)); +#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) + return static_cast(r); +#else + return c10::Float8_e4m3fnuz( + __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret), + c10::Float8_e4m3fnuz::from_bits()); +#endif +} + +// TODO(woosuk): Further optimize this kernel. +template +__global__ void rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + const float x = (float)input[blockIdx.x * input_stride + idx]; + variance += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)input[blockIdx.x * input_stride + idx]; + float const 
out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion(out_norm, scale_inv); + } +} + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, + const int num_tokens, + const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + const int vec_input_stride = input_stride / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. 
Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int stride_id = blockIdx.x * vec_input_stride + idx; + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[stride_id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; +#pragma unroll + for (int i = 0; i < width; ++i) { + out[id * width + i] = scaled_fp8_conversion(float(temp.data[i]), scale_inv); + } + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. 
+ */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( + fp8_type* __restrict__ out, // [..., hidden_size] + scalar_t* __restrict__ input, // [..., hidden_size] + const int input_stride, + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float* __restrict__ scale, // [1] + const float epsilon, + const int num_tokens, + const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * input_stride + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, CubAddOp{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + // invert scale to avoid division + float const scale_inv = 1.0f / *scale; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; + out[blockIdx.x * hidden_size + idx] = scaled_fp8_conversion(out_norm, scale_inv); + } +} + +void rms_norm_static_fp8_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + int hidden_size = input.size(-1); + int input_stride = input.stride(-2); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const 
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "rms_norm_static_fp8_quant", [&] { + AT_DISPATCH_FP8_TYPES(out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { + rms_norm_static_fp8_quant_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + input_stride, + weight.data_ptr(), + scale.data_ptr(), + epsilon, + num_tokens, + hidden_size); + }); + }); +} + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, input.scalar_type(), "fused_add_rms_norm", [&] { \ + AT_DISPATCH_FP8_TYPES(out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ + fused_add_rms_norm_static_fp8_quant_kernel<<>>( \ + out.data_ptr(), \ + input.data_ptr(), \ + input_stride, \ + residual.data_ptr(), \ + weight.data_ptr(), \ + scale.data_ptr(), \ + epsilon, \ + num_tokens, \ + hidden_size); \ + }); \ + }); +void fused_add_rms_norm_static_fp8_quant( + torch::Tensor& out, // [..., hidden_size], + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + torch::Tensor& scale, // [1] + double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(residual.is_contiguous()); + int hidden_size = input.size(-1); + int input_stride = input.stride(-2); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. 
+ Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} diff --git a/sgl-kernel/csrc/elementwise/rms_quant/type_convert.cuh b/sgl-kernel/csrc/elementwise/rms_quant/type_convert.cuh new file mode 100644 index 000000000000..fa45192c125c --- /dev/null +++ b/sgl-kernel/csrc/elementwise/rms_quant/type_convert.cuh @@ -0,0 +1,166 @@ +#pragma once + +#include + +#ifndef USE_ROCM +#include +#include +#else +#include +#include + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. 
+ */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if (defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { + return __half2float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. 
+ */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) + data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) + data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + 
for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 8bb8f4684999..959ca35ac862 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -129,6 +129,15 @@ int64_t cutlass_mla_get_workspace_size( void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); void sgl_fused_add_rmsnorm( torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl); +void rms_norm_static_fp8_quant( + torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, torch::Tensor& scale, double epsilon); +void fused_add_rms_norm_static_fp8_quant( + torch::Tensor& out, + torch::Tensor& input, + torch::Tensor& residual, + torch::Tensor& weight, + torch::Tensor& scale, + double epsilon); void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl); void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl); void silu_and_mul(at::Tensor& out, at::Tensor& input); diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py index ca56261dd536..ca3435048935 100644 --- a/sgl-kernel/python/sgl_kernel/__init__.py +++ b/sgl-kernel/python/sgl_kernel/__init__.py @@ -22,11 +22,13 @@ concat_mla_k, copy_to_gpu_no_ce, downcast_fp8, + fused_add_rms_norm_static_fp8_quant, fused_add_rmsnorm, gelu_and_mul, gelu_tanh_and_mul, gemma_fused_add_rmsnorm, gemma_rmsnorm, + rms_norm_static_fp8_quant, rmsnorm, rotary_embedding, silu_and_mul, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 1ed1ae474a79..68d0196f39f5 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -165,6 +165,31 @@ def 
fused_add_rmsnorm( _flashinfer_norm.fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) +def rms_norm_static_fp8_quant( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float = 1e-6, +) -> None: + torch.ops.sgl_kernel.rms_norm_static_fp8_quant.default( + out, input, weight, scale, epsilon + ) + + +def fused_add_rms_norm_static_fp8_quant( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float = 1e-6, +): + torch.ops.sgl_kernel.fused_add_rms_norm_static_fp8_quant.default( + out, input, residual, weight, scale, epsilon + ) + + def gemma_rmsnorm( input: torch.Tensor, weight: torch.Tensor, diff --git a/sgl-kernel/tests/test_layernorm_quant_kernels.py b/sgl-kernel/tests/test_layernorm_quant_kernels.py new file mode 100644 index 000000000000..6c8b9f5dea2b --- /dev/null +++ b/sgl-kernel/tests/test_layernorm_quant_kernels.py @@ -0,0 +1,169 @@ +from typing import Union + +import pytest +import sgl_kernel +import torch +import torch.nn.functional as F + +DEVICE = "cuda" +FP8_DTYPE = torch.float8_e4m3fn +# maximum value for e4m3fn for clamping in kernel +FP8_E4M3_MAX = 448.0 +# FP8 is low precision, so the tolerance needs to be higher +TOLERANCE = {"atol": 1.5e-1, "rtol": 1.5e-1} +FP_TOLERANCE = {"atol": 1e-4, "rtol": 1e-4} + +# PyTorch Reference Implementations + + +def scaled_fp8_conversion_ref( + val: torch.Tensor, scale: torch.Tensor, fp8_dtype: torch.dtype +) -> torch.Tensor: + """Helper function matching the scaled_fp8_conversion device function.""" + quant_scale = 1.0 / scale + + x = val * quant_scale + + r = torch.clamp(x, min=-FP8_E4M3_MAX, max=FP8_E4M3_MAX) + + if r.dtype != fp8_dtype: + return r.to(fp8_dtype) + return r + + +def rms_norm_static_fp8_quant_ref( + out: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float, +) -> torch.Tensor: + """Pure PyTorch reference for 
rms_norm_static_fp8_quant_kernel.""" + # RMS Normalization + variance = (input.pow(2)).to(torch.float32).mean(dim=-1, keepdim=True) + inv_rms = torch.rsqrt(variance + epsilon).to(input.dtype) + normalized_input = input * inv_rms + + # Apply Weight + out_norm = normalized_input * weight + + # Static FP8 Quantization + fp8_dtype = out.dtype + quantized_output = scaled_fp8_conversion_ref(out_norm, scale.squeeze(), fp8_dtype) + + out.copy_(quantized_output) + return out + + +def fused_add_rms_norm_static_fp8_quant_ref( + out: torch.Tensor, + input: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + epsilon: float, +) -> torch.Tensor: + """Pure PyTorch reference for fused_add_rms_norm_static_fp8_quant_kernel.""" + # Fused Add + residual.add_(input) + norm_input = residual + + # RMS Normalization + variance = (norm_input.pow(2)).to(torch.float32).mean(dim=-1, keepdim=True) + inv_rms = torch.rsqrt(variance + epsilon).to(input.dtype) + normalized_residual = norm_input * inv_rms + + # Apply Weight + out_norm = normalized_residual * weight + + # Static FP8 Quantization + fp8_dtype = out.dtype + quantized_output = scaled_fp8_conversion_ref(out_norm, scale.squeeze(), fp8_dtype) + + out.copy_(quantized_output) + return out + + +@pytest.mark.parametrize("batch_size", [1, 2048]) +@pytest.mark.parametrize("hidden_size", [64, 128, 255, 1023, 1024, 1025, 4096]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_rms_norm_static_fp8_quant(batch_size, hidden_size, dtype): + """ + Tests the standard rms_norm_static_fp8_quant kernel against the reference. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available. 
Skipping kernel test.") + + epsilon = 1e-5 + + input_t = torch.randn(batch_size, hidden_size, dtype=dtype, device=DEVICE) + weight_t = torch.randn(hidden_size, dtype=dtype, device=DEVICE) + scale = torch.tensor([4.0], dtype=torch.float32, device=DEVICE) + + out_kernel = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE).to(DEVICE) + out_ref = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE).to(DEVICE) + + rms_norm_static_fp8_quant_ref(out_ref, input_t.clone(), weight_t, scale, epsilon) + + sgl_kernel.rms_norm_static_fp8_quant(out_kernel, input_t, weight_t, scale, epsilon) + + max_diff = torch.abs(out_kernel.float() - out_ref.float()).max() + + assert torch.allclose( + out_kernel.float(), + out_ref.float(), + atol=TOLERANCE["atol"], + rtol=TOLERANCE["rtol"], + equal_nan=True, + ), f"RMS Norm ({dtype}) kernel output mismatch. BS={batch_size}, HS={hidden_size}. Max diff: {max_diff.item():.10f}" + + +@pytest.mark.parametrize("batch_size", [1, 2048]) +@pytest.mark.parametrize("hidden_size", [64, 128, 255, 1023, 1024, 1025, 4096]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_fused_add_rms_norm_static_fp8_quant(batch_size, hidden_size, dtype): + """ + Tests the fused_add_rms_norm_static_fp8_quant kernel against the reference. + """ + if not torch.cuda.is_available(): + pytest.skip("CUDA not available. 
Skipping kernel test.") + + epsilon = 1e-6 + + input_t = torch.randn(batch_size, hidden_size, dtype=dtype, device=DEVICE) + base_residual = torch.randn_like(input_t) + weight_t = torch.randn(hidden_size, dtype=dtype, device=DEVICE) + scale = torch.tensor([4.0], dtype=torch.float32, device=DEVICE) + + residual_ref = base_residual.clone() + residual_kernel = base_residual.clone() + + out_kernel = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE).to(DEVICE) + out_ref = torch.empty((batch_size, hidden_size), dtype=FP8_DTYPE).to(DEVICE) + + fused_add_rms_norm_static_fp8_quant_ref( + out_ref, input_t.clone(), residual_ref, weight_t, scale, epsilon + ) + + sgl_kernel.fused_add_rms_norm_static_fp8_quant( + out_kernel, input_t, residual_kernel, weight_t, scale, epsilon + ) + + max_diff_fp8 = torch.abs(out_kernel.float() - out_ref.float()).max() + max_diff_fp = torch.abs(residual_kernel - residual_ref).max() + + assert torch.allclose( + out_kernel.float(), + out_ref.float(), + atol=TOLERANCE["atol"], + rtol=TOLERANCE["rtol"], + equal_nan=True, + ), f"Fused RMS Norm ({dtype}) FP8 output mismatch. BS={batch_size}, HS={hidden_size}. Max diff: {max_diff_fp8.item():.10f}" + + assert torch.allclose( + residual_kernel, + residual_ref, + atol=FP_TOLERANCE["atol"], + rtol=FP_TOLERANCE["rtol"], + equal_nan=True, + ), f"Fused RMS Norm ({dtype}) in-place residual update mismatch. BS={batch_size}, HS={hidden_size}. Max diff: {max_diff_fp.item():.10f}" diff --git a/test/srt/compilation/fusion/passes/test_fused_activation.py b/test/srt/compilation/fusion/passes/test_fused_activation.py new file mode 100644 index 000000000000..ac7ef2d7d932 --- /dev/null +++ b/test/srt/compilation/fusion/passes/test_fused_activation.py @@ -0,0 +1,100 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Verifies that the fused-activation inductor pass rewrites the compiled
# Llama MLP graph to use the fused dual-GEMM kernel in place of the
# separate GEMM + silu_and_mul kernels.

import os
from typing import Optional

import pytest
import torch
from torch._inductor.utils import run_and_get_code
from transformers import LlamaConfig

from sglang.srt.compilation.fusion.pattern.dual_gemm_pattern import (
    _is_cutedsl_dual_gemm_available,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.test.model_bench import LlamaBench, ModelBenchArgs


def init_llama_mlp(
    bench: LlamaBench, config: LlamaConfig, quant_config: Optional[QuantizationConfig]
) -> torch.nn.Module:
    """Build only the MLP sub-module of the Llama model under test."""
    return bench.init_mlp()


test_data = [
    {
        "models": [
            # TODO: Llama-3.2-1B passes only with torch compile mode default
            # and torch._inductor.config.coordinate_descent_tuning = False
            # "meta-llama/Llama-3.2-1B",
            "RedHatAI/Llama-2-7b-chat-hf-FP8",
            "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        ],
        "model_initializer": init_llama_mlp,
    }
]

# Flatten into (model, initializer) pairs for pytest parametrization.
test_cases = [
    (model, data["model_initializer"])
    for data in test_data
    for model in data["models"]
]


@pytest.mark.parametrize("model, model_initializer", test_cases)
def test_fused_activation_pass(model, model_initializer):
    """Compile the MLP with fusion enabled and check the generated code."""
    # Each pytest-xdist worker gets a distinct nccl_port to avoid collisions.
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    server_args = ServerArgs(
        model_path=model,
        attention_backend="none",
        enable_torch_compile=True,
        enable_torch_compile_fusion=True,
        disable_rmsnorm_quant_pass=True,
        enable_torch_compile_graph_trace_logs=True,
        nccl_port=12345 + int(xdist_worker.split("gw")[1]),
    )

    bench_args = ModelBenchArgs(
        num_tokens=1,
        forward_mode=ForwardMode.DECODE,
        # use_real_weights=True,
    )

    with LlamaBench(server_args, bench_args, model_initializer) as bench:
        hidden_states = bench.get_rand_input_hidden_states()

        # The reference result must be computed before torch.compile
        # wraps the model.
        ref_res = bench.model(hidden_states)

        bench.torch_compile()
        res, source_codes = run_and_get_code(bench.model, hidden_states)
        code = "\n".join(source_codes)

        torch.testing.assert_close(ref_res, res)

        # The fused dual-GEMM kernel must appear in the generated code and
        # the unfused activation kernel must be gone.
        if _is_cutedsl_dual_gemm_available():
            assert "sglang.cutedsl_dual_gemm" in code
        else:
            assert "sglang.triton_dual_gemm" in code

        assert "sgl_kernel.silu_and_mul" not in code


if __name__ == "__main__":
    pytest.main([__file__])
# Worked example of registering a custom inductor fusion pass that rewrites
# the traced softmax + topk subgraph into the fused sgl_kernel.topk_softmax op.

import logging
from types import SimpleNamespace

import torch
from torch._inductor.utils import run_and_get_code

from sglang.srt.compilation.inductor_pass import SGLangPatternMatcherInductorPass

# FusionConfig.enable_torch_compile_graph_trace_logs requires
# log level debug to print the pre and post graph changes
logging.basicConfig(
    level=logging.DEBUG,
    format="[%(asctime)s Fusion Example] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    force=True,
)


class ExampleModel(torch.nn.Module):
    """Toy model: softmax over gating logits followed by top-k selection."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, gating_output: torch.Tensor, topk: int):
        softmax_output = torch.softmax(gating_output, dim=-1)
        topk_weights_ref, topk_indices_ref = torch.topk(softmax_output, topk, dim=-1)
        return topk_weights_ref, topk_indices_ref


# Fake op registration, dynamo uses this while tracing to avoid
# running the actual kernel and slow down the compilation process.
#
# This registration should be part of sgl-kernel and can be done
# in python or C++
@torch.library.register_fake("sgl_kernel::topk_softmax")
def topk_softmax_fake(
    topk_weights: torch.Tensor,
    topk_indices: torch.Tensor,
    # Fixed annotation: the op receives the gating logits tensor, not a float
    # (see ExampleModel.forward and the replacement graph below).
    gating_output: torch.Tensor,
    renormalize: bool = False,
):
    pass


class ExampleFusionPass(SGLangPatternMatcherInductorPass):
    def build_pass(self):
        """graph trace of Example Model obtained using TORCH_LOGS="post_grad_graphs"
        def forward(self, arg0_1: "f32[1, 4][4, 1]cuda:0"):
            prepare_softmax_online_default = torch.ops.prims.prepare_softmax_online.default(arg0_1, -1)
            getitem_2: "f32[1, 1][1, 1]cuda:0" = prepare_softmax_online_default[0]
            getitem_3: "f32[1, 1][1, 1]cuda:0" = prepare_softmax_online_default[1]; prepare_softmax_online_default = None
            sub_tensor: "f32[1, 4][4, 1]cuda:0" = torch.ops.aten.sub.Tensor(arg0_1, getitem_2); arg0_1 = getitem_2 = None
            exp_default: "f32[1, 4][4, 1]cuda:0" = torch.ops.aten.exp.default(sub_tensor); sub_tensor = None

            div: "f32[1, 4][4, 1]cuda:0" = torch.ops.aten.div.Tensor(exp_default, getitem_3); exp_default = getitem_3 = None

            topk = torch.ops.aten.topk.default(div, 1); div = None
            getitem: "f32[1, 1][1, 1]cuda:0" = topk[0]
            getitem_1: "i64[1, 1][1, 1]cuda:0" = topk[1]; topk = None
            return (getitem, getitem_1)
        """

        def pattern(gating_output, topk):
            # Mirrors the post-grad decomposition of softmax + topk above.
            prepare_softmax_online_default = (
                torch.ops.prims.prepare_softmax_online.default(gating_output, -1)
            )
            sub_tensor = torch.ops.aten.sub.Tensor(
                gating_output, prepare_softmax_online_default[0]
            )
            exp_default = torch.ops.aten.exp.default(sub_tensor)
            div = torch.ops.aten.div.Tensor(
                exp_default, prepare_softmax_online_default[1]
            )
            topk_op = torch.ops.aten.topk.default(div, topk)
            return topk_op[0], topk_op[1]

        """ Replacement graph obtained by running topk_softmax_kernel_compiled_run with TORCH_LOGS="post_grad_graphs"
        def forward(self, arg0_1: "f32[1, 4][4, 1]cuda:0"):
            empty: "f32[1, 1][1, 1]cuda:0" = torch.ops.aten.empty.memory_format([1, 1], dtype = torch.float32, device = device(type='cuda'), pin_memory = False)

            empty_1: "i32[1, 1][1, 1]cuda:0" = torch.ops.aten.empty.memory_format([1, 1], dtype = torch.int32, device = device(type='cuda'), pin_memory = False)

            auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.sgl_kernel.topk_softmax.default, gating_output = arg0_1, renormalize = False, _topk_weights_base_index = 0, _topk_indices_base_index = 1, _all_bases = [empty, empty_1]); arg0_1 = empty = empty_1 = None
            getitem_1: "f32[1, 1][1, 1]cuda:0" = auto_functionalized_v2[1]
            getitem_2: "i32[1, 1][1, 1]cuda:0" = auto_functionalized_v2[2]; auto_functionalized_v2 = None
            return (getitem_1, getitem_2)
        """

        def replacement(gating_output, topk):
            empty = torch.empty(
                (gating_output.shape[0], topk), dtype=torch.float32, device="cuda"
            )
            empty_1 = torch.empty(
                (gating_output.shape[0], topk), dtype=torch.int32, device="cuda"
            )
            topk_softmax = torch.ops.higher_order.auto_functionalized_v2(
                torch.ops.sgl_kernel.topk_softmax.default,
                gating_output=gating_output,
                renormalize=False,
                _topk_weights_base_index=0,
                _topk_indices_base_index=1,
                _all_bases=[empty, empty_1],
            )
            return topk_softmax[1], topk_softmax[2]

        """
        Input used by graph for tracing, this is passed to the fake op
        The absolute shape don't matter as much as the relative shapes of the input,
        In this case since there is only 1 input we only need to match no. of dims
        """
        example_inputs = [torch.empty(16, 16).cuda()]

        self.register_replacement_pattern(
            pattern=pattern,
            replacement=replacement,
            example_inputs=example_inputs,
            # Handling scalars is not the cleanest I feel, this essentially requires
            # passes to be registered for each expected scalar value
            scalar_workaround={"topk": 2},
        )


def mock_fusion_manager(graph: torch.fx.Graph):
    """Standalone pass entry point used as the inductor post-grad custom pass.

    Fixed annotation: ``torch.fx.graph`` is the module; the class is
    ``torch.fx.Graph``.
    """
    ExampleFusionPass(
        fusion_config=SimpleNamespace(enable_torch_compile_graph_trace_logs=True)
    )(graph)


def fusion_example_pass(num_experts, num_tokens, topk):
    """Compile ExampleModel with the custom pass installed and verify the
    fused sgl_kernel.topk_softmax op appears and results match eager."""
    model, model_ref = ExampleModel(), ExampleModel()

    torch._inductor.config.post_grad_custom_post_pass = mock_fusion_manager

    model.compile()

    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )

    topk_weights_ref, topk_indices_ref = model_ref(gating_output, topk)

    res, source_codes = run_and_get_code(model, gating_output, topk)
    code = "\n".join(source_codes)

    assert "sgl_kernel.topk_softmax" in code

    torch.testing.assert_close(res[0], topk_weights_ref, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(res[1].int(), topk_indices_ref.int(), atol=0, rtol=0)


# Use this with TORCH_LOGS="post_grad_graphs" to print the fx graph
# of torch compiled kernel to figure out the replacement
def topk_softmax_kernel_compiled_run(num_experts, num_tokens, topk):
    @torch.compile()
    def fwd(gating_output, topk):
        topk_weights = torch.empty(
            (gating_output.shape[0], topk), dtype=torch.float32, device="cuda"
        )
        topk_indices = torch.empty(
            (gating_output.shape[0], topk), dtype=torch.int32, device="cuda"
        )
        torch.ops.sgl_kernel.topk_softmax.default(
            topk_weights, topk_indices, gating_output, False
        )
        return topk_weights, topk_indices

    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=torch.float32, device="cuda"
    )

    _, _ = fwd(gating_output, topk)


if __name__ == "__main__":
    # topk_softmax_kernel_compiled_run(8, 16, 2)
    fusion_example_pass(8, 16, 2)
# Verifies that the RMSNorm+quant fusion pass rewrites the compiled Llama
# decoder graph to use the fused (JIT / flashinfer / sgl_kernel) kernels.

import os
from typing import Optional

import pytest
import torch
from torch._inductor.utils import run_and_get_code
from transformers import LlamaConfig

from sglang.srt.compilation.fusion.pattern.rmsnorm_quant_fp8_pattern import (
    _is_jit_rmsnorm_quant_available,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_flashinfer_rmsnorm_quant_kernels_available
from sglang.test.model_bench import LlamaBench, ModelBenchArgs


def init_llama_decoder(
    bench: LlamaBench, config: LlamaConfig, quant_config: Optional[QuantizationConfig]
) -> torch.nn.Module:
    """Build only the decoder sub-module of the Llama model under test."""
    return bench.init_decoder()


test_data = [
    {
        "models": [
            "RedHatAI/Llama-2-7b-chat-hf-FP8",
            "RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
        ],
        "model_initializer": init_llama_decoder,
    }
]

# Flatten into (model, initializer) pairs for pytest parametrization.
test_cases = [
    (model, data["model_initializer"])
    for data in test_data
    for model in data["models"]
]


@pytest.mark.parametrize("model, model_initializer", test_cases)
def test_rmsnorm_quant_pass(model, model_initializer):
    """Compile the decoder with fusion enabled and check the generated code."""
    # Each pytest-xdist worker gets a distinct nccl_port to avoid collisions.
    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    server_args = ServerArgs(
        model_path=model,
        attention_backend="dummy",
        enable_torch_compile=True,
        enable_torch_compile_fusion=True,
        disable_fused_activation_pass=True,
        enable_torch_compile_graph_trace_logs=True,
        nccl_port=12345 + int(xdist_worker.split("gw")[1]),
    )

    bench_args = ModelBenchArgs(
        num_tokens=1,
        forward_mode=ForwardMode.DECODE,
    )

    with LlamaBench(server_args, bench_args, model_initializer) as bench:
        positions = bench.get_rand_input_positions()
        hidden_states = bench.get_rand_input_hidden_states()
        forward_batch = bench.get_rand_input_forward_batch()

        # The reference result must be computed before torch.compile
        # wraps the model.
        ref_res = bench.model(positions, hidden_states, forward_batch, None)

        bench.torch_compile()
        res, source_codes = run_and_get_code(
            bench.model, positions, hidden_states, forward_batch, None
        )
        code = "\n".join(source_codes)

        torch.testing.assert_close(ref_res, res)

        # The fused kernel for the available backend must appear and the
        # unfused rmsnorm kernels must be gone from the generated code.
        if _is_jit_rmsnorm_quant_available():
            assert "sglang.jit_rmsnorm_quant" in code
            assert "sgl_kernel.rmsnorm" not in code

            assert "sglang.jit_fused_add_rmsnorm_quant" in code
            assert "sgl_kernel.fused_add_rmsnorm" not in code
        elif is_flashinfer_rmsnorm_quant_kernels_available():
            assert "sglang.flashinfer_rmsnorm_quant" in code
            assert "sgl_kernel.rmsnorm" not in code

            assert "sglang.flashinfer_fused_add_rmsnorm_quant" in code
            assert "sgl_kernel.fused_add_rmsnorm" not in code
        else:
            assert "sgl_kernel.rms_norm_static_fp8_quant" in code
            assert "sgl_kernel.rmsnorm" not in code

            assert "sgl_kernel.fused_add_rms_norm_static_fp8_quant" in code
            assert "sgl_kernel.fused_add_rmsnorm" not in code


if __name__ == "__main__":
    pytest.main([__file__])
"__main__": + pytest.main([__file__]) diff --git a/test/srt/compilation/fusion/triton_ops/test_triton_fused_dual_gemm.py b/test/srt/compilation/fusion/triton_ops/test_triton_fused_dual_gemm.py new file mode 100644 index 000000000000..9f8898aeb1d4 --- /dev/null +++ b/test/srt/compilation/fusion/triton_ops/test_triton_fused_dual_gemm.py @@ -0,0 +1,87 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import math +import random +from typing import Tuple + +import numpy as np +import pytest +import torch + +from sglang.srt.compilation.fusion.triton_ops.dual_gemm import dual_gemm_fwd + + +def seed_rng(): + SEED = 42 + np.random.seed(SEED) # For Numpy + torch.manual_seed(SEED) # For CPU tensors + torch.cuda.manual_seed_all(SEED) # For CUDA tensors + random.seed(SEED) # For Python's own RNG + + +@pytest.fixture(autouse=True, scope="module") +def module_fixture(): + seed_rng() + + +def make_input_and_weights( + batch_size: int, seq_len: int, d_model: int, d_intermediate: int, dtype: torch.dtype +) -> Tuple[torch.HalfTensor, torch.HalfTensor, torch.HalfTensor]: + x = torch.randn( + (batch_size * seq_len, d_model), device="cuda", dtype=dtype + ).requires_grad_(False) + + w_gate = torch.randn((d_model, d_intermediate), device="cuda", dtype=dtype) / ( + math.sqrt(d_model) + ) + w_up = torch.randn((d_model, d_intermediate), device="cuda", dtype=dtype) / ( + 
math.sqrt(d_model) + ) + + w_gate.requires_grad_(False) + w_up.requires_grad_(False) + + w = torch.concat((w_gate, w_up), dim=1) + + return x, w + + +def dual_gemm_ref_torch( + x: torch.HalfTensor, w_gate: torch.HalfTensor, w_up: torch.HalfTensor +) -> torch.HalfTensor: + """Reference PyTorch implementation.""" + x_gate = torch.matmul(x, w_gate) + x_up = torch.matmul(x, w_up) + + return torch.nn.functional.silu(x_gate) * x_up + + +@pytest.mark.parametrize("d_model", [1024, 2048, 4096]) +@pytest.mark.parametrize("d_intermediate", [1024, 2048, 4096, 8192]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seq_len", [1, 32, 128]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16, torch.float16], ids=["bfloat16", "float16"] +) +def test_dual_gemm(d_model, d_intermediate, batch_size, seq_len, dtype): + x, w = make_input_and_weights(batch_size, seq_len, d_model, d_intermediate, dtype) + w_gate, w_up = torch.split(w, w.shape[1] // 2, dim=1) + out_ref = dual_gemm_ref_torch(x, w_gate, w_up) + out = dual_gemm_fwd(x, w) + torch.testing.assert_close(out, out_ref, atol=1e-3, rtol=2e-2) + + +if __name__ == "__main__": + pytest.main([__file__])