From dc0d12847ef0bae39a35509eb13d38499706b57f Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 15 Jan 2026 19:37:45 +0000 Subject: [PATCH 1/7] Add norm benchmarking --- benchmarks/flashinfer_benchmark.py | 12 +- .../routines/flashinfer_benchmark_utils.py | 76 +- benchmarks/routines/norm.py | 1047 +++++++++++++++++ benchmarks/samples/sample_testlist.txt | 62 + 4 files changed, 1191 insertions(+), 6 deletions(-) create mode 100644 benchmarks/routines/norm.py diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index 330d734221..3e564004ef 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -9,6 +9,7 @@ ) from routines.gemm import parse_gemm_args, run_gemm_test from routines.moe import parse_moe_args, run_moe_test +from routines.norm import parse_norm_args, run_norm_test def run_test(args): @@ -26,6 +27,8 @@ def run_test(args): res = run_gemm_test(args) elif args.routine in benchmark_apis["moe"]: res = run_moe_test(args) + elif args.routine in benchmark_apis["norm"]: + res = run_norm_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -34,7 +37,9 @@ def run_test(args): with open(args.output_path, "a") as fout: for cur_res in res: for key in output_column_dict["general"]: - cur_res[key] = getattr(args, key) + # Use getattr with default "" for optional columns like batch_size/hidden_size + # that may not be present in all routine types + cur_res[key] = getattr(args, key, "") output_line = ",".join( [str(cur_res[col]) for col in full_output_columns] @@ -65,7 +70,8 @@ def parse_args(line=sys.argv[1:]): required=True, choices=list(benchmark_apis["attention"]) + list(benchmark_apis["gemm"]) - + list(benchmark_apis["moe"]), + + list(benchmark_apis["moe"]) + + list(benchmark_apis["norm"]), ) args, _ = parser.parse_known_args(line[:]) @@ -156,6 +162,8 @@ def parse_args(line=sys.argv[1:]): args = parse_gemm_args(line, parser) elif args.routine in benchmark_apis["moe"]: args = 
parse_moe_args(line, parser) + elif args.routine in benchmark_apis["norm"]: + args = parse_norm_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}") diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 26909a5dd9..a93b14bbd3 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -15,7 +15,6 @@ ], "attention": [ "page_size", - "batch_size", "s_qo", "s_kv", "num_qo_heads", @@ -37,14 +36,12 @@ "group_size", "tile_size", "scale_major_mode", - "out_dtype", "mma_sm", "use_128x4_sf_layout", "use_nvfp4", ], "moe": [ "num_tokens", - "hidden_size", "intermediate_size", "num_experts", "top_k", @@ -58,7 +55,6 @@ "weight_layout", "use_routing_bias", "use_routing_scales_on_input", - "input_dtype", "weight_dtype", "gated_act", # CUTLASS fused MoE specific @@ -69,7 +65,19 @@ "ep_size", "ep_rank", ], + "norm": [ + "num_heads", + "scale", + "eps", + "enable_pdl", + "use_global_scale", + "is_sf_swizzled_layout", + ], "general": [ + "batch_size", + "hidden_size", + "input_dtype", + "out_dtype", "refcheck", "no_cuda_graph", "use_cupti", @@ -86,6 +94,7 @@ + output_column_dict["attention"] + output_column_dict["gemm"] + output_column_dict["moe"] + + output_column_dict["norm"] + output_column_dict["general"] ) @@ -109,6 +118,13 @@ "trtllm_fp8_per_tensor_scale_moe", "cutlass_fused_moe", ], + "norm": [ + "rmsnorm", + "rmsnorm_quant", + "fused_add_rmsnorm_quant", + "rmsnorm_fp4quant", + "add_rmsnorm_fp4quant", + ], } @@ -289,6 +305,58 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cutlass"], "12.0": [], }, + # NORM + "rmsnorm": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + "rmsnorm_quant": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + 
"10.3": ["cuda"], + "12.0": ["cuda"], + }, + "fused_add_rmsnorm_quant": { + "7.5": ["cuda"], + "8.0": ["cuda"], + "8.6": ["cuda"], + "8.9": ["cuda"], + "9.0": ["cuda"], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": ["cuda"], + }, + # NORM - FP4 Quantization (Blackwell SM100+ only, CuTe-DSL kernels) + "rmsnorm_fp4quant": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cute-dsl"], + "10.3": ["cute-dsl"], + "12.0": [], + }, + "add_rmsnorm_fp4quant": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cute-dsl"], + "10.3": ["cute-dsl"], + "12.0": [], + }, } diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py new file mode 100644 index 0000000000..9c45ee630f --- /dev/null +++ b/benchmarks/routines/norm.py @@ -0,0 +1,1047 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict + +import numpy as np +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + +from .flashinfer_benchmark_utils import ( + dtype_str_to_torch_dtype, + get_device, + print_perf_metrics, + is_close_stats, + filter_backends_by_compute_capability, +) + + +def run_norm_test(args): + """ + Run a norm test. 
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.routine == "rmsnorm": + return testRmsnorm(args) + elif args.routine == "rmsnorm_quant": + return testRmsnormQuant(args) + elif args.routine == "fused_add_rmsnorm_quant": + return testFusedAddRmsnormQuant(args) + elif args.routine == "rmsnorm_fp4quant": + return testRmsnormFp4quant(args) + elif args.routine == "add_rmsnorm_fp4quant": + return testAddRmsnormFp4quant(args) + else: + raise ValueError(f"Unsupported routine: {args.routine}") + + +def parse_norm_args(line, parser): + """ + Parse command line arguments for norm test configuration. + + Args: + line: Command line arguments + parser: ArgumentParser object already populated with shared arguments + + Returns: + Parsed argument namespace + """ + parser.add_argument( + "--batch_size", + type=int, + required=True, + help="Batch size.", + ) + parser.add_argument( + "--hidden_size", + type=int, + required=True, + help="Hidden dimension size.", + ) + parser.add_argument( + "--num_heads", + type=int, + required=False, + default=None, + help="Number of heads (for 3D input shape). 
If not specified, uses 2D shape.", + ) + parser.add_argument( + "--input_dtype", + type=str, + required=False, + default="bfloat16", + help="Data type of the input tensor.", + ) + parser.add_argument( + "--eps", + type=float, + required=False, + default=1e-6, + help="Epsilon for numerical stability.", + ) + parser.add_argument( + "--enable_pdl", + action="store_true", + default=False, + help="Enable programmatic dependent launch.", + ) + parser.add_argument( + "--scale", + type=float, + required=False, + default=1.0, + help="Scale factor for quantization (used by rmsnorm_quant and fused_add_rmsnorm_quant).", + ) + parser.add_argument( + "--out_dtype", + type=str, + required=False, + default="fp8_e4m3", + choices=["fp8_e4m3", "fp8_e5m2", "nvfp4", "mxfp4"], + help="Output dtype for quantized operations. fp8_e4m3/fp8_e5m2 for FP8 quant; nvfp4/mxfp4 for FP4 quant.", + ) + parser.add_argument( + "--backends", + type=str, + required=False, + nargs="+", + default=["cuda"], + choices=["cuda", "cute-dsl"], + help="Backend to test. Default: cuda. Use cute-dsl for FP4 quantization.", + ) + # FP4 quantization specific arguments (for rmsnorm_fp4quant, add_rmsnorm_fp4quant) + parser.add_argument( + "--use_global_scale", + action="store_true", + default=False, + help="Use global scale factor (NVFP4 format). Default: False", + ) + parser.add_argument( + "--is_sf_swizzled_layout", + action="store_true", + default=False, + help="Use swizzled scale factor layout for tensor core GEMM. Default: False", + ) + + args = parser.parse_args(line) + if args.verbose >= 1: + print(f"[INFO] {args = }") + return args + + +def testRmsnorm(args): + """ + Test rmsnorm API. + + This test: + 1. Generates random input tensors + 2. Runs rmsnorm + 3. Runs reference check + 4. Measures performance metrics (memory bandwidth) + + Note: RMSNorm is memory-bandwidth bound, so TB/sec is the primary metric. 
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testRmsnorm") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + batch_size = args.batch_size + hidden_size = args.hidden_size + num_heads = args.num_heads + eps = args.eps + enable_pdl = args.enable_pdl + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + if num_heads is not None: + input_shape = (batch_size, num_heads, hidden_size) + else: + input_shape = (batch_size, hidden_size) + + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {weight.shape = }") + + def run_backend(backend, input_tensor, weight): + if backend == "cuda": + return flashinfer.rmsnorm( + input_tensor, weight, eps=eps, enable_pdl=enable_pdl + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference: PyTorch implementation of RMSNorm + has_reference_output = False + if run_refcheck: + rms = torch.sqrt( + torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps + ) + reference_output = (input_tensor.float() / rms * weight.float()).to(input_dtype) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, input_tensor, weight + ).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor, weight), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + reference_output, tested_outputs[i], rtol=1e-2, atol=1e-2 + ) + if num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch 
from backend {tested_backends[i]}: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with {num_different_elements} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for RMSNorm + # Read: input tensor + weight tensor + # Write: output tensor (same shape as input) + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + hidden_size * input_dtype.itemsize # weight read + + num_elements * input_dtype.itemsize # output write + ) + # RMSNorm is memory-bound, so TFLOPS is not the primary metric + # But we compute approximate FLOPS for completeness: + # Per element: square, sum reduction, sqrt, divide, multiply + problem_flops = num_elements * 5 # rough estimate + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = str(input_dtype) + cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testRmsnormQuant(args): + """ + Test rmsnorm_quant API. + + This test: + 1. Generates random input tensors + 2. Runs rmsnorm_quant with quantized output + 3. Runs reference check + 4. 
Measures performance metrics (memory bandwidth) + + Note: RMSNorm is memory-bandwidth bound, so TB/sec is the primary metric. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testRmsnormQuant") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + batch_size = args.batch_size + hidden_size = args.hidden_size + scale = args.scale + eps = args.eps + enable_pdl = args.enable_pdl + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." + ) + + out_dtype = dtype_str_to_torch_dtype(args.out_dtype) + if out_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError( + f"Unsupported out dtype: {args.out_dtype}. Supported dtypes are fp8_e4m3, fp8_e5m2." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors (2D only for rmsnorm_quant) + input_shape = (batch_size, hidden_size) + + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + out_tensor = torch.empty(input_shape, dtype=out_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {weight.shape = }") + print(f"[VVERBOSE] {out_tensor.dtype = }") + print(f"[VVERBOSE] {scale = }") + + def run_backend(backend, out_tensor, input_tensor, weight): + if backend == "cuda": + flashinfer.norm.rmsnorm_quant( + out_tensor, input_tensor, weight, scale=scale, eps=eps, enable_pdl=enable_pdl + ) + return out_tensor + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference: PyTorch implementation of RMSNorm + quantization + has_reference_output = False + if run_refcheck: + rms = torch.sqrt( + torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps + ) + rmsnorm_output = (input_tensor.float() / rms * weight.float()) + # Quantize to output dtype + reference_output = (rmsnorm_output * scale).clamp( + torch.finfo(out_dtype).min, torch.finfo(out_dtype).max + ).to(out_dtype) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + # Create fresh output tensor for each run + cur_out = torch.empty(input_shape, dtype=out_dtype, device=device) + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, cur_out, input_tensor, weight + ).detach().clone() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, out_tensor, input_tensor, weight), + ) 
+ + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + # Compare in float for FP8 outputs + ref_float = reference_output.float() + out_float = tested_outputs[i].float() + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats(ref_float, out_float, rtol=1e-1, atol=1e-1) + if num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with {num_different_elements} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for RMSNorm + Quant + # Read: input tensor + weight tensor + # Write: output tensor (quantized, smaller dtype) + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + hidden_size * input_dtype.itemsize # weight read + + num_elements * out_dtype.itemsize # output write (quantized) + ) + problem_flops = num_elements * 5 # rough estimate + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = str(out_dtype) + cur_res["scale"] = scale + 
cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testFusedAddRmsnormQuant(args): + """ + Test fused_add_rmsnorm_quant API. + + This test: + 1. Generates random input and residual tensors + 2. Runs fused_add_rmsnorm_quant (residual += input, then RMSNorm with quantized output) + 3. Runs reference check + 4. Measures performance metrics (memory bandwidth) + + Note: This operation is memory-bandwidth bound, so TB/sec is the primary metric. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testFusedAddRmsnormQuant") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + batch_size = args.batch_size + hidden_size = args.hidden_size + scale = args.scale + eps = args.eps + enable_pdl = args.enable_pdl + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." + ) + + out_dtype = dtype_str_to_torch_dtype(args.out_dtype) + if out_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]: + raise ValueError( + f"Unsupported out dtype: {args.out_dtype}. Supported dtypes are fp8_e4m3, fp8_e5m2." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors (2D only for fused_add_rmsnorm_quant) + input_shape = (batch_size, hidden_size) + + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + residual_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + out_tensor = torch.empty(input_shape, dtype=out_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {residual_tensor.shape = }") + print(f"[VVERBOSE] {weight.shape = }") + print(f"[VVERBOSE] {out_tensor.dtype = }") + print(f"[VVERBOSE] {scale = }") + + def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): + if backend == "cuda": + flashinfer.norm.fused_add_rmsnorm_quant( + out_tensor, input_tensor, residual_tensor, weight, + scale=scale, eps=eps, enable_pdl=enable_pdl + ) + return out_tensor + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference: PyTorch implementation of fused add + RMSNorm + quantization + has_reference_output = False + if run_refcheck: + # Clone residual for reference computation since it gets modified + ref_residual = residual_tensor.clone() + # Step 1: residual += input + ref_residual = ref_residual + input_tensor + # Step 2: RMSNorm on residual + rms = torch.sqrt( + torch.mean(ref_residual.float() ** 2, dim=-1, keepdim=True) + eps + ) + rmsnorm_output = (ref_residual.float() / rms * weight.float()) + # Quantize to output dtype + reference_output = (rmsnorm_output * scale).clamp( + torch.finfo(out_dtype).min, torch.finfo(out_dtype).max + ).to(out_dtype) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + # Create fresh tensors for each run (residual is mutated) + cur_out = 
torch.empty(input_shape, dtype=out_dtype, device=device) + cur_residual = residual_tensor.clone() + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, cur_out, input_tensor, cur_residual, weight + ).detach().clone() + # For timing, use fresh residual each iteration + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, out_tensor, input_tensor, residual_tensor.clone(), weight), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + # Compare in float for FP8 outputs + ref_float = reference_output.float() + out_float = tested_outputs[i].float() + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats(ref_float, out_float, rtol=1e-1, atol=1e-1) + if num_different_elements > 0: + print( + f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}: " + f"{num_different_elements}/{num_elements} ({num_different_elements_percentage:.2f}%) elements differ" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with {num_different_elements} elements" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for Fused Add + RMSNorm + Quant + # Read: input tensor + residual tensor + weight tensor + # Write: residual tensor (updated) + output tensor (quantized) + num_elements = np.prod(input_shape) + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + num_elements * input_dtype.itemsize # residual read + + hidden_size * input_dtype.itemsize # weight read + 
+ num_elements * input_dtype.itemsize # residual write + + num_elements * out_dtype.itemsize # output write (quantized) + ) + problem_flops = num_elements * 6 # rough estimate (add + rmsnorm ops) + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = str(out_dtype) + cur_res["scale"] = scale + cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testRmsnormFp4quant(args): + """ + Test rmsnorm_fp4quant API from flashinfer.cute_dsl. + + This test: + 1. Generates random input tensors + 2. Runs rmsnorm_fp4quant with FP4 quantized output + 3. Runs reference check + 4. Measures performance metrics (memory bandwidth) + + Note: This is a CuTe-DSL kernel requiring SM10.0+ (Blackwell). 
+ + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testRmsnormFp4quant") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + + # Default backend to cute-dsl for FP4 quantization routines + if backends == ["cuda"]: + backends = ["cute-dsl"] + + batch_size = args.batch_size + hidden_size = args.hidden_size + num_heads = args.num_heads + eps = args.eps + out_dtype = args.out_dtype + use_global_scale = args.use_global_scale + is_sf_swizzled_layout = args.is_sf_swizzled_layout + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + # Default to nvfp4 if FP8 dtype is specified (for backwards compatibility) + if out_dtype in ["fp8_e4m3", "fp8_e5m2"]: + out_dtype = "nvfp4" + + # Derive block_size from out_dtype + # nvfp4: block_size=16, e4m3 scale factors + # mxfp4: block_size=32, ue8m0 scale factors + if out_dtype == "nvfp4": + block_size = 16 + elif out_dtype == "mxfp4": + block_size = 32 + else: + raise ValueError( + f"Unsupported out_dtype for FP4 quant: {out_dtype}. Supported: nvfp4, mxfp4." + ) + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + if num_heads is not None: + input_shape = (batch_size, num_heads, hidden_size) + else: + input_shape = (batch_size, hidden_size) + + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + + # Prepare global_scale if using NVFP4 format + global_scale = None + if use_global_scale: + global_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {weight.shape = }") + print(f"[VVERBOSE] {out_dtype = }") + print(f"[VVERBOSE] {block_size = }") + print(f"[VVERBOSE] {use_global_scale = }") + print(f"[VVERBOSE] {is_sf_swizzled_layout = }") + + def run_backend(backend, input_tensor, weight): + if backend == "cute-dsl": + return flashinfer.rmsnorm_fp4quant( + input_tensor, weight, eps=eps, block_size=block_size, + global_scale=global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference: PyTorch implementation of RMSNorm + FP4 quantization + has_reference_output = False + if run_refcheck: + rms = torch.sqrt( + torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps + ) + rmsnorm_output = (input_tensor.float() / rms * weight.float()) + # For FP4 quantization reference, we just verify the RMSNorm part + # since FP4 quantization details are complex and implementation-specific + has_reference_output = True + reference_rmsnorm = rmsnorm_output.to(input_dtype) + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + out_fp4, out_scale = run_backend(cur_backend, input_tensor, weight) + outputs[cur_backend] = (out_fp4.detach().clone(), out_scale.detach().clone()) + 
backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor, weight), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + # For FP4, we just verify output shapes are correct + for i in range(len(tested_backends)): + out_fp4, out_scale = outputs[tested_backends[i]] + if args.verbose >= 2: + print(f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}") + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for RMSNorm + FP4 Quant + # Read: input tensor + weight tensor + # Write: FP4 output (2 elements per byte) + scale factors + num_elements = np.prod(input_shape) + num_scale_elements = num_elements // block_size + # FP4: 2 elements per byte (4 bits each), packed in float4_e2m1fn_x2 + fp4_output_bytes = num_elements // 2 + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + hidden_size * input_dtype.itemsize # weight read + + fp4_output_bytes # FP4 output write + + num_scale_elements # scale factors write (1 byte each) + ) + problem_flops = num_elements * 5 # rough estimate + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = 
str(input_dtype) + cur_res["out_dtype"] = out_dtype + cur_res["eps"] = eps + cur_res["use_global_scale"] = use_global_scale + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testAddRmsnormFp4quant(args): + """ + Test add_rmsnorm_fp4quant API from flashinfer.cute_dsl. + + This test: + 1. Generates random input and residual tensors + 2. Runs add_rmsnorm_fp4quant (h = input + residual, then RMSNorm with FP4 quantized output) + 3. Runs reference check + 4. Measures performance metrics (memory bandwidth) + + Note: This is a CuTe-DSL kernel requiring SM10.0+ (Blackwell). + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testAddRmsnormFp4quant") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + + # Default backend to cute-dsl for FP4 quantization routines + if backends == ["cuda"]: + backends = ["cute-dsl"] + + batch_size = args.batch_size + hidden_size = args.hidden_size + num_heads = args.num_heads + eps = args.eps + out_dtype = args.out_dtype + use_global_scale = args.use_global_scale + is_sf_swizzled_layout = args.is_sf_swizzled_layout + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + # Default to nvfp4 if FP8 dtype is specified (for backwards compatibility) + if out_dtype in ["fp8_e4m3", "fp8_e5m2"]: + out_dtype = "nvfp4" + + # Derive block_size from out_dtype + # nvfp4: block_size=16, e4m3 scale factors + # mxfp4: block_size=32, ue8m0 scale factors + if out_dtype == 
"nvfp4": + block_size = 16 + elif out_dtype == "mxfp4": + block_size = 32 + else: + raise ValueError( + f"Unsupported out_dtype for FP4 quant: {out_dtype}. Supported: nvfp4, mxfp4." + ) + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." + ) + ## Done parsing input arguments + + ## Prepare input tensors + if num_heads is not None: + input_shape = (batch_size, num_heads, hidden_size) + else: + input_shape = (batch_size, hidden_size) + + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + residual_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + weight = torch.randn(hidden_size, dtype=input_dtype, device=device) + + # Prepare global_scale if using NVFP4 format + global_scale = None + if use_global_scale: + global_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {residual_tensor.shape = }") + print(f"[VVERBOSE] {weight.shape = }") + print(f"[VVERBOSE] {out_dtype = }") + print(f"[VVERBOSE] {block_size = }") + print(f"[VVERBOSE] {use_global_scale = }") + print(f"[VVERBOSE] {is_sf_swizzled_layout = }") + + def run_backend(backend, input_tensor, residual_tensor, weight): + if backend == "cute-dsl": + return flashinfer.add_rmsnorm_fp4quant( + input_tensor, residual_tensor, weight, eps=eps, block_size=block_size, + global_scale=global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference: PyTorch implementation of Add + RMSNorm + FP4 quantization + 
has_reference_output = False + if run_refcheck: + # Step 1: h = input + residual + h = input_tensor + residual_tensor + # Step 2: RMSNorm on h + rms = torch.sqrt( + torch.mean(h.float() ** 2, dim=-1, keepdim=True) + eps + ) + rmsnorm_output = (h.float() / rms * weight.float()) + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + out_fp4, out_scale, out_h = run_backend( + cur_backend, input_tensor, residual_tensor.clone(), weight + ) + outputs[cur_backend] = (out_fp4.detach().clone(), out_scale.detach().clone(), out_h.detach().clone()) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor, residual_tensor.clone(), weight), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + # For FP4, we just verify output shapes are correct + for i in range(len(tested_backends)): + out_fp4, out_scale, out_h = outputs[tested_backends[i]] + if args.verbose >= 2: + print(f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}, out_h.shape = {out_h.shape}") + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for Add + RMSNorm + FP4 Quant + # Read: input tensor + residual tensor + weight tensor + # Write: FP4 output + scale factors + h tensor (residual updated) + num_elements = np.prod(input_shape) + num_scale_elements = num_elements // block_size + # FP4: 2 elements per byte (4 bits each) + fp4_output_bytes = num_elements // 2 + problem_bytes = ( + num_elements * input_dtype.itemsize # 
input read + + num_elements * input_dtype.itemsize # residual read + + hidden_size * input_dtype.itemsize # weight read + + fp4_output_bytes # FP4 output write + + num_scale_elements # scale factors write (1 byte each) + + num_elements * input_dtype.itemsize # h output write + ) + problem_flops = num_elements * 6 # rough estimate (add + rmsnorm ops) + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = out_dtype + cur_res["eps"] = eps + cur_res["use_global_scale"] = use_global_scale + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 051b793e57..7cd326c74d 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -48,3 +48,65 @@ --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights" --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 2 --top_k 2 --cutlass_variant nvfp4 --quantized_input --input_dtype float16 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_weights_quantized" --routine cutlass_fused_moe --num_tokens 32 --hidden_size 128 --intermediate_size 128 --num_experts 8 --top_k 2 
--cutlass_variant base --input_dtype float16 --tp_size 2 --tp_rank 0 --ep_size 4 --ep_rank 0 -vv --generate_repro_command --case_tag "cutlass_moe_nvfp4_ep_tp" + +## RMSNorm +# Basic RMSNorm with 2D input shape +--routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_llama_hidden" +--routine rmsnorm --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_large_hidden" + +# RMSNorm with 3D input shape (batch, num_heads, head_dim) +--routine rmsnorm --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_3d_gqa" +--routine rmsnorm --batch_size 16 --num_heads 64 --hidden_size 128 --input_dtype float16 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_3d_mha" + +# RMSNorm with PDL enabled +--routine rmsnorm --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --enable_pdl --refcheck -vv --generate_repro_command --case_tag "rmsnorm_pdl" + +## RMSNorm with Quantized Output +# RMSNorm with FP8 e4m3 output +--routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e4m3" +--routine rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_large" + +# RMSNorm with FP8 e5m2 output +--routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype float16 --out_dtype fp8_e5m2 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e5m2" + +## Fused Add + RMSNorm with Quantized Output +# Fused add + RMSNorm with FP8 e4m3 output +--routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv 
--generate_repro_command --case_tag "fused_add_rmsnorm_quant_fp8_e4m3" +--routine fused_add_rmsnorm_quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "fused_add_rmsnorm_quant_large" + +# Fused add + RMSNorm with PDL enabled +--routine fused_add_rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --enable_pdl --refcheck -vv --generate_repro_command --case_tag "fused_add_rmsnorm_quant_pdl" + +## RMSNorm with FP4 Quantization (Blackwell SM10.0+ only, cute-dsl backend) +# NVFP4 format (block_size=16, e4m3 scale factors) - nvfp4 is default out_dtype +--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4" +--routine rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_large" + +# NVFP4 with global scale +--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_global" + +# NVFP4 with swizzled scale factor layout for tensor core GEMM +--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_nvfp4_swizzled" + +# MXFP4 format (block_size=32, ue8m0 scale factors) +--routine rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_mxfp4" + +# 3D input shape (batch, num_heads, head_dim) +--routine rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "rmsnorm_fp4quant_3d" + +## Fused Add + RMSNorm with FP4 Quantization (Blackwell SM10.0+ only, cute-dsl backend) +# 
NVFP4 format (block_size=16, e4m3 scale factors) - nvfp4 is default out_dtype +--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4" +--routine add_rmsnorm_fp4quant --batch_size 64 --hidden_size 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_large" + +# NVFP4 with global scale +--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --use_global_scale -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_global" + +# NVFP4 with swizzled scale factor layout for tensor core GEMM +--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --is_sf_swizzled_layout -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_nvfp4_swizzled" + +# MXFP4 format (block_size=32, ue8m0 scale factors) +--routine add_rmsnorm_fp4quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype mxfp4 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_mxfp4" + +# 3D input shape (batch, num_heads, head_dim) +--routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d" From 62fe9b4175d9cfe1bfb11fccdf5b73ef016573ca Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 15 Jan 2026 21:12:34 +0000 Subject: [PATCH 2/7] Add quantization benchmarking --- benchmarks/flashinfer_benchmark.py | 8 +- .../routines/flashinfer_benchmark_utils.py | 59 ++ benchmarks/routines/norm.py | 112 ++- benchmarks/routines/quantization.py | 819 ++++++++++++++++++ benchmarks/samples/sample_testlist.txt | 41 + 5 files changed, 997 insertions(+), 42 deletions(-) create mode 100644 benchmarks/routines/quantization.py diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index 3e564004ef..3bf3977ad1 100644 --- 
a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -10,6 +10,7 @@ from routines.gemm import parse_gemm_args, run_gemm_test from routines.moe import parse_moe_args, run_moe_test from routines.norm import parse_norm_args, run_norm_test +from routines.quantization import parse_quantization_args, run_quantization_test def run_test(args): @@ -29,6 +30,8 @@ def run_test(args): res = run_moe_test(args) elif args.routine in benchmark_apis["norm"]: res = run_norm_test(args) + elif args.routine in benchmark_apis["quantization"]: + res = run_quantization_test(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -71,7 +74,8 @@ def parse_args(line=sys.argv[1:]): choices=list(benchmark_apis["attention"]) + list(benchmark_apis["gemm"]) + list(benchmark_apis["moe"]) - + list(benchmark_apis["norm"]), + + list(benchmark_apis["norm"]) + + list(benchmark_apis["quantization"]), ) args, _ = parser.parse_known_args(line[:]) @@ -164,6 +168,8 @@ def parse_args(line=sys.argv[1:]): args = parse_moe_args(line, parser) elif args.routine in benchmark_apis["norm"]: args = parse_norm_args(line, parser) + elif args.routine in benchmark_apis["quantization"]: + args = parse_quantization_args(line, parser) else: raise ValueError(f"Unsupported routine: {args.routine}") diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index a93b14bbd3..6ba35a1fd4 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -73,6 +73,17 @@ "use_global_scale", "is_sf_swizzled_layout", ], + "quantization": [ + "m", + "k", + "is_sf_swizzled_layout", + "alignment", + "enable_pdl", + "global_scale", + "sf_layout", + "do_shuffle", + "sf_vec_size", + ], "general": [ "batch_size", "hidden_size", @@ -95,6 +106,7 @@ + output_column_dict["gemm"] + output_column_dict["moe"] + output_column_dict["norm"] + + output_column_dict["quantization"] + 
output_column_dict["general"] ) @@ -125,6 +137,12 @@ "rmsnorm_fp4quant", "add_rmsnorm_fp4quant", ], + "quantization": [ + "mxfp8_quantize", + "mxfp4_quantize", + "nvfp4_quantize", + "nvfp4_batched_quantize", + ], } @@ -357,6 +375,47 @@ def dtype_str_to_torch_dtype(dtype_str): "10.3": ["cute-dsl"], "12.0": [], }, + # QUANTIZATION + "mxfp8_quantize": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": [], + }, + "mxfp4_quantize": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": [], + }, + "nvfp4_quantize": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": [], + }, + "nvfp4_batched_quantize": { + "7.5": [], + "8.0": [], + "8.6": [], + "8.9": [], + "9.0": [], + "10.0": ["cuda"], + "10.3": ["cuda"], + "12.0": [], + }, } diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index 9c45ee630f..94085145d4 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -384,7 +384,12 @@ def testRmsnormQuant(args): def run_backend(backend, out_tensor, input_tensor, weight): if backend == "cuda": flashinfer.norm.rmsnorm_quant( - out_tensor, input_tensor, weight, scale=scale, eps=eps, enable_pdl=enable_pdl + out_tensor, + input_tensor, + weight, + scale=scale, + eps=eps, + enable_pdl=enable_pdl, ) return out_tensor else: @@ -396,11 +401,13 @@ def run_backend(backend, out_tensor, input_tensor, weight): rms = torch.sqrt( torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps ) - rmsnorm_output = (input_tensor.float() / rms * weight.float()) + rmsnorm_output = input_tensor.float() / rms * weight.float() # Quantize to output dtype - reference_output = (rmsnorm_output * scale).clamp( - torch.finfo(out_dtype).min, torch.finfo(out_dtype).max - ).to(out_dtype) + reference_output = ( + (rmsnorm_output * scale) + .clamp(torch.finfo(out_dtype).min, 
torch.finfo(out_dtype).max) + .to(out_dtype) + ) has_reference_output = True # Storage for timing results and outputs @@ -410,9 +417,9 @@ def run_backend(backend, out_tensor, input_tensor, weight): # Create fresh output tensor for each run cur_out = torch.empty(input_shape, dtype=out_dtype, device=device) if run_refcheck: - outputs[cur_backend] = run_backend( - cur_backend, cur_out, input_tensor, weight - ).detach().clone() + outputs[cur_backend] = ( + run_backend(cur_backend, cur_out, input_tensor, weight).detach().clone() + ) backend_times[cur_backend] = bench_gpu_time( fn=run_backend, dry_run_iters=args.dry_run_iters, @@ -559,8 +566,13 @@ def testFusedAddRmsnormQuant(args): def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): if backend == "cuda": flashinfer.norm.fused_add_rmsnorm_quant( - out_tensor, input_tensor, residual_tensor, weight, - scale=scale, eps=eps, enable_pdl=enable_pdl + out_tensor, + input_tensor, + residual_tensor, + weight, + scale=scale, + eps=eps, + enable_pdl=enable_pdl, ) return out_tensor else: @@ -577,11 +589,13 @@ def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): rms = torch.sqrt( torch.mean(ref_residual.float() ** 2, dim=-1, keepdim=True) + eps ) - rmsnorm_output = (ref_residual.float() / rms * weight.float()) + rmsnorm_output = ref_residual.float() / rms * weight.float() # Quantize to output dtype - reference_output = (rmsnorm_output * scale).clamp( - torch.finfo(out_dtype).min, torch.finfo(out_dtype).max - ).to(out_dtype) + reference_output = ( + (rmsnorm_output * scale) + .clamp(torch.finfo(out_dtype).min, torch.finfo(out_dtype).max) + .to(out_dtype) + ) has_reference_output = True # Storage for timing results and outputs @@ -592,9 +606,11 @@ def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): cur_out = torch.empty(input_shape, dtype=out_dtype, device=device) cur_residual = residual_tensor.clone() if run_refcheck: - outputs[cur_backend] = run_backend( - 
cur_backend, cur_out, input_tensor, cur_residual, weight - ).detach().clone() + outputs[cur_backend] = ( + run_backend(cur_backend, cur_out, input_tensor, cur_residual, weight) + .detach() + .clone() + ) # For timing, use fresh residual each iteration backend_times[cur_backend] = bench_gpu_time( fn=run_backend, @@ -602,7 +618,13 @@ def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): repeat_iters=args.num_iters, enable_cupti=args.use_cupti, use_cuda_graph=is_cuda_graph_compatible, - input_args=(cur_backend, out_tensor, input_tensor, residual_tensor.clone(), weight), + input_args=( + cur_backend, + out_tensor, + input_tensor, + residual_tensor.clone(), + weight, + ), ) tested_backends = list(outputs.keys()) @@ -768,8 +790,12 @@ def testRmsnormFp4quant(args): def run_backend(backend, input_tensor, weight): if backend == "cute-dsl": return flashinfer.rmsnorm_fp4quant( - input_tensor, weight, eps=eps, block_size=block_size, - global_scale=global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + input_tensor, + weight, + eps=eps, + block_size=block_size, + global_scale=global_scale, + is_sf_swizzled_layout=is_sf_swizzled_layout, ) else: raise ValueError(f"Unsupported backend: {backend}") @@ -777,14 +803,9 @@ def run_backend(backend, input_tensor, weight): # Reference: PyTorch implementation of RMSNorm + FP4 quantization has_reference_output = False if run_refcheck: - rms = torch.sqrt( - torch.mean(input_tensor.float() ** 2, dim=-1, keepdim=True) + eps - ) - rmsnorm_output = (input_tensor.float() / rms * weight.float()) - # For FP4 quantization reference, we just verify the RMSNorm part + # For FP4 quantization, we verify output shapes and dtypes # since FP4 quantization details are complex and implementation-specific has_reference_output = True - reference_rmsnorm = rmsnorm_output.to(input_dtype) # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} @@ -792,7 +813,10 @@ def run_backend(backend, 
input_tensor, weight): for cur_backend in backends: if run_refcheck: out_fp4, out_scale = run_backend(cur_backend, input_tensor, weight) - outputs[cur_backend] = (out_fp4.detach().clone(), out_scale.detach().clone()) + outputs[cur_backend] = ( + out_fp4.detach().clone(), + out_scale.detach().clone(), + ) backend_times[cur_backend] = bench_gpu_time( fn=run_backend, dry_run_iters=args.dry_run_iters, @@ -809,7 +833,9 @@ def run_backend(backend, input_tensor, weight): for i in range(len(tested_backends)): out_fp4, out_scale = outputs[tested_backends[i]] if args.verbose >= 2: - print(f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}") + print( + f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}" + ) for backend in backends: if len(backend_times[backend]) > 0: @@ -956,22 +982,21 @@ def testAddRmsnormFp4quant(args): def run_backend(backend, input_tensor, residual_tensor, weight): if backend == "cute-dsl": return flashinfer.add_rmsnorm_fp4quant( - input_tensor, residual_tensor, weight, eps=eps, block_size=block_size, - global_scale=global_scale, is_sf_swizzled_layout=is_sf_swizzled_layout + input_tensor, + residual_tensor, + weight, + eps=eps, + block_size=block_size, + global_scale=global_scale, + is_sf_swizzled_layout=is_sf_swizzled_layout, ) else: raise ValueError(f"Unsupported backend: {backend}") - # Reference: PyTorch implementation of Add + RMSNorm + FP4 quantization + # Reference: For FP4 quantization, we verify output shapes and dtypes + # since FP4 quantization details are complex and implementation-specific has_reference_output = False if run_refcheck: - # Step 1: h = input + residual - h = input_tensor + residual_tensor - # Step 2: RMSNorm on h - rms = torch.sqrt( - torch.mean(h.float() ** 2, dim=-1, keepdim=True) + eps - ) - rmsnorm_output = (h.float() / rms * weight.float()) has_reference_output = True # Storage for timing results and 
outputs @@ -982,7 +1007,11 @@ def run_backend(backend, input_tensor, residual_tensor, weight): out_fp4, out_scale, out_h = run_backend( cur_backend, input_tensor, residual_tensor.clone(), weight ) - outputs[cur_backend] = (out_fp4.detach().clone(), out_scale.detach().clone(), out_h.detach().clone()) + outputs[cur_backend] = ( + out_fp4.detach().clone(), + out_scale.detach().clone(), + out_h.detach().clone(), + ) backend_times[cur_backend] = bench_gpu_time( fn=run_backend, dry_run_iters=args.dry_run_iters, @@ -999,7 +1028,9 @@ def run_backend(backend, input_tensor, residual_tensor, weight): for i in range(len(tested_backends)): out_fp4, out_scale, out_h = outputs[tested_backends[i]] if args.verbose >= 2: - print(f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}, out_h.shape = {out_h.shape}") + print( + f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}, out_h.shape = {out_h.shape}" + ) for backend in backends: if len(backend_times[backend]) > 0: @@ -1044,4 +1075,3 @@ def run_backend(backend, input_tensor, residual_tensor, weight): cur_res["case_tag"] = args.case_tag res.append(cur_res) return res - diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py new file mode 100644 index 0000000000..6b54bd331e --- /dev/null +++ b/benchmarks/routines/quantization.py @@ -0,0 +1,819 @@ +""" +Copyright (c) 2023 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +from collections import defaultdict + +import numpy as np +import torch + +import flashinfer +from flashinfer.testing.utils import bench_gpu_time + +from .flashinfer_benchmark_utils import ( + dtype_str_to_torch_dtype, + get_device, + print_perf_metrics, + is_close_stats, + filter_backends_by_compute_capability, +) + + +def run_quantization_test(args): + """ + Run a quantization test. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.routine == "mxfp8_quantize": + return testMxfp8Quantize(args) + elif args.routine == "mxfp4_quantize": + return testMxfp4Quantize(args) + elif args.routine == "nvfp4_quantize": + return testNvfp4Quantize(args) + elif args.routine == "nvfp4_batched_quantize": + return testNvfp4BatchedQuantize(args) + else: + raise ValueError(f"Unsupported routine: {args.routine}") + + +def parse_quantization_args(line, parser): + """ + Parse command line arguments for quantization test configuration. + + Args: + line: Command line arguments + parser: ArgumentParser object already populated with shared arguments + + Returns: + Parsed argument namespace + """ + parser.add_argument( + "--m", + type=int, + required=True, + help="Number of rows in input tensor.", + ) + parser.add_argument( + "--k", + type=int, + required=True, + help="Number of columns in input tensor (must be divisible by 32).", + ) + parser.add_argument( + "--input_dtype", + type=str, + required=False, + default="bfloat16", + choices=["bfloat16", "float16"], + help="Data type of the input tensor.", + ) + parser.add_argument( + "--is_sf_swizzled_layout", + action="store_true", + default=True, + help="Use swizzled layout for scale factors. 
Default: True", + ) + parser.add_argument( + "--no_sf_swizzled_layout", + action="store_true", + default=False, + help="Disable swizzled layout for scale factors.", + ) + parser.add_argument( + "--alignment", + type=int, + required=False, + default=32, + help="sfVecSize for quantization. Default: 32", + ) + parser.add_argument( + "--enable_pdl", + action="store_true", + default=False, + help="Enable programmatic dependent launch.", + ) + parser.add_argument( + "--backends", + type=str, + required=False, + nargs="+", + default=["cuda"], + choices=["cuda"], + help="Backend to test. Default: cuda", + ) + # FP4 quantization specific arguments + parser.add_argument( + "--batch_size", + type=int, + required=False, + default=None, + help="Batch size for batched quantization (nvfp4_batched_quantize).", + ) + parser.add_argument( + "--global_scale", + type=float, + required=False, + default=1.0, + help="Global scale factor for NVFP4 quantization. Default: 1.0", + ) + parser.add_argument( + "--sf_layout", + type=str, + required=False, + default="128x4", + choices=["128x4", "8x4", "linear"], + help="Scale factor layout for NVFP4 quantization. Default: 128x4", + ) + parser.add_argument( + "--do_shuffle", + action="store_true", + default=False, + help="Shuffle scale factors for TRTLLM backend (nvfp4_quantize only).", + ) + parser.add_argument( + "--sf_vec_size", + type=int, + required=False, + default=16, + help="Scale factor vector size for NVFP4 quantization. Default: 16", + ) + + args = parser.parse_args(line) + + # Handle swizzled layout flag + if args.no_sf_swizzled_layout: + args.is_sf_swizzled_layout = False + + if args.verbose >= 1: + print(f"[INFO] {args = }") + return args + + +def testMxfp8Quantize(args): + """ + Test mxfp8_quantize API. + + This test: + 1. Generates random input tensors + 2. Runs mxfp8_quantize + 3. Runs reference check (via dequantize round-trip) + 4. 
Measures performance metrics (memory bandwidth) + + Note: Quantization is memory-bandwidth bound, so TB/sec is the primary metric. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMxfp8Quantize") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + m = args.m + k = args.k + is_sf_swizzled_layout = args.is_sf_swizzled_layout + alignment = args.alignment + enable_pdl = args.enable_pdl + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + # Validate k is divisible by 32 (sf_vec_size) + if k % 32 != 0: + raise ValueError(f"k ({k}) must be divisible by 32 (sf_vec_size)") + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + input_shape = (m, k) + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {is_sf_swizzled_layout = }") + print(f"[VVERBOSE] {alignment = }") + print(f"[VVERBOSE] {enable_pdl = }") + + def run_backend(backend, input_tensor): + if backend == "cuda": + return flashinfer.mxfp8_quantize( + input_tensor, + is_sf_swizzled_layout=is_sf_swizzled_layout, + alignment=alignment, + enable_pdl=enable_pdl, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference check via dequantize round-trip + has_reference_output = False + if run_refcheck: + # For mxfp8, we verify by dequantizing and comparing + # This tests the quantize->dequantize round-trip + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + x_q, sf = run_backend(cur_backend, input_tensor) + outputs[cur_backend] = (x_q.detach().clone(), sf.detach().clone()) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + x_q, sf = outputs[tested_backends[i]] + if args.verbose >= 2: + print( + f"[VVERBOSE] Backend {tested_backends[i]}: " + f"x_q.shape = {x_q.shape}, x_q.dtype = {x_q.dtype}, " + f"sf.shape = {sf.shape}, sf.dtype = {sf.dtype}" + ) + # Dequantize and compare with original + # Note: mxfp8_dequantize_host is a HOST function, so tensors must be on CPU + # and expects uint8 
dtype + try: + x_q_cpu = x_q.cpu().view(torch.uint8) + sf_cpu = sf.cpu().view(torch.uint8).reshape(-1) + dequantized = flashinfer.mxfp8_dequantize_host( + x_q_cpu, sf_cpu, is_sf_swizzled_layout=is_sf_swizzled_layout + ) + # Move back to GPU for comparison + dequantized = dequantized.to(input_tensor.device) + # Compare with original input (allowing for quantization error) + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + input_tensor.float(), dequantized, rtol=0.5, atol=0.5 + ) + if args.verbose >= 2: + print( + f"[VVERBOSE] Round-trip error: {num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ" + ) + except Exception as e: + if args.verbose >= 1: + print(f"[WARNING] Dequantize check failed: {e}") + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for mxfp8_quantize + # Read: input tensor + # Write: quantized tensor (fp8) + scale factors + num_elements = m * k + sf_vec_size = 32 + num_scale_factors = num_elements // sf_vec_size + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + num_elements * 1 # quantized output write (fp8 = 1 byte) + + num_scale_factors * 1 # scale factors write (1 byte each) + ) + # Quantization is memory-bound, TFLOPS not primary metric + problem_flops = num_elements * 3 # rough estimate (scale, clamp, convert) + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + 
cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["alignment"] = alignment + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testMxfp4Quantize(args): + """ + Test mxfp4_quantize API. + + This test: + 1. Generates random input tensors + 2. Runs mxfp4_quantize + 3. Runs reference check (via dequantize round-trip) + 4. Measures performance metrics (memory bandwidth) + + Note: Quantization is memory-bandwidth bound, so TB/sec is the primary metric. + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMxfp4Quantize") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends[:] # Make a copy to avoid modifying the original + m = args.m + k = args.k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + res = [] + + backends = filter_backends_by_compute_capability(backends, args.routine, device) + if len(backends) == 0: + print("[ERROR] No backends to test. Exiting.") + return res + + input_dtype = dtype_str_to_torch_dtype(args.input_dtype) + if input_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16." 
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + input_shape = (m, k) + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + + def run_backend(backend, input_tensor): + if backend == "cuda": + return flashinfer.mxfp4_quantize(input_tensor) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Reference check via dequantize round-trip + has_reference_output = False + if run_refcheck: + has_reference_output = True + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + x_q, sf = run_backend(cur_backend, input_tensor) + outputs[cur_backend] = (x_q.detach().clone(), sf.detach().clone()) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + x_q, sf = outputs[tested_backends[i]] + if args.verbose >= 2: + print( + f"[VVERBOSE] Backend {tested_backends[i]}: " + f"x_q.shape = {x_q.shape}, x_q.dtype = {x_q.dtype}, " + f"sf.shape = {sf.shape}, sf.dtype = {sf.dtype}" + ) + # Dequantize and compare with original + try: + dequantized = flashinfer.mxfp4_dequantize(x_q, sf) + # Move back to GPU for comparison + dequantized = dequantized.to(input_tensor.device) + # Compare with original input (allowing for quantization error) + ( + num_different_elements, + num_elements, + num_different_elements_percentage, + ) = is_close_stats( + input_tensor.float(), dequantized, rtol=0.5, atol=0.5 + ) + if args.verbose >= 2: + print( + f"[VVERBOSE] Round-trip 
error: {num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ" + ) + except Exception as e: + if args.verbose >= 1: + print(f"[WARNING] Dequantize check failed: {e}") + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for mxfp4_quantize + # Read: input tensor + # Write: quantized tensor (fp4 = 0.5 bytes per element) + scale factors + num_elements = m * k + sf_vec_size = 32 # mxfp4 uses sf_vec_size=32 + num_scale_factors = num_elements // sf_vec_size + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + num_elements // 2 # quantized output write (fp4 = 0.5 byte) + + num_scale_factors * 1 # scale factors write (1 byte each, ue8m0) + ) + problem_flops = num_elements * 3 # rough estimate + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testNvfp4Quantize(args): + """ + Test nvfp4_quantize API. + + This test: + 1. Generates random input tensors + 2. Runs nvfp4_quantize with specified layout + 3. Verifies output shapes + 4. Measures performance metrics (memory bandwidth) + + Note: Quantization is memory-bandwidth bound, so TB/sec is the primary metric. 
+
+    Args:
+        args: Parsed command line arguments containing test configuration
+
+    Returns:
+        list[dict]: List of dictionaries containing performance results
+    """
+    from flashinfer.fp4_quantization import SfLayout
+
+    if args.verbose >= 1:
+        print("[INFO] Running testNvfp4Quantize")
+        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
+
+    device = get_device(args)
+    if args.generate_repro_command:
+        print(
+            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
+        )
+
+    ## Parse input arguments
+    backends = args.backends[:]  # Make a copy to avoid modifying the original
+    m = args.m
+    k = args.k
+    global_scale = args.global_scale
+    sf_layout_str = args.sf_layout
+    do_shuffle = args.do_shuffle
+    sf_vec_size = args.sf_vec_size
+    enable_pdl = args.enable_pdl
+    is_cuda_graph_compatible = not args.no_cuda_graph
+    run_refcheck = args.refcheck
+    res = []
+
+    # Convert sf_layout string to enum
+    sf_layout_map = {
+        "128x4": SfLayout.layout_128x4,
+        "8x4": SfLayout.layout_8x4,
+        "linear": SfLayout.layout_linear,
+    }
+    sf_layout = sf_layout_map[sf_layout_str]
+
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return res
+
+    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
+    if input_dtype not in [torch.bfloat16, torch.float16]:
+        raise ValueError(
+            f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16."
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + input_shape = (m, k) + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + global_sf_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {global_scale = }") + print(f"[VVERBOSE] {sf_layout_str = }") + print(f"[VVERBOSE] {do_shuffle = }") + print(f"[VVERBOSE] {sf_vec_size = }") + print(f"[VVERBOSE] {enable_pdl = }") + + def run_backend(backend, input_tensor, global_sf_tensor): + if backend == "cuda": + return flashinfer.nvfp4_quantize( + input_tensor, + global_sf_tensor, + sfLayout=sf_layout, + do_shuffle=do_shuffle, + sf_vec_size=sf_vec_size, + enable_pdl=enable_pdl, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + x_q, sf = run_backend(cur_backend, input_tensor, global_sf_tensor) + outputs[cur_backend] = (x_q.detach().clone(), sf.detach().clone()) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor, global_sf_tensor), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck: + for i in range(len(tested_backends)): + x_q, sf = outputs[tested_backends[i]] + if args.verbose >= 2: + print( + f"[VVERBOSE] Backend {tested_backends[i]}: " + f"x_q.shape = {x_q.shape}, x_q.dtype = {x_q.dtype}, " + f"sf.shape = {sf.shape}, sf.dtype = {sf.dtype}" + ) + # Verify output shape (M, K/2) for FP4 + expected_shape = (m, k // 2) + if x_q.shape != expected_shape: + print( + f"[WARNING] Unexpected output shape: {x_q.shape}, 
expected {expected_shape}" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for nvfp4_quantize + # Read: input tensor + global_sf + # Write: quantized tensor (fp4 = 0.5 bytes per element) + scale factors + num_elements = m * k + num_scale_factors = num_elements // sf_vec_size + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + 4 # global_sf read (float32) + + num_elements // 2 # quantized output write (fp4 = 0.5 byte) + + num_scale_factors * 1 # scale factors write (1 byte each, e4m3) + ) + problem_flops = num_elements * 3 # rough estimate + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["global_scale"] = global_scale + cur_res["sf_layout"] = sf_layout_str + cur_res["do_shuffle"] = do_shuffle + cur_res["sf_vec_size"] = sf_vec_size + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testNvfp4BatchedQuantize(args): + """ + Test nvfp4_batched_quantize API. + + This test: + 1. Generates random batched input tensors + 2. Runs nvfp4_batched_quantize + 3. Verifies output shapes + 4. Measures performance metrics (memory bandwidth) + + Note: Quantization is memory-bandwidth bound, so TB/sec is the primary metric. 
+
+    Args:
+        args: Parsed command line arguments containing test configuration
+
+    Returns:
+        list[dict]: List of dictionaries containing performance results
+    """
+    if args.verbose >= 1:
+        print("[INFO] Running testNvfp4BatchedQuantize")
+        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
+
+    device = get_device(args)
+    if args.generate_repro_command:
+        print(
+            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
+        )
+
+    ## Parse input arguments
+    backends = args.backends[:]  # Make a copy to avoid modifying the original
+    batch_size = args.batch_size
+    m = args.m
+    k = args.k
+    global_scale = args.global_scale
+    sf_vec_size = args.sf_vec_size
+    is_cuda_graph_compatible = not args.no_cuda_graph
+    run_refcheck = args.refcheck
+    res = []
+
+    if batch_size is None:
+        raise ValueError("--batch_size is required for nvfp4_batched_quantize")
+
+    backends = filter_backends_by_compute_capability(backends, args.routine, device)
+    if len(backends) == 0:
+        print("[ERROR] No backends to test. Exiting.")
+        return res
+
+    input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
+    if input_dtype not in [torch.bfloat16, torch.float16]:
+        raise ValueError(
+            f"Unsupported input dtype: {args.input_dtype}. Supported dtypes are bfloat16, float16."
+ ) + ## Done parsing input arguments + + ## Prepare input tensors + input_shape = (batch_size, m, k) + input_tensor = torch.randn(input_shape, dtype=input_dtype, device=device) + global_sf_tensor = torch.tensor([global_scale], dtype=torch.float32, device=device) + + if args.verbose >= 2: + print(f"[VVERBOSE] {input_tensor.shape = }") + print(f"[VVERBOSE] {input_tensor.dtype = }") + print(f"[VVERBOSE] {global_scale = }") + print(f"[VVERBOSE] {sf_vec_size = }") + + def run_backend(backend, input_tensor, global_sf_tensor): + if backend == "cuda": + return flashinfer.nvfp4_batched_quantize( + input_tensor, + global_sf_tensor, + sf_vec_size=sf_vec_size, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + x_q, sf = run_backend(cur_backend, input_tensor, global_sf_tensor) + outputs[cur_backend] = (x_q.detach().clone(), sf.detach().clone()) + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + input_args=(cur_backend, input_tensor, global_sf_tensor), + ) + + tested_backends = list(outputs.keys()) + if len(tested_backends) > 0: + if run_refcheck: + for i in range(len(tested_backends)): + x_q, sf = outputs[tested_backends[i]] + if args.verbose >= 2: + print( + f"[VVERBOSE] Backend {tested_backends[i]}: " + f"x_q.shape = {x_q.shape}, x_q.dtype = {x_q.dtype}, " + f"sf.shape = {sf.shape}, sf.dtype = {sf.dtype}" + ) + # Verify output shape (B, M, K/2) for FP4 + expected_shape = (batch_size, m, k // 2) + if x_q.shape != expected_shape: + print( + f"[WARNING] Unexpected output shape: {x_q.shape}, expected {expected_shape}" + ) + + for backend in backends: + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + 
std_time = np.std(backend_times[backend]) + + # Memory bandwidth calculation for nvfp4_batched_quantize + # Read: input tensor + global_sf + # Write: quantized tensor (fp4 = 0.5 bytes per element) + scale factors + num_elements = batch_size * m * k + num_scale_factors = num_elements // sf_vec_size + problem_bytes = ( + num_elements * input_dtype.itemsize # input read + + 4 # global_sf read (float32) + + num_elements // 2 # quantized output write (fp4 = 0.5 byte) + + num_scale_factors * 1 # scale factors write (1 byte each) + ) + problem_flops = num_elements * 3 # rough estimate + tflops = problem_flops / (10**9 * median_time) + tb_per_sec = problem_bytes / (10**9 * median_time) + + print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["batch_size"] = batch_size + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["global_scale"] = global_scale + cur_res["sf_vec_size"] = sf_vec_size + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + diff --git a/benchmarks/samples/sample_testlist.txt b/benchmarks/samples/sample_testlist.txt index 7cd326c74d..24c4dd1d9a 100644 --- a/benchmarks/samples/sample_testlist.txt +++ b/benchmarks/samples/sample_testlist.txt @@ -110,3 +110,44 @@ # 3D input shape (batch, num_heads, head_dim) --routine add_rmsnorm_fp4quant --batch_size 32 --num_heads 32 --hidden_size 128 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "add_rmsnorm_fp4quant_3d" + +## Quantization (Blackwell SM10.0+ only) +# MxFP8 Quantization - basic +--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_basic" +--routine mxfp8_quantize --m 
2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp8_quantize_large" + +# MxFP8 Quantization - float16 input +--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype float16 -vv --generate_repro_command --case_tag "mxfp8_quantize_fp16" + +# MxFP8 Quantization - with swizzled layout disabled +--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --no_sf_swizzled_layout -vv --generate_repro_command --case_tag "mxfp8_quantize_no_swizzle" + +# MxFP8 Quantization - with PDL enabled +--routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --enable_pdl -vv --generate_repro_command --case_tag "mxfp8_quantize_pdl" + +# MxFP8 Quantization - with refcheck (round-trip verification) +--routine mxfp8_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp8_quantize_refcheck" + +# MxFP4 Quantization (Blackwell SM10.0+ only) +--routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp4_quantize_basic" +--routine mxfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 -vv --generate_repro_command --case_tag "mxfp4_quantize_large" +--routine mxfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp4_quantize_refcheck" + +# NVFP4 Quantization (Blackwell SM10.0+ only) +# With 128x4 layout (default, for large tileN GEMMs) +--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_128x4" +--routine nvfp4_quantize --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 128x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_128x4_large" + +# With 8x4 layout (for small tileN GEMMs) +--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --sf_layout 8x4 -vv --generate_repro_command --case_tag "nvfp4_quantize_8x4" + +# With shuffle 
for TRTLLM backend +--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --do_shuffle -vv --generate_repro_command --case_tag "nvfp4_quantize_shuffle" + +# With PDL enabled +--routine nvfp4_quantize --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 --enable_pdl -vv --generate_repro_command --case_tag "nvfp4_quantize_pdl" + +# NVFP4 Batched Quantization (Blackwell SM10.0+ only) +--routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_basic" +--routine nvfp4_batched_quantize --batch_size 8 --m 2048 --k 8192 --input_dtype bfloat16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_large" +--routine nvfp4_batched_quantize --batch_size 4 --m 1024 --k 4096 --input_dtype float16 --global_scale 1.0 -vv --generate_repro_command --case_tag "nvfp4_batched_fp16" From 7517a05781ab8f93c24c9903ac14794326f1aa35 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 15 Jan 2026 21:19:08 +0000 Subject: [PATCH 3/7] First fixes --- benchmarks/flashinfer_benchmark.py | 7 +- benchmarks/routines/norm.py | 154 ++++++++++++++-------------- benchmarks/routines/quantization.py | 129 ++++++++++++----------- 3 files changed, 145 insertions(+), 145 deletions(-) diff --git a/benchmarks/flashinfer_benchmark.py b/benchmarks/flashinfer_benchmark.py index 3bf3977ad1..f86869b702 100644 --- a/benchmarks/flashinfer_benchmark.py +++ b/benchmarks/flashinfer_benchmark.py @@ -40,9 +40,10 @@ def run_test(args): with open(args.output_path, "a") as fout: for cur_res in res: for key in output_column_dict["general"]: - # Use getattr with default "" for optional columns like batch_size/hidden_size - # that may not be present in all routine types - cur_res[key] = getattr(args, key, "") + # Only set from args if the routine hasn't already set a value + # This preserves routine-specific formatting while providing defaults + if key not in cur_res or 
cur_res[key] == "": + cur_res[key] = getattr(args, key, "") output_line = ",".join( [str(cur_res[col]) for col in full_output_columns] diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index 94085145d4..fec3bba4ce 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2023 by FlashInfer team. +Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -293,20 +293,20 @@ def run_backend(backend, input_tensor, weight): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["num_heads"] = num_heads if num_heads else "" - cur_res["input_dtype"] = str(input_dtype) - cur_res["eps"] = eps - cur_res["enable_pdl"] = enable_pdl - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = str(input_dtype) + cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -472,21 +472,21 @@ def run_backend(backend, out_tensor, input_tensor, weight): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - 
cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["out_dtype"] = str(out_dtype) - cur_res["scale"] = scale - cur_res["eps"] = eps - cur_res["enable_pdl"] = enable_pdl - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = str(out_dtype) + cur_res["scale"] = scale + cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -672,21 +672,21 @@ def run_backend(backend, out_tensor, input_tensor, residual_tensor, weight): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["input_dtype"] = str(input_dtype) - cur_res["out_dtype"] = str(out_dtype) - cur_res["scale"] = scale - cur_res["eps"] = eps - cur_res["enable_pdl"] = enable_pdl - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = str(out_dtype) + cur_res["scale"] = scale + cur_res["eps"] = eps + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + 
res.append(cur_res) return res @@ -861,22 +861,22 @@ def run_backend(backend, input_tensor, weight): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["num_heads"] = num_heads if num_heads else "" - cur_res["input_dtype"] = str(input_dtype) - cur_res["out_dtype"] = out_dtype - cur_res["eps"] = eps - cur_res["use_global_scale"] = use_global_scale - cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = out_dtype + cur_res["eps"] = eps + cur_res["use_global_scale"] = use_global_scale + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -1058,20 +1058,20 @@ def run_backend(backend, input_tensor, residual_tensor, weight): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["num_heads"] = num_heads if num_heads else "" - cur_res["input_dtype"] = str(input_dtype) - cur_res["out_dtype"] = out_dtype - cur_res["eps"] = eps - cur_res["use_global_scale"] = use_global_scale - 
cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["num_heads"] = num_heads if num_heads else "" + cur_res["input_dtype"] = str(input_dtype) + cur_res["out_dtype"] = out_dtype + cur_res["eps"] = eps + cur_res["use_global_scale"] = use_global_scale + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index 6b54bd331e..1b83911e9c 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -1,5 +1,5 @@ """ -Copyright (c) 2023 by FlashInfer team. +Copyright (c) 2025 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -328,22 +328,22 @@ def run_backend(backend, input_tensor): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["m"] = m - cur_res["k"] = k - cur_res["input_dtype"] = str(input_dtype) - cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout - cur_res["alignment"] = alignment - cur_res["enable_pdl"] = enable_pdl - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["is_sf_swizzled_layout"] = is_sf_swizzled_layout + cur_res["alignment"] = alignment + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -485,19 +485,19 @@ def run_backend(backend, input_tensor): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["m"] = m - cur_res["k"] = k - cur_res["input_dtype"] = str(input_dtype) - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = 
tb_per_sec + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -647,24 +647,24 @@ def run_backend(backend, input_tensor, global_sf_tensor): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["m"] = m - cur_res["k"] = k - cur_res["input_dtype"] = str(input_dtype) - cur_res["global_scale"] = global_scale - cur_res["sf_layout"] = sf_layout_str - cur_res["do_shuffle"] = do_shuffle - cur_res["sf_vec_size"] = sf_vec_size - cur_res["enable_pdl"] = enable_pdl - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["global_scale"] = global_scale + cur_res["sf_layout"] = sf_layout_str + cur_res["do_shuffle"] = do_shuffle + cur_res["sf_vec_size"] = sf_vec_size + cur_res["enable_pdl"] = enable_pdl + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res @@ -799,21 +799,20 @@ def run_backend(backend, input_tensor, global_sf_tensor): print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec) - if args.output_path is not None: - cur_res = defaultdict(str) - cur_res["routine"] = args.routine - cur_res["median_time"] = median_time - cur_res["std_time"] = std_time - cur_res["tflops"] = tflops - cur_res["tb_per_sec"] = tb_per_sec - cur_res["batch_size"] = batch_size - 
cur_res["m"] = m - cur_res["k"] = k - cur_res["input_dtype"] = str(input_dtype) - cur_res["global_scale"] = global_scale - cur_res["sf_vec_size"] = sf_vec_size - cur_res["backend"] = backend - cur_res["case_tag"] = args.case_tag - res.append(cur_res) + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["batch_size"] = batch_size + cur_res["m"] = m + cur_res["k"] = k + cur_res["input_dtype"] = str(input_dtype) + cur_res["global_scale"] = global_scale + cur_res["sf_vec_size"] = sf_vec_size + cur_res["backend"] = backend + cur_res["case_tag"] = args.case_tag + res.append(cur_res) return res - From 5ff09b20f455217107aaf58a7e003e89f3b1788e Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 15 Jan 2026 21:45:14 +0000 Subject: [PATCH 4/7] Second fixes --- benchmarks/README.md | 77 ++++++++++++++++++- .../routines/flashinfer_benchmark_utils.py | 12 +-- benchmarks/routines/norm.py | 60 ++------------- benchmarks/routines/quantization.py | 7 ++ 4 files changed, 94 insertions(+), 62 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 162c166cee..66c8fd49b6 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -5,11 +5,11 @@ The aim of `flashinfer_benchmark.py` is to provide a single framework for benchm ## Overview This framework provides tools to: -- Benchmark FlashInfer's Attention, GEMM, and MOE API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, and TensorRT-LLM +- Benchmark FlashInfer's Attention, GEMM, MOE, Norm, and Quantization API performance from different kernel backends such as FlashAttention2/3, cuDNN, cuBLAS, CUTLASS, CuTe-DSL, and TensorRT-LLM - Compare performance across different configurations -- Batch performance test multiple attention test cases +- Batch performance test multiple test 
cases -Currently supports testing most attention, gemm, and fused MOE APIs: +Currently supports testing attention, gemm, fused MOE, normalization, and quantization APIs: - Attention: - `BatchDecodeWithPagedKVCacheWrapper` - Decode attention with paged KV cache. - Also supports computationally similar `cudnn_batch_decode_with_kv_cache` and `trtllm_batch_decode_with_kv_cache`. @@ -29,6 +29,17 @@ Currently supports testing most attention, gemm, and fused MOE APIs: - `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling. - `trtllm_fp8_per_tensor_scale_moe` - MOE with FP8 quantized weights and per-tensor scaling. - `cutlass_fused_moe` - CUTLASS fused MoE (base/fp8/nvfp4 variants with optional TP/EP) +- Norm: + - `rmsnorm` - Root Mean Square Layer Normalization. + - `rmsnorm_quant` - RMSNorm with FP8 quantized output. + - `fused_add_rmsnorm_quant` - Fused residual add + RMSNorm with FP8 quantized output. + - `rmsnorm_fp4quant` - RMSNorm with FP4 quantized output (CuTe-DSL, Blackwell SM10.0+). + - `add_rmsnorm_fp4quant` - Fused residual add + RMSNorm with FP4 quantized output (CuTe-DSL, Blackwell SM10.0+). +- Quantization: + - `mxfp8_quantize` - Quantize tensor to MxFP8 format (Blackwell SM10.0+). + - `mxfp4_quantize` - Quantize tensor to MxFP4 format (Blackwell SM10.0+). + - `nvfp4_quantize` - Quantize tensor to NVFP4 format with configurable scale factor layout (Blackwell SM10.0+). + - `nvfp4_batched_quantize` - Batched NVFP4 quantization (Blackwell SM10.0+). 
## Quick Start ### Single Test Run @@ -81,6 +92,20 @@ $ python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper [PERF] fa2 :: median time 0.495 ms; std 0.006 ms; achieved tflops 219.336 TFLOPs/sec; achieved tb_per_sec 1.736 TB/sec [PERF] cutlass :: median time 0.530 ms; std 0.002 ms; achieved tflops 204.674 TFLOPs/sec; achieved tb_per_sec 1.620 TB/sec [PERF] cudnn :: median time 0.313 ms; std 0.000 ms; achieved tflops 346.715 TFLOPs/sec; achieved tb_per_sec 2.745 TB/sec + +# RMSNorm with FP8 quantized output +$ python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e4m3" +[INFO] args = Namespace(routine='rmsnorm_quant', ...) +[INFO] Running testRmsnormQuant +[INFO] FlashInfer version: 0.6.1 +[PERF] cuda :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.131 TFLOPs/sec; achieved tb_per_sec 0.132 TB/sec + +# MxFP8 Quantization (Blackwell SM10.0+ only) +$ python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp8_quantize" +[INFO] args = Namespace(routine='mxfp8_quantize', ...) 
+[INFO] Running testMxfp8Quantize +[INFO] FlashInfer version: 0.6.1 +[PERF] cuda :: median time 0.041 ms; std 0.000 ms; achieved tflops 1.228 TFLOPs/sec; achieved tb_per_sec 1.240 TB/sec ``` ### Batch Testing @@ -104,7 +129,7 @@ The output CSV will contain detailed metrics including: ### General Flags | Flag | Description | |--------------------------|-------------------------------------------------------------------------------------------------------------| -| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `BatchMLAPagedAttentionWrapper`, `gemm_fp8_nt_groupwise`, `group_gemm_fp8_nt_groupwise`, `bmm_fp8`, `mm_fp4`, `trtllm_fp4_block_scale_moe`, `trtllm_fp8_block_scale_moe`, `trtllm_fp8_per_tensor_scale_moe`, `cutlass_fused_moe` | +| `--routine` | Test routine to run. See [Overview](#overview) for full list including attention, GEMM, MOE, norm, and quantization routines. | | `--num_iters` | Number of iterations for performance measurement | | `--dry_run_iters` | Number of warmup iterations | | `--no_cuda_graph` | Disable CUDA graph to execute kernels outside of the graph. | @@ -198,6 +223,38 @@ Notes: - FP8 MOE kernels require integer values for group parameters, while FP4 MOE kernels accept optional values. - CUTLASS fused MoE (`cutlass_fused_moe`) ignores `--routing_method`, `--n_group`, and `--topk_group`; it computes routing via softmax+top-k internally from the provided logits. +### Norm Flags +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--batch_size` | Batch size (number of sequences) | +| `--hidden_size` | Hidden dimension size | +| `--num_heads` | Number of heads for 3D input shape (batch, num_heads, hidden_size). Optional; if not set, uses 2D shape. 
| +| `--input_dtype` | Input data type: `bfloat16` (default) or `float16` | +| `--eps` | Epsilon for numerical stability. Default: 1e-6 | +| `--enable_pdl` | Enable programmatic dependent launch | +| `--scale` | Scale factor for FP8 quantization (used by `rmsnorm_quant`, `fused_add_rmsnorm_quant`). Default: 1.0 | +| `--out_dtype` | Output dtype: `fp8_e4m3`, `fp8_e5m2` (for FP8 quant); `nvfp4`, `mxfp4` (for FP4 quant). Default: `fp8_e4m3`| +| `--use_global_scale` | Use global scale factor for NVFP4 format (FP4 routines only) | +| `--is_sf_swizzled_layout`| Use swizzled scale factor layout for tensor core GEMM (FP4 routines only) | +| `--backends` | Backend to test: `cuda` (default) or `cute-dsl` (for FP4 routines) | + +### Quantization Flags +| Flag | Description | +|--------------------------|-------------------------------------------------------------------------------------------------------------| +| `--m` | Number of rows in input tensor | +| `--k` | Number of columns in input tensor (must be divisible by 32) | +| `--input_dtype` | Input data type: `bfloat16` (default) or `float16` | +| `--is_sf_swizzled_layout`| Use swizzled layout for scale factors. Default: True | +| `--no_sf_swizzled_layout`| Disable swizzled layout for scale factors | +| `--alignment` | sfVecSize for quantization. Default: 32 | +| `--enable_pdl` | Enable programmatic dependent launch | +| `--batch_size` | Batch size for batched quantization (`nvfp4_batched_quantize` only) | +| `--global_scale` | Global scale factor for NVFP4 quantization. Default: 1.0 | +| `--sf_layout` | Scale factor layout for NVFP4: `128x4` (default), `8x4`, or `linear` | +| `--do_shuffle` | Shuffle scale factors for TRTLLM backend (`nvfp4_quantize` only) | +| `--sf_vec_size` | Scale factor vector size for NVFP4 quantization. Default: 16 | +| `--backends` | Backend to test. 
Default: `cuda` | + ## `flashinfer_benchmark.py` Routine & Backend Support Matrix The following table summarizes the support surface of each routine & backend's on various [CUDA Compute Capabilities](https://developer.nvidia.com/cuda-gpus). @@ -228,13 +285,25 @@ Legend: | **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | | | **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | | | **cutlass_fused_moe** | | | | | | cutlass | cutlass | | +| **rmsnorm** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **fused_add_rmsnorm_quant** | cuda | cuda | cuda | cuda | cuda | cuda | cuda | cuda | +| **rmsnorm_fp4quant** | | | | | | cute-dsl | cute-dsl | | +| **add_rmsnorm_fp4quant** | | | | | | cute-dsl | cute-dsl | | +| **mxfp8_quantize** | | | | | | cuda | cuda | | +| **mxfp4_quantize** | | | | | | cuda | cuda | | +| **nvfp4_quantize** | | | | | | cuda | cuda | | +| **nvfp4_batched_quantize** | | | | | | cuda | cuda | | Backend Legend: - fa2: FlashAttention2 - fa2_tc: FlashAttention2 (with Tensor Cores for `BatchDecodeWithPagedKVCacheWrapper`) - fa3: FlashAttention-3 - cudnn: cuDNN +- cublas: cuBLAS - cutlass: CUTLASS - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM - trtllm-native: TensorRT-LLM (out-of-wrapper) +- cuda: FlashInfer CUDA kernels +- cute-dsl: FlashInfer CuTe-DSL kernels (Blackwell SM10.0+) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 6ba35a1fd4..00bb04fb6f 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -363,7 +363,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cute-dsl"], "10.3": ["cute-dsl"], - "12.0": [], + "12.0": ["cute-dsl"], }, "add_rmsnorm_fp4quant": { "7.5": [], @@ -373,7 +373,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cute-dsl"], "10.3": 
["cute-dsl"], - "12.0": [], + "12.0": ["cute-dsl"], }, # QUANTIZATION "mxfp8_quantize": { @@ -384,7 +384,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cuda"], "10.3": ["cuda"], - "12.0": [], + "12.0": ["cuda"], }, "mxfp4_quantize": { "7.5": [], @@ -394,7 +394,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cuda"], "10.3": ["cuda"], - "12.0": [], + "12.0": ["cuda"], }, "nvfp4_quantize": { "7.5": [], @@ -404,7 +404,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cuda"], "10.3": ["cuda"], - "12.0": [], + "12.0": ["cuda"], }, "nvfp4_batched_quantize": { "7.5": [], @@ -414,7 +414,7 @@ def dtype_str_to_torch_dtype(dtype_str): "9.0": [], "10.0": ["cuda"], "10.3": ["cuda"], - "12.0": [], + "12.0": ["cuda"], }, } diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index fec3bba4ce..77bfe0b1bc 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -787,6 +787,10 @@ def testRmsnormFp4quant(args): print(f"[VVERBOSE] {use_global_scale = }") print(f"[VVERBOSE] {is_sf_swizzled_layout = }") + # Warn user that refcheck is not supported for FP4 quantization fusion + if run_refcheck: + print("[WARNING] --refcheck is not supported for rmsnorm_fp4quant.") + def run_backend(backend, input_tensor, weight): if backend == "cute-dsl": return flashinfer.rmsnorm_fp4quant( @@ -800,23 +804,9 @@ def run_backend(backend, input_tensor, weight): else: raise ValueError(f"Unsupported backend: {backend}") - # Reference: PyTorch implementation of RMSNorm + FP4 quantization - has_reference_output = False - if run_refcheck: - # For FP4 quantization, we verify output shapes and dtypes - # since FP4 quantization details are complex and implementation-specific - has_reference_output = True - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} - outputs = {} for cur_backend in backends: - if run_refcheck: - out_fp4, out_scale = run_backend(cur_backend, input_tensor, 
weight) - outputs[cur_backend] = ( - out_fp4.detach().clone(), - out_scale.detach().clone(), - ) backend_times[cur_backend] = bench_gpu_time( fn=run_backend, dry_run_iters=args.dry_run_iters, @@ -826,17 +816,6 @@ def run_backend(backend, input_tensor, weight): input_args=(cur_backend, input_tensor, weight), ) - tested_backends = list(outputs.keys()) - if len(tested_backends) > 0: - if run_refcheck and has_reference_output: - # For FP4, we just verify output shapes are correct - for i in range(len(tested_backends)): - out_fp4, out_scale = outputs[tested_backends[i]] - if args.verbose >= 2: - print( - f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}" - ) - for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) @@ -979,6 +958,10 @@ def testAddRmsnormFp4quant(args): print(f"[VVERBOSE] {use_global_scale = }") print(f"[VVERBOSE] {is_sf_swizzled_layout = }") + # Warn user that refcheck is not supported for FP4 quantization fusion + if run_refcheck: + print("[WARNING] --refcheck is not supported for add_rmsnorm_fp4quant. 
") + def run_backend(backend, input_tensor, residual_tensor, weight): if backend == "cute-dsl": return flashinfer.add_rmsnorm_fp4quant( @@ -993,25 +976,9 @@ def run_backend(backend, input_tensor, residual_tensor, weight): else: raise ValueError(f"Unsupported backend: {backend}") - # Reference: For FP4 quantization, we verify output shapes and dtypes - # since FP4 quantization details are complex and implementation-specific - has_reference_output = False - if run_refcheck: - has_reference_output = True - # Storage for timing results and outputs backend_times = {backend: [] for backend in backends} - outputs = {} for cur_backend in backends: - if run_refcheck: - out_fp4, out_scale, out_h = run_backend( - cur_backend, input_tensor, residual_tensor.clone(), weight - ) - outputs[cur_backend] = ( - out_fp4.detach().clone(), - out_scale.detach().clone(), - out_h.detach().clone(), - ) backend_times[cur_backend] = bench_gpu_time( fn=run_backend, dry_run_iters=args.dry_run_iters, @@ -1021,17 +988,6 @@ def run_backend(backend, input_tensor, residual_tensor, weight): input_args=(cur_backend, input_tensor, residual_tensor.clone(), weight), ) - tested_backends = list(outputs.keys()) - if len(tested_backends) > 0: - if run_refcheck and has_reference_output: - # For FP4, we just verify output shapes are correct - for i in range(len(tested_backends)): - out_fp4, out_scale, out_h = outputs[tested_backends[i]] - if args.verbose >= 2: - print( - f"[VVERBOSE] Backend {tested_backends[i]}: out_fp4.shape = {out_fp4.shape}, out_scale.shape = {out_scale.shape}, out_h.shape = {out_h.shape}" - ) - for backend in backends: if len(backend_times[backend]) > 0: median_time = np.median(backend_times[backend]) diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index 1b83911e9c..df9cf5d100 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -544,6 +544,13 @@ def testNvfp4Quantize(args): run_refcheck = args.refcheck res = 
[] + # do_shuffle involves CPU index generation which is not CUDA graph compatible + if do_shuffle and is_cuda_graph_compatible: + print( + "[WARNING] do_shuffle=True is not CUDA graph compatible. Disabling CUDA graph." + ) + is_cuda_graph_compatible = False + # Convert sf_layout string to enum sf_layout_map = { "128x4": SfLayout.layout_128x4, From c2c39ad08072d53bc277062673f2dcee2c4c3096 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Thu, 15 Jan 2026 22:03:32 +0000 Subject: [PATCH 5/7] Third fixes --- benchmarks/README.md | 23 +++++++++++++++++++---- benchmarks/routines/norm.py | 6 ++++-- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index 66c8fd49b6..efcafdc403 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -95,17 +95,32 @@ $ python3 flashinfer_benchmark.py --routine BatchPrefillWithRaggedKVCacheWrapper # RMSNorm with FP8 quantized output $ python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag "rmsnorm_quant_fp8_e4m3" -[INFO] args = Namespace(routine='rmsnorm_quant', ...) 
[INFO] Running testRmsnormQuant [INFO] FlashInfer version: 0.6.1 -[PERF] cuda :: median time 0.005 ms; std 0.000 ms; achieved tflops 0.131 TFLOPs/sec; achieved tb_per_sec 0.132 TB/sec +[VVERBOSE] gpu_name = 'NVIDIA_B300_SXM6_AC' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine rmsnorm_quant --batch_size 32 --hidden_size 4096 --input_dtype bfloat16 --out_dtype fp8_e4m3 --scale 1.0 --refcheck -vv --generate_repro_command --case_tag rmsnorm_quant_fp8_e4m3 +[VVERBOSE] input_tensor.shape = torch.Size([32, 4096]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] weight.shape = torch.Size([4096]) +[VVERBOSE] out_tensor.dtype = torch.float8_e4m3fn +[VVERBOSE] scale = 1.0 +[PERF] cuda :: median time 0.003 ms; std 0.000 ms; achieved tflops 0.229 TFLOPs/sec; achieved tb_per_sec 0.140 TB/sec # MxFP8 Quantization (Blackwell SM10.0+ only) $ python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag "mxfp8_quantize" -[INFO] args = Namespace(routine='mxfp8_quantize', ...) 
+[INFO] args = Namespace(routine='mxfp8_quantize', no_cuda_graph=False, use_cupti=False, use_cuda_events=False, refcheck=True, allow_output_mismatch=False, random_seed=42, verbose=2, output_path=None, num_iters=30, dry_run_iters=5, case_tag='mxfp8_quantize', generate_repro_command=True, repro_command='', m=2048, k=8192, input_dtype='bfloat16', is_sf_swizzled_layout=True, no_sf_swizzled_layout=False, alignment=32, enable_pdl=False, backends=['cuda'], batch_size=None, global_scale=1.0, sf_layout='128x4', do_shuffle=False, sf_vec_size=16) [INFO] Running testMxfp8Quantize [INFO] FlashInfer version: 0.6.1 -[PERF] cuda :: median time 0.041 ms; std 0.000 ms; achieved tflops 1.228 TFLOPs/sec; achieved tb_per_sec 1.240 TB/sec +[VVERBOSE] gpu_name = 'NVIDIA_B300_SXM6_AC' +[INFO] To reproduce this test case, run the following command: python3 flashinfer_benchmark.py --routine mxfp8_quantize --m 2048 --k 8192 --input_dtype bfloat16 --refcheck -vv --generate_repro_command --case_tag mxfp8_quantize +[VVERBOSE] input_tensor.shape = torch.Size([2048, 8192]) +[VVERBOSE] input_tensor.dtype = torch.bfloat16 +[VVERBOSE] is_sf_swizzled_layout = True +[VVERBOSE] alignment = 32 +[VVERBOSE] enable_pdl = False +[VVERBOSE] Backend cuda: x_q.shape = torch.Size([2048, 8192]), x_q.dtype = torch.float8_e4m3fn, sf.shape = torch.Size([524288]), sf.dtype = torch.uint8 +[VVERBOSE] Round-trip error: 0/16777216 (0.00%) elements differ +[PERF] cuda :: median time 0.016 ms; std 0.000 ms; achieved tflops 3.118 TFLOPs/sec; achieved tb_per_sec 3.150 TB/sec ``` ### Batch Testing diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index 77bfe0b1bc..7a5fcd1ef3 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -774,9 +774,10 @@ def testRmsnormFp4quant(args): weight = torch.randn(hidden_size, dtype=input_dtype, device=device) # Prepare global_scale if using NVFP4 format + # Note: API expects a 1D tensor of shape [1], not a 0D scalar global_scale = None if 
use_global_scale: - global_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + global_scale = torch.tensor([1.0], dtype=torch.float32, device=device) if args.verbose >= 2: print(f"[VVERBOSE] {input_tensor.shape = }") @@ -944,9 +945,10 @@ def testAddRmsnormFp4quant(args): weight = torch.randn(hidden_size, dtype=input_dtype, device=device) # Prepare global_scale if using NVFP4 format + # Note: API expects a 1D tensor of shape [1], not a 0D scalar global_scale = None if use_global_scale: - global_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + global_scale = torch.tensor([1.0], dtype=torch.float32, device=device) if args.verbose >= 2: print(f"[VVERBOSE] {input_tensor.shape = }") From e9227f01a54a143214babe6fd182b2e44d799e47 Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 16 Jan 2026 00:32:14 +0000 Subject: [PATCH 6/7] Address reviewers comments --- benchmarks/routines/norm.py | 14 +++++++ benchmarks/routines/quantization.py | 62 +++++++++++++++++++++++------ 2 files changed, 64 insertions(+), 12 deletions(-) diff --git a/benchmarks/routines/norm.py b/benchmarks/routines/norm.py index 7a5fcd1ef3..29c0678ca3 100644 --- a/benchmarks/routines/norm.py +++ b/benchmarks/routines/norm.py @@ -752,6 +752,13 @@ def testRmsnormFp4quant(args): f"Unsupported out_dtype for FP4 quant: {out_dtype}. Supported: nvfp4, mxfp4." ) + # Validate alignment: hidden_size must be divisible by block_size + if hidden_size % block_size != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size}) " + f"for {out_dtype} quantization." + ) + backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") @@ -922,6 +929,13 @@ def testAddRmsnormFp4quant(args): f"Unsupported out_dtype for FP4 quant: {out_dtype}. Supported: nvfp4, mxfp4." 
) + # Validate alignment: hidden_size must be divisible by block_size + if hidden_size % block_size != 0: + raise ValueError( + f"hidden_size ({hidden_size}) must be divisible by block_size ({block_size}) " + f"for {out_dtype} quantization." + ) + backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index df9cf5d100..10ff1632ff 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -86,16 +86,17 @@ def parse_quantization_args(line, parser): ) parser.add_argument( "--is_sf_swizzled_layout", + dest="is_sf_swizzled_layout", action="store_true", - default=True, - help="Use swizzled layout for scale factors. Default: True", + help="Use swizzled layout for scale factors. (default)", ) parser.add_argument( "--no_sf_swizzled_layout", - action="store_true", - default=False, + dest="is_sf_swizzled_layout", + action="store_false", help="Disable swizzled layout for scale factors.", ) + parser.set_defaults(is_sf_swizzled_layout=True) parser.add_argument( "--alignment", type=int, @@ -157,10 +158,6 @@ def parse_quantization_args(line, parser): args = parser.parse_args(line) - # Handle swizzled layout flag - if args.no_sf_swizzled_layout: - args.is_sf_swizzled_layout = False - if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -205,9 +202,10 @@ def testMxfp8Quantize(args): run_refcheck = args.refcheck res = [] - # Validate k is divisible by 32 (sf_vec_size) - if k % 32 != 0: - raise ValueError(f"k ({k}) must be divisible by 32 (sf_vec_size)") + # Validate k is divisible by alignment (sf_vec_size) + sf_vec_size = alignment + if k % sf_vec_size != 0: + raise ValueError(f"k ({k}) must be divisible by {sf_vec_size} (sf_vec_size)") backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: @@ -301,6 +299,17 
@@ def run_backend(backend, input_tensor): f"[VVERBOSE] Round-trip error: {num_different_elements}/{num_elements} " f"({num_different_elements_percentage:.2f}%) elements differ" ) + # Enforce refcheck: fail or warn on mismatches + if num_different_elements > 0: + mismatch_msg = ( + f"[mxfp8_quantize] Round-trip mismatch: " + f"{num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ" + ) + if args.allow_output_mismatch: + print(f"[WARNING] {mismatch_msg}") + else: + raise AssertionError(mismatch_msg) except Exception as e: if args.verbose >= 1: print(f"[WARNING] Dequantize check failed: {e}") @@ -314,7 +323,6 @@ def run_backend(backend, input_tensor): # Read: input tensor # Write: quantized tensor (fp8) + scale factors num_elements = m * k - sf_vec_size = 32 num_scale_factors = num_elements // sf_vec_size problem_bytes = ( num_elements * input_dtype.itemsize # input read @@ -383,6 +391,13 @@ def testMxfp4Quantize(args): run_refcheck = args.refcheck res = [] + # mxfp4 uses sf_vec_size=32 (hardcoded in the API) + sf_vec_size = 32 + if k % sf_vec_size != 0: + raise ValueError( + f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size}) for mxfp4_quantize" + ) + backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: print("[ERROR] No backends to test. 
Exiting.") @@ -459,6 +474,17 @@ def run_backend(backend, input_tensor): f"[VVERBOSE] Round-trip error: {num_different_elements}/{num_elements} " f"({num_different_elements_percentage:.2f}%) elements differ" ) + # Enforce refcheck: fail or warn on mismatches + if num_different_elements > 0: + mismatch_msg = ( + f"[mxfp4_quantize] Round-trip mismatch: " + f"{num_different_elements}/{num_elements} " + f"({num_different_elements_percentage:.2f}%) elements differ" + ) + if args.allow_output_mismatch: + print(f"[WARNING] {mismatch_msg}") + else: + raise AssertionError(mismatch_msg) except Exception as e: if args.verbose >= 1: print(f"[WARNING] Dequantize check failed: {e}") @@ -559,6 +585,12 @@ def testNvfp4Quantize(args): } sf_layout = sf_layout_map[sf_layout_str] + # Validate k is divisible by sf_vec_size + if k % sf_vec_size != 0: + raise ValueError( + f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size}) for nvfp4_quantize" + ) + backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: print("[ERROR] No backends to test. Exiting.") @@ -717,6 +749,12 @@ def testNvfp4BatchedQuantize(args): if batch_size is None: raise ValueError("--batch_size is required for nvfp4_batched_quantize") + # Validate k is divisible by sf_vec_size + if k % sf_vec_size != 0: + raise ValueError( + f"k ({k}) must be divisible by sf_vec_size ({sf_vec_size}) for nvfp4_batched_quantize" + ) + backends = filter_backends_by_compute_capability(backends, args.routine, device) if len(backends) == 0: print("[ERROR] No backends to test. 
Exiting.") From 1f2d48a3064d4930aa1dabfb8fc1d6547cd279ca Mon Sep 17 00:00:00 2001 From: Brian Ryu Date: Fri, 16 Jan 2026 00:46:37 +0000 Subject: [PATCH 7/7] Address reviewers comment --- benchmarks/routines/quantization.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/benchmarks/routines/quantization.py b/benchmarks/routines/quantization.py index 10ff1632ff..e1f40a9220 100644 --- a/benchmarks/routines/quantization.py +++ b/benchmarks/routines/quantization.py @@ -312,7 +312,11 @@ def run_backend(backend, input_tensor): raise AssertionError(mismatch_msg) except Exception as e: if args.verbose >= 1: - print(f"[WARNING] Dequantize check failed: {e}") + print( + f"[WARNING] [mxfp8_quantize] Dequantize check failed: {e}" + ) + if not args.allow_output_mismatch: + raise for backend in backends: if len(backend_times[backend]) > 0: @@ -487,7 +491,11 @@ def run_backend(backend, input_tensor): raise AssertionError(mismatch_msg) except Exception as e: if args.verbose >= 1: - print(f"[WARNING] Dequantize check failed: {e}") + print( + f"[WARNING] [mxfp4_quantize] Dequantize check failed: {e}" + ) + if not args.allow_output_mismatch: + raise for backend in backends: if len(backend_times[backend]) > 0: