Merged
12 changes: 12 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
@@ -134,6 +134,7 @@
"bmm_fp8",
"bmm_mxfp8",
"mm_fp4",
"mm_mxfp8",
],
"moe": [
"trtllm_fp4_block_scale_moe",
@@ -296,6 +297,17 @@ def dtype_str_to_torch_dtype(dtype_str):
"10.3": ["cudnn"],
"12.0": [],
},
"mm_mxfp8": {
"7.5": [],
"8.0": [],
"8.6": [],
"8.9": [],
"9.0": [],
"10.0": ["cutlass"],
"10.3": ["cutlass"],
"11.0": ["cutlass"],
"12.0": [],
},
# Note: mm_fp4 uses support checkers to filter backends, so it is not listed here
# MOE
"trtllm_fp4_block_scale_moe": {
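For context, the table above maps a GPU compute capability ("major.minor") to the backends that the new mm_mxfp8 routine is allowed to run on. The sketch below is illustrative only, showing how such a table could drive backend filtering; the helper name is hypothetical and the repo's real filter_backends_by_compute_capability may be implemented differently.

# Illustrative sketch: filter requested backends using a routine -> capability -> backends table.
# The helper name is hypothetical; it is not the repo's actual implementation.
import torch

def _filter_backends_sketch(requested_backends, routine, device, support_map):
    # Compute capability of the current device, e.g. (10, 0) -> "10.0".
    major, minor = torch.cuda.get_device_capability(device)
    supported = support_map.get(routine, {}).get(f"{major}.{minor}", [])
    # Keep only backends listed as supported for this routine and architecture.
    return [b for b in requested_backends if b in supported]

# Example: per the mm_mxfp8 entry above, on an SM 10.0 GPU
# _filter_backends_sketch(["cutlass", "cudnn"], "mm_mxfp8", "cuda:0", table)
# would return ["cutlass"].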
216 changes: 214 additions & 2 deletions benchmarks/routines/gemm.py
@@ -43,6 +43,8 @@ def run_gemm_test(args):
return testBmmMxfp8(args)
elif args.routine == "mm_fp4":
return testMmFp4(args)
elif args.routine == "mm_mxfp8":
return testMmMxfp8(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

@@ -147,12 +149,13 @@ def parse_gemm_args(line, parser):
action="store_true",
help="In mm_fp4, whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.",
)
# TODO: add bmm_mxfp8 ?
parser.add_argument(
"--autotune",
action="store_true",
default=False,
help=("Enable autotuner warmup for supported routines (mm_fp4 and bmm_fp8)."),
help=(
"Enable autotuner warmup for supported routines (mm_fp4, bmm_fp8, bmm_mxfp8 and mm_mxfp8)."
),
)

args = parser.parse_args(line)
@@ -1233,3 +1236,212 @@ def run_backend(
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res


def testMmMxfp8(args):
"""
Test mm_mxfp8 API.

This test:
1. Generates random input tensors
2. Quantizes input tensors to MXFP8
3. Runs mm_mxfp8
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)

Args:
args: Parsed command line arguments containing test configuration

Returns:
list: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testMmMxfp8")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)

## Parse input arguments
backends = args.backends
m = args.m
n = args.n
k = args.k
res_dtype = args.out_dtype
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = [
"cutlass",
]
res = []

backends = filter_backends_by_compute_capability(backends, args.routine, device)
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res

res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments

if getattr(args, "autotune", False):
backends_to_remove = []
for cur_backend in backends:
if cur_backend not in autotune_supported_backends:
print(f"[INFO] {cur_backend} backend does not support autotune")
backends_to_remove.append(cur_backend)
for cur_backend in backends_to_remove:
backends.remove(cur_backend)

if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return res

## Prepare input tensors
# Use swizzled layout for optimal performance
is_sf_swizzled_layout = True

input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
input_mxfp8, input_scale = mxfp8_quantize(
input, is_sf_swizzled_layout=is_sf_swizzled_layout
)

mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
mat2_mxfp8, mat2_scale = mxfp8_quantize(
mat2, is_sf_swizzled_layout=is_sf_swizzled_layout
)

if args.verbose >= 2:
print(f"[VVERBOSE] {input_mxfp8.shape = }")
print(f"[VVERBOSE] {input_mxfp8.dtype = }")
print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
print(f"[VVERBOSE] {input_scale.shape = }")
print(f"[VVERBOSE] {input_scale.dtype = }")
print(f"[VVERBOSE] {mat2_scale.shape = }")
print(f"[VVERBOSE] {mat2_scale.dtype = }")

def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
if backend == "cutlass":
return flashinfer.gemm.mm_mxfp8(
a=input_mxfp8,
b=mat2_mxfp8.t(), # mm_mxfp8 expects b.t()
a_descale=input_scale,
b_descale=mat2_scale, # mm_mxfp8 handles swizzled 1D internally
out_dtype=res_dtype,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")

has_reference_output = False
if run_refcheck:
reference_output = torch.mm(input, mat2.t())
has_reference_output = True

if getattr(args, "autotune", False):
warmup_iters = (
args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
)
for cur_backend in backends:
if cur_backend in autotune_supported_backends:
if args.verbose >= 1:
print(f"[INFO] Autotune warmup for mm_mxfp8: {warmup_iters} iters")
with autotune(True):
for _ in range(warmup_iters):
run_backend(
cur_backend,
input_mxfp8,
mat2_mxfp8,
input_scale,
mat2_scale,
)

# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(
cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
).detach()
backend_times[cur_backend] = bench_gpu_time(
fn=run_backend,
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
sleep_after_run=True,
enable_cupti=args.use_cupti,
use_cuda_graph=is_cuda_graph_compatible,
cold_l2_cache=True,
input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
)

# Minimum acceptable cosine similarity against the bf16 reference (loose bound to absorb MXFP8 quantization error)
min_cos_sim = 0.98

tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
cos_sim = F.cosine_similarity(
reference_output.reshape(-1),
tested_outputs[i].reshape(-1),
dim=0,
)
if cos_sim < min_cos_sim:
print(
"[ERROR] Output tensor mismatch between the bf16 reference "
f"output and backend {tested_backends[i]}"
)
if not args.allow_output_mismatch:
raise AssertionError(
"[ERROR] Output tensor mismatch between the bf16 reference "
f"output and backend {tested_backends[i]} "
f"with {cos_sim=} (expected >= {min_cos_sim})"
)
for backend in backends:
backend_name = backend + (
"_autotune"
if (
getattr(args, "autotune", False)
and backend in autotune_supported_backends
)
else ""
)
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k
# MXFP8 stores data as fp8_e4m3fn (1 byte per element) plus uint8 scale factors;
# the scale tensors are much smaller than the data, so they are omitted here.
problem_bytes = (
m * k * torch.float8_e4m3fn.itemsize
+ n * k * torch.float8_e4m3fn.itemsize
+ m * n * res_dtype.itemsize
)
# median_time is in milliseconds, so dividing by 10**9 yields TFLOPs/sec and TB/sec.
tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec
print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)

if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["out_dtype"] = res_dtype
cur_res["backend"] = backend_name
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
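Taken together, the new routine boils down to the call pattern below. This is a standalone, illustrative sketch based on the calls visible in the diff (mxfp8_quantize, flashinfer.gemm.mm_mxfp8) and the reporting math in testMmMxfp8; the import path for mxfp8_quantize, the Blackwell-class GPU, and the availability of the cutlass backend are assumptions rather than guarantees.

# Illustrative sketch of the mm_mxfp8 path exercised by testMmMxfp8.
import torch
import flashinfer
from flashinfer import mxfp8_quantize  # import path assumed; gemm.py calls it directly

m, n, k = 4096, 4096, 4096
a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

# Quantize both operands to MXFP8 with the swizzled scale-factor layout.
a_q, a_sf = mxfp8_quantize(a, is_sf_swizzled_layout=True)
b_q, b_sf = mxfp8_quantize(b, is_sf_swizzled_layout=True)

out = flashinfer.gemm.mm_mxfp8(
    a=a_q,
    b=b_q.t(),            # second operand is passed transposed, as in the test
    a_descale=a_sf,
    b_descale=b_sf,
    out_dtype=torch.bfloat16,
    backend="cutlass",
)

# Reporting math mirrored from the benchmark: median_time is in milliseconds,
# so flops / (10**9 * time_ms) equals flops / (10**12 * time_s), i.e. TFLOPs/sec.
median_time_ms = 1.0                  # placeholder; the benchmark measures this
problem_flops = 2 * m * n * k
problem_bytes = m * k + n * k + m * n * out.dtype.itemsize  # 1-byte fp8 inputs, bf16 output
print(f"{problem_flops / (10**9 * median_time_ms):.1f} TFLOPs/s")
print(f"{problem_bytes / (10**9 * median_time_ms):.4f} TB/s")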