def testBmmMxfp8(args):
    """
    Test bmm_mxfp8 API.

    This test:
    1. Generates random input tensors
    2. Quantizes input tensors to MXFP8
    3. Runs bmm_mxfp8
    4. Runs reference check
    5. Measures performance metrics (TFLOPS, TB/sec)

    Args:
        args: Parsed command line arguments containing test configuration

    Returns:
        list: List of dictionaries containing performance results
    """
    if args.verbose >= 1:
        print("[INFO] Running testBmmMxfp8")
        print(f"[INFO] FlashInfer version: {flashinfer.__version__}")

    device = get_device(args)
    if args.generate_repro_command:
        print(
            f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
        )

    ## Parse input arguments
    backends = args.backends
    batch_size = args.batch_size
    m = args.m
    n = args.n
    k = args.k
    is_cuda_graph_compatible = not args.no_cuda_graph
    run_refcheck = args.refcheck
    autotune_supported_backends = [
        "cudnn",
    ]
    res = []

    backends = filter_backends_by_compute_capability(backends, args.routine, device)
    if len(backends) == 0:
        print("[ERROR] No backends to test. Exiting.")
        return res

    # FIX: removed the dead `res_dtype = args.out_dtype` (string) assignment
    # that was immediately overwritten by the parsed torch dtype below.
    res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
    if res_dtype not in [torch.bfloat16, torch.float16]:
        raise ValueError(
            f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
        )
    ## Done parsing input arguments

    if getattr(args, "autotune", False):
        # Drop backends that cannot be autotuned rather than failing the run.
        backends_to_remove = []
        for cur_backend in backends:
            if cur_backend not in autotune_supported_backends:
                print(f"[INFO] {cur_backend} backend does not support autotune")
                backends_to_remove.append(cur_backend)
        for cur_backend in backends_to_remove:
            backends.remove(cur_backend)

        if len(backends) == 0:
            print("[ERROR] No backends to test. Exiting.")
            return res

    ## Prepare input tensors
    # FIX: renamed local `input` -> `input_tensor` (shadowed the builtin).
    input_tensor = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
    input_mxfp8, input_scale = mxfp8_quantize(input_tensor, is_sf_swizzled_layout=True)

    # B operand is built k-major (column major), as the cudnn backend expects.
    mat2 = (
        torch.randn([batch_size, n, k], device=device, dtype=torch.bfloat16)
        .transpose(-2, -1)
        .contiguous()
    )
    mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout=True)

    if args.verbose >= 2:
        print(f"[VVERBOSE] {input_mxfp8.shape = }")
        print(f"[VVERBOSE] {input_mxfp8.dtype = }")
        print(f"[VVERBOSE] {mat2_mxfp8.shape = }")
        print(f"[VVERBOSE] {mat2_mxfp8.dtype = }")
        print(f"[VVERBOSE] {input_scale.shape = }")
        print(f"[VVERBOSE] {input_scale.dtype = }")
        print(f"[VVERBOSE] {mat2_scale.shape = }")
        print(f"[VVERBOSE] {mat2_scale.dtype = }")

    def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale):
        # Dispatch to the requested backend; only cudnn implements bmm_mxfp8.
        if backend == "cudnn":
            return flashinfer.gemm.bmm_mxfp8(
                A=input_mxfp8,
                B=mat2_mxfp8,
                A_scale=input_scale,
                B_scale=mat2_scale,
                dtype=res_dtype,
                backend=backend,
            )
        else:
            raise ValueError(f"Unsupported backend: {backend}")

    has_reference_output = False
    if run_refcheck:
        # Unquantized bf16 reference; quantization error is absorbed by the
        # cosine-similarity check below.
        reference_output = torch.bmm(input_tensor, mat2)
        has_reference_output = True

    if getattr(args, "autotune", False):
        warmup_iters = (
            args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10
        )
        for cur_backend in backends:
            if cur_backend in autotune_supported_backends:
                if args.verbose >= 1:
                    print(f"[INFO] Autotune warmup for bmm_mxfp8: {warmup_iters} iters")
                with autotune(True):
                    for _ in range(warmup_iters):
                        run_backend(
                            cur_backend,
                            input_mxfp8,
                            mat2_mxfp8,
                            input_scale,
                            mat2_scale,
                        )

    # Storage for timing results and outputs
    backend_times = {backend: [] for backend in backends}
    outputs = {}
    for cur_backend in backends:
        if run_refcheck:
            outputs[cur_backend] = run_backend(
                cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale
            ).detach()
        backend_times[cur_backend] = bench_gpu_time(
            fn=run_backend,
            dry_run_iters=args.dry_run_iters,
            repeat_iters=args.num_iters,
            sleep_after_run=True,
            enable_cupti=args.use_cupti,
            use_cuda_graph=is_cuda_graph_compatible,
            cold_l2_cache=True,
            input_args=(cur_backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale),
        )

    min_cos_sim = 0.9  # TODO: check if can be increased

    tested_backends = list(outputs.keys())
    tested_outputs = list(outputs.values())
    if len(tested_backends) > 0:
        if run_refcheck and has_reference_output:
            for i in range(len(tested_backends)):
                cos_sim = F.cosine_similarity(
                    reference_output.reshape(-1),
                    tested_outputs[i].reshape(-1),
                    dim=0,
                )
                if cos_sim < min_cos_sim:
                    # FIX: the comparison is against the bf16 reference, not
                    # between two backends -- report it accurately.
                    print(
                        f"[ERROR] Output tensor mismatch between backend {tested_backends[i]} and reference"
                    )
                    if not args.allow_output_mismatch:
                        raise AssertionError(
                            f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}"
                        )

    for backend in backends:
        backend_name = backend + (
            "_autotune"
            if (
                getattr(args, "autotune", False)
                and backend in autotune_supported_backends
            )
            else ""
        )
        if len(backend_times[backend]) > 0:
            median_time = np.median(backend_times[backend])
            std_time = np.std(backend_times[backend])
            problem_flops = 2 * m * n * k * batch_size
            # MXFP8 uses fp8_e4m3fn for data (1 byte) and uint8 for scales.
            # Scale tensors are much smaller, so approximate as 1 byte per
            # element of the data tensors for simplicity.
            problem_bytes = (
                m * k * torch.float8_e4m3fn.itemsize
                + n * k * torch.float8_e4m3fn.itemsize
                + m * n * res_dtype.itemsize
            ) * batch_size
            tflops = problem_flops / (10**9 * median_time)  # in TFLOPs/sec
            tb_per_sec = problem_bytes / (10**9 * median_time)  # in TB/sec
            print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec)

            if args.output_path is not None:
                cur_res = defaultdict(str)
                cur_res["batch_size"] = batch_size
                cur_res["routine"] = args.routine
                cur_res["median_time"] = median_time
                cur_res["std_time"] = std_time
                cur_res["tflops"] = tflops
                cur_res["tb_per_sec"] = tb_per_sec
                cur_res["m"] = m
                cur_res["n"] = n
                cur_res["k"] = k
                # NOTE(review): stores the torch.dtype object; sibling routines
                # may record the string form (args.out_dtype) -- confirm before
                # changing to keep CSV output consistent.
                cur_res["out_dtype"] = res_dtype
                cur_res["backend"] = backend_name
                cur_res["case_tag"] = args.case_tag
                res.append(cur_res)
    return res
