Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ set(SOURCES
"csrc/elementwise/rope.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
Expand Down
57 changes: 57 additions & 0 deletions sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput",
args={},
)
)
def benchmark(num_tokens, impl):
kHdIn = 7168
kHdOut = 2112
M, K, N = num_tokens, kHdIn, kHdOut

mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

quantiles = [0.5, 0.2, 0.8]

if impl == "torch":

def runner():
F.linear(mat_a, mat_b.T)

elif impl == "sgl-kernel":

def runner():
dsv3_fused_a_gemm(mat_a, mat_b)

ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12

return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()

benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
3 changes: 3 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);

// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
Expand Down
Loading
Loading