diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 41ba7b3c18..0209e1bd40 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -140,7 +140,16 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "cute-dsl", "auto"], + choices=[ + "cudnn", + "cublas", + "trtllm", + "cutlass", + "tgv", + "cublaslt", + "cute-dsl", + "auto", + ], help="Kernel backends to test. Default: cudnn", ) parser.add_argument( @@ -1553,7 +1562,7 @@ def testMmBf16(args): use_pdl = getattr(args, "enable_pdl", False) is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"] + autotune_supported_backends = ["cudnn", "cutlass", "tgv", "cublaslt", "auto"] res = [] out_dtype = dtype_str_to_torch_dtype(args.out_dtype) @@ -1618,7 +1627,7 @@ def testMmBf16(args): return res def run_backend(backend, a, b, bias, use_pdl, out_dtype): - if backend in ["cudnn", "cutlass", "tgv", "auto"]: + if backend in ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]: return flashinfer.mm_bf16( a=a, b=b, diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index 4de464fac0..54e8c26898 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -49,7 +49,8 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso auto stream = get_stream(A.device()); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( - workspace_buffer.data_ptr(), workspace_buffer.numel(), + workspace_buffer.data_ptr(), + workspace_buffer.numel() * get_element_size(workspace_buffer), static_cast(B.data_ptr()), static_cast(A.data_ptr()), static_cast(D.data_ptr()), batch_size, n, m, k, static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), @@ -61,3 +62,91 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso }); }); } + +int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer) { + CHECK_CUDA(A); + CHECK_CUDA(B); + CHECK_CUDA(D); + CHECK_DIM(3, A); + CHECK_DIM(3, B); + CHECK_DIM(3, D); + CHECK_CONTIGUOUS(algo_buffer); + TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; + TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; + TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) + << "Result tensor has incorrect shape"; + + int64_t result = 0; + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(A.device().device_id); + + int max_algos = static_cast(algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::bmm_fp8::kAlgoBytes); + result = flashinfer::bmm_fp8::get_fp8_algorithms( + batch_size, n, m, k, static_cast(B_scale.data_ptr()), + static_cast(A_scale.data_ptr()), + workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle, + algo_buffer.data_ptr(), max_algos); + return true; + }); + }); + }); + return static_cast(result); +} + +void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx) { + CHECK_CUDA(A); + CHECK_CUDA(B); + CHECK_CUDA(D); + CHECK_DIM(3, A); + CHECK_DIM(3, B); + CHECK_DIM(3, D); + CHECK_CONTIGUOUS(algo_buffer); + TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; + TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; + TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) + << "Result tensor has incorrect shape"; + + int64_t max_algos = + algo_buffer.numel() * get_element_size(algo_buffer) / flashinfer::bmm_fp8::kAlgoBytes; + TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos) + << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")"; + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(A.device().device_id); + auto stream = get_stream(A.device()); + + auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo( + workspace_buffer.data_ptr(), + workspace_buffer.numel() * get_element_size(workspace_buffer), + static_cast(B.data_ptr()), static_cast(A.data_ptr()), + static_cast(D.data_ptr()), batch_size, n, m, k, + static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), + lt_handle, stream, algo_buffer.data_ptr(), static_cast(algo_idx)); + TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) + << "bmm_fp8_run_with_algo failed: " << cublasGetStatusString(status); + return true; + }); + }); + }); +} diff --git a/csrc/flashinfer_gemm_binding.cu b/csrc/flashinfer_gemm_binding.cu index 52d0551413..80fffb4de4 100644 --- a/csrc/flashinfer_gemm_binding.cu +++ b/csrc/flashinfer_gemm_binding.cu @@ -19,9 +19,19 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle); +int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer); + +void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx); + void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr, TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld, TensorView y_ld, TensorView empty_x_data, bool weight_column_major); TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM); TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_get_algos, bmm_fp8_get_algos); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_run_with_algo, bmm_fp8_run_with_algo); diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu new file mode 100644 index 0000000000..d6f0330620 --- /dev/null +++ b/csrc/mm_bf16_cublaslt.cu @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2026 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include + +#include "tvm_ffi_utils.h" + +namespace { + +cudaDataType_t get_d_type(DLDataType dtype) { + switch (encode_dlpack_dtype(dtype)) { + case bfloat16_code: + return CUDA_R_16BF; + case float16_code: + return CUDA_R_16F; + case float32_code: + return CUDA_R_32F; + default: + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32."; + return CUDA_R_16BF; + } +} + +} // namespace + +// Serialize all heuristic algorithms into a CPU uint8 tensor for caching. +// algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes. +// Returns number of algorithms written. +int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer) { + CHECK_CUDA(mat1); + CHECK_CUDA(mat2); + CHECK_CUDA(out); + CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); + CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + CHECK_DIM(2, out); + CHECK_CPU(algo_buffer); + CHECK_CONTIGUOUS(algo_buffer); + CHECK_CUDA(workspace_buffer); + + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + + TVM_FFI_ICHECK_EQ(mat2.size(1), k) + << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); + TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; + TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + + cudaDataType_t d_type = get_d_type(out.dtype()); + + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); + auto lt_handle = reinterpret_cast(cublas_handle); + int max_algos = static_cast(algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::mm_bf16_cublaslt::kAlgoBytes); + return static_cast(flashinfer::mm_bf16_cublaslt::get_algorithms( + static_cast(m), static_cast(n), static_cast(k), d_type, + workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle, + algo_buffer.data_ptr(), max_algos)); +} + +// Run matmul using a pre-cached algorithm — zero heuristic overhead. +void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx) { + CHECK_CUDA(mat1); + CHECK_CUDA(mat2); + CHECK_CUDA(out); + CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); + CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + CHECK_DIM(2, out); + CHECK_CPU(algo_buffer); + CHECK_CONTIGUOUS(algo_buffer); + CHECK_CUDA(workspace_buffer); + + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + + TVM_FFI_ICHECK_EQ(mat2.size(1), k) + << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); + TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; + TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + + int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::mm_bf16_cublaslt::kAlgoBytes; + TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos) + << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")"; + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); + auto stream = get_stream(mat1.device()); + cudaDataType_t d_type = get_d_type(out.dtype()); + + auto status = flashinfer::mm_bf16_cublaslt::run_with_algo( + static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), + out.data_ptr(), static_cast(m), static_cast(n), static_cast(k), d_type, + workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer), + lt_handle, stream, algo_buffer.data_ptr(), static_cast(algo_idx)); + TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) + << "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo); diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 9909befbbd..e839a13162 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -77,6 +77,7 @@ gen_gemm_sm100_module_cutlass_mxfp8, gen_gemm_sm120_module, gen_gemm_sm120_module_cutlass_fp4, + gen_mm_bf16_cublaslt_module, gen_tgv_gemm_sm10x_module, gen_trtllm_gen_gemm_module, gen_trtllm_low_latency_gemm_module, @@ -511,6 +512,8 @@ def gen_all_modules( ) jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True)) jit_specs.append(gen_moe_utils_module()) + if has_sm100 or has_sm103: + jit_specs.append(gen_mm_bf16_cublaslt_module()) if has_sm103: jit_specs.append(gen_fp4_quantization_sm103_module()) jit_specs.append(gen_cutlass_fused_moe_sm103_module()) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index cb0141fdf8..48120ba23f 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -397,7 +397,16 @@ def forward( raise NotImplementedError def __hash__(self): - return hash(tuple(self.__dict__.values())) + hashable_vals = [] + for k, v in self.__dict__.items(): + if k.endswith("_cache"): + continue + try: + hash(v) + hashable_vals.append(v) + except TypeError: + hashable_vals.append(id(v)) + return hash(tuple(hashable_vals)) @contextlib.contextmanager diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 02b54b5b8b..565304b2fe 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 from ..jit.gemm import gen_gemm_sm100_module_cutlass_mxfp8 from ..jit.gemm import gen_gemm_sm100_module_cutlass_bf16 +from ..jit.gemm import gen_mm_bf16_cublaslt_module from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module @@ -89,6 +90,11 @@ DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 +# sizeof(cublasLtMatmulAlgo_t) = uint64_t[8] = 64 bytes. +# Shared by cuBLAS FP8, cuBLASLt BF16, and any other cuBLASLt-based runners. +_CUBLASLT_ALGO_BYTES = 64 +_CUBLASLT_MAX_ALGOS = 100 + # Error messages CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120/SM121 with cuDNN backend version < 9.14.0." @@ -108,16 +114,49 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): def get_gemm_module(): module = gen_gemm_module().build_and_load() - # auto-tuned cublas fp8 gemm runner def cublas_fp8_gemm_runner(): class CublasFp8GemmRunner(TunableRunner): + def __init__(self): + self._algo_cache: dict = {} + + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + + def _get_algos(self, inputs): + a, b, scale_a, scale_b, out, workspace_buffer = inputs + key = (a.shape, b.shape, a.dtype, b.dtype, out.dtype) + cached = self._algo_cache.get(key) + if cached is not None: + return cached + algo_buf = torch.empty( + _CUBLASLT_MAX_ALGOS * _CUBLASLT_ALGO_BYTES, + dtype=torch.uint8, + device="cpu", + ) + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() + count = module.bmm_fp8_get_algos( + a, + b, + out, + scale_a, + scale_b, + workspace_buffer, + cublas_handle, + algo_buf, + ) + result = (algo_buf, count) + self._algo_cache[key] = result + return result + def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # cublas has heuristic for fp8 gemm, so we only need to use the default tactic - return [0] + _, count = self._get_algos(inputs) + return list(range(count)) def forward( self, @@ -126,11 +165,33 @@ def forward( do_preparation: bool = False, **kwargs, ) -> torch.Tensor: - cublas_handle = torch.cuda.current_blas_handle() a, b, scale_a, scale_b, out, workspace_buffer = inputs - module.bmm_fp8( - a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle - ) + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() + if tactic >= 0: + algo_buf, count = self._get_algos(inputs) + if count == 0: + raise RuntimeError( + "cuBLASLt heuristic returned zero FP8 algorithms for " + f"A={tuple(a.shape)}, B={tuple(b.shape)}, out={tuple(out.shape)}." + ) + if tactic >= count: + tactic = 0 + module.bmm_fp8_run_with_algo( + a, + b, + out, + scale_a, + scale_b, + workspace_buffer, + cublas_handle, + algo_buf, + tactic, + ) + else: + module.bmm_fp8( + a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle + ) return out return CublasFp8GemmRunner() @@ -197,7 +258,7 @@ def _cutlass_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -213,6 +274,31 @@ def _cutlass_mm_bf16_requirement( return True +# Gated to Blackwell (SM100/SM103) for the initial scope of this backend. +# cuBLASLt supports BF16 GEMM on SM80+; the gate can be widened in a follow-up. +@supported_compute_capability([100, 103]) +def _cublaslt_mm_bf16_requirement( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + bias: Optional[torch.Tensor] = None, + pdl: bool = False, + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", +): + if bias is not None: + raise ValueError( + "You cannot use the cuBLASLt backend with a bias. Use the TGV backend instead." + ) + if pdl: + raise ValueError( + "The cuBLASLt backend does not support PDL. Use the TGV backend instead." + ) + _validate_bf16_output_dtype(out_dtype) + + return True + + @supported_compute_capability([100, 103]) def _cudnn_mm_bf16_requirement( a: torch.Tensor, @@ -221,7 +307,7 @@ def _cudnn_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -246,7 +332,7 @@ def _tgv_gemm_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if out_dtype != torch.bfloat16: raise ValueError( @@ -262,7 +348,7 @@ def _check_mm_bf16_problem_size( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if a.dtype != torch.bfloat16: raise ValueError( @@ -303,16 +389,18 @@ def _heuristic_func_mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): heuristic_backends = [] if bias is not None or pdl: - # cuDNN and CUTLASS don't support bias/pdl, only TGV does + # cuDNN, CUTLASS, and cuBLASLt don't support bias/pdl, only TGV does if "tgv" in suitable_backends: heuristic_backends.append("tgv") else: if "cudnn" in suitable_backends: heuristic_backends.append("cudnn") + if "cublaslt" in suitable_backends: + heuristic_backends.append("cublaslt") if "cutlass" in suitable_backends: heuristic_backends.append("cutlass") if "tgv" in suitable_backends: @@ -325,6 +413,7 @@ def _heuristic_func_mm_bf16( "cudnn": _cudnn_mm_bf16_requirement, "cutlass": _cutlass_mm_bf16_requirement, "tgv": _tgv_gemm_requirement, + "cublaslt": _cublaslt_mm_bf16_requirement, }, common_check=_check_mm_bf16_problem_size, heuristic_func=_heuristic_func_mm_bf16, @@ -337,7 +426,7 @@ def mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ) -> torch.Tensor: r"""MM BF16 @@ -353,21 +442,23 @@ def mm_bf16( Optional bias tensor, shape (n,). Enabled for TGV backend. Defaults to ``None``. pdl: bool - Whether to use persistant data loader mode. Enabled for TGV backend. Defaults to ``False``. + Whether to enable Programmatic Dependent Launch (PDL). Enabled for TGV backend. Defaults to ``False``. out: Optional[torch.Tensor] - Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends. + Out tensor, shape (m, n). Preallocated output is supported by all backends. + TGV requires ``torch.bfloat16``; CUTLASS, cuDNN, and cuBLASLt also support fp16/fp32. Defaults to ``None``. out_dtype: torch.dtype - Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends. + Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt backends. Defaults to ``torch.bfloat16``. - backend: Literal["cudnn", "cutlass", "tgv", "auto"] + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] The backend to use for the operation. Defaults to ``"cudnn"``. ``"cudnn"`` uses the cuDNN backend. ``"cutlass"`` uses the CUTLASS backend. ``"tgv"`` uses the TGV backend. + ``"cublaslt"`` uses the cuBLASLt backend with heuristic algorithm search. ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. Returns @@ -401,6 +492,12 @@ def mm_bf16( torch.Size([48, 80]) >>> out.dtype torch.bfloat16 + >>> # Using the cuBLASLt backend + >>> out = flashinfer.mm_bf16(a, b, backend="cublaslt") + >>> out.shape + torch.Size([48, 80]) + >>> out.dtype + torch.bfloat16 """ if out is None: @@ -427,6 +524,10 @@ def mm_bf16( backends = _heuristic_func_mm_bf16( ["tgv"], a, b, bias, pdl, out, out_dtype, backend ) + elif backend == "cublaslt": + backends = _heuristic_func_mm_bf16( + ["cublaslt"], a, b, None, False, out, out_dtype, backend + ) else: backends = [backend] @@ -616,7 +717,8 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # For now, return a single default tactic + # TODO: add multi-tactic support when PingPong 64x128x128 schedule + # is implemented (see gemm_groupwise_sm120.cuh) return [-1] def forward( @@ -870,6 +972,114 @@ def forward( ) +@functools.cache +def get_mm_bf16_cublaslt_module(): + module = gen_mm_bf16_cublaslt_module().build_and_load() + + def cublaslt_bf16_gemm_runner(): + class CublasltBf16GemmRunner(TunableRunner): + def __init__(self): + self._algo_cache: dict = {} + + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + _, _, _, _, out, _ = inputs + return (self._compute_dtype(out.dtype),) + + @staticmethod + def _compute_dtype(out_dtype): + # cuBLASLt with BF16 inputs supports BF16 or FP32 output natively. + # FP16 output is achieved via BF16 compute + cast. + if out_dtype == torch.float16: + return torch.bfloat16 + return out_dtype + + def _get_algos(self, inputs): + a, b, _, _, out, workspace_buffer = inputs + compute_dt = self._compute_dtype(out.dtype) + key = (a.shape[0], b.shape[1], a.shape[1], compute_dt) + cached = self._algo_cache.get(key) + if cached is not None: + return cached + algo_buf = torch.empty( + _CUBLASLT_MAX_ALGOS * _CUBLASLT_ALGO_BYTES, + dtype=torch.uint8, + device="cpu", + ) + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() + proxy_out = ( + out + if out.dtype == compute_dt + else torch.empty_like(out, dtype=compute_dt) + ) + count = module.mm_bf16_cublaslt_get_algos( + a, + b.transpose(-2, -1), + proxy_out, + workspace_buffer, + cublas_handle, + algo_buf, + ) + result = (algo_buf, count) + self._algo_cache[key] = result + return result + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + _, count = self._get_algos(inputs) + return list(range(count)) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + a, b, _, _, out, workspace_buffer = inputs + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() + b_t = b.transpose(-2, -1) + + need_cast = out.dtype == torch.float16 + if need_cast: + compute_out = torch.empty_like(out, dtype=torch.bfloat16) + else: + compute_out = out + + algo_buf, count = self._get_algos(inputs) + if count == 0: + raise RuntimeError( + "cuBLASLt heuristic returned zero algorithms for " + f"M={a.shape[0]}, N={b.shape[1]}, K={a.shape[1]}, " + f"dtype={compute_out.dtype}. " + "This shape/dtype combination may not be supported." + ) + if tactic < 0 or tactic >= count: + tactic = 0 + module.mm_bf16_cublaslt_run_with_algo( + a, + b_t, + compute_out, + workspace_buffer, + cublas_handle, + algo_buf, + tactic, + ) + if need_cast: + out.copy_(compute_out) + return out + + return CublasltBf16GemmRunner() + + return SimpleNamespace( + cublaslt_bf16_gemm_runner=cublaslt_bf16_gemm_runner, + ) + + _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( @@ -902,6 +1112,8 @@ def bf16_gemm_sm100( use_sm_100f = is_sm100f_supported(a.device) if "cudnn" in runner_names: runners.append(_cudnn_gemm_bf16_runner()) + if "cublaslt" in runner_names: + runners.append(get_mm_bf16_cublaslt_module().cublaslt_bf16_gemm_runner()) if "cutlass" in runner_names: runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) if "tgv" in runner_names: @@ -2463,9 +2675,17 @@ def _cudnn_gemm_mxfp8_override_shape( ) -@functools.lru_cache(maxsize=1024) +@functools.lru_cache(maxsize=2048) def build_cudnn_gemm_with_per_tensor_q_graph( - a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device + a_shape, + a_stride, + b_shape, + b_stride, + a_type, + b_type, + o_type, + device, + policy=None, ): """Build a cuDNN graph for GEMM with per-tensor quantization. @@ -2479,11 +2699,15 @@ def build_cudnn_gemm_with_per_tensor_q_graph( a_type: Data type for input tensor A b_type: Data type for input tensor B o_type: Data type for output tensor + policy: cuDNN build plan policy. None defaults to HEURISTICS_CHOICE. + Use ALL to enumerate all execution plans for autotuning. Returns: cuDNN graph object """ _check_cudnn_availability() + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE stream = torch.cuda.current_stream(device) with cudnn.graph(_get_cudnn_handle(device, stream)) as (graph, _): @@ -2539,13 +2763,20 @@ def build_cudnn_gemm_with_per_tensor_q_graph( graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() - graph.build_plans() + graph.build_plans(policy) return graph def execute_cudnn_gemm_with_per_tensor_q_graph( - graph, a, b, a_scale, b_scale, c_final, workspace + graph, + a, + b, + a_scale, + b_scale, + c_final, + workspace, + tactic: int = -1, ): variant_pack = { UIDs.A_UID.value: a, @@ -2563,7 +2794,12 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( graph.get_workspace_size(), device=a.device, dtype=torch.uint8 ) - graph.execute(variant_pack, workspace, handle=cudnn_handle) + if tactic == -1: + graph.execute(variant_pack, workspace, handle=cudnn_handle) + else: + graph.execute_plan_at_index( + variant_pack, workspace, tactic, handle=cudnn_handle + ) # --------------------------------------------------------------------------- @@ -2708,9 +2944,15 @@ def _cudnn_gemm_fp8( b_scale: torch.Tensor, out: Optional[torch.Tensor], torch_out_dtype: torch.dtype, + tactic: int = -1, ): _check_cudnn_availability() + if tactic == -1: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE + else: + policy = cudnn.build_plan_policy.ALL + graph = build_cudnn_gemm_with_per_tensor_q_graph( a.shape, a.stride(), @@ -2720,23 +2962,46 @@ def _cudnn_gemm_fp8( _torch_data_type_to_cudnn_data_type(b.dtype), _torch_data_type_to_cudnn_data_type(torch_out_dtype), a.device, + policy=policy, ) execute_cudnn_gemm_with_per_tensor_q_graph( - graph, a, b, a_scale, b_scale, out, workspace + graph, + a, + b, + a_scale, + b_scale, + out, + workspace, + tactic=tactic, ) return out def _cudnn_gemm_fp8_runner(): class CudnnFp8GemmRunner(TunableRunner): + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # cudnn has heuristic for fp8 gemm, so we only need to use the default tactic - return [0] + a, b, _, _, out, _ = inputs + graph = build_cudnn_gemm_with_per_tensor_q_graph( + a.shape, + a.stride(), + b.shape, + b.stride(), + _torch_data_type_to_cudnn_data_type(a.dtype), + _torch_data_type_to_cudnn_data_type(b.dtype), + _torch_data_type_to_cudnn_data_type(out.dtype), + a.device, + policy=cudnn.build_plan_policy.ALL, + ) + return list(range(graph.get_execution_plan_count())) def forward( self, @@ -2746,7 +3011,16 @@ def forward( **kwargs, ) -> torch.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs - _cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype) + _cudnn_gemm_fp8( + workspace_buffer, + a, + b, + scale_a, + scale_b, + out, + out.dtype, + tactic=tactic, + ) return out return CudnnFp8GemmRunner() @@ -6719,6 +6993,7 @@ def _get_cudnn_mxfp8_gemm_graph( out: Optional[torch.Tensor] = None, block_size: int = 32, # mxfp8 block size is 32 tactic: int = -1, + build_all: bool = False, ): graph = create_cudnn_execution_plans_mxfp8_gemm( a_shape=a.shape, @@ -6733,7 +7008,9 @@ def _get_cudnn_mxfp8_gemm_graph( ) graph.check_support() - if tactic != -1: + if build_all: + graph.build_plans() + elif tactic != -1: graph.build_plan_at_index(tactic) else: graph.build_plans() @@ -6777,14 +7054,24 @@ def _cudnn_gemm_mxfp8( def _cudnn_gemm_mxfp8_runner(): class CudnnMxfp8GemmRunner(TunableRunner): + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # TODO: check if this is correct - # cudnn has heuristic for mxfp8 gemm, so we only need to use the default tactic - return [0] + a, b, _, _, out, _ = inputs + graph = _get_cudnn_mxfp8_gemm_graph( + a=a, + b=b, + out_dtype=out.dtype, + out=out, + build_all=True, + ) + return list(range(graph.get_execution_plan_count())) def forward( self, diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index 7fa72c353d..92c11b6859 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -22,6 +22,7 @@ gen_gemm_sm100_module_cutlass_fp8, gen_gemm_sm100_module_cutlass_mxfp8, gen_gemm_sm100_module_cutlass_bf16, + gen_mm_bf16_cublaslt_module, gen_gemm_sm100_module, gen_gemm_sm120_module, gen_trtllm_gen_gemm_module, @@ -40,6 +41,7 @@ "gen_gemm_sm100_module_cutlass_fp8", "gen_gemm_sm100_module_cutlass_mxfp8", "gen_gemm_sm100_module_cutlass_bf16", + "gen_mm_bf16_cublaslt_module", "gen_gemm_sm100_module", "gen_gemm_sm120_module", "gen_trtllm_gen_gemm_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 590f839678..e91c522639 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -50,6 +50,20 @@ def gen_gemm_module() -> JitSpec: ) +def gen_mm_bf16_cublaslt_module() -> JitSpec: + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10] + ) + return gen_jit_spec( + "mm_bf16_cublaslt", + [ + jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu", + ], + extra_cuda_cflags=nvcc_flags, + extra_ldflags=["-lcublas", "-lcublasLt"], + ) + + def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_fp4" os.makedirs(gen_directory, exist_ok=True) diff --git a/include/flashinfer/gemm/bmm_fp8.cuh b/include/flashinfer/gemm/bmm_fp8.cuh index 7778934160..4c8539d1c0 100644 --- a/include/flashinfer/gemm/bmm_fp8.cuh +++ b/include/flashinfer/gemm/bmm_fp8.cuh @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -137,6 +139,107 @@ cudaDataType_t get_cuda_data_type() { } } +static constexpr int kMaxFp8Algorithms = 100; +static constexpr size_t kAlgoBytes = sizeof(cublasLtMatmulAlgo_t); + +/*! + * \brief Set up cuBLASLt descriptors for FP8 BMM. + * + * Factors out the common descriptor creation shared by heuristic query, + * run, and run_with_algo paths. + */ +template +struct Fp8GemmDescriptors { + CuBlasLtMatmulDescriptor matmul_desc; + CuBlasLtMatrixLayout a_layout; + CuBlasLtMatrixLayout b_layout; + CuBlasLtMatrixLayout d_layout; + + Fp8GemmDescriptors(int batch_size, int m, int n, int k, const float* A_scale, + const float* B_scale) + : matmul_desc(CUBLAS_COMPUTE_32F, CUDA_R_32F), + a_layout(get_cuda_data_type(), m, k, k, true), + b_layout(get_cuda_data_type(), k, n, k), + d_layout(get_cuda_data_type
(), m, n, m) { + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + int8_t fast_accum = 1; + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum); + + const void* A_scale_ptr = static_cast(A_scale); + const void* B_scale_ptr = static_cast(B_scale); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr); + + if constexpr (std::is_same_v && std::is_same_v) { + FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2"); + } + + if (batch_size > 1) { + int64_t stride_a = m * k; + int64_t stride_b = k * n; + int64_t stride_d = m * n; + a_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + a_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a); + b_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + b_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b); + d_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + d_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d); + } + } +}; + +/*! + * \brief Query heuristics and serialize all cublasLtMatmulAlgo_t structs for FP8 BMM. + */ +template +int get_fp8_algorithms(int batch_size, int m, int n, int k, const float* A_scale, + const float* B_scale, size_t workspace_size_in_bytes, + cublasLtHandle_t lt_handle, void* algo_buf, int max_algos) { + Fp8GemmDescriptors desc(batch_size, m, n, k, A_scale, B_scale); + + CuBlasLtMatmulPreference preference; + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); + + int request_count = (max_algos > kMaxFp8Algorithms) ? kMaxFp8Algorithms : max_algos; + std::array results; + int returned_count = 0; + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), + desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), + preference.descriptor(), request_count, results.data(), &returned_count); + if (status != CUBLAS_STATUS_SUCCESS) return 0; + + auto* out = static_cast(algo_buf); + for (int i = 0; i < returned_count; ++i) { + std::memcpy(out + i * kAlgoBytes, &results[i].algo, kAlgoBytes); + } + return returned_count; +} + +/*! + * \brief Run FP8 BMM using a pre-resolved algorithm — zero heuristic overhead. + */ +template +cublasStatus_t bmm_fp8_run_with_algo(void* workspace, size_t workspace_size_in_bytes, const AT* A, + const BT* B, DT* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale, + cublasLtHandle_t lt_handle, cudaStream_t stream, + const void* algo_buf, int algo_idx) { + Fp8GemmDescriptors desc(batch_size, m, n, k, A_scale, B_scale); + + cublasLtMatmulAlgo_t algo; + std::memcpy(&algo, static_cast(algo_buf) + algo_idx * kAlgoBytes, kAlgoBytes); + + const float alpha = 1.0f; + const float beta = 0.0f; + FLASHINFER_CUBLAS_CALL(cublasLtMatmul( + lt_handle, desc.matmul_desc.descriptor(), &alpha, A, desc.a_layout.descriptor(), B, + desc.b_layout.descriptor(), &beta, nullptr, desc.d_layout.descriptor(), D, + desc.d_layout.descriptor(), &algo, workspace, workspace_size_in_bytes, stream)); + return CUBLAS_STATUS_SUCCESS; +} + template cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_in_bytes, const AT* A, const BT* B, DT* D, int batch_size, int m, diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh new file mode 100644 index 0000000000..c7f6307597 --- /dev/null +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2026 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ +#define FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ + +#include +#include + +#include +#include + +#include "bmm_fp8.cuh" + +namespace flashinfer { +namespace mm_bf16_cublaslt { + +using bmm_fp8::CuBlasLtMatmulDescriptor; +using bmm_fp8::CuBlasLtMatmulPreference; +using bmm_fp8::CuBlasLtMatrixLayout; + +static constexpr int kMaxAlgorithms = 100; +using bmm_fp8::kAlgoBytes; + +/*! + * \brief Set up cuBLASLt descriptors for BF16 GEMM in row-major convention. + * + * Python mm_bf16 passes mat1 (m,k) row-major and mat2 (n,k) row-major + * (after b.transpose(-2,-1)). We want out (m,n) row-major. + * + * cuBLASLt is column-major, so we use the standard trick: + * out^T = mat2 @ mat1^T (all column-major) + * + * Memory layouts: + * mat2 row-major (n,k) = col-major (k,n) ld=k → cuBLASLt "A", TRANSA=T → (n,k) + * mat1 row-major (m,k) = col-major (k,m) ld=k → cuBLASLt "B", TRANSB=N → (k,m) + * out row-major (m,n) = col-major (n,m) ld=n → cuBLASLt "D" + * + * Result: (n,k)×(k,m) = (n,m) col-major = (m,n) row-major ✓ + */ +struct GemmDescriptors { + CuBlasLtMatmulDescriptor matmul_desc; + CuBlasLtMatrixLayout a_layout; // mat2 + CuBlasLtMatrixLayout b_layout; // mat1 + CuBlasLtMatrixLayout d_layout; // out + + GemmDescriptors(int m, int n, int k, cudaDataType_t d_type) + : matmul_desc(CUBLAS_COMPUTE_32F, CUDA_R_32F), + a_layout(CUDA_R_16BF, n, k, k, /*t=*/true), + b_layout(CUDA_R_16BF, k, m, k), + d_layout(d_type, n, m, n) { + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + } +}; + +/*! + * \brief Query heuristics once and serialize all cublasLtMatmulAlgo_t structs into a buffer. + * + * Each algo occupies kAlgoBytes (64) contiguous bytes. The buffer can be cached + * and later passed to run_with_algo() to skip the heuristic lookup entirely. + * + * \param algo_buf Output buffer, must hold at least max_algos * kAlgoBytes bytes. + * \param max_algos Maximum number of algorithms to retrieve. + * \return Number of algorithms written to algo_buf. + */ +inline int get_algorithms(int m, int n, int k, cudaDataType_t d_type, + size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle, + void* algo_buf, int max_algos) { + GemmDescriptors desc(m, n, k, d_type); + + CuBlasLtMatmulPreference preference; + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); + + int request_count = (max_algos > kMaxAlgorithms) ? kMaxAlgorithms : max_algos; + std::array results; + int returned_count = 0; + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), + desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), + preference.descriptor(), request_count, results.data(), &returned_count); + if (status != CUBLAS_STATUS_SUCCESS) return 0; + + auto* out = static_cast(algo_buf); + for (int i = 0; i < returned_count; ++i) { + std::memcpy(out + i * kAlgoBytes, &results[i].algo, kAlgoBytes); + } + return returned_count; +} + +/*! + * \brief Run a BF16 GEMM using a pre-resolved algorithm — zero heuristic overhead. + * + * \param algo_buf Buffer of serialized cublasLtMatmulAlgo_t structs (from get_algorithms). + * \param algo_idx Index into algo_buf selecting which algorithm to use. + */ +inline cublasStatus_t run_with_algo(const __nv_bfloat16* mat1, const __nv_bfloat16* mat2, void* out, + int m, int n, int k, cudaDataType_t d_type, void* workspace, + size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle, + cudaStream_t stream, const void* algo_buf, int algo_idx) { + GemmDescriptors desc(m, n, k, d_type); + + cublasLtMatmulAlgo_t algo; + std::memcpy(&algo, static_cast(algo_buf) + algo_idx * kAlgoBytes, kAlgoBytes); + + const float alpha = 1.0f; + const float beta = 0.0f; + FLASHINFER_CUBLAS_CALL(cublasLtMatmul( + lt_handle, desc.matmul_desc.descriptor(), &alpha, mat2, desc.a_layout.descriptor(), mat1, + desc.b_layout.descriptor(), &beta, nullptr, desc.d_layout.descriptor(), out, + desc.d_layout.descriptor(), &algo, workspace, workspace_size_in_bytes, stream)); + return CUBLAS_STATUS_SUCCESS; +} + +} // namespace mm_bf16_cublaslt +} // namespace flashinfer + +#endif // FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 493dd38f73..646ac654f5 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("n", [80, 64]) @pytest.mark.parametrize("k", [64, 256]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("backend", ["cutlass", "cudnn"]) +@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "auto"]) def test_bmm_bf16(b, m, n, k, res_dtype, backend): compute_capability = get_compute_capability(torch.device(device="cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] @@ -21,8 +21,11 @@ def test_bmm_bf16(b, m, n, k, res_dtype, backend): f"bmm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." ) - if not bmm_bf16.is_backend_supported(backend, compute_capability_number): - pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend != "auto": + if not bmm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip( + f"{backend} backend not supported on current compute capability." + ) if backend == "cudnn" and not CUDNN_AVAILABLE: pytest.skip("cuDNN is not available on this system.") diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index 7d83c7f5c0..f5371d48d4 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -13,7 +13,8 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize("pdl", [True, False]) -@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv"]) +@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]) +@pytest.mark.parametrize("auto_tuning", [False, True]) def test_mm_bf16( m: int, n: int, @@ -22,6 +23,7 @@ def test_mm_bf16( enable_bias: bool, pdl: bool, backend: str, + auto_tuning: bool, ): compute_capability = get_compute_capability(torch.device(device="cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] @@ -30,12 +32,17 @@ def test_mm_bf16( f"mm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." ) - if not mm_bf16.is_backend_supported(backend, compute_capability_number): - pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend != "auto": + if not mm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip( + f"{backend} backend not supported on current compute capability." + ) if backend == "cudnn" and not CUDNN_AVAILABLE: pytest.skip("cuDNN is not available on this system.") + if backend == "auto" and (enable_bias or pdl): + pytest.skip("mm_bf16 with auto backend does not support bias or pdl arguments.") if backend == "cudnn" and (enable_bias or pdl): pytest.skip( "mm_bf16 with cuDNN backend does not support bias or pdl arguments." @@ -44,6 +51,10 @@ def test_mm_bf16( pytest.skip( "mm_bf16 with CUTLASS backend does not support bias or pdl arguments." ) + if backend == "cublaslt" and (enable_bias or pdl): + pytest.skip( + "mm_bf16 with cuBLASLt backend does not support bias or pdl arguments." + ) if res_dtype != torch.bfloat16 and backend == "tgv": pytest.skip( "mm_bf16 with TGV backend does not support specifying non-bfloat16 result dtypes." @@ -68,12 +79,41 @@ def test_mm_bf16( reference = torch.mm(input, mat2.T) out = torch.empty([m, n], device="cuda", dtype=res_dtype) - with autotune(): + with autotune(auto_tuning): mm_bf16(input, mat2.T, bias, pdl, out, res_dtype, backend) cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0) assert cos_sim > 0.99 +def test_cublaslt_bf16_runner_zero_algos(): + """CublasltBf16GemmRunner.forward() must raise when heuristic returns 0 algorithms.""" + from flashinfer.gemm.gemm_base import get_mm_bf16_cublaslt_module + from flashinfer.utils import get_compute_capability + + compute_capability = get_compute_capability(torch.device("cuda")) + cc_num = compute_capability[0] * 10 + compute_capability[1] + if not mm_bf16.is_backend_supported("cublaslt", cc_num): + pytest.skip("cublaslt backend not supported on this GPU") + + runner = get_mm_bf16_cublaslt_module().cublaslt_bf16_gemm_runner() + + m, n, k = 16, 1024, 1024 + a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).transpose(-2, -1) + out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16) + workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + inputs = [a, b, None, None, out, workspace] + + zero_algo_buf = torch.empty(0, dtype=torch.uint8, device="cpu") + original_get_algos = runner._get_algos + runner._get_algos = lambda _inputs: (zero_algo_buf, 0) + try: + with pytest.raises(RuntimeError, match="zero algorithms"): + runner.forward(inputs) + finally: + runner._get_algos = original_get_algos + + if __name__ == "__main__": pytest.main([__file__])