graph.get_workspace_size() + + if workspace_buffer.numel() < workspace_size: + workspace_buffer = torch.empty( + workspace_size, device=a.device, dtype=torch.uint8 + ) + + stream = torch.cuda.current_stream(a.device) + + if tactic == -1: + graph.execute(variant_pack, workspace_buffer, handle=_get_cudnn_handle(stream)) + else: + graph.execute_plan_at_index( + variant_pack, workspace_buffer, tactic, handle=_get_cudnn_handle(stream) + ) + + @functools.cache def build_cudnn_gemm_with_per_tensor_q_graph( a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device @@ -3755,3 +3790,409 @@ def fp8_blockscale_gemm_sm90( runner.run_gemm(input, weight, out, input_scale, weight_scale) return out + + +def _calculate_block_scale_dims( + m: int, n: int, k: int, block_size: int +) -> Tuple[int, int, int]: + """Calculate block scale dimensions using indestructible block formula.""" + INDESTRUCTIBLE_128x4_BLOCK_M_N = 128 + INDESTRUCTIBLE_128x4_BLOCK_K = 4 + + def div_up(a, b): + return (a + b - 1) // b + + block_scale_dim_m = ( + div_up(m, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + ) + block_scale_dim_n = ( + div_up(n, INDESTRUCTIBLE_128x4_BLOCK_M_N) * INDESTRUCTIBLE_128x4_BLOCK_M_N + ) + block_scale_dim_k = ( + div_up(div_up(k, block_size), INDESTRUCTIBLE_128x4_BLOCK_K) + * INDESTRUCTIBLE_128x4_BLOCK_K + ) + + return block_scale_dim_m, block_scale_dim_n, block_scale_dim_k + + +@functools.cache +def create_cudnn_execution_plans_mxfp8_gemm( + a_shape, + a_stride, + a_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2 + b_shape, + b_stride, + b_type, # cudnn.data_type, FP8_E4M3 or FP8_E5M2 + block_size, + o_type, # cudnn.data_type, BF16 or FP16 + device, +): + if len(a_shape) != 3: + raise ValueError(f"A shape must be 3D, got {a_shape}") + if len(b_shape) != 3: + raise ValueError(f"B shape must be 3D, got {b_shape}") + + if a_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"A type must be FP8_E4M3 or FP8_E5M2, got 
{a_type}") + if b_type not in [cudnn.data_type.FP8_E4M3, cudnn.data_type.FP8_E5M2]: + raise ValueError(f"B type must be FP8_E4M3 or FP8_E5M2, got {b_type}") + if o_type not in [cudnn.data_type.BFLOAT16, cudnn.data_type.HALF]: + raise ValueError(f"Output type must be BF16 or FP16, got {o_type}") + + # Extract batch, m, n, k dimensions + b_dim = a_shape[0] + m = a_shape[1] + k = a_shape[2] + n = b_shape[2] + + # Calculate block scale dimensions using indestructible block formula + block_scale_dim_m, block_scale_dim_n, block_scale_dim_k = ( + _calculate_block_scale_dims(m, n, k, block_size) + ) + + # For mxfp8, scale tensors need to be reshaped to 3D with correct strides + # cuDNN expects K-major layout: stride for K dimension should be 1 + # For block_descale_a: shape [b, block_scale_dim_m, block_scale_dim_k], stride [block_scale_dim_m * block_scale_dim_k, block_scale_dim_k, 1] + # For block_descale_b: shape [b, block_scale_dim_k, block_scale_dim_n], stride [block_scale_dim_n * block_scale_dim_k, 1, block_scale_dim_k] + + a_descale_shape = (b_dim, block_scale_dim_m, block_scale_dim_k) + a_descale_stride = ( + block_scale_dim_m * block_scale_dim_k, + block_scale_dim_k, + 1, + ) + + b_descale_shape = (b_dim, block_scale_dim_k, block_scale_dim_n) + b_descale_stride = ( + block_scale_dim_n * block_scale_dim_k, + 1, + block_scale_dim_k, + ) + + # MXFP8 uses FP8_E4M3/FP8_E5M2 for quantized data + # MXFP8 uses FP8_E8M0 for scale data + scale_type = cudnn.data_type.FP8_E8M0 + + stream = torch.cuda.current_stream(device) + with cudnn.graph(_get_cudnn_handle(stream)) as (graph, _): + a_cudnn_tensor = graph.tensor( + name="a", + dim=tuple(a_shape), # [b, m, k] + stride=tuple(a_stride), # [m * k, k, 1] + data_type=a_type, + ) + b_cudnn_tensor = graph.tensor( + name="b", + dim=tuple(b_shape), # [b, k, n] + stride=tuple(b_stride), # [k * n, 1, k] + data_type=b_type, + ) + block_descale_a_cudnn_tensor = graph.tensor( + name="block_descale_a", + dim=a_descale_shape, + 
stride=a_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + block_descale_b_cudnn_tensor = graph.tensor( + name="block_descale_b", + dim=b_descale_shape, + stride=b_descale_stride, + data_type=scale_type, + reordering_type=cudnn.tensor_reordering.F8_128x4, + ) + + # Dequantize the input tensors + dequant_a_tensor = graph.block_scale_dequantize( + a_cudnn_tensor, + block_descale_a_cudnn_tensor, + block_size=[1, block_size], + name="dequant_a", + ) + dequant_a_tensor.set_data_type(cudnn.data_type.FLOAT) + dequant_b_tensor = graph.block_scale_dequantize( + b_cudnn_tensor, + block_descale_b_cudnn_tensor, + block_size=[block_size, 1], + name="dequant_b", + ) + dequant_b_tensor.set_data_type(cudnn.data_type.FLOAT) + + # The actual matmul operation + c_tensor = graph.matmul( + dequant_a_tensor, + dequant_b_tensor, + compute_data_type=cudnn.data_type.FLOAT, + name="gemm", + ) + c_tensor.set_data_type(cudnn.data_type.FLOAT) + + # Output the dequantized result with the specified output dtype + c_tensor.set_output(True).set_data_type(o_type) + c_final_cudnn_tensor = c_tensor + + a_cudnn_tensor.set_uid(UIDs.A_UID.value) + b_cudnn_tensor.set_uid(UIDs.B_UID.value) + block_descale_a_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_A_UID.value) + block_descale_b_cudnn_tensor.set_uid(UIDs.BLOCK_DESCALE_B_UID.value) + c_final_cudnn_tensor.set_uid(UIDs.O_UID.value) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B]) + + return graph + + +def _get_cudnn_mxfp8_gemm_graph( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype = torch.bfloat16, + out: Optional[torch.Tensor] = None, + block_size: int = 32, # mxfp8 block size is 32 + tactic: int = -1, +): + graph = create_cudnn_execution_plans_mxfp8_gemm( + a_shape=a.shape, + a_stride=a.stride(), + b_shape=b.shape, + b_stride=b.stride(), + a_type=_torch_data_type_to_cudnn_data_type(a.dtype), + 
def _cudnn_gemm_mxfp8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_descale: torch.Tensor,
    b_descale: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
    out: Optional[torch.Tensor] = None,
    workspace_buffer: torch.Tensor = None,
    tactic: int = -1,
):
    """Run one MXFP8 batched GEMM through the cached cuDNN graph, writing into ``out``."""
    block_size = 32  # mxfp8 block size is fixed at 32

    # The graph is cached by shape/dtype and was already built when the
    # backend requirement check (_cudnn_bmm_mxfp8_requirement) ran, so this
    # lookup is cheap.
    graph = _get_cudnn_mxfp8_gemm_graph(
        a=a,
        b=b,
        out_dtype=out_dtype,
        out=out,
        block_size=block_size,
        tactic=tactic,
    )
    execute_cudnn_gemm_mxfp8_graph(
        graph=graph,
        a=a,
        b=b,
        a_descale=a_descale,
        b_descale=b_descale,
        c_final=out,
        workspace_buffer=workspace_buffer,
        tactic=tactic,
    )


def _cudnn_gemm_mxfp8_runner():
    """Wrap the cudnn MXFP8 GEMM in a TunableRunner for the autotuner."""

    class CudnnMxfp8GemmRunner(TunableRunner):
        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
        ) -> List[int]:
            # cudnn selects plans via its own heuristics for mxfp8, so a
            # single default tactic is exposed.
            # TODO: check if this is correct
            return [0]

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic: int = -1,
            do_preparation: bool = False,
            **kwargs,
        ) -> torch.Tensor:
            a, b, scale_a, scale_b, out, workspace_buffer = inputs
            _cudnn_gemm_mxfp8(
                a=a,
                b=b,
                a_descale=scale_a,
                b_descale=scale_b,
                out=out,
                out_dtype=out.dtype,
                workspace_buffer=workspace_buffer,
                tactic=tactic,
            )
            return out

    return CudnnMxfp8GemmRunner()


