benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm_blackwell.py (329 additions, 0 deletions)
@@ -0,0 +1,329 @@
import argparse
from typing import Tuple

import torch
import triton
from deep_gemm import ceil_div
from flashinfer.gemm import gemm_fp8_nt_groupwise

from sglang.srt.layers.quantization.fp8_kernel import (
    sglang_per_token_group_quant_fp8,
    w8a8_block_fp8_matmul_deepgemm,
)
from sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0

BLOCK_SIZE = 128


def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    assert BLOCK_SIZE == 128
    m, n = x.shape
    x_padded = torch.zeros(
        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
    )
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
Comment on lines +22 to +26 (medium):
The constant BLOCK_SIZE is defined at the top of the file, but the magic number 128 is used here. For better readability and maintainability, please use the defined constant BLOCK_SIZE.

Suggested change:
-    x_padded = torch.zeros(
-        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
-    )
-    x_padded[:m, :n] = x
-    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_padded = torch.zeros(
+        (ceil_div(m, BLOCK_SIZE) * BLOCK_SIZE, ceil_div(n, BLOCK_SIZE) * BLOCK_SIZE), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, BLOCK_SIZE, x_padded.size(1) // BLOCK_SIZE, BLOCK_SIZE)

    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
        x_view.size(0), x_view.size(2)
    )
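Note: 448.0 is the largest finite value representable in torch.float8_e4m3fn, so each 128x128 block is rescaled to use the full FP8 dynamic range. The returned scale tensor holds one float per block, with shape (ceil_div(m, 128), ceil_div(n, 128)).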


def get_weight_shapes(tp_size):
    # cannot TP
    total = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # N can TP
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # K can TP
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]

    weight_shapes = []
    for t in total:
        weight_shapes.append(t)
    for n_t in n_tp:
        new_t = (n_t[0] // tp_size, n_t[1])
        weight_shapes.append(new_t)
    for k_t in k_tp:
        new_t = (k_t[0], k_t[1] // tp_size)
        weight_shapes.append(new_t)
Comment on lines +54 to +62 (medium):
The construction of weight_shapes can be simplified by using list.extend with list comprehensions. This is more Pythonic and improves readability.

Suggested change:
-    weight_shapes = []
-    for t in total:
-        weight_shapes.append(t)
-    for n_t in n_tp:
-        new_t = (n_t[0] // tp_size, n_t[1])
-        weight_shapes.append(new_t)
-    for k_t in k_tp:
-        new_t = (k_t[0], k_t[1] // tp_size)
-        weight_shapes.append(new_t)
+    weight_shapes = list(total)
+    weight_shapes.extend([(n_t[0] // tp_size, n_t[1]) for n_t in n_tp])
+    weight_shapes.extend([(k_t[0], k_t[1] // tp_size) for k_t in k_tp])


    return weight_shapes
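For example, with tp_size=2 the N-shardable shape (18432 * 2, 7168) becomes (18432, 7168) and the K-shardable shape (7168, 18432) becomes (7168, 9216), while the shapes in total are kept unsharded.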


def create_benchmark_configs(tp_size):
    configs = []
    weight_shapes = get_weight_shapes(tp_size)
    batch_sizes = [8, 16, 32, 64, 128, 256, 1024, 2048, 4096]

    for n, k in weight_shapes:
        for m in batch_sizes:
            configs.append((m, n, k, tp_size))

    return configs


def fp8_gemm_flashinfer(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
):
    """Flashinfer implementation of FP8 GEMM"""
    output = gemm_fp8_nt_groupwise(
        x_fp8,
        y_fp8,
        x_scale,
        y_scale,
        out_dtype=torch.bfloat16,
        backend="trtllm",
    )
    return output


def fp8_gemm_deepgemm_blackwell(
    x_fp8: torch.Tensor,
    x_scale: torch.Tensor,
    y_fp8: torch.Tensor,
    y_scale: torch.Tensor,
):
    """DeepGEMM implementation of FP8 GEMM"""
    block_size = [BLOCK_SIZE, BLOCK_SIZE]
    output = w8a8_block_fp8_matmul_deepgemm(
        x_fp8, y_fp8, x_scale, y_scale, block_size, output_dtype=torch.bfloat16
    )
    return output


def check_accuracy(a, b, atol, rtol, percent):
    """Unified accuracy checking function with detailed error reporting."""
    if not torch.isfinite(a).all():
        print("Non-finite values in reference output")
        return False
    if not torch.isfinite(b).all():
        print("Non-finite values in actual output")
        return False
    assert a.shape == b.shape, f"Shape mismatch: {a.shape} vs {b.shape}"

    close = torch.isclose(a, b, atol=atol, rtol=rtol)
    match_ratio = close.float().mean()
    if match_ratio >= percent:
        return True

    mismatch_percent = 1.0 - match_ratio.item()
    if mismatch_percent > 1 - percent:
        print(
            f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
            f"(threshold: {1 - percent:.4f})"
        )
    return False
Comment on lines +126 to +132 (medium):
The condition if mismatch_percent > 1 - percent: is redundant. If the execution reaches this point, it means match_ratio < percent, which implies 1.0 - match_ratio.item() > 1.0 - percent. Therefore, the if statement is always true. You can simplify the code by removing this condition.

Suggested change:
    mismatch_percent = 1.0 - match_ratio.item()
    print(
        f"Mismatch percentage is {mismatch_percent:.4f} for rtol {rtol} "
        f"(threshold: {1 - percent:.4f})"
    )
    return False
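In effect, check_accuracy passes only when at least percent of all elements are elementwise-close under torch.isclose with the given atol/rtol; the correctness calls below use atol=0.1, rtol=0.6, and percent=0.95, i.e. 95% of elements must agree within those tolerances.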



def calculate_diff(m: int, n: int, k: int):
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    y_fp8, y_scale = per_block_cast_to_fp8(y)
    x_fp8, x_scale = sglang_per_token_group_quant_fp8(
        x, BLOCK_SIZE, column_major_scales=True
    )
    out_flashinfer = fp8_gemm_flashinfer(
        x_fp8,
        x_scale,
        y_fp8,
        y_scale,
    )

    dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(
        x,
        BLOCK_SIZE,
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=True,
    )
    # We could quantize y directly, but to mimic the behavior of the actual
    # implementation we requantize the already-quantized FP8 weight instead.
    dg_y_fp8, dg_y_scale = requant_weight_ue8m0(
        y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]
    )
    out_deepgemm = fp8_gemm_deepgemm_blackwell(
        dg_x_fp8, dg_x_scale, dg_y_fp8, dg_y_scale
    )

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"Flashinfer output: {out_flashinfer[0, 0:5]}")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")

    flashinfer_deepgemm_match = check_accuracy(
        out_flashinfer, out_deepgemm, 0.1, 0.6, 0.95
    )
    print("Correctness check:")
    print(f" - Flashinfer vs DeepGEMM: {'✅' if flashinfer_deepgemm_match else '❌'}")


def _benchmark(m, n, k, tp_size, provider):
    print(f"Shape (m={m}, n={n}, k={k}, tp={tp_size}), Provider: {provider}")
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

    # Preprocess data before benchmarking
    y_fp8, y_scale = per_block_cast_to_fp8(y)
    x_fp8, x_scale = sglang_per_token_group_quant_fp8(
        x, BLOCK_SIZE, column_major_scales=True
    )
    dg_x_fp8, dg_x_scale = sglang_per_token_group_quant_fp8(
        x,
        BLOCK_SIZE,
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=True,
    )
    dg_y_fp8, dg_y_scale = requant_weight_ue8m0(
        y_fp8, y_scale, [BLOCK_SIZE, BLOCK_SIZE]
    )

    quantiles = [0.5, 0.2, 0.8]

    if provider == "deepgemm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_gemm_deepgemm_blackwell(
                dg_x_fp8,
                dg_x_scale,
                dg_y_fp8,
                dg_y_scale,
            ),
            quantiles=quantiles,
        )
    elif provider == "flashinfer":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_gemm_flashinfer(
                x_fp8,
                x_scale,
                y_fp8,
                y_scale,
            ),
            quantiles=quantiles,
        )

    # Calculate TFLOPS
    flops = 2 * m * n * k  # two FLOPs (multiply + add) per output MAC
    tflops = flops / (ms * 1e-3) / 1e12

    # Print shape-specific results with TFLOPS
    print(f"Time: {ms*1000:.2f} us, TFLOPS: {tflops:.2f}")
    return ms, max_ms, min_ms
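As a sanity check on the throughput formula, a worked example with hypothetical numbers (the shape is one of the benchmarked ones; the 0.25 ms timing is made up):

    flops = 2 * 64 * 7168 * 16384          # m=64, n=7168, k=16384 -> ~1.50e10 FLOPs
    tflops = flops / (0.25 * 1e-3) / 1e12  # ~60.1 TFLOPS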


def get_benchmark_plot_friendly(tp_size):
    all_configs = create_benchmark_configs(tp_size)
    x_vals = list(range(len(all_configs)))

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["cfg_id"],
            x_vals=x_vals,
            line_arg="provider",
            line_vals=["deepgemm", "flashinfer"],
            line_names=["DeepGEMM", "Flashinfer"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
Inline comment on ylabel="us" (medium):
The ylabel is set to "us" (microseconds), but the benchmark function returns values in milliseconds (ms * 1000). Please change ylabel to "ms" to match the returned values.

Suggested change:
            ylabel="ms",

            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(cfg_id, provider):
        m, n, k, tp_size = all_configs[cfg_id]
        ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to us

    return benchmark


def get_benchmark(tp_size):
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "tp_size"],
            x_vals=[list(config) for config in all_configs],
            line_arg="provider",
            line_vals=["deepgemm", "flashinfer"],
            line_names=["DeepGEMM", "Flashinfer"],
            styles=[("blue", "-"), ("red", "-")],
            ylabel="us",
Inline comment on ylabel="us" (medium):
The ylabel is set to "us" (microseconds), but the benchmark function returns values in milliseconds (ms * 1000). This should be corrected to "ms" to match the returned values.

Suggested change:
            ylabel="ms",

            plot_name=f"fp8-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, tp_size, provider):
        ms, min_ms, max_ms = _benchmark(m, n, k, tp_size, provider)
        return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to us

    return benchmark


if __name__ == "__main__":
    if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] != 10:
        print("Skipping benchmark because the device is not supported")
        exit(0)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save-path",
        type=str,
        default="./configs/benchmark_ops/fp8_gemm/",
        help="Path to save fp8 gemm benchmark results",
    )
    parser.add_argument(
        "--run-correctness",
        action="store_true",
        default=True,
        help="Whether to run correctness test",
Comment on lines +291 to +294 (medium):
The current argparse configuration for --run-correctness with action="store_true" and default=True makes it impossible to disable the correctness check. A better approach is to use action=argparse.BooleanOptionalAction, which automatically creates both --run-correctness and --no-run-correctness flags, providing more intuitive control.

Suggested change:
        "--run-correctness",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Whether to run correctness test",
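With the suggestion applied, both spellings would work from the command line (argparse.BooleanOptionalAction is available from Python 3.9):

    python benchmark_deepgemm_fp8_gemm_blackwell.py --run-correctness     # enabled (the default)
    python benchmark_deepgemm_fp8_gemm_blackwell.py --no-run-correctness  # disabled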

    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    parser.add_argument(
        "--plot-friendly",
        action="store_true",
        default=False,
        help="Plot the x axis as the config index instead of m",
    )
    args = parser.parse_args()

    # Set random seed for reproducibility
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Run correctness tests on a few examples
    if args.run_correctness:
        print("Running correctness tests...")
        calculate_diff(64, 512, 7168)  # Small test
        calculate_diff(64, 7168, 16384)  # Medium test
        calculate_diff(64, 18432, 7168)  # Large test

    # Get the benchmark function with the specified tp_size
    benchmark = (
        get_benchmark_plot_friendly(args.tp_size)
        if args.plot_friendly
        else get_benchmark(args.tp_size)
    )

    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)
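For reference, a typical invocation using the flags defined above (the tp-size value is illustrative):

    python benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm_blackwell.py --tp-size 8 --plot-friendly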
python/sglang/srt/layers/quantization/fp8_utils.py (2 additions, 2 deletions)
@@ -439,15 +439,15 @@ def requant_weight_ue8m0_inplace(weight, weight_scale_inv, weight_block_size):
     assert isinstance(weight, torch.nn.Parameter)
     assert isinstance(weight_scale_inv, torch.nn.Parameter)

-    new_weight, new_weight_scale_inv = _requant_weight_ue8m0(
+    new_weight, new_weight_scale_inv = requant_weight_ue8m0(
         weight.to(weight_scale_inv.device), weight_scale_inv, weight_block_size
     )

     offloader.update_param(weight, new_weight)
     weight_scale_inv.data = new_weight_scale_inv


-def _requant_weight_ue8m0(
+def requant_weight_ue8m0(
     weight: torch.Tensor,
     weight_scale_inv: torch.Tensor,
     weight_block_size: List[int],
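This rename drops the leading underscore from _requant_weight_ue8m0 so the helper can be imported from outside the module, which the new benchmark above relies on via from sglang.srt.layers.quantization.fp8_utils import requant_weight_ue8m0.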