diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py
index c3f804753568..5f16ca0284b0 100644
--- a/sgl-kernel/benchmark/bench_fp8_gemm.py
+++ b/sgl-kernel/benchmark/bench_fp8_gemm.py
@@ -1,10 +1,12 @@
 import argparse
 import copy
 import itertools
+from typing import Optional, Tuple
 
 import torch
 import triton
 from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
+from sgl_kernel import sgl_per_tensor_quant_fp8
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
 
@@ -69,6 +71,21 @@
 }
 
 
+def sglang_scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    fp8_type_: torch.dtype = torch.float8_e4m3fn
+    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
+    is_static = True
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        is_static = False
+    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
+
+    return output, scale
+
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
@@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
     b = torch.ones((N, K), device="cuda") * 5.0
     scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
     scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
-    a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
-    b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
-    b_fp8 = b_fp8.t()
     quantiles = [0.5, 0.2, 0.8]
     dtype = torch.float16 if "fp16" in provider else torch.bfloat16
 
     if "vllm-fp8" in provider:
+        a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
             quantiles=quantiles,
         )
     elif "sglang-fp8" in provider:
+        a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
+        b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
+        b_fp8 = b_fp8.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: sgl_scaled_mm(
                 a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
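For context on the new benchmark path: `sgl_per_tensor_quant_fp8` fills a preallocated FP8 output and, when `is_static` is false, computes the per-tensor scale on the fly. Below is a minimal pure-PyTorch sketch of the usual per-tensor recipe (scale = amax / fp8_max); the helper name `ref_per_tensor_quant_fp8` is invented for illustration, and the kernel's exact rounding/clamping may differ in detail.

```python
from typing import Optional, Tuple

import torch


def ref_per_tensor_quant_fp8(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    if scale is None:
        # Dynamic mode: derive the scale from the tensor's absolute maximum.
        scale = (x.abs().max() / finfo.max).to(torch.float32).reshape(1)
    # Static mode: the caller-supplied scale is used as-is.
    q = (x.to(torch.float32) / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn), scale


x = torch.randn(4, 8)
q_dyn, s = ref_per_tensor_quant_fp8(x)     # dynamic: scale computed here
q_sta, _ = ref_per_tensor_quant_fp8(x, s)  # static: reuse the same scale
assert torch.equal(q_dyn.to(torch.float16), q_sta.to(torch.float16))
```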
diff --git a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
index 0d25e9985f7e..77b5c500f04e 100644
--- a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
@@ -48,6 +48,7 @@ limitations under the License.
 #include
 #include
 
+#include "math.hpp"
 #include "utils.h"
 
 using namespace cute;
@@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
     const torch::Tensor& scales_a,
     const torch::Tensor& scales_b,
     const c10::optional<torch::Tensor>& bias) {
-  using CTAShape = Shape<_256, _128, _64>;
-  using ClusterShape = Shape<_2, _2, _1>;
+  using CTAShapeDefault = Shape<_256, _128, _64>;
+  using ClusterShapeDefault = Shape<_2, _2, _1>;
+
+  using CTAShape256 = Shape<_128, _128, _128>;
+  using ClusterShape256 = Shape<_2, _1, _1>;
+
+  using CTAShape64 = Shape<_64, _64, _128>;
+  using ClusterShape64 = Shape<_1, _1, _1>;
+
+  using CTAShape16 = Shape<_64, _64, _128>;
+  using ClusterShape16 = Shape<_1, _4, _1>;
+
   using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
   using TileSchedulerType = void;
@@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
   using ElementOutput = OutType;
   using AccumElementType = float;
 
+  // Gemm type with bias
+  using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShapeDefault,
+      ClusterShapeDefault,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      true>;
+  using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape256,
+      ClusterShape256,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      true>;
+  using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape64,
+      ClusterShape64,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      true>;
+  using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape16,
+      ClusterShape16,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      true>;
+
+  // Gemm type without bias
+  using GemmDefault = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShapeDefault,
+      ClusterShapeDefault,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      false>;
+  using Gemm256 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape256,
+      ClusterShape256,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      false>;
+  using Gemm64 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape64,
+      ClusterShape64,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      false>;
+  using Gemm16 = DeviceGemmFp8RowwiseSm100<
+      ElementInput,
+      ElementOutput,
+      AccumElementType,
+      CTAShape16,
+      ClusterShape16,
+      MainloopScheduleType,
+      EpilogueScheduleType,
+      TileSchedulerType,
+      false>;
+
+  // next power of 2 (minimum 16)
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
+
   if (bias) {
-    using Gemm = DeviceGemmFp8RowwiseSm100<
-        ElementInput,
-        ElementOutput,
-        AccumElementType,
-        CTAShape,
-        ClusterShape,
-        MainloopScheduleType,
-        EpilogueScheduleType,
-        TileSchedulerType,
-        true>;
-    return launch_sm100_fp8_scaled_mm<Gemm>(out, a, b, scales_a, scales_b, bias);
+    if (mp2 <= 16) {
+      // m in [1, 16]
+      return launch_sm100_fp8_scaled_mm<BiasGemm16>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 64) {
+      // m in (16, 64]
+      return launch_sm100_fp8_scaled_mm<BiasGemm64>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 256) {
+      // m in (64, 256]
+      return launch_sm100_fp8_scaled_mm<BiasGemm256>(out, a, b, scales_a, scales_b, bias);
+    } else {
+      // m in (256, inf]
+      return launch_sm100_fp8_scaled_mm<BiasGemmDefault>(out, a, b, scales_a, scales_b, bias);
+    }
   } else {
-    using Gemm = DeviceGemmFp8RowwiseSm100<
-        ElementInput,
-        ElementOutput,
-        AccumElementType,
-        CTAShape,
-        ClusterShape,
-        MainloopScheduleType,
-        EpilogueScheduleType,
-        TileSchedulerType,
-        false>;
-    return launch_sm100_fp8_scaled_mm<Gemm>(out, a, b, scales_a, scales_b, bias);
+    if (mp2 <= 16) {
+      // m in [1, 16]
+      return launch_sm100_fp8_scaled_mm<Gemm16>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 64) {
+      // m in (16, 64]
+      return launch_sm100_fp8_scaled_mm<Gemm64>(out, a, b, scales_a, scales_b, bias);
+    } else if (mp2 <= 256) {
+      // m in (64, 256]
+      return launch_sm100_fp8_scaled_mm<Gemm256>(out, a, b, scales_a, scales_b, bias);
+    } else {
+      return launch_sm100_fp8_scaled_mm<GemmDefault>(out, a, b, scales_a, scales_b, bias);
+    }
   }
 }
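To make the new dispatch rule easy to eyeball, here is a plain-Python mirror of it: M is rounded up to a power of two (floored at 16) and mapped to the (M, N, K) CTA tile and cluster shape pairs from the diff above. The helper names here are illustrative only; the C++ code uses `next_pow_2` from the new `math.hpp`.

```python
def next_pow_2(num: int) -> int:
    # Same semantics as the C++ next_pow_2 (0 and 1 map to themselves).
    return num if num <= 1 else 1 << (num - 1).bit_length()


def select_sm100_tile(m: int):
    mp2 = max(16, next_pow_2(m))
    if mp2 <= 16:     # m in [1, 16]
        return (64, 64, 128), (1, 4, 1)
    elif mp2 <= 64:   # m in (16, 64]
        return (64, 64, 128), (1, 1, 1)
    elif mp2 <= 256:  # m in (64, 256]
        return (128, 128, 128), (2, 1, 1)
    else:             # m in (256, inf]
        return (256, 128, 64), (2, 2, 1)


assert select_sm100_tile(1) == ((64, 64, 128), (1, 4, 1))
assert select_sm100_tile(200) == ((128, 128, 128), (2, 1, 1))   # rounds up to 256
assert select_sm100_tile(4096) == ((256, 128, 64), (2, 2, 1))   # default tile
```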
diff --git a/sgl-kernel/csrc/gemm/math.hpp b/sgl-kernel/csrc/gemm/math.hpp
new file mode 100644
index 000000000000..6764e1fd6054
--- /dev/null
+++ b/sgl-kernel/csrc/gemm/math.hpp
@@ -0,0 +1,28 @@
+#pragma once
+
+#include <climits>
+#include <cstdint>
+
+inline constexpr uint32_t next_pow_2(uint32_t const num) {
+  if (num <= 1) return num;
+  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
+}
+
+template <typename A, typename B>
+static inline constexpr auto div_ceil(A a, B b) {
+  return (a + b - 1) / b;
+}
+
+// Round a down to the previous multiple of b. The caller is responsible for
+// making sure that b is non-zero.
+template <typename T>
+inline constexpr T round_to_previous_multiple_of(T a, T b) {
+  return a % b == 0 ? a : (a / b) * b;
+}
+
+// Round a up to the next multiple of b. The caller is responsible for making
+// sure that b is non-zero.
+template <typename T>
+inline constexpr T round_to_next_multiple_of(T a, T b) {
+  return a % b == 0 ? a : ((a / b) + 1) * b;
+}
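And a quick sanity check of the `math.hpp` helper semantics, transcribed into Python with the same names (`next_pow_2` is already exercised in the dispatch sketch above):

```python
def div_ceil(a: int, b: int) -> int:
    return (a + b - 1) // b


def round_to_previous_multiple_of(a: int, b: int) -> int:
    # Exact multiples are returned unchanged, matching the C++ ternary.
    return a if a % b == 0 else (a // b) * b


def round_to_next_multiple_of(a: int, b: int) -> int:
    return a if a % b == 0 else (a // b + 1) * b


assert div_ceil(10, 4) == 3
assert round_to_previous_multiple_of(10, 4) == 8
assert round_to_next_multiple_of(10, 4) == 12
assert round_to_previous_multiple_of(12, 4) == round_to_next_multiple_of(12, 4) == 12
```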