Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 121 additions & 121 deletions benchmark/kernels/quantization/bench_fp4_quant.py
Original file line number Diff line number Diff line change
@@ -1,137 +1,137 @@
"""Benchmark FP4 quantize: sglang jit_kernel vs flashinfer.

Compares ``sglang.jit_kernel.nvfp4.scaled_fp4_quant`` against
``flashinfer.fp4_quantize`` over a sweep of (M, K) shapes.

Timing uses ``flashinfer.testing.bench_gpu_time`` (CUDA-graph based with
rotating-buffer cold-L2).
"""

import argparse
import itertools

import numpy as np
import torch
import triton
from flashinfer import (
scaled_fp4_grouped_quantize,
silu_and_mul_scaled_nvfp4_experts_quantize,
)
from sgl_kernel.elementwise import silu_and_mul

from sglang.benchmark.bench_utils import run_bench
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe.ep_moe.kernels import silu_and_mul_masked_post_quant_fwd


def _test_accuracy_once(E, M, K, input_dtype, device):
x = torch.randn(E, M, K, device=device, dtype=input_dtype)
glb_scales = torch.ones((E,), dtype=torch.float32, device=device)
masks = torch.full((E,), M, dtype=torch.int32, device=device)
out, blk_scales = silu_and_mul_scaled_nvfp4_experts_quantize(x, masks, glb_scales)
out1, blk_scales1 = scaled_fp4_grouped_quantize(
silu_and_mul(x),
masks,
glb_scales,
from flashinfer import fp4_quantize as flashinfer_fp4_quantize
from flashinfer.testing import bench_gpu_time

from sglang.jit_kernel.nvfp4 import scaled_fp4_quant

Ms = [1, 8, 32, 128, 512, 1024, 2048, 4096, 8192, 16384, 32768]
Ks = [128, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096, 5120, 6144, 8192, 16384]


def _bench(fn, input_args) -> float:
    """Median GPU time (ms) of ``fn(*input_args)`` via flashinfer's CUDA-graph timer.

    Args:
        fn: Callable under test.
        input_args: Positional arguments forwarded to ``fn`` on every invocation.

    Returns:
        Median wall-clock time in milliseconds over the measured repeats.
    """
    bench_kwargs = dict(
        fn=fn,
        input_args=input_args,
        use_cuda_graph=True,
        dry_run_time_ms=25,
        repeat_time_ms=100,
    )
    samples = bench_gpu_time(**bench_kwargs)
    return float(np.median(samples))


def benchmark(M: int, K: int, dtype: torch.dtype, device: str):
    """Benchmark FP4 quantization of one (M, K) activation tensor on both backends.

    Defect fixed: this span was diff/merge residue — the old triton-decorated
    ``benchmark`` and its constants were interleaved with the new function body,
    leaving the region syntactically invalid. This is the reconstructed
    new-version function that ``main`` calls.

    Args:
        M: Number of rows of the input tensor.
        K: Number of columns (quantized dimension).
        dtype: Input dtype (torch.bfloat16 or torch.float16).
        device: Torch device string, e.g. ``"cuda"``.

    Returns:
        Tuple ``(sglang_ms, flashinfer_ms)`` of median kernel times in
        milliseconds for the sglang jit kernel and flashinfer respectively.
    """
    x = torch.randn(M, K, device=device, dtype=dtype)
    # Per-tensor global scale; ones keeps the comparison purely about kernel speed.
    global_scale = torch.ones(1, device=device, dtype=torch.float32)

    sglang_ms = _bench(
        lambda x, gs: scaled_fp4_quant(x, gs),
        input_args=(x, global_scale),
    )

    flashinfer_ms = _bench(
        lambda x, gs: flashinfer_fp4_quantize(x, gs, backend="cute-dsl"),
        input_args=(x, global_scale),
    )

    return sglang_ms, flashinfer_ms


def plot_speedup(rows, path):
    """Render a heatmap of flashinfer/sglang speedups over the (M, K) sweep.

    Args:
        rows: Iterable of ``(M, K, sglang_us, flashinfer_us, speedup)`` tuples.
        path: Output image file path.
    """
    import matplotlib

    # Headless backend so the script works without a display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    ms_sorted = sorted({int(r[0]) for r in rows})
    ks_sorted = sorted({int(r[1]) for r in rows})
    row_of = {m: i for i, m in enumerate(ms_sorted)}
    col_of = {k: j for j, k in enumerate(ks_sorted)}

    # NaN marks (M, K) cells that were skipped during the sweep.
    grid = np.full((len(ms_sorted), len(ks_sorted)), np.nan)
    for M, K, _, _, sp in rows:
        grid[row_of[int(M)], col_of[int(K)]] = float(sp)

    fig, ax = plt.subplots(figsize=(12, 8))
    # Widen the color range so mild speedups stay visually distinguishable.
    im = ax.imshow(
        grid,
        aspect="auto",
        cmap="RdYlGn",
        vmin=min(0.5, np.nanmin(grid)),
        vmax=max(2.0, np.nanmax(grid)),
        origin="lower",
    )
    ax.set_xticks(range(len(ks_sorted)))
    ax.set_xticklabels(ks_sorted, rotation=45)
    ax.set_yticks(range(len(ms_sorted)))
    ax.set_yticklabels(ms_sorted)
    ax.set_xlabel("K")
    ax.set_ylabel("M")
    ax.set_title("Speedup: flashinfer / sglang (>1 means sglang faster)")
    # Annotate each populated cell with its speedup value.
    for i, j in itertools.product(range(len(ms_sorted)), range(len(ks_sorted))):
        value = grid[i, j]
        if np.isfinite(value):
            ax.text(j, i, f"{value:.2f}", ha="center", va="center", fontsize=7)
    fig.colorbar(im, ax=ax, label="speedup")
    fig.tight_layout()
    fig.savefig(path, dpi=130)
    print(f"Saved plot to {path}")


def main():
    """Sweep (M, K) shapes, print a comparison table, optionally save CSV/plot.

    Defect fixed: this span was diff/merge residue — lines of the removed
    provider-branch benchmark (``quantiles``/``run_bench`` calls, a stray
    ``return``) and the removed ``test_accuracy`` were interleaved inside the
    new ``main`` body, leaving it syntactically invalid. This is the
    reconstructed new-version function.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", choices=["bf16", "fp16"], default="bf16")
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--csv", type=str, default=None)
    parser.add_argument("--plot", type=str, default=None)
    args = parser.parse_args()

    dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float16

    rows = []
    header = (
        f"{'M':>8} {'K':>8} {'sglang(us)':>12} {'flashinfer(us)':>16} {'speedup':>10}"
    )
    print(header)
    print("-" * len(header))

    for M, K in itertools.product(Ms, Ks):
        try:
            sglang_ms, flashinfer_ms = benchmark(M, K, dtype, args.device)
        except Exception as e:
            # A shape may be unsupported by one backend; report it and keep sweeping.
            print(f"{M:>8} {K:>8} skipped: {e}")
            continue
        sglang_us = sglang_ms * 1e3
        flashinfer_us = flashinfer_ms * 1e3
        speedup = flashinfer_us / sglang_us
        print(
            f"{M:>8} {K:>8} {sglang_us:>12.3f} {flashinfer_us:>16.3f} {speedup:>10.3f}"
        )
        rows.append((M, K, sglang_us, flashinfer_us, speedup))

    if args.csv:
        with open(args.csv, "w") as f:
            f.write("M,K,sglang_us,flashinfer_us,speedup_flashinfer_over_sglang\n")
            for M, K, s, fi, sp in rows:
                f.write(f"{M},{K},{s:.6f},{fi:.6f},{sp:.6f}\n")
        print(f"Saved CSV to {args.csv}")

    if args.plot:
        plot_speedup(rows, args.plot)


if __name__ == "__main__":
    # Defect fixed: the guard body was diff/merge residue — the removed
    # save_path argparse block, test_accuracy(), and benchmark.run(...) were
    # interleaved with the new entry point. The new version simply runs main().
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,9 @@ def fused_experts_none_to_flashinfer_cutedsl_fp4(
quant_info: CuteDslFp4MoeQuantInfo,
runner_config: MoeRunnerConfig,
) -> StandardCombineInput:
from flashinfer import fp4_quantize

from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
from sglang.srt.layers.moe.topk import TopKOutputChecker
from sglang.srt.layers.quantization.fp4_utils import fp4_quantize

assert runner_config.activation == "silu", "Only silu is supported for CuteDSL MoE."

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def round_up_to_multiple(x: int, m: int) -> int:
)

if is_flashinfer_available():
from flashinfer import fp4_quantize
from sglang.srt.layers.quantization.fp4_utils import fp4_quantize
elif is_cuda_alike():
from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as fp4_quantize
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,13 @@
from sglang.srt.utils import get_int_env_var

try:
from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
from flashinfer import nvfp4_block_scale_interleave
from flashinfer.comm import MoeAlltoAll, moe_a2a_get_workspace_size_per_rank
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig

from sglang.srt.layers.quantization.fp4_utils import fp4_quantize

use_flashinfer = True
except ImportError:
use_flashinfer = False
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/moe/token_dispatcher/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,13 @@


try:
from flashinfer import fp4_quantize as fp4_quantize_flashinfer
from flashinfer import (
nvfp4_block_scale_interleave as nvfp4_block_scale_interleave_flashinfer,
)

from sglang.srt.layers.quantization.modelopt_quant import (
fp4_quantize as fp4_quantize_flashinfer,
)
except ImportError:
fp4_quantize_flashinfer = None
nvfp4_block_scale_interleave_flashinfer = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def apply_weights(
topk_output = dispatch_output.topk_output

if self.use_flashinfer_trtllm:
from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe
from flashinfer import trtllm_fp4_block_scale_moe

from sglang.srt.layers.quantization.fp4_utils import fp4_quantize

router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
Expand Down
Loading
Loading