Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions benchmark/kernels/quantization/bench_fp4_quant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import argparse
import itertools

import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper


def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
)

torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")


NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)

quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)

return ms, min_ms, max_ms


def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()

test_accuracy()

benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
118 changes: 118 additions & 0 deletions sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,107 @@ cvt_fp16_to_fp4(
#endif
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
#else
cvt_fp16_to_fp4_expert(
#endif
int32_t numRows,
int32_t numCols,
Type const* in,
float const* SFScale,
uint32_t* out,
uint32_t* SFout,
uint32_t* input_offset_by_experts,
uint32_t* output_scale_offset_by_experts,
int32_t* mask,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

// Input tensor row/col loops.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = (gridDim.x * blockDim.x) / n_experts;
int remainder = (gridDim.x * blockDim.x) % n_experts;
int expert_idx;
int tid_in_expert;
int actual_stride;
if (remainder > 0) {
int bound = remainder * (stride + 1);
if (tid < bound) {
expert_idx = tid / (stride + 1);
tid_in_expert = tid % (stride + 1);
actual_stride = stride + 1;
} else {
expert_idx = remainder + (tid - bound) / stride;
tid_in_expert = (tid - bound) % stride;
actual_stride = stride;
}
} else {
expert_idx = tid / stride;
tid_in_expert = tid % stride;
actual_stride = stride;
}

int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// TODO(kaixih@nvidia): For now, we assume mask is used together with
// silu_and_mal. Maybe we want a more general behavior of mask later. In the
// silu case, the input last dim doubles.
bool use_mask = mask != nullptr;
int actualColsPerRow = use_mask ? colsPerRow * 2 : colsPerRow;

// Each global thread processes one element
for (int globalIdx = tid_in_expert + input_offset_by_experts[expert_idx] * colsPerRow;
globalIdx < input_offset_by_experts[expert_idx + 1] * colsPerRow;
globalIdx += actual_stride) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;

// Find index within the experts
int rowIdx_in_expert = rowIdx - input_offset_by_experts[expert_idx];

// Eerly exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
break;
}

int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_mask) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}

// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];

// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert = SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout;

auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
#endif
}

// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
Expand Down Expand Up @@ -522,6 +623,23 @@ void quant_impl(
block.x = (block.x + 1) / 2;
}

// TODO(kaixih@nvidia): Should relax this to allow any grid size.
if (mask != nullptr) {
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
m_topk,
k,
reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<uint32_t*>(input_offset_by_experts),
reinterpret_cast<uint32_t*>(output_scale_offset_by_experts),
reinterpret_cast<int32_t*>(mask),
n_experts);
return;
}

int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
Expand Down
19 changes: 11 additions & 8 deletions sgl-kernel/tests/test_fp4_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,12 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
@pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
def test_quantize_to_fp4_grouped():
@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)])
def test_quantize_to_fp4_grouped(shape):
torch.manual_seed(42)
torch.set_default_device("cuda:0")

l, m, k = 2, 512, 2048
l, m, k = shape
x = torch.randn((l, m, k), dtype=torch.bfloat16)
tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32)
x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
Expand All @@ -196,22 +197,24 @@ def test_quantize_to_fp4_grouped():
for i in range(l):
a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i])
torch.testing.assert_close(a_fp4, output[i])
torch.testing.assert_close(
a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float)
)
# Recover swizzled scales to linear layout and drop padded values, so
# no extra checks on padding are needed.
scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k)
scale_ans = recover_swizzled_scales(output_scales[i], m, k)
torch.testing.assert_close(scale_ref, scale_ans)


@pytest.mark.skipif(
skip_condition, reason="Nvfp4 Requires compute capability of 10 or above."
)
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None:
@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)])
def test_silu_and_mul_quantize_to_fp4_grouped(shape):
torch.manual_seed(42)
torch.set_default_device("cuda:0")

l, m, k = shape
x = torch.randn((l, m, k * 2), dtype=torch.bfloat16)
max_m = 8
max_m = m // 2
assert max_m <= m
mask = torch.randint(1, max_m, (l,), dtype=torch.int32)

Expand Down
Loading