diff --git a/README.md b/README.md index c4ff46d7e2..ec1fb86c29 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ High-Performance GPU Kernels for Inference - **POD-Attention**: Fused prefill+decode for mixed batching ### GEMM & Linear Operations +- **BF16 GEMM**: BF16 matrix multiplication for SM10.0+ GPUs. - **FP8 GEMM**: Per-tensor and groupwise scaling - **FP4 GEMM**: NVFP4 and MXFP4 matrix multiplication for Blackwell GPUs - **Grouped GEMM**: Efficient batched matrix operations for LoRA and multi-expert routing diff --git a/benchmarks/README.md b/benchmarks/README.md index f8555fbb24..777f62748e 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -24,6 +24,8 @@ Currently supports testing attention, gemm, fused MOE, normalization, quantizati - `group_gemm_fp8_nt_groupwise` - Group GEMM with FP8 data types using groupwise scaling. - `bmm_fp8` - Batched matrix multiplication with FP8 inputs. - `mm_fp4` - Matrix multiplication with NVFP4 inputs. + - `mm_bf16` - Matrix multiplication with BF16 inputs (Blackwell SM10.0+). + - `bmm_bf16` - Batched matrix multiplication with BF16 inputs (Blackwell SM10.0+). - MOE: - `trtllm_fp4_block_scale_moe` - MOE with FP4 quantized weights and block-wise scaling. - `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling. @@ -219,7 +221,8 @@ The output CSV will contain detailed metrics including: | `--mat2_dtype` | Data type for second matrix (for FP8 GEMM, e.g. 
`fp8_e4m3`) | | `--use_128x4_sf_layout` | Use 128x4 scale/format layout for FP4 GEMM (for `mm_fp4` routine) | | `--use_nvfp4` | Whether to use nvfp4 quantization or mxfp4 quantization, defaults to False.(for `mm_fp4` routine) | -| `--autotune` | Enable autotune for supported operation (`trtllm` and `cutlass` backends for `mm_fp4` and `bmm_fp8` routines)| +| `--autotune` | Enable autotune for supported operation (`mm_fp4`, `bmm_fp8`, `mm_bf16`, `bmm_bf16` routines) | +| `--bias` | Use bias for `mm_bf16` (Enabled for TGV backend) | ### MOE Flags | Flag | Description | @@ -406,6 +409,8 @@ Legend: | **group_gemm_fp8_nt_groupwise** | | | | | | cutlass | cutlass | | | **bmm_fp8** | | | | cudnn, cublas | cudnn, cublas | cudnn, cublas, cutlass | cudnn, cublas, cutlass | cudnn, cublas | | **mm_fp4** | | | | | | cudnn, trtllm, cutlass | cudnn, trtllm, cutlass | cudnn | +| **mm_bf16** | | | | | | cudnn, cutlass, tgv | cudnn, cutlass, tgv | | +| **bmm_bf16** | | | | | | cudnn, cutlass | cudnn, cutlass | | | **trtllm_fp4_block_scale_moe** | | | | | | trtllm | trtllm | | | **trtllm_fp8_block_scale_moe** | | | | | | trtllm | trtllm | | | **trtllm_fp8_per_tensor_scale_moe** | | | | | | trtllm | trtllm | | @@ -452,6 +457,7 @@ Backend Legend: - cudnn: cuDNN (via wrapper API) - cudnn-native: cuDNN (direct API call) - cutlass: CUTLASS +- tgv: TGV - trtllm: TensorRT-LLM - trtllm-gen: TensorRT-LLM - trtllm-native: TensorRT-LLM (out-of-wrapper) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index e3d83f405b..bf8b9ec46b 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -35,6 +35,7 @@ "mma_sm", "use_128x4_sf_layout", "use_nvfp4", + "bias", ], "moe": [ "num_tokens", @@ -153,6 +154,8 @@ "bmm_mxfp8", "mm_fp4", "mm_mxfp8", + "mm_bf16", + "bmm_bf16", ], "moe": [ "trtllm_fp4_block_scale_moe", @@ -353,7 +356,7 @@ def dtype_str_to_torch_dtype(dtype_str): 
"11.0": ["cutlass"], "12.0": [], }, - # Note: mm_fp4 uses support checkers to filter backends, so it is not listed here + # Note: mm_fp4, mm_bf16, and bmm_bf16 use support checkers to filter backends, so they are not listed here # MOE "trtllm_fp4_block_scale_moe": { "7.5": [], diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index c4f2488fd6..035a79bdfa 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -45,6 +45,10 @@ def run_gemm_test(args): return testMmFp4(args) elif args.routine == "mm_mxfp8": return testMmMxfp8(args) + elif args.routine == "mm_bf16": + return testMmBf16(args) + elif args.routine == "bmm_bf16": + return testBmmBf16(args) else: raise ValueError(f"Unsupported routine: {args.routine}") @@ -136,7 +140,7 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass", "auto"], + choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "auto"], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -154,9 +158,21 @@ def parse_gemm_args(line, parser): action="store_true", default=False, help=( - "Enable autotuner warmup for supported routines (mm_fp4, bmm_fp8, bmm_mxfp8 and mm_mxfp8)." + "Enable autotuner warmup for supported routines (mm_fp4, bmm_fp8, bmm_mxfp8, mm_mxfp8, mm_bf16, bmm_bf16)." 
), ) + parser.add_argument( + "--bias", + action="store_true", + default=False, + help="Use bias (enabled for mm_bf16 with TGV backend for now)", + ) + parser.add_argument( + "--enable_pdl", + action="store_true", + default=False, + help="Enable programmatic dependent launch.", + ) args = parser.parse_args(line) if args.verbose >= 1: @@ -1406,6 +1422,7 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): f"{tested_backends[0]} and backend {tested_backends[i]} " f"with {cos_sim=} (expected >= {min_cos_sim})" ) + for backend in backends: backend_name = backend + ( "_autotune" @@ -1445,3 +1462,418 @@ def run_backend(backend, input_mxfp8, mat2_mxfp8, input_scale, mat2_scale): cur_res["case_tag"] = args.case_tag res.append(cur_res) return res + + +def testMmBf16(args): + """ + Test mm_bf16 API. + + This test: + 1. Generates random BF16 input tensors + 2. Runs mm_bf16 with specified backend + 3. Runs reference check (torch.mm) + 4. Measures performance metrics (TFLOPS, TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testMmBf16") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends + m = args.m + n = args.n + k = args.k + use_bias = getattr(args, "bias", False) + use_pdl = getattr(args, "enable_pdl", False) + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"] + res = [] + + out_dtype = dtype_str_to_torch_dtype(args.out_dtype) + if out_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported output dtype: {args.out_dtype}. 
Supported dtypes are bfloat16 and float16." + ) + + ## Prepare input tensors + # a: (m, k), row-major + a = torch.randn([m, k], device=device, dtype=torch.bfloat16) + # b: (n, k) then transpose to (k, n) for column-major layout + b = torch.randn([n, k], device=device, dtype=torch.bfloat16).transpose(-2, -1) + + bias = None + if use_bias: + bias = torch.randn([n], device=device, dtype=torch.bfloat16) + + if args.verbose >= 2: + print(f"[VVERBOSE] {a.shape = }") + print(f"[VVERBOSE] {a.dtype = }") + print(f"[VVERBOSE] {b.shape = }") + print(f"[VVERBOSE] {b.dtype = }") + if bias is not None: + print(f"[VVERBOSE] {bias.shape = }") + print(f"[VVERBOSE] {bias.dtype = }") + print(f"[VVERBOSE] {use_pdl = }") + + # Programmatically filter backends + backends_to_remove = [] + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + flashinfer.mm_bf16( + a=a, + b=b, + bias=bias, + pdl=use_pdl, + out_dtype=out_dtype, + backend=backend, + ) + except Exception as e: + print( + f"[INFO] {backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. 
Exiting.") + return res + + def run_backend(backend, a, b, bias, use_pdl, out_dtype): + if backend in ["cudnn", "cutlass", "tgv", "auto"]: + return flashinfer.mm_bf16( + a=a, + b=b, + bias=bias, + pdl=use_pdl, + out_dtype=out_dtype, + backend=backend, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + has_reference_output = False + reference_output_base = None + if run_refcheck: + reference_output_base = torch.mm(a, b).to(out_dtype) + has_reference_output = True + + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + for cur_backend in backends: + if cur_backend in autotune_supported_backends: + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for mm_bf16: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend, a, b, bias, use_pdl, out_dtype) + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend( + cur_backend, a, b, bias, use_pdl, out_dtype + ).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=True, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=(cur_backend, a, b, bias, use_pdl, out_dtype), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + # Only add bias to reference when comparing against tgv backend + if tested_backends[i] == "tgv" and bias is not None: + reference_output = reference_output_base + bias.unsqueeze(0).to( + out_dtype + ) + else: + reference_output = reference_output_base + + cos_sim = F.cosine_similarity( + 
reference_output.reshape(-1), + tested_outputs[i].reshape(-1), + dim=0, + ) + if cos_sim < 0.99: + print( + f"[ERROR] Output tensor mismatch from backend {tested_backends[i]} with cos_sim={cos_sim}" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}" + ) + + for backend in backends: + backend_name = backend + ( + "_autotune" + if ( + getattr(args, "autotune", False) + and backend in autotune_supported_backends + ) + else "" + ) + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + problem_flops = 2 * m * n * k + if use_bias: + problem_flops += m * n # bias addition + problem_bytes = ( + m * k * torch.bfloat16.itemsize + + k * n * torch.bfloat16.itemsize + + m * n * out_dtype.itemsize + ) + if use_bias: + problem_bytes += n * torch.bfloat16.itemsize + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["n"] = n + cur_res["k"] = k + cur_res["out_dtype"] = str(out_dtype) + cur_res["backend"] = backend_name + cur_res["bias"] = use_bias + cur_res["enable_pdl"] = use_pdl + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res + + +def testBmmBf16(args): + """ + Test bmm_bf16 API. + + This test: + 1. Generates random BF16 batched input tensors + 2. Runs bmm_bf16 with specified backend + 3. Runs reference check (torch.bmm) + 4. 
Measures performance metrics (TFLOPS, TB/sec) + + Args: + args: Parsed command line arguments containing test configuration + + Returns: + dict: List of dictionaries containing performance results + """ + if args.verbose >= 1: + print("[INFO] Running testBmmBf16") + print(f"[INFO] FlashInfer version: {flashinfer.__version__}") + + device = get_device(args) + if args.generate_repro_command: + print( + f"[INFO] To reproduce this test case, run the following command: {args.repro_command}" + ) + + ## Parse input arguments + backends = args.backends + batch_size = args.batch_size + m = args.m + n = args.n + k = args.k + is_cuda_graph_compatible = not args.no_cuda_graph + run_refcheck = args.refcheck + autotune_supported_backends = ["cudnn", "cutlass", "auto"] + res = [] + + out_dtype = dtype_str_to_torch_dtype(args.out_dtype) + if out_dtype not in [torch.bfloat16, torch.float16]: + raise ValueError( + f"Unsupported output dtype: {args.out_dtype}. Supported dtypes are bfloat16 and float16." + ) + + ## Prepare input tensors + # A: (batch_size, m, k), row-major + A = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16) + # B: (batch_size, n, k) then transpose to (batch_size, k, n) for column-major layout + B = torch.randn([batch_size, n, k], device=device, dtype=torch.bfloat16).transpose( + -2, -1 + ) + + if args.verbose >= 2: + print(f"[VVERBOSE] {A.shape = }") + print(f"[VVERBOSE] {A.dtype = }") + print(f"[VVERBOSE] {B.shape = }") + print(f"[VVERBOSE] {B.dtype = }") + + # Programmatically filter backends + backends_to_remove = [] + for backend in backends: + # Skip autotune check for now (handled separately below) + if ( + getattr(args, "autotune", False) + and backend not in autotune_supported_backends + ): + print(f"[INFO] {backend} backend does not support autotune") + backends_to_remove.append(backend) + continue + + try: + flashinfer.bmm_bf16( + A=A, + B=B, + out_dtype=out_dtype, + backend=backend, + ) + except Exception as e: + print( + f"[INFO] 
{backend} backend does not support this configuration: {type(e).__name__}: {e}" + ) + backends_to_remove.append(backend) + + # Remove unsupported backends + for backend in backends_to_remove: + backends.remove(backend) + + if len(backends) == 0: + print("[ERROR] No backends passed validation. Exiting.") + return res + + def run_backend(backend, A, B, out_dtype): + if backend in ["cudnn", "cutlass", "auto"]: + return flashinfer.bmm_bf16( + A=A, + B=B, + out_dtype=out_dtype, + backend=backend, + ) + else: + raise ValueError(f"Unsupported backend: {backend}") + + has_reference_output = False + if run_refcheck: + reference_output = torch.bmm(A, B).to(out_dtype) + has_reference_output = True + + if getattr(args, "autotune", False): + warmup_iters = ( + args.dry_run_iters if args.dry_run_iters and args.dry_run_iters > 0 else 10 + ) + for cur_backend in backends: + if cur_backend in autotune_supported_backends: + if args.verbose >= 1: + print(f"[INFO] Autotune warmup for bmm_bf16: {warmup_iters} iters") + with autotune(True): + for _ in range(warmup_iters): + run_backend(cur_backend, A, B, out_dtype) + + # Storage for timing results and outputs + backend_times = {backend: [] for backend in backends} + outputs = {} + for cur_backend in backends: + if run_refcheck: + outputs[cur_backend] = run_backend(cur_backend, A, B, out_dtype).detach() + backend_times[cur_backend] = bench_gpu_time( + fn=run_backend, + dry_run_iters=args.dry_run_iters, + repeat_iters=args.num_iters, + sleep_after_run=True, + enable_cupti=args.use_cupti, + use_cuda_graph=is_cuda_graph_compatible, + cold_l2_cache=True, + input_args=(cur_backend, A, B, out_dtype), + ) + + tested_backends = list(outputs.keys()) + tested_outputs = list(outputs.values()) + if len(tested_backends) > 0: + if run_refcheck and has_reference_output: + for i in range(len(tested_backends)): + cos_sim = F.cosine_similarity( + reference_output.reshape(-1), + tested_outputs[i].reshape(-1), + dim=0, + ) + if cos_sim < 0.99: + print( + 
f"[ERROR] Output tensor mismatch from backend {tested_backends[i]} with cos_sim={cos_sim}" + ) + if not args.allow_output_mismatch: + raise AssertionError( + f"[ERROR] Backend {tested_backends[i]} output mismatch with cos_sim={cos_sim}" + ) + + for backend in backends: + backend_name = backend + ( + "_autotune" + if ( + getattr(args, "autotune", False) + and backend in autotune_supported_backends + ) + else "" + ) + if len(backend_times[backend]) > 0: + median_time = np.median(backend_times[backend]) + std_time = np.std(backend_times[backend]) + problem_flops = 2 * batch_size * m * n * k + problem_bytes = batch_size * ( + m * k * torch.bfloat16.itemsize + + k * n * torch.bfloat16.itemsize + + m * n * out_dtype.itemsize + ) + tflops = problem_flops / (10**9 * median_time) # in TFLOPs/sec + tb_per_sec = problem_bytes / (10**9 * median_time) # in TB/sec + print_perf_metrics(backend_name, median_time, std_time, tflops, tb_per_sec) + + if args.output_path is not None: + cur_res = defaultdict(str) + cur_res["batch_size"] = batch_size + cur_res["routine"] = args.routine + cur_res["median_time"] = median_time + cur_res["std_time"] = std_time + cur_res["tflops"] = tflops + cur_res["tb_per_sec"] = tb_per_sec + cur_res["m"] = m + cur_res["n"] = n + cur_res["k"] = k + cur_res["out_dtype"] = str(out_dtype) + cur_res["backend"] = backend_name + cur_res["case_tag"] = args.case_tag + res.append(cur_res) + return res diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ad55fe6b7f..4630583982 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -192,7 +192,7 @@ def _cutlass_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -216,7 +216,7 @@ def _cudnn_mm_bf16_requirement( out_dtype: torch.dtype 
= torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -241,7 +241,7 @@ def _tgv_gemm_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ): if out_dtype != torch.bfloat16: raise ValueError( @@ -257,7 +257,7 @@ def _check_mm_bf16_problem_size( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ): if a.dtype != torch.bfloat16: raise ValueError( @@ -298,7 +298,7 @@ def _heuristic_func_mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ): heuristic_backends = [] if bias is not None or pdl: @@ -332,7 +332,7 @@ def mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "tgv", + backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", ) -> torch.Tensor: r"""MM BF16 @@ -345,22 +345,22 @@ def mm_bf16( Weight tensor, shape (k, n), bf16 in column-major layout. bias: Optional[torch.Tensor] - Optional bias tensor, shape (n,). If provided, can only be used with the TGV backend. Defaults to ``None``. + Optional bias tensor, shape (n,). Enabled for TGV backend. Defaults to ``None``. pdl: bool - Whether to use persistant data loader mode. Can only be used with the TGV backend. Defaults to ``False``. + Whether to use programmatic dependent launch (PDL). 
Enabled for TGV backend. Defaults to ``False``. out: Optional[torch.Tensor] - Out tensor, shape (m, n), bf16 or fp16. If provided, can only be used with the CUTLASS backend. Defaults to ``None``. + Out tensor, shape (m, n), bf16 or fp16. Enabled for CUTLASS backend. Defaults to ``None``. out_dtype: torch.dtype - Output dtype, bf16 or fp16. Can be used with the CUTLASS or cuDNN backends. Defaults to ``torch.bfloat16``. + Output dtype, bf16 or fp16. Enabled for CUTLASS and cuDNN backends. Defaults to ``torch.bfloat16``. backend: Literal["cudnn", "cutlass", "tgv", "auto"] - The backend to use for the operation. Defaults to ``"tgv"``. - ``"cudnn"`` uses the cuDNN backend (no bias/pdl support). - ``"cutlass"`` uses the CUTLASS backend (no bias/pdl support). - ``"tgv"`` uses the TGV backend (supports bias/pdl, bf16 output only). + The backend to use for the operation. Defaults to ``"cudnn"``. + ``"cudnn"`` uses the cuDNN backend. + ``"cutlass"`` uses the CUTLASS backend. + ``"tgv"`` uses the TGV backend. ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. 
Returns @@ -388,6 +388,12 @@ def mm_bf16( torch.Size([48, 80]) >>> out.dtype torch.float16 + >>> # Using the cuDNN backend + >>> out = flashinfer.mm_bf16(a, b, backend="cudnn") + >>> out.shape + torch.Size([48, 80]) + >>> out.dtype + torch.bfloat16 """ if out is None: @@ -427,7 +433,7 @@ def _cutlass_bmm_bf16_requirement( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cudnn", ): _validate_bf16_output_dtype(out_dtype) @@ -440,7 +446,7 @@ def _cudnn_bmm_bf16_requirement( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cudnn", ): _validate_bf16_output_dtype(out_dtype) _check_cudnn_availability() @@ -452,7 +458,7 @@ def _check_bmm_bf16_problem_size( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cudnn", ): if A.dtype != torch.bfloat16: raise ValueError( @@ -487,7 +493,7 @@ def _heuristic_func_bmm_bf16( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cudnn", ): heuristic_backends = [] if "cudnn" in suitable_backends: @@ -511,7 +517,7 @@ def bmm_bf16( B: torch.Tensor, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "auto"] = "cutlass", + backend: Literal["cudnn", "cutlass", "auto"] = "cudnn", ) -> torch.Tensor: r"""BMM BF16 @@ -530,7 +536,7 @@ def bmm_bf16( Output dtype, bf16 (default) or fp16. backend: Literal["cudnn", "cutlass", "auto"] - Backend to use, defaults to "cutlass". 
``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. + Backend to use, defaults to "cudnn". ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. Returns ------- @@ -541,9 +547,17 @@ def bmm_bf16( -------- >>> import torch >>> import flashinfer + >>> # Using the CUTLASS backend >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1) - >>> out = flashinfer.bmm_bf16(input, weight) + >>> fp16_out = torch.empty([16, 48, 80], device="cuda", dtype=torch.float16) + >>> out = flashinfer.bmm_bf16(input, weight, out=fp16_out, out_dtype=torch.float16, backend="cutlass") + >>> out.shape + torch.Size([16, 48, 80]) + >>> out.dtype + torch.float16 + >>> # using the cuDNN backend + >>> out = flashinfer.bmm_bf16(input, weight, backend="cudnn") >>> out.shape torch.Size([16, 48, 80]) >>> out.dtype diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 24541f329f..e9a15d20e3 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -3,6 +3,7 @@ import torch.nn.functional as F from flashinfer import autotune, bmm_bf16 +from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE from flashinfer.utils import get_compute_capability @@ -22,6 +23,10 @@ def test_bmm_bf16(b, m, n, k, res_dtype, backend): ) if not bmm_bf16.is_backend_supported(backend, compute_capability_number): pytest.skip(f"{backend} backend not supported on current compute capability.") + + if backend == "cudnn" and not CUDNN_AVAILABLE: + pytest.skip("cuDNN is not available on this system.") + # cuDNN on SM103 does not support bf16 input -> fp16 output if ( backend == "cudnn" diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index 384d3a52e6..d5d978a011 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -3,6 +3,7 @@ import torch.nn.functional as 
F from flashinfer import autotune, mm_bf16 +from flashinfer.gemm.gemm_base import CUDNN_AVAILABLE from flashinfer.utils import get_compute_capability @@ -32,6 +33,9 @@ def test_mm_bf16( if not mm_bf16.is_backend_supported(backend, compute_capability_number): pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend == "cudnn" and not CUDNN_AVAILABLE: + pytest.skip("cuDNN is not available on this system.") + if backend == "cudnn" and (enable_bias or pdl): pytest.skip( "mm_bf16 with cuDNN backend does not support bias or pdl arguments."