Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions sgl-kernel/benchmark/bench_fp8_gemm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import argparse
import copy
import itertools
from typing import Optional, Tuple

import torch
import triton
from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant

Expand Down Expand Up @@ -69,6 +71,21 @@
}


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a tensor to FP8 (e4m3) with sgl-kernel's per-tensor kernel.

    Args:
        input: Tensor to quantize.
        scale: Optional pre-computed per-tensor scale. When provided, static
            quantization is performed with it; when ``None``, the kernel
            computes the scale dynamically and writes it into a fresh
            zero-initialized fp32 tensor.

    Returns:
        Tuple of (fp8 output tensor, scale tensor that was used).
    """
    output = torch.empty_like(input, device=input.device, dtype=torch.float8_e4m3fn)
    is_static = scale is not None
    if not is_static:
        # Dynamic path: the kernel fills this placeholder with the computed scale.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
Expand Down Expand Up @@ -100,19 +117,22 @@ def benchmark(batch_size, provider, N, K):
b = torch.ones((N, K), device="cuda") * 5.0
scale_a = torch.randn((M,), device="cuda", dtype=torch.float32)
scale_b = torch.randn((N,), device="cuda", dtype=torch.float32)
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
quantiles = [0.5, 0.2, 0.8]

dtype = torch.float16 if "fp16" in provider else torch.bfloat16

if "vllm-fp8" in provider:
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
elif "sglang-fp8" in provider:
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
Comment on lines 133 to 135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider extracting the quantization and transpose operations outside the if/elif block to reduce code duplication. This improves maintainability by adhering to the DRY principle.

a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()

ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
Expand Down
150 changes: 126 additions & 24 deletions sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ limitations under the License.
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

#include "math.hpp"
#include "utils.h"

using namespace cute;
Expand Down Expand Up @@ -1019,8 +1020,18 @@ void sm100_fp8_dispatch_bias(
const torch::Tensor& scales_a,
const torch::Tensor& scales_b,
const c10::optional<torch::Tensor>& bias) {
using CTAShape = Shape<_256, _128, _64>;
using ClusterShape = Shape<_2, _2, _1>;
using CTAShapeDefault = Shape<_256, _128, _64>;
using ClusterShapeDefault = Shape<_2, _2, _1>;

using CTAShape256 = Shape<_128, _128, _128>;
using ClusterShape256 = Shape<_2, _1, _1>;

using CTAShape64 = Shape<_64, _64, _128>;
using ClusterShape64 = Shape<_1, _1, _1>;

using CTAShape16 = Shape<_64, _64, _128>;
using ClusterShape16 = Shape<_1, _4, _1>;

using MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto;
using EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto;
using TileSchedulerType = void;
Expand All @@ -1029,30 +1040,121 @@ void sm100_fp8_dispatch_bias(
using ElementOutput = OutType;
using AccumElementType = float;

// Gemm type with bias
using BiasGemmDefault = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm256 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm64 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
using BiasGemm16 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape16,
ClusterShape16,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;

// Gemm type without bias
using GemmDefault = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShapeDefault,
ClusterShapeDefault,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm256 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape256,
ClusterShape256,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm64 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape64,
ClusterShape64,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
using Gemm16 = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape16,
ClusterShape16,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;

// next power of 2 (minimum 16)
uint32_t const m = a.size(0);
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));

if (bias) {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
true>;
return launch_sm100_fp8_scaled_mm<Gemm, true>(out, a, b, scales_a, scales_b, bias);
if (mp2 <= 16) {
// m in [1, 16]
return launch_sm100_fp8_scaled_mm<BiasGemm16, true>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 64) {
// m in (16, 64]
return launch_sm100_fp8_scaled_mm<BiasGemm64, true>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (64, 256]
return launch_sm100_fp8_scaled_mm<BiasGemm256, true>(out, a, b, scales_a, scales_b, bias);
} else {
// m in (256, inf]
return launch_sm100_fp8_scaled_mm<BiasGemmDefault, true>(out, a, b, scales_a, scales_b, bias);
}
} else {
using Gemm = DeviceGemmFp8RowwiseSm100<
ElementInput,
ElementOutput,
AccumElementType,
CTAShape,
ClusterShape,
MainloopScheduleType,
EpilogueScheduleType,
TileSchedulerType,
false>;
return launch_sm100_fp8_scaled_mm<Gemm, false>(out, a, b, scales_a, scales_b, bias);
if (mp2 <= 16) {
// m in [1, 16]
return launch_sm100_fp8_scaled_mm<Gemm16, false>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 64) {
// m in (16, 64]
return launch_sm100_fp8_scaled_mm<Gemm64, false>(out, a, b, scales_a, scales_b, bias);
} else if (mp2 <= 256) {
// m in (64, 256]
return launch_sm100_fp8_scaled_mm<Gemm256, false>(out, a, b, scales_a, scales_b, bias);
} else {
return launch_sm100_fp8_scaled_mm<GemmDefault, false>(out, a, b, scales_a, scales_b, bias);
}
}
}

Expand Down
28 changes: 28 additions & 0 deletions sgl-kernel/csrc/gemm/math.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <climits>
#include <cstdint>
#include <iostream>

// Round `num` up to the nearest power of two; returns `num` unchanged when it
// is already a power of two (and returns 0 or 1 for num <= 1). The caller must
// keep num <= 2^31, otherwise the shift count would reach the bit width.
inline constexpr uint32_t next_pow_2(uint32_t const num) {
  if (num <= 1) return num;
  // Shift an unsigned 32-bit one: the original `1 <<` shifted a signed int,
  // which overflows (UB before C++20) when the result is 2^31.
  return uint32_t{1} << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

// Integer ceiling division: div_ceil(a, b) == ceil(a / b) for non-negative a
// and positive b. The caller is responsible for making sure that b is
// non-zero. `static` was dropped: it is redundant on a constexpr (implicitly
// inline) function and would force internal linkage in every including TU.
template <typename A, typename B>
inline constexpr auto div_ceil(A a, B b) {
  return (a + b - 1) / b;
}

// Round `a` down to the nearest multiple of `b`; returns `a` unchanged when it
// is already a multiple. The caller is responsible for making sure that b is
// non-zero. Under truncating integer division this is simply (a / b) * b, so
// the original `a % b == 0` branch was redundant and has been removed.
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
  return (a / b) * b;
}

// Round `a` up to the nearest multiple of `b`; returns `a` unchanged when it
// is already a multiple. The caller is responsible for making sure that b is
// non-zero.
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
  // a == (a / b) * b + rem, so adding (b - rem) lands exactly on the next
  // multiple whenever rem is non-zero.
  T const rem = a % b;
  return rem == 0 ? a : a + (b - rem);
}
Loading