Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
02842ca
[fusion][1/N] add fusion pass manager and base fusion pass
DevashishLal-CB Sep 18, 2025
3981b86
[fusion][2/N] add fused activation pass with triton kernel
DevashishLal-CB Sep 18, 2025
10accfc
[fusion][3/N] add rmsnorm and quant fusion pass
DevashishLal-CB Sep 19, 2025
ff13b9c
[fusion][4/N] added topk softmax fusion example and cleanup
DevashishLal-CB Sep 26, 2025
e9ef0e1
[fusion][5/N] uprev and integration with piecewise cuda graphs
DevashishLal-CB Nov 19, 2025
dada90f
Merge branch 'main' into gh/dlal/sgl-fusion
BLaZeKiLL Dec 19, 2025
b1976df
Merge branch 'main' into gh/dlal/sgl-fusion
BLaZeKiLL Dec 30, 2025
af11430
Merge branch 'main' into gh/dlal/sgl-fusion
BLaZeKiLL Feb 1, 2026
2ebffb6
[fusion][6/N] uprev and fix fusion passes
BLaZeKiLL Feb 1, 2026
fc7eeaa
[fusion][7/N] add support for flashinfer rmsnorm + quant fused kernels
BLaZeKiLL Feb 1, 2026
54c96b3
Merge branch 'main' into gh/dlal/sgl-fusion
Mar 10, 2026
680c182
Merge remote-tracking branch 'upstream/main' into gh/dlal/sgl-fusion
Mar 11, 2026
59ca839
[fusion][8/N] Add composable fusion pass framework
Mar 11, 2026
2053ce1
Merge remote-tracking branch 'upstream/main' into gh/dlal/sgl-fusion
Mar 12, 2026
811d620
[fusion][9/N] add norm quant jit kernels and benchmarks
Mar 14, 2026
20ee820
Merge remote-tracking branch 'upstream/main' into gh/dlal/sgl-fusion
Mar 18, 2026
7ec84f3
[fusion][9/N] rmsnorm_quant, fused_add_rmsnorm_quant jit kernels
Mar 19, 2026
b6bb605
[fusion][10/N] Remove graph breaks and ensure full graph compilation
Mar 20, 2026
dc0d2f3
[fusion][11/N] Add cutedsl dual gemm kernel
Mar 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_dual_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import itertools

import sgl_kernel
import torch
import triton
import triton.testing

from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.cutedsl_dual_gemm import cutedsl_dual_gemm
from sglang.srt.compilation.fusion.ops.triton_ops.dual_gemm import dual_gemm

# Benchmark sweep axes: every (M, K, N, dtype) combination below is timed once.
DEVICE = "cuda"
# M_LIST = [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024]
M_LIST = [1, 8, 32, 64, 128, 256, 512, 1024]  # M: rows of the activation x
K_LIST = [4096]  # K: inner (reduction) dimension of the GEMM
N_LIST = [4096, 11008]  # N: columns of each of the gate/up weight matrices
DTYPE_LIST = [torch.bfloat16]

# Cartesian product of all sweep axes; consumed as x_vals by perf_report below.
configs = list(itertools.product(M_LIST, K_LIST, N_LIST, DTYPE_LIST))


def _is_sm90():
    """Return True when the current CUDA device is compute capability 9.x (SM90) or newer."""
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9


def cutedsl_fn(x, w_gate, w_up, out):
    """Benchmark wrapper: run the CuteDSL dual GEMM, writing the result into *out* in place."""
    cutedsl_dual_gemm(x, w_gate, w_up, out)


def triton_fn(x, w, out):
    """Benchmark wrapper: compute the Triton dual GEMM and copy its result into *out*."""
    result = dual_gemm(x, w)
    out.copy_(result)


def reference_fn(x, w, out):
    """Unfused reference: plain matmul to (M, 2*N), then fused SiLU-and-mul into *out* (M, N)."""
    gate_up = torch.mm(x, w)  # (M, 2*N): gate and up halves concatenated along dim 1
    sgl_kernel.silu_and_mul(gate_up, out)  # writes the (M, N) activation into out


def warmup():
    """Trigger CuteDSL kernel compilation once per unique (K, N, dtype) before timing."""
    if not _is_sm90():
        return
    unique_shapes = {(K, N, dtype) for _, K, N, dtype in configs}
    M = M_LIST[0]  # smallest M is enough; compilation does not depend on M values swept later
    for K, N, dtype in unique_shapes:
        x = torch.randn((M, K), dtype=dtype, device=DEVICE)
        w_gate = torch.randn((K, N), dtype=dtype, device=DEVICE)
        w_up = torch.randn((K, N), dtype=dtype, device=DEVICE)
        out = torch.empty((M, N), dtype=dtype, device=DEVICE)
        cutedsl_dual_gemm(x, w_gate, w_up, out)
    torch.cuda.synchronize()