def mxfp8_gemm_sm100(
    a: torch.Tensor,
    b: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out: torch.Tensor,
    workspace_buffer: torch.Tensor,
    runner_names: List[str],
) -> None:
    """Dispatch an MXFP8 batched GEMM to an available runner via the autotuner."""
    runners = []
    if "cudnn" in runner_names:
        runners.append(_cudnn_gemm_mxfp8_runner())
    assert runners, "No suitable runners found"

    tuner = AutoTuner.get()
    inputs = [a, b, scale_a, scale_b, out, workspace_buffer]
    runner, tactic = tuner.choose_one(
        "mxfp8_gemm",  # TODO: check if this is correct
        runners,
        _FP8_GEMM_SM100_TUNING_CONFIG,  # TODO: check if this is correct
        inputs,
    )
    runner(inputs=inputs, tactic=tactic)
"cudnn" in runner_names: + runners.append(_cudnn_gemm_mxfp8_runner()) + assert runners, "No suitable runners found" + tuner = AutoTuner.get() + + inputs = [a, b, scale_a, scale_b, out, workspace_buffer] + runner, tactic = tuner.choose_one( + "mxfp8_gemm", # TODO: check if this is correct + runners, + _FP8_GEMM_SM100_TUNING_CONFIG, # TODO: check if this is correct + inputs, + ) + + runner(inputs=inputs, tactic=tactic) + + +@supported_compute_capability([100, 103]) +def _cudnn_bmm_mxfp8_requirement( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn"] = "cudnn", +): + _check_cudnn_availability() + return True + + +def _validate_mxfp8_output_dtype(dtype: torch.dtype): + """Validate that the output dtype is either bf16 or fp16.""" + if dtype not in (torch.bfloat16, torch.float16): + raise ValueError( + f"Unsupported output dtype: {dtype}. " + f"Only torch.bfloat16 and torch.float16 are supported for MXFP8 GEMM operations." + ) + + +def _check_bmm_mxfp8_problem_size( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, + backend: Literal["cudnn"] = "cudnn", +): + # Check input tensors + if A.ndim != 3 or B.ndim != 3: + # A is [b, m, k], B is [b, k, n] + raise ValueError(f"bmm_mxfp8 accepts 3d tensors, got {A.shape=} and {B.shape=}") + if A.shape[2] != B.shape[1]: + raise ValueError( + f"K dimension (last dim of A) mismatch in bmm_mxfp8. 
def _heuristic_func_bmm_mxfp8(
    suitable_backends: List[str],
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
    backend: Literal["cudnn"] = "cudnn",
):
    """Pick the backend to run; cudnn is the only candidate and needs cuDNN present."""
    return ["cudnn"] if CUDNN_AVAILABLE and "cudnn" in suitable_backends else []


@backend_requirement(
    {
        "cudnn": _cudnn_bmm_mxfp8_requirement,
    },
    common_check=_check_bmm_mxfp8_problem_size,
    heuristic_func=_heuristic_func_bmm_mxfp8,
)
@flashinfer_api
def bmm_mxfp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
    backend: Literal["cudnn"] = "cudnn",
) -> torch.Tensor:
    r"""Batched matrix multiplication of MXFP8-quantized operands.

    Parameters
    ----------
    A: torch.Tensor
        Input tensor, shape (b, m, k), fp8 e4m3 or fp8 e5m2.

    B: torch.Tensor
        Mat2 tensor, shape (b, k, n), should be column major, fp8 e4m3 or fp8 e5m2.

    A_scale: torch.Tensor
        Scale tensor for A, uint8 (fp8 e8m0 format).

    B_scale: torch.Tensor
        Scale tensor for B, uint8 (fp8 e8m0 format).

    dtype: torch.dtype
        out dtype, bf16 or fp16.

    out: Optional[torch.Tensor]
        Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``
        (allocated internally).

    backend: Literal["cudnn"]
        The backend to use for the operation. Defaults to ``"cudnn"``.

    Returns
    -------
    out: torch.Tensor
        Out tensor, shape (b, m, n), bf16 or fp16.
    """
    # Defensive re-checks; the backend_requirement decorator also enforces these.
    if backend != "cudnn":
        raise ValueError(f"Invalid backend: {backend}")
    if not CUDNN_AVAILABLE:
        raise ValueError("cudnn is not available")

    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )

    # Shared, cached workspace for the cudnn graph execution.
    workspace_buffer = _get_cache_buf(
        "bmm_mxfp8_workspace", DEFAULT_WORKSPACE_SIZE, A.device
    )

    mxfp8_gemm_sm100(A, B, A_scale, B_scale, out, workspace_buffer, ["cudnn"])
    return out
