diff --git a/benchmark/kernels/quantization/bench_fp4_quant.py b/benchmark/kernels/quantization/bench_fp4_quant.py new file mode 100644 index 00000000000..318e820adda --- /dev/null +++ b/benchmark/kernels/quantization/bench_fp4_quant.py @@ -0,0 +1,133 @@ +import argparse +import itertools + +import torch +import triton +from sgl_kernel import scaled_fp4_grouped_quant, silu_and_mul_scaled_fp4_grouped_quant +from sgl_kernel.elementwise import silu_and_mul + +from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd +from sglang.srt.layers.quantization import deep_gemm_wrapper + + +def _test_accuracy_once(E, M, K, input_dtype, device): + x = torch.randn(E, M, K, device=device, dtype=input_dtype) + glb_scales = torch.ones((E,), dtype=torch.float32, device=device) + masks = torch.full((E,), M, dtype=torch.int32, device=device) + out, blk_scales = silu_and_mul_scaled_fp4_grouped_quant(x, glb_scales, masks) + out1, blk_scales1 = scaled_fp4_grouped_quant( + silu_and_mul(x), + glb_scales, + masks, + ) + + torch.testing.assert_close(out, out1) + torch.testing.assert_close(blk_scales, blk_scales1) + print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK") + + +NUM_RANKS = 48 +M_PER_RANKs = [128, 256, 512, 1024] +Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs] +Ks = [2048, 4096, 7168] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "K"], + x_vals=list(itertools.product(Ms, Ks)), + x_log=False, + line_arg="provider", + line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"], + line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"], + styles=[("blue", "-"), ("orange", "-"), ("green", "-")], + ylabel="ms", + plot_name="fp4 quant", + args={}, + ) +) +def benchmark(M, K, provider): + E = 6 + device = "cuda" + x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16) + glb_scales = torch.ones((E,), dtype=torch.float32, device=device) + masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device) + fp8_out = torch.empty( + ( + x.shape[0], + x.shape[1], + x.shape[2] // 2, + ), + device=x.device, + dtype=torch.float8_e4m3fn, + ) + scale_block_size = 128 + fp8_scales = torch.empty( + ( + x.shape[0], + x.shape[1], + x.shape[2] // 2 // scale_block_size, + ), + device=x.device, + dtype=torch.float32, + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "triton_fp8": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: silu_and_mul_masked_post_quant_fwd( + x, + fp8_out, + fp8_scales, + scale_block_size, + masks, + scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0, + ), + quantiles=quantiles, + ) + if provider == "cuda_unfused_fp4": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: scaled_fp4_grouped_quant( + silu_and_mul(x), + glb_scales, + masks, + ), + quantiles=quantiles, + ) + if provider == "cuda_fused_fp4": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: silu_and_mul_scaled_fp4_grouped_quant( + x, + glb_scales, + masks, + ), + quantiles=quantiles, + ) + + return ms, min_ms, max_ms + + +def test_accuracy(): + E = 6 + N_RANKS = 48 + Ms = [128, 256, 512, 1024] + Ks = [2048, 4096, 7168] + input_dtype = torch.bfloat16 + for M in Ms: + for K in Ks: + _test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./bench_fp4_quant_res", + help="Path to save fp4 quant benchmark results", + ) + args = parser.parse_args() + + test_accuracy() + + benchmark.run(print_data=True, show_plots=True, save_path=args.save_path) diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index c204dc1513b..8ff06f45421 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -159,8 +159,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.def( "silu_and_mul_scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," - "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," - "Tensor output_scale_offset_by_experts, Tensor mask) -> ()"); + "Tensor input, Tensor input_global_scale, Tensor mask, bool use_silu_and_mul) -> ()"); m.impl("silu_and_mul_scaled_fp4_experts_quant", torch::kCUDA, &silu_and_mul_scaled_fp4_experts_quant); m.def( diff --git a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu index 3f996f66852..38d94365327 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu @@ -347,7 +347,7 @@ cvt_fp16_to_fp4( } } - // Eerly exit when using masks. + // Early exit when using masks. if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { continue; } @@ -383,6 +383,107 @@ cvt_fp16_to_fp4( #endif } +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4_expert( +#else +cvt_fp16_to_fp4_expert( +#endif + int32_t numRows, + int32_t numCols, + Type const* in, + float const* SFScale, + uint32_t* out, + uint32_t* SFout, + int32_t* mask, + bool use_silu_and_mul, + int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, "Vec size is not matched."); + + // Input tensor row/col loops. + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = (gridDim.x * blockDim.x) / n_experts; + int remainder = (gridDim.x * blockDim.x) % n_experts; + int expert_idx; + int tid_in_expert; + int actual_stride; + if (remainder > 0) { + int bound = remainder * (stride + 1); + if (tid < bound) { + expert_idx = tid / (stride + 1); + tid_in_expert = tid % (stride + 1); + actual_stride = stride + 1; + } else { + expert_idx = remainder + (tid - bound) / stride; + tid_in_expert = (tid - bound) % stride; + actual_stride = stride; + } + } else { + expert_idx = tid / stride; + tid_in_expert = tid % stride; + actual_stride = stride; + } + int m = numRows / n_experts; + int padded_m = (m + (128 - 1)) / 128 * 128; + + int colsPerRow = numCols / CVT_FP4_ELTS_PER_THREAD; + // TODO(kaixih@nvidia): For now, we assume mask is used together with + // silu_and_mal. Maybe we want a more general behavior of mask later. In the + // silu case, the input last dim doubles. + bool use_mask = mask != nullptr; + int actualColsPerRow = use_silu_and_mul ? colsPerRow * 2 : colsPerRow; + + // Each global thread processes one element + for (int globalIdx = tid_in_expert + expert_idx * m * colsPerRow; globalIdx < (expert_idx + 1) * m * colsPerRow; + globalIdx += actual_stride) { + // Calculate which row and column this global thread should process + int rowIdx = globalIdx / colsPerRow; + int colIdx = globalIdx % colsPerRow; + + // Find index within the experts + int rowIdx_in_expert = rowIdx - expert_idx * m; + + // Early exit when using masks. + if (use_mask && rowIdx_in_expert >= mask[expert_idx]) { + break; + } + + int64_t inOffset = rowIdx * actualColsPerRow + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + if (use_silu_and_mul) { + PackedVec in_vec_mul = reinterpret_cast(in)[inOffset + colsPerRow]; + silu_and_mul(in_vec, in_vec_mul); + } + + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = rowIdx * colsPerRow + colIdx; + auto& out_pos = out[outOffset]; + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = SFout + expert_idx * padded_m * numCols_SFout; + + auto sf_out = cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } +#endif +} + // Kernel for LARGE_M_TOPK = true (large m_topk optimized version) template __global__ void @@ -499,6 +600,7 @@ void quant_impl( void* input_offset_by_experts, void* output_scale_offset_by_experts, void* mask, + bool use_silu_and_mul, int m_topk, int k, int n_experts, @@ -522,6 +624,22 @@ void quant_impl( block.x = (block.x + 1) / 2; } + // TODO(kaixih@nvidia): Should relax this to allow any grid size. + if (mask != nullptr) { + grid.x = (grid.x + n_experts - 1) / n_experts * n_experts; + cvt_fp16_to_fp4_expert<<>>( + m_topk, + k, + reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(mask), + use_silu_and_mul, + n_experts); + return; + } + int const blockRepeat = (totalWorkSize + block.x * grid.x - 1) / (block.x * grid.x); if (blockRepeat > 1) { size_t shared_mem_size = (n_experts + 1) * sizeof(uint32_t); @@ -652,6 +770,7 @@ void scaled_fp4_experts_quant_sm100a( input_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(), nullptr, // mask + false, // use_silu_and_mul m_topk, k, n_experts, @@ -665,6 +784,7 @@ void scaled_fp4_experts_quant_sm100a( input_offset_by_experts.data_ptr(), output_scale_offset_by_experts.data_ptr(), nullptr, // mask + false, // use_silu_and_mul m_topk, k, n_experts, @@ -679,28 +799,21 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, - torch::Tensor const& mask) { + torch::Tensor const& mask, + bool use_silu_and_mul) { CHECK_INPUT(output, "output must be a CUDA tensor"); CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); CHECK_INPUT(input, "input must be a CUDA tensor"); CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); - CHECK_INPUT(input_offset_by_experts, "input_offset_by_experts must be a CUDA tensor"); - CHECK_INPUT(output_scale_offset_by_experts, "output_scale_offset_by_experts must be a CUDA tensor"); CHECK_INPUT(mask, "mask must be a CUDA tensor"); TORCH_CHECK(output.dim() == 2); TORCH_CHECK(output_scale.dim() == 2); TORCH_CHECK(input.dim() == 2); TORCH_CHECK(input_global_scale.dim() == 1); - TORCH_CHECK(input_offset_by_experts.dim() == 1); - TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); - TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); - TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); TORCH_CHECK(mask.scalar_type() == INT); // output is uint8 (two nvfp4 values are packed into one uint8) // output_scale is int32 (four fp8 values are packed into one int32) @@ -710,12 +823,12 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( const int BLOCK_SIZE = 16; auto m_topk = input.size(0); auto k_by_2 = input.size(1); - TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2"); - auto k = k_by_2 / 2; - TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto k = k_by_2; + if (use_silu_and_mul) { + TORCH_CHECK(k_by_2 % 2 == 0, "k must be a multiple of 2"); + k = k_by_2 / 2; + } auto n_experts = input_global_scale.size(0); - TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); - TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); TORCH_CHECK(mask.size(0) == n_experts); TORCH_CHECK(output.size(0) == m_topk); TORCH_CHECK(output.size(1) == k / 2); @@ -734,9 +847,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( output_scale.data_ptr(), input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), + nullptr, // input_offset_by_experts + nullptr, // output_scale_offset_by_experts mask.data_ptr(), + use_silu_and_mul, m_topk, k, n_experts, @@ -747,9 +861,10 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( output_scale.data_ptr(), input.data_ptr(), input_global_scale.data_ptr(), - input_offset_by_experts.data_ptr(), - output_scale_offset_by_experts.data_ptr(), + nullptr, // input_offset_by_experts + nullptr, // output_scale_offset_by_experts mask.data_ptr(), + use_silu_and_mul, m_topk, k, n_experts, diff --git a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu index 335fd512a8d..d960aa73017 100644 --- a/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu +++ b/sgl-kernel/csrc/gemm/nvfp4_quant_entry.cu @@ -32,9 +32,8 @@ void silu_and_mul_scaled_fp4_experts_quant_sm100a( torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, - torch::Tensor const& mask); + torch::Tensor const& mask, + bool use_silu_and_mul); #endif @@ -65,12 +64,11 @@ void silu_and_mul_scaled_fp4_experts_quant( torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, - torch::Tensor const& mask) { + torch::Tensor const& mask, + bool use_silu_and_mul) { #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return silu_and_mul_scaled_fp4_experts_quant_sm100a( - output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts, mask); + output, output_scale, input, input_global_scale, mask, use_silu_and_mul); #endif TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 experts quantization kernel"); } diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h index 5765a0b7ee3..28422ad18ef 100644 --- a/sgl-kernel/include/sgl_kernel_ops.h +++ b/sgl-kernel/include/sgl_kernel_ops.h @@ -394,9 +394,8 @@ void silu_and_mul_scaled_fp4_experts_quant( torch::Tensor& output_scale, torch::Tensor const& input, torch::Tensor const& input_global_scale, - torch::Tensor const& input_offset_by_experts, - torch::Tensor const& output_scale_offset_by_experts, - torch::Tensor const& mask); + torch::Tensor const& mask, + bool use_silu_and_mul); /* * From csrc/moe/cutlass_moe/w4a8 */ diff --git a/sgl-kernel/python/sgl_kernel/gemm.py b/sgl-kernel/python/sgl_kernel/gemm.py index bd85ee94935..36672877d70 100644 --- a/sgl-kernel/python/sgl_kernel/gemm.py +++ b/sgl-kernel/python/sgl_kernel/gemm.py @@ -298,6 +298,7 @@ def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape): def scaled_fp4_grouped_quant( input_tensor: torch.Tensor, input_global_scale: torch.Tensor, + mask: torch.Tensor, ): """ Quantize input tensor to FP4 and return quantized tensor and scale, for @@ -331,22 +332,14 @@ def scaled_fp4_grouped_quant( output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 ) - input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device) - output_offsets = torch.arange( - 0, - (l + 1) * padded_m, - step=padded_m, - dtype=torch.int, - device=device, - ) - torch.ops.sgl_kernel.scaled_fp4_experts_quant.default( + torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default( output.view(l * m, k // 2), output_scales.view(l * padded_m, padded_k_int32), input_tensor.view(l * m, k), input_global_scale, - input_offsets, - output_offsets, + mask, + use_silu_and_mul=False, ) # The physical layout of the output is (l, m, k // 2), but we want to return a # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. @@ -400,23 +393,14 @@ def silu_and_mul_scaled_fp4_grouped_quant( output_scales = torch.empty( l, padded_m, padded_k_int32, device=device, dtype=torch.int32 ) - input_offsets = torch.arange(0, (l + 1) * m, step=m, dtype=torch.int, device=device) - output_offsets = torch.arange( - 0, - (l + 1) * padded_m, - step=padded_m, - dtype=torch.int, - device=device, - ) torch.ops.sgl_kernel.silu_and_mul_scaled_fp4_experts_quant.default( output.view(l * m, k // 2), output_scales.view(l * padded_m, padded_k_int32), input_tensor.view(l * m, k_by_2), input_global_scale, - input_offsets, - output_offsets, mask, + use_silu_and_mul=True, ) # The physical layout of the output is (l, m, k // 2), but we want to return a # logical layout (m, k // 2, l) required by the flashinfer masked group gemm. diff --git a/sgl-kernel/tests/test_fp4_quantize.py b/sgl-kernel/tests/test_fp4_quantize.py index 6f68330cd10..3e83e47ac67 100644 --- a/sgl-kernel/tests/test_fp4_quantize.py +++ b/sgl-kernel/tests/test_fp4_quantize.py @@ -174,17 +174,22 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None: @pytest.mark.skipif( skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." ) -def test_quantize_to_fp4_grouped(): +@pytest.mark.parametrize("shape", [(2, 512, 2048), (2, 100, 128), (2, 128, 96)]) +def test_quantize_to_fp4_grouped(shape): torch.manual_seed(42) torch.set_default_device("cuda:0") - l, m, k = 2, 512, 2048 + l, m, k = shape x = torch.randn((l, m, k), dtype=torch.bfloat16) + max_m = m // 2 + assert max_m <= m + mask = torch.randint(1, max_m, (l,), dtype=torch.int32) tensor_amax = x.abs().amax(dim=(1, 2)).to(torch.float32) x_sf_global = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax output, output_scales = scaled_fp4_grouped_quant( x, x_sf_global, + mask, ) # output in logical (m, k, l), but its physical layout is (l, m, k). # So permute first to (l, m, k). @@ -195,23 +200,25 @@ def test_quantize_to_fp4_grouped(): output_scales = output_scales.permute(5, 2, 4, 0, 1, 3).view(l, padded_m, -1) for i in range(l): a_fp4, a_scale_interleaved = scaled_fp4_quant(x[i], x_sf_global[i]) - torch.testing.assert_close(a_fp4, output[i]) - torch.testing.assert_close( - a_scale_interleaved.to(torch.float), output_scales[i].to(torch.float) - ) + torch.testing.assert_close(a_fp4[: mask[i]], output[i][: mask[i]]) + # Recover swizzled scales to linear layout and drop padded values, so + # no extra checks on padding are needed. + scale_ref = recover_swizzled_scales(a_scale_interleaved, m, k) + scale_ans = recover_swizzled_scales(output_scales[i], m, k) + torch.testing.assert_close(scale_ref[: mask[i]], scale_ans[: mask[i]]) @pytest.mark.skipif( skip_condition, reason="Nvfp4 Requires compute capability of 10 or above." ) -@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048)]) -def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None: +@pytest.mark.parametrize("shape", [(32, 100, 2048), (32, 512, 2048), (6, 6144, 2048)]) +def test_silu_and_mul_quantize_to_fp4_grouped(shape): torch.manual_seed(42) torch.set_default_device("cuda:0") l, m, k = shape x = torch.randn((l, m, k * 2), dtype=torch.bfloat16) - max_m = 8 + max_m = m // 2 assert max_m <= m mask = torch.randint(1, max_m, (l,), dtype=torch.int32) @@ -221,6 +228,7 @@ def test_silu_and_mul_quantize_to_fp4_grouped(shape: tuple[int, int]) -> None: ref_output, ref_output_scales = scaled_fp4_grouped_quant( ref_y, y_sf_global, + mask, ) output, output_scales = silu_and_mul_scaled_fp4_grouped_quant( x,