Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ set(SOURCES
"csrc/elementwise/rope.cu"
"csrc/gemm/awq_kernel.cu"
"csrc/gemm/bmm_fp8.cu"
"csrc/gemm/dsv3_fused_a_gemm.cu"
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
"csrc/gemm/fp8_gemm_kernel.cu"
"csrc/gemm/int8_gemm_kernel.cu"
Expand Down
57 changes: 57 additions & 0 deletions sgl-kernel/benchmark/bench_dsv3_fused_a_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import argparse

import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import dsv3_fused_a_gemm


@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"],
x_vals=[i + 1 for i in range(16)],
x_log=False,
line_arg="impl",
line_vals=["torch", "sgl-kernel"],
line_names=["torch (bf16)", "dsv3_fused_a_gemm"],
styles=[("blue", "-"), ("orange", "-")],
ylabel="TFLOPs",
plot_name="bf16 dsv3 fused a GEMM throughput",
args={},
)
)
def benchmark(num_tokens, impl):
kHdIn = 7168
kHdOut = 2112
M, K, N = num_tokens, kHdIn, kHdOut

mat_a = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous()
mat_b = torch.randn((N, K), dtype=torch.bfloat16, device="cuda").transpose(0, 1)

quantiles = [0.5, 0.2, 0.8]

if impl == "torch":

def runner():
F.linear(mat_a, mat_b.T)

elif impl == "sgl-kernel":

def runner():
dsv3_fused_a_gemm(mat_a, mat_b)

ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)

def tflops(t_ms):
flops = 2 * M * K * N
return flops / (t_ms * 1e-3) / 1e12

return tflops(ms), tflops(max_ms), tflops(min_ms)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()

benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
3 changes: 3 additions & 0 deletions sgl-kernel/csrc/common_extension.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
" Tensor! output_scale, Tensor! input_scale) -> ()");
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);

m.def("dsv3_fused_a_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()");
m.impl("dsv3_fused_a_gemm", torch::kCUDA, &dsv3_fused_a_gemm);

// Compute NVFP4 experts quantization.
m.def(
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
Expand Down
Loading
Loading