diff --git a/benchmark/kernels/quantization/bench_fp4_quant.py b/benchmark/kernels/quantization/bench_fp4_quant.py
index 9baedf4077be..0d5b54aebd37 100644
--- a/benchmark/kernels/quantization/bench_fp4_quant.py
+++ b/benchmark/kernels/quantization/bench_fp4_quant.py
@@ -1,137 +1,163 @@
+"""Benchmark FP4 quantize: sglang jit_kernel vs flashinfer.
+
+Compares ``sglang.jit_kernel.nvfp4.scaled_fp4_quant`` against
+``flashinfer.fp4_quantize`` over a sweep of (M, K) shapes.
+
+Timing uses ``flashinfer.testing.bench_gpu_time`` (CUDA-graph based with
+rotating-buffer cold-L2).
+"""
+
 import argparse
 import itertools
 
+import numpy as np
 import torch
-import triton
-from flashinfer import (
-    scaled_fp4_grouped_quantize,
-    silu_and_mul_scaled_nvfp4_experts_quantize,
-)
-from sgl_kernel.elementwise import silu_and_mul
-
-from sglang.benchmark.bench_utils import run_bench
-from sglang.srt.layers import deep_gemm_wrapper
-from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd
-
-
-def _test_accuracy_once(E, M, K, input_dtype, device):
-    x = torch.randn(E, M, K, device=device, dtype=input_dtype)
-    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
-    masks = torch.full((E,), M, dtype=torch.int32, device=device)
-    out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
-    out1, blk_scales1 = scaled_fp4_grouped_quantize(
-        silu_and_mul(x),
-        masks,
-        glb_scales,
+from flashinfer import fp4_quantize as flashinfer_fp4_quantize
+from flashinfer.testing import bench_gpu_time
+
+from sglang.jit_kernel.nvfp4 import scaled_fp4_quant
+
+Ms = [1, 8, 32, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
+Ks = [128, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 8192, 16384]
+
+
+def _bench(fn, input_args) -> float:
+    """Return the median of the samples ``bench_gpu_time`` reports for ``fn``."""
+    times = bench_gpu_time(
+        fn=fn,
+        input_args=input_args,
+        use_cuda_graph=True,
+        dry_run_time_ms=25,
+        repeat_time_ms=100,
     )
+    return float(np.median(times))
+
+
+def benchmark(
+    M: int,
+    K: int,
+    dtype: torch.dtype,
+    device: str,
+    flashinfer_backend: str = "cute-dsl",
+):
+    """Time both quantizers on one random (M, K) input.
+
+    Returns ``(sglang_ms, flashinfer_ms)``, each the median from ``_bench``.
+    ``flashinfer_backend`` is forwarded as flashinfer's ``backend=`` argument.
+    """
+    x = torch.randn(M, K, device=device, dtype=dtype)
+    global_scale = torch.ones(1, device=device, dtype=torch.float32)
 
-    torch.testing.assert_close(out, out1)
-    torch.testing.assert_close(blk_scales, blk_scales1)
-    print(f"E: {E}, M: {M}, K: {K}, type: {input_dtype} OK")
-
-
-NUM_RANKS = 48
-M_PER_RANKs = [128, 256, 512, 1024]
-Ms = [M_PER_RANK * NUM_RANKS for M_PER_RANK in M_PER_RANKs]
-Ks = [2048, 4096, 7168]
-
-
-@triton.testing.perf_report(
-    triton.testing.Benchmark(
-        x_names=["M", "K"],
-        x_vals=list(itertools.product(Ms, Ks)),
-        x_log=False,
-        line_arg="provider",
-        line_vals=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
-        line_names=["triton_fp8", "cuda_unfused_fp4", "cuda_fused_fp4"],
-        styles=[("blue", "-"), ("orange", "-"), ("green", "-")],
-        ylabel="ms",
-        plot_name="fp4 quant",
-        args={},
+    sglang_ms = _bench(
+        lambda x, gs: scaled_fp4_quant(x, gs),
+        input_args=(x, global_scale),
     )
-)
-def benchmark(M, K, provider):
-    E = 6
-    device = "cuda"
-    x = torch.randn(E, M, K, device=device, dtype=torch.bfloat16)
-    glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
-    masks = torch.randint(1, 4096, (E,), dtype=torch.int32, device=device)
-    fp8_out = torch.empty(
-        (
-            x.shape[0],
-            x.shape[1],
-            x.shape[2] // 2,
-        ),
-        device=x.device,
-        dtype=torch.float8_e4m3fn,
+    flashinfer_ms = _bench(
+        lambda x, gs: flashinfer_fp4_quantize(x, gs, backend=flashinfer_backend),
+        input_args=(x, global_scale),
     )
-    scale_block_size = 128
-    fp8_scales = torch.empty(
-        (
-            x.shape[0],
-            x.shape[1],
-            x.shape[2] // 2 // scale_block_size,
-        ),
-        device=x.device,
-        dtype=torch.float32,
+
+    return sglang_ms, flashinfer_ms
+
+
+def plot_speedup(rows, path):
+    """Save a heatmap of the per-shape speedup in ``rows`` to ``path``.
+
+    ``rows`` holds ``(M, K, sglang_us, flashinfer_us, speedup)`` tuples.
+    """
+    if not rows:
+        # np.nanmin/np.nanmax raise ValueError on an empty grid.
+        print(f"No benchmark results to plot; skipping {path}")
+        return
+
+    import matplotlib
+
+    matplotlib.use("Agg")
+    import matplotlib.pyplot as plt
+
+    Ms_unique = sorted({int(r[0]) for r in rows})
+    Ks_unique = sorted({int(r[1]) for r in rows})
+    grid = np.full((len(Ms_unique), len(Ks_unique)), np.nan)
+    m_idx = {m: i for i, m in enumerate(Ms_unique)}
+    k_idx = {k: i for i, k in enumerate(Ks_unique)}
+    for M, K, _, _, sp in rows:
+        grid[m_idx[int(M)], k_idx[int(K)]] = float(sp)
+
+    fig, ax = plt.subplots(figsize=(12, 8))
+    vmax = max(2.0, np.nanmax(grid))
+    vmin = min(0.5, np.nanmin(grid))
+    im = ax.imshow(
+        grid,
+        aspect="auto",
+        cmap="RdYlGn",
+        vmin=vmin,
+        vmax=vmax,
+        origin="lower",
     )
+    ax.set_xticks(range(len(Ks_unique)))
+    ax.set_xticklabels(Ks_unique, rotation=45)
+    ax.set_yticks(range(len(Ms_unique)))
+    ax.set_yticklabels(Ms_unique)
+    ax.set_xlabel("K")
+    ax.set_ylabel("M")
+    ax.set_title("Speedup: flashinfer / sglang (>1 means sglang faster)")
+    for i in range(len(Ms_unique)):
+        for j in range(len(Ks_unique)):
+            v = grid[i, j]
+            if np.isfinite(v):
+                ax.text(j, i, f"{v:.2f}", ha="center", va="center", fontsize=7)
+    fig.colorbar(im, ax=ax, label="speedup")
+    fig.tight_layout()
+    fig.savefig(path, dpi=130)
+    print(f"Saved plot to {path}")
+
+
+def main():
+    """Sweep all (M, K) shapes, print a table, and optionally write CSV/plot."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
+    parser.add_argument("--device", default="cuda")
+    parser.add_argument("--flashinfer-backend", default="cute-dsl")
+    parser.add_argument("--csv", type=str, default=None)
+    parser.add_argument("--plot", type=str, default=None)
+    args = parser.parse_args()
 
-    quantiles = (0.5, 0.2, 0.8)
-    if provider == "triton_fp8":
-        ms, min_ms, max_ms = run_bench(
-            lambda: silu_and_mul_masked_post_quant_fwd(
-                x,
-                fp8_out,
-                fp8_scales,
-                scale_block_size,
-                masks,
-                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-            ),
-            quantiles=quantiles,
-        )
-    if provider == "cuda_unfused_fp4":
-        ms, min_ms, max_ms = run_bench(
-            lambda: scaled_fp4_grouped_quantize(
-                silu_and_mul(x),
-                masks,
-                glb_scales,
-            ),
-            quantiles=quantiles,
-        )
-    if provider == "cuda_fused_fp4":
-        ms, min_ms, max_ms = run_bench(
-            lambda: silu_and_mul_scaled_nvfp4_experts_quantize(
-                x,
-                masks,
-                glb_scales,
-            ),
-            quantiles=quantiles,
-        )
+    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16
 
-    return ms, min_ms, max_ms
+    rows = []
+    header = (
+        f"{'M':>8} {'K':>8} {'sglang(us)':>12} {'flashinfer(us)':>16} {'speedup':>10}"
+    )
+    print(header)
+    print("-" * len(header))
+
+    for M, K in itertools.product(Ms, Ks):
+        try:
+            sglang_ms, flashinfer_ms = benchmark(
+                M, K, dtype, args.device, args.flashinfer_backend
+            )
+        except Exception as e:
+            # Best-effort sweep: report the failing shape and keep going.
+            print(f"{M:>8} {K:>8} skipped: {e}")
+            continue
+        sglang_us = sglang_ms * 1e3
+        flashinfer_us = flashinfer_ms * 1e3
+        speedup = flashinfer_us / sglang_us
+        print(
+            f"{M:>8} {K:>8} {sglang_us:>12.3f} {flashinfer_us:>16.3f} {speedup:>10.3f}"
+        )
+        rows.append((M, K, sglang_us, flashinfer_us, speedup))
 
+    if args.csv:
+        with open(args.csv, "w") as f:
+            f.write("M,K,sglang_us,flashinfer_us,speedup_flashinfer_over_sglang\n")
+            for M, K, s, fi, sp in rows:
+                f.write(f"{M},{K},{s:.6f},{fi:.6f},{sp:.6f}\n")
+        print(f"Saved CSV to {args.csv}")
 
-def test_accuracy():
-    E = 6
-    N_RANKS = 48
-    Ms = [128, 256, 512, 1024]
-    Ks = [2048, 4096, 7168]
-    input_dtype = torch.bfloat16
-    for M in Ms:
-        for K in Ks:
-            _test_accuracy_once(E, N_RANKS * M, K, input_dtype, "cuda")
+    if args.plot:
+        plot_speedup(rows, args.plot)
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--save_path",
-        type=str,
-        default="./bench_fp4_quant_res",
-        help="Path to save fp4 quant benchmark results",
-    )
-    args = parser.parse_args()
-
-    test_accuracy()
-
-    benchmark.run(print_data=True, show_plots=True, save_path=args.save_path)
+    main()
diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py
index a9213d83bfee..be2ae6b42f06 100644
--- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py
+++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_cutedsl.py
@@ -313,10 +313,9 @@ def fused_experts_none_to_flashinfer_cutedsl_fp4(
     quant_info: CuteDslFp4MoeQuantInfo,
     runner_config: MoeRunnerConfig,
 ) -> StandardCombineInput:
-    from flashinfer import fp4_quantize
-
     from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
     from sglang.srt.layers.moe.topk import TopKOutputChecker
+    from sglang.srt.layers.quantization.fp4_utils import fp4_quantize
 
     assert runner_config.activation == "silu", "Only silu is supported for CuteDSL MoE."
 
diff --git a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
index e2b0f08c49f5..01078cf1c1af 100644
--- a/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
+++ b/python/sglang/srt/layers/moe/moe_runner/flashinfer_trtllm.py
@@ -49,7 +49,7 @@ def round_up_to_multiple(x: int, m: int) -> int:
 )
 
 if is_flashinfer_available():
-    from flashinfer import fp4_quantize
+    from sglang.srt.layers.quantization.fp4_utils import fp4_quantize
 elif is_cuda_alike():
     from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize
 else:
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py
index 7b5080bb860f..c4384ba7995d 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/flashinfer.py
@@ -25,11 +25,13 @@
 from sglang.srt.utils import get_int_env_var
 
 try:
-    from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+    from flashinfer import nvfp4_block_scale_interleave
     from flashinfer.comm import MoeAlltoAll, moe_a2a_get_workspace_size_per_rank
     from flashinfer.comm.mapping import Mapping
     from flashinfer.comm.mnnvl import MnnvlConfig
 
+    from sglang.srt.layers.quantization.fp4_utils import fp4_quantize
+
     use_flashinfer = True
 except ImportError:
     use_flashinfer = False
diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
index 35ee82fed85c..7658c28d43a2 100644
--- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py
+++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py
@@ -44,10 +44,13 @@
 
 
 try:
-    from flashinfer import fp4_quantize as fp4_quantize_flashinfer
     from flashinfer import (
         nvfp4_block_scale_interleave as nvfp4_block_scale_interleave_flashinfer,
     )
+
+    from sglang.srt.layers.quantization.fp4_utils import (
+        fp4_quantize as fp4_quantize_flashinfer,
+    )
 except ImportError:
     fp4_quantize_flashinfer = None
     nvfp4_block_scale_interleave_flashinfer = None
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
index 6b285809ba16..69e572498c66 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4_moe.py
@@ -304,7 +304,9 @@ def apply_weights(
         topk_output = dispatch_output.topk_output
 
         if self.use_flashinfer_trtllm:
-            from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe
+            from flashinfer import trtllm_fp4_block_scale_moe
+
+            from sglang.srt.layers.quantization.fp4_utils import fp4_quantize
 
             router_logits = topk_output.router_logits
             topk_config = topk_output.topk_config
diff --git a/python/sglang/srt/layers/quantization/fp4_utils.py b/python/sglang/srt/layers/quantization/fp4_utils.py
index 938a2203e35e..96409750cbf7 100644
--- a/python/sglang/srt/layers/quantization/fp4_utils.py
+++ b/python/sglang/srt/layers/quantization/fp4_utils.py
@@ -2,9 +2,12 @@
 
 import logging
 from enum import Enum
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
+
+import torch
 
 from sglang.srt.utils.common import is_sm100_supported, is_sm120_supported
+from sglang.srt.utils.custom_op import register_custom_op_from_extern
 
 if TYPE_CHECKING:
     from sglang.srt.server_args import ServerArgs
@@ -12,6 +15,83 @@
 logger = logging.getLogger(__name__)
 
 
+fp4_quantize = None
+try:
+    from flashinfer import fp4_quantize as _flashinfer_fp4_quantize
+
+    _flashinfer_fp4_quantize_backend = "cute-dsl" if is_sm100_supported() else "cuda"
+
+    def _round_up(x: int, y: int) -> int:
+        """Round ``x`` up to the nearest multiple of ``y``."""
+        return ((x + y - 1) // y) * y
+
+    def _flashinfer_fp4_quantize_impl(
+        input: torch.Tensor,
+        global_scale: Optional[torch.Tensor] = None,
+        sf_vec_size: int = 16,
+        sf_use_ue8m0: bool = False,
+        is_sf_swizzled_layout: bool = True,
+        is_sf_8x4_layout: bool = False,
+        enable_pdl: Optional[bool] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Call flashinfer's ``fp4_quantize`` with the backend chosen at import."""
+        return _flashinfer_fp4_quantize(
+            input,
+            global_scale,
+            sf_vec_size,
+            sf_use_ue8m0,
+            is_sf_swizzled_layout,
+            is_sf_8x4_layout,
+            enable_pdl,
+            backend=_flashinfer_fp4_quantize_backend,
+        )
+
+    def _flashinfer_fp4_quantize_fake(
+        input: torch.Tensor,
+        global_scale: Optional[torch.Tensor] = None,
+        sf_vec_size: int = 16,
+        sf_use_ue8m0: bool = False,
+        is_sf_swizzled_layout: bool = True,
+        is_sf_8x4_layout: bool = False,
+        enable_pdl: Optional[bool] = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Shape-only stand-in passed as ``fake_impl``: returns empty uint8 tensors."""
+        is_column_major = input.stride(-2) == 1
+        if is_column_major:
+            m = input.shape[-1]
+            K = input.shape[-2]
+        else:
+            m = input.numel() // input.shape[-1]
+            K = input.shape[-1]
+        if is_column_major:
+            x_q = input.new_empty((*input.shape[:-2], K // 2, m), dtype=torch.uint8)
+        else:
+            x_q = input.new_empty((*input.shape[:-1], K // 2), dtype=torch.uint8)
+        if is_sf_swizzled_layout:
+            row_size = 8 if is_sf_8x4_layout else 128
+            sf_rows = _round_up(m, row_size)
+            sf_cols = _round_up(K // sf_vec_size, 4)
+        else:
+            sf_rows = m
+            sf_cols = K // sf_vec_size
+        if is_column_major:
+            sf = input.new_empty((sf_cols, sf_rows), dtype=torch.uint8)
+        else:
+            sf = input.new_empty((sf_rows, sf_cols), dtype=torch.uint8)
+        return x_q, sf
+
+    fp4_quantize = register_custom_op_from_extern(
+        _flashinfer_fp4_quantize_impl,
+        op_name="flashinfer_fp4_quantize",
+        fake_impl=_flashinfer_fp4_quantize_fake,
+    )
+except ImportError:
+    # NOTE(review): the block removed from modelopt_quant.py fell back to
+    # sglang.jit_kernel.nvfp4.scaled_fp4_quant; here fp4_quantize is just None
+    # when flashinfer is missing -- confirm that is intended.
+    fp4_quantize = None
+
+
 class Fp4GemmRunnerBackend(Enum):
     """Enum for FP4 GEMM runner backend selection."""
 
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index d8ad8a1dda53..01b2d38aaa73 100755
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -35,7 +35,10 @@
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp4_utils import get_fp4_gemm_runner_backend
+from sglang.srt.layers.quantization.fp4_utils import (
+    fp4_quantize,
+    get_fp4_gemm_runner_backend,
+)
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
@@ -70,18 +73,6 @@
 )
 from sglang.srt.models.utils import WeightsMapper
 
-fp4_quantize = None
-try:
-    if is_sm120_supported():
-        try:
-            from flashinfer import fp4_quantize
-        except ImportError:
-            from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize
-    else:
-        from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize
-except ImportError:
-    fp4_quantize = None
-
 try:
     from flashinfer import mm_fp4 as flashinfer_fp4_gemm
     from flashinfer import reorder_rows_for_gated_act_gemm, shuffle_matrix_sf_a