@pytest.mark.parametrize("b", [1, 16])
@pytest.mark.parametrize("m", [128, 256, 512])
@pytest.mark.parametrize("n", [128, 256, 512])
@pytest.mark.parametrize("k", [128, 256, 512, 1024])
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
@pytest.mark.parametrize("res_dtype", [torch.bfloat16])
@pytest.mark.parametrize("backend", ["cudnn"])
@pytest.mark.parametrize("auto_tuning", [True, False])
def test_bmm_mxfp8(
    b, m, n, k, input_dtype, is_sf_swizzled_layout, res_dtype, backend, auto_tuning
):
    """End-to-end check of bmm_mxfp8 against an unquantized torch.bmm reference."""
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] in [11, 12]:
        pytest.skip("Not tested on SM110/SM120/SM121")
    if compute_capability[0] < 10:
        pytest.skip(
            "bmm_mxfp8 with cudnn backend is only supported on SM100 and above GPUs."
        )

    # Quantize a random LHS to MXFP8: data is float8_e4m3fn, scales are uint8
    # (e8m0), one scale per 32-element block along K.
    input_mat = torch.randn([b, m, k], device="cuda", dtype=input_dtype)
    input_mxfp8, input_scale = mxfp8_quantize(input_mat, is_sf_swizzled_layout)
    assert input_mxfp8.numel() == (input_scale.numel() * 32)

    # RHS is built k-major (column major), as bmm_mxfp8 expects.
    mat2 = (
        torch.randn([b, n, k], device="cuda", dtype=input_dtype)
        .transpose(-2, -1)
        .contiguous()
    )
    mat2_mxfp8, mat2_scale = mxfp8_quantize(mat2, is_sf_swizzled_layout)
    assert mat2_mxfp8.numel() == (mat2_scale.numel() * 32)

    # Unquantized reference.
    reference = torch.bmm(input_mat, mat2)

    # Run the quantized GEMM into a preallocated output.
    res = torch.empty([b, m, n], device="cuda", dtype=res_dtype)
    with autotune(auto_tuning):
        bmm_mxfp8(
            input_mxfp8,
            mat2_mxfp8,
            input_scale,
            mat2_scale,
            res_dtype,
            res,
            backend=backend,
        )

    # Verify output properties.
    assert res.shape == (b, m, n), f"Expected shape {(b, m, n)}, got {res.shape}"
    assert res.dtype == res_dtype, f"Expected dtype {res_dtype}, got {res.dtype}"
    assert not torch.isnan(res).any(), "Output contains NaN values"

    # Use the same metric as in test_bmm_fp8
    min_cos_sim = 0.9  # TODO: check if can be increased
    cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0)
    assert cos_sim > min_cos_sim, (
        f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})"
    )


if __name__ == "__main__":
    pytest.main([__file__])