# Providers compared by the benchmark; index i of each list describes provider i.
LINE_VALS = ["cutedsl", "triton", "reference"]
LINE_NAMES = ["CuteDSL Dual GEMM", "Triton Dual GEMM", "Reference (mm + silu_and_mul)"]
STYLES = [("blue", "-"), ("orange", "--"), ("red", ":")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "K", "N", "dtype"],
        x_vals=configs,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="dual-gemm-performance",
        args={},
    )
)
def benchmark(M: int, K: int, N: int, dtype: torch.dtype, provider: str):
    """Time one (M, K, N, dtype) configuration for the given provider."""
    # The CuteDSL kernel needs SM90+; report zeros instead of failing on older GPUs.
    if provider == "cutedsl" and not _is_sm90():
        return 0.0, 0.0, 0.0

    x = torch.randn((M, K), dtype=dtype, device=DEVICE)
    w_gate = torch.randn((K, N), dtype=dtype, device=DEVICE)
    w_up = torch.randn((K, N), dtype=dtype, device=DEVICE)
    out = torch.empty((M, N), dtype=dtype, device=DEVICE)

    # Triton and the reference path consume a single concatenated (K, 2*N) weight.
    w = torch.cat([w_gate, w_up], dim=1)

    if provider == "cutedsl":
        fn = lambda: cutedsl_fn(x, w_gate, w_up, out)
    elif provider == "triton":
        fn = lambda: triton_fn(x, w, out)
    else:
        fn = lambda: reference_fn(x, w, out)
    return run_benchmark(fn)


if __name__ == "__main__":
    # Warm the CuteDSL kernel first so JIT compilation is excluded from the timings.
    warmup()
    benchmark.run(print_data=True)
166 changes: 166 additions & 0 deletions python/sglang/jit_kernel/benchmark/bench_dual_gemm_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
import itertools

import sgl_kernel
import torch
import triton
import triton.testing

from sglang.jit_kernel.benchmark.utils import run_benchmark
from sglang.jit_kernel.cutedsl_dual_gemm import cutedsl_dual_gemm
from sglang.jit_kernel.per_tensor_quant_fp8 import per_tensor_quant_fp8
from sglang.srt.compilation.fusion.ops.triton_ops.dual_gemm import dual_gemm_kernel

# Benchmark sweep axes: every (M, K, N) combination below is timed once.
DTYPE = torch.bfloat16  # high-precision dtype used before/after FP8 quantization
DEVICE = "cuda"
M_LIST = [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024]  # M: rows of the activation x
K_LIST = [1024, 2048, 4096]  # K: inner (reduction) dimension of the GEMM
N_LIST = [1024, 2048, 4096, 8192]  # N: columns of each of the gate/up weight matrices

# Cartesian product of all sweep axes; consumed as x_vals by perf_report below.
configs = list(itertools.product(M_LIST, K_LIST, N_LIST))


def _is_sm90():
    """True iff the active CUDA device is at least SM90 (compute capability major >= 9)."""
    capability = torch.cuda.get_device_capability()
    return capability[0] >= 9


def cutedsl_fn(x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale):
    """Benchmark wrapper: run the CuteDSL FP8 dual GEMM, writing into *out_fp8* in place."""
    cutedsl_dual_gemm(
        x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale
    )


def triton_fn(x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale):
    """Launch the Triton fused dual-GEMM FP8 kernel, writing the quantized
    result into *out_fp8* in place.

    Args:
        x_fp8: (M, K) FP8 activation.
        w_gate_fp8 / w_up_fp8: (K, N) FP8 gate and up weights.
        out_fp8: (M, N) FP8 output buffer, written in place.
        x_scale / w_scale / o_scale: per-tensor FP32 scales.
    """
    M_val, K_val = x_fp8.shape
    N_val = w_gate_fp8.shape[1]

    def grid(META):
        # One program per output tile: M is tiled by BLOCK_SIZE_M and N by
        # BLOCK_SIZE_N. BUGFIX: the original tiled N by BLOCK_SIZE_K, which
        # launches the wrong number of programs whenever
        # BLOCK_SIZE_N != BLOCK_SIZE_K (here 64 vs 128, i.e. only half the
        # tiles needed to cover the (M, N) output).
        return (
            triton.cdiv(M_val, META["BLOCK_SIZE_M"])
            * triton.cdiv(N_val, META["BLOCK_SIZE_N"]),
        )

    dual_gemm_kernel[grid](
        x_fp8,
        w_gate_fp8,
        w_up_fp8,
        out_fp8,
        x_scale,
        w_scale,
        o_scale,
        True,
        x_fp8.stride(0),
        x_fp8.stride(1),
        w_gate_fp8.stride(0),
        w_gate_fp8.stride(1),
        out_fp8.stride(0),
        out_fp8.stride(1),
        M_val,
        K_val,
        N_val,
        # FP8 e4m3 representable range, used by the kernel to clamp before quantizing.
        torch.finfo(torch.float8_e4m3fn).min,
        torch.finfo(torch.float8_e4m3fn).max,
        128,  # presumably BLOCK_SIZE_M — TODO confirm against dual_gemm_kernel signature
        64,  # presumably BLOCK_SIZE_N
        128,  # presumably BLOCK_SIZE_K
        1,  # presumably GROUP_SIZE_M
        num_warps=4,
        num_stages=4,
    )


