11 changes: 11 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -100,6 +100,7 @@
"gemm_fp8_nt_groupwise",
"group_gemm_fp8_nt_groupwise",
"bmm_fp8",
"bmm_mxfp8",
"mm_fp4",
],
"moe": [
@@ -236,6 +237,16 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cudnn", "cublas", "cutlass"],
"12.0": ["cudnn", "cublas"],
},
"bmm_mxfp8": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cudnn"],
"10.3": ["cudnn"],
"12.0": [],
},
# Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
# MOE
"trtllm_fp4_block_scale_moe": {
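The new `bmm_mxfp8` entry enables only the cuDNN backend on compute capabilities 10.0 and 10.3 and leaves every other architecture empty. For context, here is a minimal sketch of how a map like this can be consulted when filtering requested backends; the real `filter_backends_by_compute_capability` helper used by the benchmark may differ, so treat the body below as an illustrative assumption:

```python
import torch

# Illustrative sketch (assumed logic): keep only the backends listed for the
# current device's compute capability in the per-routine support map.
def filter_backends_sketch(backends, routine, device, support_map):
    major, minor = torch.cuda.get_device_capability(device)
    supported = support_map.get(routine, {}).get(f"{major}.{minor}", [])
    kept = [b for b in backends if b in supported]
    for b in backends:
        if b not in supported:
            print(f"[INFO] Skipping {b}: not supported for {routine} on SM {major}.{minor}")
    return kept
```

With the table above, requesting `cudnn` and `cublas` for `bmm_mxfp8` on an SM 10.0 device keeps only `cudnn`, and every backend is dropped on pre-Blackwell parts.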
209 changes: 209 additions & 0 deletions benchmarks/routines/gemm.py
@@ -7,6 +7,7 @@

import flashinfer
from flashinfer.autotuner import autotune
from flashinfer.fp8_quantization import mxfp8_quantize
from flashinfer.testing.utils import (
bench_gpu_time,
dequantize_fp8,
@@ -38,6 +39,8 @@ def run_gemm_test(args):
return testGroupGemmFp8NtGroupwise(args)
elif args.routine == "bmm_fp8":
return testBmmFp8(args)
elif args.routine == "bmm_mxfp8":
return testBmmMxfp8(args)
elif args.routine == "mm_fp4":
return testMmFp4(args)
else:
@@ -144,6 +147,7 @@ def parse_gemm_args(line, parser):
action="store_true",
help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.",
)
# TODO: add bmm_mxfp8 ?
parser.add_argument(
"--autotune",
action="store_true",
@@ -757,6 +761,211 @@ def run_backend(backend, input_fp8, mat2_fp8, input_inv_s, mat2_inv_s):
return res


def testBmmMxfp8(args):
[Review comment, Contributor (author)]: If you think it's better, I can merge this with the existing testBmmFp8.
[Review comment, Collaborator]: cc @bkryu for opinion
"""
Test bmm_mxfp8 API.

This test:
1. Generates random input tensors
2. Quantizes input tensors to MXFP8
3. Runs bmm_mxfp8
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)

Args:
args: Parsed command line arguments containing test configuration

Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBmmMxfp8")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)

## Parse input arguments
backends = args.backends
batch_size = args.batch_size
m = args.m
n = args.n
k = args.k
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = [
"cudnn",
]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res

res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments

if getattr(args, "autotune", False):
backends_to_remove = []
for cur_backend in backends:
if cur_backend not in autotune_supported_backends:
print(f"[INFO] {cur_backend} backend does not support autotune")
backends_to_remove.append(cur_backend)
for cur_backend in backends_to_remove:
backends.remove(cur_backend)

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res

## Prepare input tensors
input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(input, is_sf_swizzled_layout=True)

mat2 = (
torch.randn([batch_size, n, k], device=device, dtype=torch.bfloat16)
.transpose(-2, -1)
.contiguous()
)
mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout=True)

if args.verbose >= 2:
print(f"[VVERBOSE] {input_mxfp8.shape = }")
print(f"[VVERBOSE] {input_mxfp8.dtype = }")
print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
print(f"[VVERBOSE] {input_scale.shape = }")
print(f"[VVERBOSE] {input_scale.dtype = }")
print(f"[VVERBOSE] {mat2_scale.shape = }")
print(f"[VVERBOSE] {mat2_scale.dtype = }")

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend == "cudnn":
return flashinfer.gemm.bmm_mxfp8(
A=input_mxfp8,
B=mat2_mxfp8,
A_scale=input_scale,
B_scale=mat2_scale,
dtype=res_dtype,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")

has_reference_output = False
if run_refcheck:
reference_output = torch.bmm(input, mat2)
has_reference_output = True

if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
for cur_backend in backends:
if cur_backend in autotune_supported_backends:
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for bmm_mxfp8: {warmup_iters} iters")
with autotune(True):
for _ in range(warmup_iters):
run_backend(
cur_backend,
input_mxfp8,
mat2_mxfp8,
input_scale,
mat2_scale,
)

# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(
cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
).detach()
backend_times[cur_backend] = bench_gpu_time(
fn=run_backend,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
sleep_after_run=True,
enable_cupti=args.use_cupti,
use_cuda_graph=is_cuda_graph_compatible,
cold_l2_cache=True,
input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
)

min_cos_sim = 0.9 # TODO: check if can be increased

tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
cos_sim = F.cosine_similarity(
reference_output.reshape(-1),
tested_outputs[i].reshape(-1),
dim=0,
)
if cos_sim < min_cos_sim:
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
)
if not args.allow_output_mismatch:
raise AssertionError(
f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}"
)

for backend in backends:
backend_name = backend + (
"_autotune"
if (
getattr(args, "autotune", False)
and backend in autotune_supported_backends
)
else ""
)
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k * batch_size
# MXFP8 stores data in fp8_e4m3fn (1 byte per element); the uint8 scale
# tensors are comparatively tiny, so they are omitted from the byte count
problem_bytes = (
m * k * torch.float8_e4m3fn.itemsize
+ n * k * torch.float8_e4m3fn.itemsize
+ m * n * res_dtype.itemsize
) * batch_size
tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec
tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec
print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)

if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["batch_size"] = batch_size
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["out_dtype"] = res_dtype
cur_res["backend"] = backend_name
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res


def testMmFp4(args):
"""
Test mm_fp4 API.
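End to end, the new routine follows the call pattern that `testBmmMxfp8` exercises above. Below is a condensed usage sketch based on that benchmark code; the problem size, output dtype, and 0.9 cosine-similarity threshold are illustrative choices, while the `mxfp8_quantize` and `bmm_mxfp8` signatures are taken directly from the diff:

```python
import torch
import torch.nn.functional as F
import flashinfer
from flashinfer.fp8_quantization import mxfp8_quantize

batch_size, m, n, k = 4, 128, 256, 512  # illustrative problem size
device = "cuda"

# High-precision operands; mat2 is built as [b, n, k], then transposed to
# [b, k, n] and made contiguous, matching the benchmark's preparation.
a = torch.randn(batch_size, m, k, device=device, dtype=torch.bfloat16)
b = (
    torch.randn(batch_size, n, k, device=device, dtype=torch.bfloat16)
    .transpose(-2, -1)
    .contiguous()
)

# Quantize both operands to MXFP8 with the swizzled scale-factor layout.
a_mxfp8, a_scale = mxfp8_quantize(a, is_sf_swizzled_layout=True)
b_mxfp8, b_scale = mxfp8_quantize(b, is_sf_swizzled_layout=True)

# Batched MXFP8 GEMM through the cuDNN backend (the only backend the
# benchmark currently lists for this routine).
out = flashinfer.gemm.bmm_mxfp8(
    A=a_mxfp8,
    B=b_mxfp8,
    A_scale=a_scale,
    B_scale=b_scale,
    dtype=torch.bfloat16,
    backend="cudnn",
)

# Loose sanity check against a bf16 reference, as in the benchmark's refcheck.
ref = torch.bmm(a, b)
cos_sim = F.cosine_similarity(ref.reshape(-1), out.reshape(-1), dim=0)
assert cos_sim > 0.9
```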
1 change: 1 addition & 0 deletions flashinfer/__init__.py
@@ -86,6 +86,7 @@
)
from .gemm import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm import bmm_fp8 as bmm_fp8
from .gemm import bmm_mxfp8 as bmm_mxfp8
from .gemm import mm_fp4 as mm_fp4
from .gemm import mm_fp8 as mm_fp8
from .gemm import tgv_gemm_sm100 as tgv_gemm_sm100
2 changes: 2 additions & 0 deletions flashinfer/gemm/__init__.py
@@ -1,5 +1,6 @@
from .gemm_base import SegmentGEMMWrapper as SegmentGEMMWrapper
from .gemm_base import bmm_fp8 as bmm_fp8
from .gemm_base import bmm_mxfp8 as bmm_mxfp8
from .gemm_base import mm_fp4 as mm_fp4
from .gemm_base import mm_fp8 as mm_fp8
from .gemm_base import tgv_gemm_sm100 as tgv_gemm_sm100
@@ -22,6 +23,7 @@
__all__ = [
"SegmentGEMMWrapper",
"bmm_fp8",
"bmm_mxfp8",
"mm_fp4",
"mm_fp8",
"tgv_gemm_sm100",
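With the two `__init__.py` changes above, the routine is reachable from both the package root and the gemm subpackage:

```python
# Both import paths resolve to the same function after this change.
from flashinfer import bmm_mxfp8
from flashinfer.gemm import bmm_mxfp8 as bmm_mxfp8_from_gemm

assert bmm_mxfp8 is bmm_mxfp8_from_gemm
```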