133 changes: 133 additions & 0 deletions benchmark/kernels/quantization/bench_fp4_quant.py
@@ -0,0 +1,133 @@
import argparse
import itertools

import torch
import triton
from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant
from sgl_kernel.elementwise import silu_and_mul

from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
from sglang.srt.layers.quantization import deep_gemm_wrapper


def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks)
out1, blk_scales1 = scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
)

torch.testing.assert_close(out, out1)
torch.testing.assert_close(blk_scales, blk_scales1)
print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")


NUM_RANKS = 48
M_PER_RANKs = [128, 256, 512, 1024]
Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
Ks = [2048, 4096, 7168]


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["M", "K"],
x_vals=list(itertools.product(Ms, Ks)),
x_log=False,
line_arg="provider",
line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
ylabel="ms",
plot_name="fp4 quant",
args={},
)
)
def benchmark(M, K, provider):
E = 6
device = "cuda"
x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
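# Buffers for the reference fp8 path: silu_and_mul halves the last dim
# (K -> K // 2), and the fp8 scales are per 128-element block along that
# halved dimension.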
fp8_out = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2,
),
device=x.device,
dtype=torch.float8_e4m3fn,
)
scale_block_size = 128
fp8_scales = torch.empty(
(
x.shape[0],
x.shape[1],
x.shape[2] // 2 // scale_block_size,
),
device=x.device,
dtype=torch.float32,
)

quantiles = [0.5, 0.2, 0.8]
if provider == "triton_fp8":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_masked_post_quant_fwd(
x,
fp8_out,
fp8_scales,
scale_block_size,
masks,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
),
quantiles=quantiles,
)
if provider == "cuda_unfused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: scaled_fp4_grouped_quant(
silu_and_mul(x),
glb_scales,
masks,
),
quantiles=quantiles,
)
if provider == "cuda_fused_fp4":
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: silu_and_mul_scaled_fp4_grouped_quant(
x,
glb_scales,
masks,
),
quantiles=quantiles,
)

return ms, min_ms, max_ms


def test_accuracy():
E = 6
N_RANKS = 48
Ms = [128, 256, 512, 1024]
Ks = [2048, 4096, 7168]
input_dtype = torch.bfloat16
for M in Ms:
for K in Ks:
_test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./bench_fp4_quant_res",
help="Path to save fp4 quant benchmark results",
)
args = parser.parse_args()

test_accuracy()

benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
3 changes: 1 addition & 2 deletions sgl-kernel/csrc/common_extension.cc
@@ -159,8 +159,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {

m.def(
"silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
"Tensor output_scale_offset_by_experts, Tensor mask) -> ()");
"Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()");
m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant);

m.def(
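Note on the schema change above: the op now takes a mask and a use_silu_and_mul flag instead of the per-expert offset tensors. A rough sketch of a direct call through the torch op follows; the buffer shapes, in particular the output_scale layout, are assumptions inferred from the kernel's checks and padded scale-factor layout, not a documented contract.

import torch
import sgl_kernel  # assumed to register torch.ops.sgl_kernel.* as an import side effect

n_experts, m_per_expert, k = 6, 256, 2048
m_topk = n_experts * m_per_expert

# With use_silu_and_mul=True, each row stores two halves back to back, so the last dim is 2 * k.
inp = torch.randn(m_topk, 2 * k, device="cuda", dtype=torch.bfloat16)
global_scale = torch.ones(n_experts, device="cuda", dtype=torch.float32)
mask = torch.full((n_experts,), m_per_expert, device="cuda", dtype=torch.int32)

out = torch.empty(m_topk, k // 2, device="cuda", dtype=torch.uint8)  # two fp4 values per byte
padded_m = (m_per_expert + 127) // 128 * 128  # assumption: per-expert rows padded to a multiple of 128
sf_cols = (k + 63) // 64  # assumption: one int32 (four fp8 scales) per 64 elements
out_scale = torch.empty(n_experts * padded_m, sf_cols, device="cuda", dtype=torch.int32)

torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant(
    out, out_scale, inp, global_scale, mask, True  # use_silu_and_mul
)

In practice the Python wrappers exercised in the benchmark (scaled_fp4_grouped_quant / silu_and_mul_scaled_fp4_grouped_quant) presumably handle this allocation and dispatch; the direct form is shown only to illustrate the new argument list.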
153 changes: 134 additions & 19 deletions sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu
@@ -347,7 +347,7 @@ cvt_fp16_to_fp4(
}
}

// Eerly exit when using masks.
// Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
continue;
}
@@ -383,6 +383,107 @@
#endif
}

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert(
#else
cvt_fp16_to_fp4_expert(
#endif
int32_t numRows,
int32_t numCols,
Type const* in,
float const* SFScale,
uint32_t* out,
uint32_t* SFout,
int32_t* mask,
bool use_silu_and_mul,
int n_experts) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using PackedVec = PackedVec<Type>;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched.");

// Input tensor row/col loops.
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = (gridDim.x * blockDim.x) / n_experts;
int remainder = (gridDim.x * blockDim.x) % n_experts;
int expert_idx;
int tid_in_expert;
int actual_stride;
if (remainder > 0) {
int bound = remainder * (stride + 1);
if (tid < bound) {
expert_idx = tid / (stride + 1);
tid_in_expert = tid % (stride + 1);
actual_stride = stride + 1;
} else {
expert_idx = remainder + (tid - bound) / stride;
tid_in_expert = (tid - bound) % stride;
actual_stride = stride;
}
} else {
expert_idx = tid / stride;
tid_in_expert = tid % stride;
actual_stride = stride;
}
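// Worked example (illustrative numbers): with gridDim.x * blockDim.x = 100
// threads and n_experts = 6, stride = 16 and remainder = 4, so experts 0-3
// each get 17 threads (tids 0-67) and experts 4-5 each get 16 (tids 68-99).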
int m = numRows / n_experts;
int padded_m = (m + (128 - 1)) / 128 * 128;

int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD;
// TODO(kaixih@nvidia): For now, we assume the mask is used together with
// silu_and_mul; we may want more general mask behavior later. In the
// silu_and_mul case, the input's last dim doubles.
bool use_mask = mask != nullptr;
int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow;

// Each global thread processes one element
for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow;
globalIdx += actual_stride) {
// Calculate which row and column this global thread should process
int rowIdx = globalIdx / colsPerRow;
int colIdx = globalIdx % colsPerRow;

// Find index within the experts
int rowIdx_in_expert = rowIdx - expert_idx * m;

// Early exit when using masks.
if (use_mask && rowIdx_in_expert >= mask[expert_idx]) {
break;
}

int64_t inOffset = rowIdx * actualColsPerRow + colIdx;
PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
if (use_silu_and_mul) {
PackedVec in_vec_mul = reinterpret_cast<PackedVec const*>(in)[inOffset + colsPerRow];
silu_and_mul(in_vec, in_vec_mul);
}
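// With use_silu_and_mul, each input row stores two halves back to back: the
// vector at colIdx is the first half x and the one at colIdx + colsPerRow is
// the second half y; silu_and_mul presumably overwrites in_vec with
// silu(x) * y before quantization.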

// Get the output tensor offset.
// Same as inOffset because 8 elements are packed into one uint32_t.
int64_t outOffset = rowIdx * colsPerRow + colIdx;
auto& out_pos = out[outOffset];

// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx];

int factor = CVT_FP4_SF_VEC_SIZE * 4;
// The actual output_scales dim is computed from the padded numCols.
int32_t numCols_padded = (numCols + factor - 1) / factor * factor;
int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4;
uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout;
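// Example (assuming CVT_FP4_SF_VEC_SIZE == 16): numCols = 2048 gives
// numCols_SFout = 32 uint32 per row, each uint32 packing four fp8 scale
// factors that together cover 4 * 16 = 64 input elements.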

auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<uint32_t, CVT_FP4_NUM_THREADS_PER_SF>(
rowIdx_in_expert, colIdx, numCols, SFout_in_expert);

out_pos = cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
}
#endif
}

// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
__global__ void
@@ -499,6 +600,7 @@ void quant_impl(
void* input_offset_by_experts,
void* output_scale_offset_by_experts,
void* mask,
bool use_silu_and_mul,
int m_topk,
int k,
int n_experts,
@@ -522,6 +624,22 @@
block.x = (block.x + 1) / 2;
}

// TODO(kaixih@nvidia): Should relax this to allow any grid size.
if (mask != nullptr) {
grid.x = (grid.x + n_experts - 1) / n_experts * n_experts;
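// Illustrative example: grid.x = 130 with n_experts = 6 is rounded up to 132,
// so the launched threads split evenly across experts in the per-expert
// partition inside cvt_fp16_to_fp4_expert.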
cvt_fp16_to_fp4_expert<T, false><<<grid, block, 0, stream>>>(
m_topk,
k,
reinterpret_cast<T*>(input),
reinterpret_cast<float*>(input_global_scale),
reinterpret_cast<uint32_t*>(output),
reinterpret_cast<uint32_t*>(output_scale),
reinterpret_cast<int32_t*>(mask),
use_silu_and_mul,
n_experts);
return;
}

int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x);
if (blockRepeat > 1) {
size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t);
@@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a(
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
false, // use_silu_and_mul
m_topk,
k,
n_experts,
@@ -665,6 +784,7 @@
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // mask
false, // use_silu_and_mul
m_topk,
k,
n_experts,
@@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
torch::Tensor& output_scale,
torch::Tensor const& input,
torch::Tensor const& input_global_scale,
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts,
torch::Tensor const& mask) {
torch::Tensor const& mask,
bool use_silu_and_mul) {
CHECK_INPUT(output, "output must be a CUDA tensor");
CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor");
CHECK_INPUT(input, "input must be a CUDA tensor");
CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor");
CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor");
CHECK_INPUT(mask, "mask must be a CUDA tensor");

TORCH_CHECK(output.dim() == 2);
TORCH_CHECK(output_scale.dim() == 2);
TORCH_CHECK(input.dim() == 2);
TORCH_CHECK(input_global_scale.dim() == 1);
TORCH_CHECK(input_offset_by_experts.dim() == 1);
TORCH_CHECK(output_scale_offset_by_experts.dim() == 1);

TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16);
TORCH_CHECK(input_global_scale.scalar_type() == FLOAT);
TORCH_CHECK(input_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT);
TORCH_CHECK(mask.scalar_type() == INT);
// output is uint8 (two nvfp4 values are packed into one uint8)
// output_scale is int32 (four fp8 values are packed into one int32)
@@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
const int BLOCK_SIZE = 16;
auto m_topk = input.size(0);
auto k_by_2 = input.size(1);
TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
auto k = k_by_2 / 2;
TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16");
auto k = k_by_2;
if (use_silu_and_mul) {
TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2");
k = k_by_2 / 2;
}
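// E.g. with use_silu_and_mul and an input of shape (m_topk, 8192), k becomes
// 4096 and the packed fp4 output checked below is (m_topk, 2048).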
auto n_experts = input_global_scale.size(0);
TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1);
TORCH_CHECK(mask.size(0) == n_experts);
TORCH_CHECK(output.size(0) == m_topk);
TORCH_CHECK(output.size(1) == k / 2);
@@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // input_offset_by_experts
nullptr, // output_scale_offset_by_experts
mask.data_ptr(),
use_silu_and_mul,
m_topk,
k,
n_experts,
@@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a(
output_scale.data_ptr(),
input.data_ptr(),
input_global_scale.data_ptr(),
input_offset_by_experts.data_ptr(),
output_scale_offset_by_experts.data_ptr(),
nullptr, // input_offset_by_experts
nullptr, // output_scale_offset_by_experts
mask.data_ptr(),
use_silu_and_mul,
m_topk,
k,
n_experts,