def reference_fp8_fn(
    x_fp8, w_fp8_col_major, out_bf16, out_fp8, x_scale, w_scale, o_scale
):
    """Unfused reference pipeline: scaled FP8 matmul -> silu_and_mul -> FP8 requantize.

    *w_fp8_col_major* is the concatenated (K, 2*N) weight laid out column-major,
    as required by torch._scaled_mm. Writes out_bf16 and out_fp8 in place.
    """
    # 1) FP8 scaled matmul producing a BF16 (M, 2*N) result.
    gate_up = torch._scaled_mm(
        x_fp8,
        w_fp8_col_major,
        scale_a=x_scale,
        scale_b=w_scale,
        out_dtype=torch.bfloat16,
    )
    # 2) Fused SiLU-and-mul collapses (M, 2*N) -> (M, N) into out_bf16.
    sgl_kernel.silu_and_mul(gate_up, out_bf16)
    # 3) Static per-tensor quantization back to FP8.
    per_tensor_quant_fp8(out_bf16, out_fp8, o_scale, is_static=True)


def warmup():
    """Trigger CuteDSL kernel compilation once per unique (K, N) before timing."""
    if not _is_sm90():
        return

    def unit_scale():
        # Per-tensor FP32 scale of 1.0, matching what the benchmark passes.
        return torch.tensor([1.0], dtype=torch.float32, device=DEVICE)

    M = M_LIST[0]  # smallest M suffices for compilation
    for K, N in {(K, N) for _, K, N in configs}:
        x_fp8 = torch.randn((M, K), dtype=DTYPE, device=DEVICE).to(torch.float8_e4m3fn)
        w_gate_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(
            torch.float8_e4m3fn
        )
        w_up_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(
            torch.float8_e4m3fn
        )
        out_fp8 = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=DEVICE)
        cutedsl_dual_gemm(
            x_fp8, w_gate_fp8, w_up_fp8, out_fp8,
            unit_scale(), unit_scale(), unit_scale(),
        )
    torch.cuda.synchronize()


# Providers compared by the benchmark; index i of each list describes provider i.
LINE_VALS = ["cutedsl", "triton", "reference"]
LINE_NAMES = [
    "CuteDSL Dual GEMM FP8",
    "Triton Dual GEMM FP8",
    "Reference (scaled_mm + silu_and_mul + quant_fp8)",
]
STYLES = [("blue", "-"), ("orange", "--"), ("red", ":")]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["M", "K", "N"],
        x_vals=configs,
        line_arg="provider",
        line_vals=LINE_VALS,
        line_names=LINE_NAMES,
        styles=STYLES,
        ylabel="us",
        plot_name="dual-gemm-fp8-performance",
        args={},
    )
)
def benchmark(M: int, K: int, N: int, provider: str):
    """Time one (M, K, N) configuration for the given provider."""
    # The CuteDSL kernel needs SM90+; report zeros instead of failing on older GPUs.
    if provider == "cutedsl" and not _is_sm90():
        return 0.0, 0.0, 0.0

    x_bf16 = torch.randn((M, K), dtype=DTYPE, device=DEVICE)
    x_fp8 = x_bf16.to(torch.float8_e4m3fn)
    w_gate_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(torch.float8_e4m3fn)
    w_up_fp8 = torch.randn((K, N), dtype=DTYPE, device=DEVICE).to(torch.float8_e4m3fn)
    out_fp8 = torch.empty((M, N), dtype=torch.float8_e4m3fn, device=DEVICE)
    x_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    w_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)
    o_scale = torch.tensor([1.0], dtype=torch.float32, device=DEVICE)

    # For reference: _scaled_mm expects b as column-major (2*N, K).t()
    w_cat = torch.cat([w_gate_fp8, w_up_fp8], dim=1)  # (K, 2*N)
    w_fp8_t = w_cat.t().contiguous().t()  # (K, 2*N) column-major
    out_bf16 = torch.empty((M, N), dtype=DTYPE, device=DEVICE)

    if provider == "cutedsl":
        fn = lambda: cutedsl_fn(
            x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale
        )
    elif provider == "triton":
        fn = lambda: triton_fn(
            x_fp8, w_gate_fp8, w_up_fp8, out_fp8, x_scale, w_scale, o_scale
        )
    else:
        fn = lambda: reference_fp8_fn(
            x_fp8, w_fp8_t, out_bf16, out_fp8, x_scale, w_scale, o_scale
        )
    return run_benchmark(fn)


if __name__ == "__main__":
    # Warm the CuteDSL kernel first so JIT compilation is excluded from the timings.
    warmup()
    benchmark.run(print_data=True)
Loading
Loading