From 044dc7ee0bb3f24df591b1f353b3a7dcd2a99ba9 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 00:40:59 +0400 Subject: [PATCH 01/15] add cublaslt to mm_bf16 Signed-off-by: Vadim Gimpelson --- benchmarks/routines/gemm.py | 15 +- csrc/mm_bf16_cublaslt.cu | 93 +++++++++++++ flashinfer/gemm/gemm_base.py | 125 +++++++++++++++-- flashinfer/jit/gemm/__init__.py | 2 + flashinfer/jit/gemm/core.py | 10 ++ include/flashinfer/gemm/mm_bf16_cublaslt.cuh | 137 +++++++++++++++++++ tests/gemm/test_mm_bf16.py | 8 +- 7 files changed, 375 insertions(+), 15 deletions(-) create mode 100644 csrc/mm_bf16_cublaslt.cu create mode 100644 include/flashinfer/gemm/mm_bf16_cublaslt.cuh diff --git a/benchmarks/routines/gemm.py b/benchmarks/routines/gemm.py index 41ba7b3c18..0209e1bd40 100644 --- a/benchmarks/routines/gemm.py +++ b/benchmarks/routines/gemm.py @@ -140,7 +140,16 @@ def parse_gemm_args(line, parser): required=False, nargs="+", default=["cudnn"], - choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "cute-dsl", "auto"], + choices=[ + "cudnn", + "cublas", + "trtllm", + "cutlass", + "tgv", + "cublaslt", + "cute-dsl", + "auto", + ], help="Kernel backends to test. 
Default: cudnn", ) parser.add_argument( @@ -1553,7 +1562,7 @@ def testMmBf16(args): use_pdl = getattr(args, "enable_pdl", False) is_cuda_graph_compatible = not args.no_cuda_graph run_refcheck = args.refcheck - autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"] + autotune_supported_backends = ["cudnn", "cutlass", "tgv", "cublaslt", "auto"] res = [] out_dtype = dtype_str_to_torch_dtype(args.out_dtype) @@ -1618,7 +1627,7 @@ def testMmBf16(args): return res def run_backend(backend, a, b, bias, use_pdl, out_dtype): - if backend in ["cudnn", "cutlass", "tgv", "auto"]: + if backend in ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]: return flashinfer.mm_bf16( a=a, b=b, diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu new file mode 100644 index 0000000000..baa630acf8 --- /dev/null +++ b/csrc/mm_bf16_cublaslt.cu @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2025 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include + +#include "tvm_ffi_utils.h" + +namespace { + +cudaDataType_t get_d_type(DLDataType dtype) { + switch (encode_dlpack_dtype(dtype)) { + case bfloat16_code: + return CUDA_R_16BF; + case float16_code: + return CUDA_R_16F; + case float32_code: + return CUDA_R_32F; + default: + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32."; + return CUDA_R_16BF; + } +} + +} // namespace + +// mat1: (m, k) bf16 contiguous row-major +// mat2: (n, k) bf16 contiguous row-major (Python passes b.transpose(-2,-1)) +// out: (m, n) bf16/fp16/fp32 contiguous row-major +void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle, int64_t tactic) { + CHECK_CUDA(mat1); + CHECK_CUDA(mat2); + CHECK_CUDA(out); + CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); + CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + CHECK_DIM(2, out); + + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + + TVM_FFI_ICHECK_EQ(mat2.size(1), k) + << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); + TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; + TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); + auto stream = get_stream(mat1.device()); + cudaDataType_t d_type = get_d_type(out.dtype()); + + auto status = flashinfer::mm_bf16_cublaslt::run( + static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), + out.data_ptr(), static_cast(m), static_cast(n), static_cast(k), d_type, + workspace_buffer.data_ptr(), workspace_buffer.numel(), lt_handle, stream, + static_cast(tactic)); + TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) + << "mm_bf16_cublaslt failed: " << cublasGetStatusString(status); +} + +int64_t 
mm_bf16_cublaslt_tactic_num(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle) { + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + cudaDataType_t d_type = get_d_type(out.dtype()); + + auto lt_handle = reinterpret_cast(cublas_handle); + return static_cast(flashinfer::mm_bf16_cublaslt::get_algorithm_count( + static_cast(m), static_cast(n), static_cast(k), d_type, + workspace_buffer.numel(), lt_handle)); +} + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt, mm_bf16_cublaslt); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_tactic_num, mm_bf16_cublaslt_tactic_num); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 02b54b5b8b..a9238d9397 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -54,6 +54,7 @@ from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 from ..jit.gemm import gen_gemm_sm100_module_cutlass_mxfp8 from ..jit.gemm import gen_gemm_sm100_module_cutlass_bf16 +from ..jit.gemm import gen_mm_bf16_cublaslt_module from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module from ..jit.gemm import gen_deepgemm_sm100_module @@ -197,7 +198,7 @@ def _cutlass_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -213,6 +214,34 @@ def _cutlass_mm_bf16_requirement( return True +@supported_compute_capability([100, 103]) +def _cublaslt_mm_bf16_requirement( + a: torch.Tensor, + b: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: torch.dtype = torch.bfloat16, + bias: Optional[torch.Tensor] = None, + pdl: bool = False, + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", +): + if bias is not None: + raise 
ValueError( + "You cannot use the cuBLASLt backend with a bias. Use the TGV backend instead." + ) + if pdl: + raise ValueError( + "The cuBLASLt backend does not support PDL. Use the TGV backend instead." + ) + if out_dtype == torch.float16: + raise ValueError( + "cuBLASLt BF16 GEMM does not support float16 output. " + "Use bfloat16 or float32, or use the CUTLASS/cuDNN backend for float16 output." + ) + _validate_bf16_output_dtype(out_dtype) + + return True + + @supported_compute_capability([100, 103]) def _cudnn_mm_bf16_requirement( a: torch.Tensor, @@ -221,7 +250,7 @@ def _cudnn_mm_bf16_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if bias is not None: raise ValueError( @@ -246,7 +275,7 @@ def _tgv_gemm_requirement( out_dtype: torch.dtype = torch.bfloat16, bias: Optional[torch.Tensor] = None, pdl: bool = False, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if out_dtype != torch.bfloat16: raise ValueError( @@ -262,7 +291,7 @@ def _check_mm_bf16_problem_size( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): if a.dtype != torch.bfloat16: raise ValueError( @@ -303,16 +332,17 @@ def _heuristic_func_mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ): heuristic_backends = [] if bias is not None or pdl: - # cuDNN and CUTLASS don't support bias/pdl, only TGV does if "tgv" in 
suitable_backends: heuristic_backends.append("tgv") else: if "cudnn" in suitable_backends: heuristic_backends.append("cudnn") + if "cublaslt" in suitable_backends: + heuristic_backends.append("cublaslt") if "cutlass" in suitable_backends: heuristic_backends.append("cutlass") if "tgv" in suitable_backends: @@ -325,6 +355,7 @@ def _heuristic_func_mm_bf16( "cudnn": _cudnn_mm_bf16_requirement, "cutlass": _cutlass_mm_bf16_requirement, "tgv": _tgv_gemm_requirement, + "cublaslt": _cublaslt_mm_bf16_requirement, }, common_check=_check_mm_bf16_problem_size, heuristic_func=_heuristic_func_mm_bf16, @@ -337,7 +368,7 @@ def mm_bf16( pdl: bool = False, out: Optional[torch.Tensor] = None, out_dtype: torch.dtype = torch.bfloat16, - backend: Literal["cudnn", "cutlass", "tgv", "auto"] = "cudnn", + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] = "cudnn", ) -> torch.Tensor: r"""MM BF16 @@ -356,18 +387,19 @@ def mm_bf16( Whether to use persistant data loader mode. Enabled for TGV backend. Defaults to ``False``. out: Optional[torch.Tensor] - Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends. - Defaults to ``None``. + Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt + backends. Defaults to ``None``. out_dtype: torch.dtype - Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS and cuDNN backends. + Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt backends. Defaults to ``torch.bfloat16``. - backend: Literal["cudnn", "cutlass", "tgv", "auto"] + backend: Literal["cudnn", "cutlass", "tgv", "cublaslt", "auto"] The backend to use for the operation. Defaults to ``"cudnn"``. ``"cudnn"`` uses the cuDNN backend. ``"cutlass"`` uses the CUTLASS backend. ``"tgv"`` uses the TGV backend. + ``"cublaslt"`` uses the cuBLASLt backend with heuristic algorithm search. ``"auto"`` allows selecting the best tactic from all available backends when autotune is enabled. 
Returns @@ -401,6 +433,12 @@ def mm_bf16( torch.Size([48, 80]) >>> out.dtype torch.bfloat16 + >>> # Using the cuBLASLt backend + >>> out = flashinfer.mm_bf16(a, b, backend="cublaslt") + >>> out.shape + torch.Size([48, 80]) + >>> out.dtype + torch.bfloat16 """ if out is None: @@ -427,6 +465,10 @@ def mm_bf16( backends = _heuristic_func_mm_bf16( ["tgv"], a, b, bias, pdl, out, out_dtype, backend ) + elif backend == "cublaslt": + backends = _heuristic_func_mm_bf16( + ["cublaslt"], a, b, None, False, out, out_dtype, backend + ) else: backends = [backend] @@ -870,6 +912,65 @@ def forward( ) +@functools.cache +def get_mm_bf16_cublaslt_module(): + module = gen_mm_bf16_cublaslt_module().build_and_load() + + def cublaslt_bf16_gemm_runner(): + class CublasltBf16GemmRunner(TunableRunner): + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + ) -> List[int]: + a, b, _, _, out, workspace_buffer = inputs + cublas_handle = torch.cuda.current_blas_handle() + num_tactics = module.mm_bf16_cublaslt_tactic_num( + a, + b.transpose(-2, -1), + out, + workspace_buffer, + cublas_handle, + ) + return list(range(num_tactics)) + + def forward( + self, + inputs: List[torch.Tensor], + tactic: int = -1, + do_preparation: bool = False, + **kwargs, + ) -> torch.Tensor: + a, b, _, _, out, workspace_buffer = inputs + cublas_handle = torch.cuda.current_blas_handle() + b_t = b.transpose(-2, -1) + if tactic >= 0: + num_available = module.mm_bf16_cublaslt_tactic_num( + a, + b_t, + out, + workspace_buffer, + cublas_handle, + ) + if tactic >= num_available: + tactic = 0 + module.mm_bf16_cublaslt( + a, + b_t, + out, + workspace_buffer, + cublas_handle, + tactic, + ) + return out + + return CublasltBf16GemmRunner() + + return SimpleNamespace( + cublaslt_bf16_gemm_runner=cublaslt_bf16_gemm_runner, + ) + + _BF16_GEMM_SM100_TUNING_CONFIG = TuningConfig( dynamic_tensor_specs=( DynamicTensorSpec( @@ -902,6 +1003,8 @@ def bf16_gemm_sm100( use_sm_100f = 
is_sm100f_supported(a.device) if "cudnn" in runner_names: runners.append(_cudnn_gemm_bf16_runner()) + if "cublaslt" in runner_names: + runners.append(get_mm_bf16_cublaslt_module().cublaslt_bf16_gemm_runner()) if "cutlass" in runner_names: runners.append(get_gemm_sm100_module_cutlass_bf16().cutlass_bf16_gemm_runner()) if "tgv" in runner_names: diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index 7fa72c353d..92c11b6859 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -22,6 +22,7 @@ gen_gemm_sm100_module_cutlass_fp8, gen_gemm_sm100_module_cutlass_mxfp8, gen_gemm_sm100_module_cutlass_bf16, + gen_mm_bf16_cublaslt_module, gen_gemm_sm100_module, gen_gemm_sm120_module, gen_trtllm_gen_gemm_module, @@ -40,6 +41,7 @@ "gen_gemm_sm100_module_cutlass_fp8", "gen_gemm_sm100_module_cutlass_mxfp8", "gen_gemm_sm100_module_cutlass_bf16", + "gen_mm_bf16_cublaslt_module", "gen_gemm_sm100_module", "gen_gemm_sm120_module", "gen_trtllm_gen_gemm_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 590f839678..c435e0fc86 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -50,6 +50,16 @@ def gen_gemm_module() -> JitSpec: ) +def gen_mm_bf16_cublaslt_module() -> JitSpec: + return gen_jit_spec( + "mm_bf16_cublaslt", + [ + jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu", + ], + extra_ldflags=["-lcublas", "-lcublasLt"], + ) + + def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm100_cutlass_fp4" os.makedirs(gen_directory, exist_ok=True) diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh new file mode 100644 index 0000000000..ecdb180248 --- /dev/null +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2025 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ +#define FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ + +#include +#include +#include + +#include + +#include "bmm_fp8.cuh" + +namespace flashinfer { +namespace mm_bf16_cublaslt { + +using bmm_fp8::CuBlasLtMatmulDescriptor; +using bmm_fp8::CuBlasLtMatmulPreference; +using bmm_fp8::CuBlasLtMatrixLayout; + +static constexpr int kMaxAlgorithms = 100; + +/*! + * \brief Set up cuBLASLt descriptors for BF16 GEMM in row-major convention. + * + * Python mm_bf16 passes mat1 (m,k) row-major and mat2 (n,k) row-major + * (after b.transpose(-2,-1)). We want out (m,n) row-major. 
+ * + * cuBLASLt is column-major, so we use the standard trick: + * out^T = mat2 @ mat1^T (all column-major) + * + * Memory layouts: + * mat2 row-major (n,k) = col-major (k,n) ld=k → cuBLASLt "A", TRANSA=T → (n,k) + * mat1 row-major (m,k) = col-major (k,m) ld=k → cuBLASLt "B", TRANSB=N → (k,m) + * out row-major (m,n) = col-major (n,m) ld=n → cuBLASLt "D" + * + * Result: (n,k)×(k,m) = (n,m) col-major = (m,n) row-major ✓ + */ +struct GemmDescriptors { + CuBlasLtMatmulDescriptor matmul_desc; + CuBlasLtMatrixLayout a_layout; // mat2 + CuBlasLtMatrixLayout b_layout; // mat1 + CuBlasLtMatrixLayout d_layout; // out + + GemmDescriptors(int m, int n, int k, cudaDataType_t d_type) + : matmul_desc(CUBLAS_COMPUTE_32F, CUDA_R_32F), + a_layout(CUDA_R_16BF, n, k, k, /*t=*/true), + b_layout(CUDA_R_16BF, k, m, k), + d_layout(d_type, n, m, n) { + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + } +}; + +/*! + * \brief Get the number of available cuBLASLt algorithms for a BF16 GEMM. + */ +inline int get_algorithm_count(int m, int n, int k, cudaDataType_t d_type, + size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle) { + GemmDescriptors desc(m, n, k, d_type); + + CuBlasLtMatmulPreference preference; + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); + + std::vector results(kMaxAlgorithms); + int returned_count = 0; + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), + desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), + preference.descriptor(), kMaxAlgorithms, results.data(), &returned_count); + if (status != CUBLAS_STATUS_SUCCESS) { + return 0; + } + return returned_count; +} + +/*! + * \brief Run a BF16 GEMM using a specific cuBLASLt algorithm (tactic). 
+ * + * \param mat1 Pointer to mat1 data, row-major (m, k) + * \param mat2 Pointer to mat2 data, row-major (n, k) — after b.transpose(-2,-1) in Python + * \param out Pointer to output data, row-major (m, n) + * \param tactic Algorithm index; -1 means use the top heuristic (index 0). + */ +inline cublasStatus_t run(const __nv_bfloat16* mat1, const __nv_bfloat16* mat2, void* out, int m, + int n, int k, cudaDataType_t d_type, void* workspace, + size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle, + cudaStream_t stream, int tactic) { + GemmDescriptors desc(m, n, k, d_type); + + CuBlasLtMatmulPreference preference; + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); + + int algo_idx = (tactic < 0) ? 0 : tactic; + int request_count = algo_idx + 1; + if (request_count > kMaxAlgorithms) { + request_count = kMaxAlgorithms; + } + + std::vector results(request_count); + int returned_count = 0; + cublasStatus_t heur_status = cublasLtMatmulAlgoGetHeuristic( + lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), + desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), + preference.descriptor(), request_count, results.data(), &returned_count); + if (heur_status != CUBLAS_STATUS_SUCCESS || returned_count <= algo_idx) { + return CUBLAS_STATUS_NOT_SUPPORTED; + } + + const float alpha = 1.0f; + const float beta = 0.0f; + // Note: mat2 is cuBLASLt "A", mat1 is cuBLASLt "B" (swap for row-major output) + FLASHINFER_CUBLAS_CALL( + cublasLtMatmul(lt_handle, desc.matmul_desc.descriptor(), &alpha, mat2, + desc.a_layout.descriptor(), mat1, desc.b_layout.descriptor(), &beta, nullptr, + desc.d_layout.descriptor(), out, desc.d_layout.descriptor(), + &results[algo_idx].algo, workspace, workspace_size_in_bytes, stream)); + return CUBLAS_STATUS_SUCCESS; +} + +} // namespace mm_bf16_cublaslt +} // namespace flashinfer + +#endif // FLASHINFER_GEMM_MM_BF16_CUBLASLT_CUH_ diff --git 
a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index 7d83c7f5c0..805c41c909 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize("pdl", [True, False]) -@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv"]) +@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv", "cublaslt"]) def test_mm_bf16( m: int, n: int, @@ -44,6 +44,12 @@ def test_mm_bf16( pytest.skip( "mm_bf16 with CUTLASS backend does not support bias or pdl arguments." ) + if backend == "cublaslt" and (enable_bias or pdl): + pytest.skip( + "mm_bf16 with cuBLASLt backend does not support bias or pdl arguments." + ) + if backend == "cublaslt" and res_dtype == torch.float16: + pytest.skip("mm_bf16 with cuBLASLt backend does not support float16 output.") if res_dtype != torch.bfloat16 and backend == "tgv": pytest.skip( "mm_bf16 with TGV backend does not support specifying non-bfloat16 result dtypes." From f6073b418ae5564243f5e3d55f66226f4eec8752 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 02:19:13 +0400 Subject: [PATCH 02/15] fixes Signed-off-by: Vadim Gimpelson --- csrc/mm_bf16_cublaslt.cu | 13 +++++++------ flashinfer/aot.py | 2 ++ include/flashinfer/gemm/mm_bf16_cublaslt.cuh | 9 ++++----- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index baa630acf8..fe583c6138 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025 by FlashInfer team. + * Copyright (c) 2026 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -42,8 +42,8 @@ cudaDataType_t get_d_type(DLDataType dtype) { // mat1: (m, k) bf16 contiguous row-major // mat2: (n, k) bf16 contiguous row-major (Python passes b.transpose(-2,-1)) // out: (m, n) bf16/fp16/fp32 contiguous row-major -void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, - TensorView workspace_buffer, int64_t cublas_handle, int64_t tactic) { +void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer, + int64_t cublas_handle, int64_t tactic) { CHECK_CUDA(mat1); CHECK_CUDA(mat2); CHECK_CUDA(out); @@ -70,8 +70,8 @@ void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, auto status = flashinfer::mm_bf16_cublaslt::run( static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), out.data_ptr(), static_cast(m), static_cast(n), static_cast(k), d_type, - workspace_buffer.data_ptr(), workspace_buffer.numel(), lt_handle, stream, - static_cast(tactic)); + workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer), + lt_handle, stream, static_cast(tactic)); TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) << "mm_bf16_cublaslt failed: " << cublasGetStatusString(status); } @@ -83,10 +83,11 @@ int64_t mm_bf16_cublaslt_tactic_num(TensorView mat1, TensorView mat2, TensorView int64_t n = mat2.size(0); cudaDataType_t d_type = get_d_type(out.dtype()); + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); auto lt_handle = reinterpret_cast(cublas_handle); return static_cast(flashinfer::mm_bf16_cublaslt::get_algorithm_count( static_cast(m), static_cast(n), static_cast(k), d_type, - workspace_buffer.numel(), lt_handle)); + workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle)); } TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt, mm_bf16_cublaslt); diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 9909befbbd..aaa5b0ca90 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -77,6 +77,7 @@ 
gen_gemm_sm100_module_cutlass_mxfp8, gen_gemm_sm120_module, gen_gemm_sm120_module_cutlass_fp4, + gen_mm_bf16_cublaslt_module, gen_tgv_gemm_sm10x_module, gen_trtllm_gen_gemm_module, gen_trtllm_low_latency_gemm_module, @@ -493,6 +494,7 @@ def gen_all_modules( jit_specs.append(gen_gemm_sm100_module_cutlass_fp4()) jit_specs.append(gen_gemm_sm100_module_cutlass_fp8()) jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8()) + jit_specs.append(gen_mm_bf16_cublaslt_module()) # Add TGV GEMM modules for both bf16 and fp16 jit_specs.append( gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False) diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh index ecdb180248..6462beb6a7 100644 --- a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025 by FlashInfer team. + * Copyright (c) 2026 by FlashInfer team. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -18,9 +18,8 @@ #include #include -#include -#include +#include #include "bmm_fp8.cuh" @@ -75,7 +74,7 @@ inline int get_algorithm_count(int m, int n, int k, cudaDataType_t d_type, CuBlasLtMatmulPreference preference; preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); - std::vector results(kMaxAlgorithms); + std::array results; int returned_count = 0; cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), @@ -110,7 +109,7 @@ inline cublasStatus_t run(const __nv_bfloat16* mat1, const __nv_bfloat16* mat2, request_count = kMaxAlgorithms; } - std::vector results(request_count); + std::array results; int returned_count = 0; cublasStatus_t heur_status = cublasLtMatmulAlgoGetHeuristic( lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), From 40d0ac32988d5127d4e4e2328fbafae6577ca3f1 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 12:53:46 +0400 Subject: [PATCH 03/15] perf fix Signed-off-by: Vadim Gimpelson --- csrc/mm_bf16_cublaslt.cu | 59 +++++++++++++++++++ flashinfer/gemm/gemm_base.py | 58 +++++++++++++------ include/flashinfer/gemm/mm_bf16_cublaslt.cuh | 61 ++++++++++++++++++++ 3 files changed, 160 insertions(+), 18 deletions(-) diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index fe583c6138..6827ef6f2f 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -90,5 +90,64 @@ int64_t mm_bf16_cublaslt_tactic_num(TensorView mat1, TensorView mat2, TensorView workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle)); } +// Serialize all heuristic algorithms into a CPU uint8 tensor for caching. +// algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes. +// Returns number of algorithms written. 
+int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer) { + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + cudaDataType_t d_type = get_d_type(out.dtype()); + + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); + auto lt_handle = reinterpret_cast(cublas_handle); + int max_algos = static_cast(algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::mm_bf16_cublaslt::kAlgoBytes); + return static_cast(flashinfer::mm_bf16_cublaslt::get_algorithms( + static_cast(m), static_cast(n), static_cast(k), d_type, + workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle, + algo_buffer.data_ptr(), max_algos)); +} + +// Run matmul using a pre-cached algorithm — zero heuristic overhead. +void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out, + TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx) { + CHECK_CUDA(mat1); + CHECK_CUDA(mat2); + CHECK_CUDA(out); + CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); + CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + CHECK_DIM(2, out); + + int64_t m = mat1.size(0); + int64_t k = mat1.size(1); + int64_t n = mat2.size(0); + + TVM_FFI_ICHECK_EQ(mat2.size(1), k) + << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); + TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; + TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(mat1.device().device_id); + auto stream = get_stream(mat1.device()); + cudaDataType_t d_type = get_d_type(out.dtype()); + + auto status = flashinfer::mm_bf16_cublaslt::run_with_algo( + static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), + out.data_ptr(), static_cast(m), 
static_cast(n), static_cast(k), d_type, + workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer), + lt_handle, stream, algo_buffer.data_ptr(), static_cast(algo_idx)); + TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) + << "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status); +} + TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt, mm_bf16_cublaslt); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_tactic_num, mm_bf16_cublaslt_tactic_num); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a9238d9397..029dd88241 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -916,23 +916,41 @@ def forward( def get_mm_bf16_cublaslt_module(): module = gen_mm_bf16_cublaslt_module().build_and_load() + _ALGO_BYTES = 64 # sizeof(cublasLtMatmulAlgo_t) = uint64_t[8] + _MAX_ALGOS = 100 + def cublaslt_bf16_gemm_runner(): class CublasltBf16GemmRunner(TunableRunner): - def get_valid_tactics( - self, - inputs: List[torch.Tensor], - profile: OptimizationProfile, - ) -> List[int]: + def __init__(self): + self._algo_cache: dict = {} + + def _get_algos(self, inputs): a, b, _, _, out, workspace_buffer = inputs + key = (a.shape[0], b.shape[0], a.shape[1], out.dtype) + cached = self._algo_cache.get(key) + if cached is not None: + return cached + algo_buf = torch.empty(_MAX_ALGOS * _ALGO_BYTES, dtype=torch.uint8) cublas_handle = torch.cuda.current_blas_handle() - num_tactics = module.mm_bf16_cublaslt_tactic_num( + count = module.mm_bf16_cublaslt_get_algos( a, b.transpose(-2, -1), out, workspace_buffer, cublas_handle, + algo_buf, ) - return list(range(num_tactics)) + result = (algo_buf, count) + self._algo_cache[key] = result + return result + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + 
profile: OptimizationProfile, + ) -> List[int]: + _, count = self._get_algos(inputs) + return list(range(count)) def forward( self, @@ -945,23 +963,27 @@ def forward( cublas_handle = torch.cuda.current_blas_handle() b_t = b.transpose(-2, -1) if tactic >= 0: - num_available = module.mm_bf16_cublaslt_tactic_num( + algo_buf, count = self._get_algos(inputs) + if tactic >= count: + tactic = 0 + module.mm_bf16_cublaslt_run_with_algo( a, b_t, out, workspace_buffer, cublas_handle, + algo_buf, + tactic, + ) + else: + module.mm_bf16_cublaslt( + a, + b_t, + out, + workspace_buffer, + cublas_handle, + tactic, ) - if tactic >= num_available: - tactic = 0 - module.mm_bf16_cublaslt( - a, - b_t, - out, - workspace_buffer, - cublas_handle, - tactic, - ) return out return CublasltBf16GemmRunner() diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh index 6462beb6a7..ed327ad3ad 100644 --- a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -20,6 +20,7 @@ #include #include +#include #include "bmm_fp8.cuh" @@ -31,6 +32,8 @@ using bmm_fp8::CuBlasLtMatmulPreference; using bmm_fp8::CuBlasLtMatrixLayout; static constexpr int kMaxAlgorithms = 100; +// cublasLtMatmulAlgo_t is { uint64_t data[8]; } — 64 bytes, trivially serializable per NVIDIA docs +static constexpr size_t kAlgoBytes = sizeof(cublasLtMatmulAlgo_t); /*! * \brief Set up cuBLASLt descriptors for BF16 GEMM in row-major convention. @@ -86,6 +89,64 @@ inline int get_algorithm_count(int m, int n, int k, cudaDataType_t d_type, return returned_count; } +/*! + * \brief Query heuristics once and serialize all cublasLtMatmulAlgo_t structs into a buffer. + * + * Each algo occupies kAlgoBytes (64) contiguous bytes. The buffer can be cached + * and later passed to run_with_algo() to skip the heuristic lookup entirely. + * + * \param algo_buf Output buffer, must hold at least max_algos * kAlgoBytes bytes. 
+ * \param max_algos Maximum number of algorithms to retrieve.
+ * \return Number of algorithms written to algo_buf.
+ */
+inline int get_algorithms(int m, int n, int k, cudaDataType_t d_type,
+                          size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle,
+                          void* algo_buf, int max_algos) {
+  GemmDescriptors desc(m, n, k, d_type);
+
+  CuBlasLtMatmulPreference preference;
+  preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes);
+
+  int request_count = (max_algos > kMaxAlgorithms) ? kMaxAlgorithms : max_algos;
+  std::array<cublasLtMatmulHeuristicResult_t, kMaxAlgorithms> results;
+  int returned_count = 0;
+  cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic(
+      lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(),
+      desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(),
+      preference.descriptor(), request_count, results.data(), &returned_count);
+  if (status != CUBLAS_STATUS_SUCCESS) return 0;
+
+  auto* out = static_cast<uint8_t*>(algo_buf);
+  for (int i = 0; i < returned_count; ++i) {
+    std::memcpy(out + i * kAlgoBytes, &results[i].algo, kAlgoBytes);
+  }
+  return returned_count;
+}
+
+/*!
+ * \brief Run a BF16 GEMM using a pre-resolved algorithm — zero heuristic overhead.
+ *
+ * \param algo_buf Buffer of serialized cublasLtMatmulAlgo_t structs (from get_algorithms).
+ * \param algo_idx Index into algo_buf selecting which algorithm to use.
+ */
+inline cublasStatus_t run_with_algo(const __nv_bfloat16* mat1, const __nv_bfloat16* mat2, void* out,
+                                    int m, int n, int k, cudaDataType_t d_type, void* workspace,
+                                    size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle,
+                                    cudaStream_t stream, const void* algo_buf, int algo_idx) {
+  GemmDescriptors desc(m, n, k, d_type);
+
+  cublasLtMatmulAlgo_t algo;
+  std::memcpy(&algo, static_cast<const uint8_t*>(algo_buf) + algo_idx * kAlgoBytes, kAlgoBytes);
+
+  const float alpha = 1.0f;
+  const float beta = 0.0f;
+  FLASHINFER_CUBLAS_CALL(cublasLtMatmul(
+      lt_handle, desc.matmul_desc.descriptor(), &alpha, mat2, desc.a_layout.descriptor(), mat1,
+      desc.b_layout.descriptor(), &beta, nullptr, desc.d_layout.descriptor(), out,
+      desc.d_layout.descriptor(), &algo, workspace, workspace_size_in_bytes, stream));
+  return CUBLAS_STATUS_SUCCESS;
+}
+
 /*!
  * \brief Run a BF16 GEMM using a specific cuBLASLt algorithm (tactic).
  *

From 17233aaf498c6b41fad86cd40734b8a170b53700 Mon Sep 17 00:00:00 2001
From: Vadim Gimpelson
Date: Sun, 29 Mar 2026 12:57:50 +0400
Subject: [PATCH 04/15] fix

Signed-off-by: Vadim Gimpelson
---
 flashinfer/gemm/gemm_base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py
index 029dd88241..17d3a531f1 100644
--- a/flashinfer/gemm/gemm_base.py
+++ b/flashinfer/gemm/gemm_base.py
@@ -384,7 +384,7 @@ def mm_bf16(
         Optional bias tensor, shape (n,). Enabled for TGV backend. Defaults to ``None``.
     pdl: bool
-        Whether to use persistant data loader mode. Enabled for TGV backend. Defaults to ``False``.
+        Whether to enable Programmatic Dependent Launch (PDL). Enabled for TGV backend. Defaults to ``False``.
     out: Optional[torch.Tensor]
         Out tensor, shape (m, n), bf16, fp16, or fp32.
Enabled for CUTLASS, cuDNN, and cuBLASLt From bdeaa6eaf3bad015c459c2b437144e8e89247f70 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 14:02:44 +0400 Subject: [PATCH 05/15] Enable multi-tactic autotuning for CublasFp8, CudnnFp8, and CudnnMxfp8 GEMM runners Signed-off-by: Vadim Gimpelson --- csrc/bmm_fp8.cu | 69 +++++++++++++ csrc/flashinfer_gemm_binding.cu | 10 ++ flashinfer/gemm/gemm_base.py | 151 ++++++++++++++++++++++++---- include/flashinfer/gemm/bmm_fp8.cuh | 103 +++++++++++++++++++ tests/gemm/test_mm_bf16.py | 4 +- 5 files changed, 317 insertions(+), 20 deletions(-) diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index 4de464fac0..da8a8670e4 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -61,3 +61,72 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso }); }); } + +int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer) { + int64_t result = 0; + // Same dtype dispatch and pointer swap as bmm_fp8 + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(A.device().device_id); + + int max_algos = static_cast(algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::bmm_fp8::kAlgoBytes); + result = flashinfer::bmm_fp8::get_fp8_algorithms( + batch_size, n, m, k, static_cast(B_scale.data_ptr()), + static_cast(A_scale.data_ptr()), workspace_buffer.numel(), lt_handle, + algo_buffer.data_ptr(), max_algos); + return true; + }); + }); + }); + return static_cast(result); +} + +void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, 
TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx) { + CHECK_CUDA(A); + CHECK_CUDA(B); + CHECK_CUDA(D); + CHECK_DIM(3, A); + CHECK_DIM(3, B); + CHECK_DIM(3, D); + TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; + TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; + TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) + << "Result tensor has incorrect shape"; + + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); + + auto lt_handle = reinterpret_cast(cublas_handle); + ffi::CUDADeviceGuard device_guard(A.device().device_id); + auto stream = get_stream(A.device()); + + auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo( + workspace_buffer.data_ptr(), workspace_buffer.numel(), + static_cast(B.data_ptr()), static_cast(A.data_ptr()), + static_cast(D.data_ptr()), batch_size, n, m, k, + static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), + lt_handle, stream, algo_buffer.data_ptr(), static_cast(algo_idx)); + TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) + << "bmm_fp8_run_with_algo failed: " << cublasGetStatusString(status); + return true; + }); + }); + }); +} diff --git a/csrc/flashinfer_gemm_binding.cu b/csrc/flashinfer_gemm_binding.cu index 52d0551413..80fffb4de4 100644 --- a/csrc/flashinfer_gemm_binding.cu +++ b/csrc/flashinfer_gemm_binding.cu @@ -19,9 +19,19 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle); +int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, 
int64_t cublas_handle, + TensorView algo_buffer); + +void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale, + TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, + TensorView algo_buffer, int64_t algo_idx); + void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr, TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld, TensorView y_ld, TensorView empty_x_data, bool weight_column_major); TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM); TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_get_algos, bmm_fp8_get_algos); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_run_with_algo, bmm_fp8_run_with_algo); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 17d3a531f1..a011d88713 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -109,16 +109,45 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): def get_gemm_module(): module = gen_gemm_module().build_and_load() - # auto-tuned cublas fp8 gemm runner + _FP8_ALGO_BYTES = 64 # sizeof(cublasLtMatmulAlgo_t) + _FP8_MAX_ALGOS = 100 + def cublas_fp8_gemm_runner(): class CublasFp8GemmRunner(TunableRunner): + def __init__(self): + self._algo_cache: dict = {} + + def _get_algos(self, inputs): + a, b, _, _, out, workspace_buffer = inputs + key = (a.shape, b.shape, a.dtype, b.dtype, out.dtype) + cached = self._algo_cache.get(key) + if cached is not None: + return cached + algo_buf = torch.empty( + _FP8_MAX_ALGOS * _FP8_ALGO_BYTES, dtype=torch.uint8 + ) + cublas_handle = torch.cuda.current_blas_handle() + count = module.bmm_fp8_get_algos( + a, + b, + out, + inputs[2], + inputs[3], + workspace_buffer, + cublas_handle, + algo_buf, + ) + result = (algo_buf, count) + self._algo_cache[key] = result + return result + def get_valid_tactics( self, inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> 
List[int]: - # cublas has heuristic for fp8 gemm, so we only need to use the default tactic - return [0] + _, count = self._get_algos(inputs) + return list(range(count)) def forward( self, @@ -129,9 +158,25 @@ def forward( ) -> torch.Tensor: cublas_handle = torch.cuda.current_blas_handle() a, b, scale_a, scale_b, out, workspace_buffer = inputs - module.bmm_fp8( - a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle - ) + if tactic >= 0: + algo_buf, count = self._get_algos(inputs) + if tactic >= count: + tactic = 0 + module.bmm_fp8_run_with_algo( + a, + b, + out, + scale_a, + scale_b, + workspace_buffer, + cublas_handle, + algo_buf, + tactic, + ) + else: + module.bmm_fp8( + a, b, out, scale_a, scale_b, workspace_buffer, cublas_handle + ) return out return CublasFp8GemmRunner() @@ -658,7 +703,8 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # For now, return a single default tactic + # TODO: add multi-tactic support when PingPong 64x128x128 schedule + # is implemented (see gemm_groupwise_sm120.cuh) return [-1] def forward( @@ -2590,7 +2636,15 @@ def _cudnn_gemm_mxfp8_override_shape( @functools.lru_cache(maxsize=1024) def build_cudnn_gemm_with_per_tensor_q_graph( - a_shape, a_stride, b_shape, b_stride, a_type, b_type, o_type, device + a_shape, + a_stride, + b_shape, + b_stride, + a_type, + b_type, + o_type, + device, + policy=None, ): """Build a cuDNN graph for GEMM with per-tensor quantization. @@ -2604,11 +2658,15 @@ def build_cudnn_gemm_with_per_tensor_q_graph( a_type: Data type for input tensor A b_type: Data type for input tensor B o_type: Data type for output tensor + policy: cuDNN build plan policy. None defaults to HEURISTICS_CHOICE. + Use ALL to enumerate all execution plans for autotuning. 
Returns: cuDNN graph object """ _check_cudnn_availability() + if policy is None: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE stream = torch.cuda.current_stream(device) with cudnn.graph(_get_cudnn_handle(device, stream)) as (graph, _): @@ -2664,13 +2722,20 @@ def build_cudnn_gemm_with_per_tensor_q_graph( graph.build_operation_graph() graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) graph.check_support() - graph.build_plans() + graph.build_plans(policy) return graph def execute_cudnn_gemm_with_per_tensor_q_graph( - graph, a, b, a_scale, b_scale, c_final, workspace + graph, + a, + b, + a_scale, + b_scale, + c_final, + workspace, + tactic: int = -1, ): variant_pack = { UIDs.A_UID.value: a, @@ -2688,7 +2753,12 @@ def execute_cudnn_gemm_with_per_tensor_q_graph( graph.get_workspace_size(), device=a.device, dtype=torch.uint8 ) - graph.execute(variant_pack, workspace, handle=cudnn_handle) + if tactic == -1: + graph.execute(variant_pack, workspace, handle=cudnn_handle) + else: + graph.execute_plan_at_index( + variant_pack, workspace, tactic, handle=cudnn_handle + ) # --------------------------------------------------------------------------- @@ -2833,9 +2903,15 @@ def _cudnn_gemm_fp8( b_scale: torch.Tensor, out: Optional[torch.Tensor], torch_out_dtype: torch.dtype, + tactic: int = -1, ): _check_cudnn_availability() + if tactic == -1: + policy = cudnn.build_plan_policy.HEURISTICS_CHOICE + else: + policy = cudnn.build_plan_policy.ALL + graph = build_cudnn_gemm_with_per_tensor_q_graph( a.shape, a.stride(), @@ -2845,10 +2921,18 @@ def _cudnn_gemm_fp8( _torch_data_type_to_cudnn_data_type(b.dtype), _torch_data_type_to_cudnn_data_type(torch_out_dtype), a.device, + policy=policy, ) execute_cudnn_gemm_with_per_tensor_q_graph( - graph, a, b, a_scale, b_scale, out, workspace + graph, + a, + b, + a_scale, + b_scale, + out, + workspace, + tactic=tactic, ) return out @@ -2860,8 +2944,19 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: 
OptimizationProfile, ) -> List[int]: - # cudnn has heuristic for fp8 gemm, so we only need to use the default tactic - return [0] + a, b, _, _, out, _ = inputs + graph = build_cudnn_gemm_with_per_tensor_q_graph( + a.shape, + a.stride(), + b.shape, + b.stride(), + _torch_data_type_to_cudnn_data_type(a.dtype), + _torch_data_type_to_cudnn_data_type(b.dtype), + _torch_data_type_to_cudnn_data_type(out.dtype), + a.device, + policy=cudnn.build_plan_policy.ALL, + ) + return list(range(graph.get_execution_plan_count())) def forward( self, @@ -2871,7 +2966,16 @@ def forward( **kwargs, ) -> torch.Tensor: a, b, scale_a, scale_b, out, workspace_buffer = inputs - _cudnn_gemm_fp8(workspace_buffer, a, b, scale_a, scale_b, out, out.dtype) + _cudnn_gemm_fp8( + workspace_buffer, + a, + b, + scale_a, + scale_b, + out, + out.dtype, + tactic=tactic, + ) return out return CudnnFp8GemmRunner() @@ -6844,6 +6948,7 @@ def _get_cudnn_mxfp8_gemm_graph( out: Optional[torch.Tensor] = None, block_size: int = 32, # mxfp8 block size is 32 tactic: int = -1, + build_all: bool = False, ): graph = create_cudnn_execution_plans_mxfp8_gemm( a_shape=a.shape, @@ -6858,7 +6963,9 @@ def _get_cudnn_mxfp8_gemm_graph( ) graph.check_support() - if tactic != -1: + if build_all: + graph.build_plans() + elif tactic != -1: graph.build_plan_at_index(tactic) else: graph.build_plans() @@ -6907,9 +7014,15 @@ def get_valid_tactics( inputs: List[torch.Tensor], profile: OptimizationProfile, ) -> List[int]: - # TODO: check if this is correct - # cudnn has heuristic for mxfp8 gemm, so we only need to use the default tactic - return [0] + a, b, _, _, out, _ = inputs + graph = _get_cudnn_mxfp8_gemm_graph( + a=a, + b=b, + out_dtype=out.dtype, + out=out, + build_all=True, + ) + return list(range(graph.get_execution_plan_count())) def forward( self, diff --git a/include/flashinfer/gemm/bmm_fp8.cuh b/include/flashinfer/gemm/bmm_fp8.cuh index 7778934160..4c8539d1c0 100644 --- a/include/flashinfer/gemm/bmm_fp8.cuh +++ 
b/include/flashinfer/gemm/bmm_fp8.cuh @@ -19,6 +19,8 @@ #include #include +#include +#include #include #include #include @@ -137,6 +139,107 @@ cudaDataType_t get_cuda_data_type() { } } +static constexpr int kMaxFp8Algorithms = 100; +static constexpr size_t kAlgoBytes = sizeof(cublasLtMatmulAlgo_t); + +/*! + * \brief Set up cuBLASLt descriptors for FP8 BMM. + * + * Factors out the common descriptor creation shared by heuristic query, + * run, and run_with_algo paths. + */ +template +struct Fp8GemmDescriptors { + CuBlasLtMatmulDescriptor matmul_desc; + CuBlasLtMatrixLayout a_layout; + CuBlasLtMatrixLayout b_layout; + CuBlasLtMatrixLayout d_layout; + + Fp8GemmDescriptors(int batch_size, int m, int n, int k, const float* A_scale, + const float* B_scale) + : matmul_desc(CUBLAS_COMPUTE_32F, CUDA_R_32F), + a_layout(get_cuda_data_type(), m, k, k, true), + b_layout(get_cuda_data_type(), k, n, k), + d_layout(get_cuda_data_type
(), m, n, m) { + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, CUBLAS_OP_T); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, CUBLAS_OP_N); + int8_t fast_accum = 1; + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_FAST_ACCUM, fast_accum); + + const void* A_scale_ptr = static_cast(A_scale); + const void* B_scale_ptr = static_cast(B_scale); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, A_scale_ptr); + matmul_desc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, B_scale_ptr); + + if constexpr (std::is_same_v && std::is_same_v) { + FLASHINFER_ERROR("Unsupported combination: both A and B are e5m2"); + } + + if (batch_size > 1) { + int64_t stride_a = m * k; + int64_t stride_b = k * n; + int64_t stride_d = m * n; + a_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + a_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_a); + b_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + b_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_b); + d_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, batch_size); + d_layout.setAttribute(CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stride_d); + } + } +}; + +/*! + * \brief Query heuristics and serialize all cublasLtMatmulAlgo_t structs for FP8 BMM. + */ +template +int get_fp8_algorithms(int batch_size, int m, int n, int k, const float* A_scale, + const float* B_scale, size_t workspace_size_in_bytes, + cublasLtHandle_t lt_handle, void* algo_buf, int max_algos) { + Fp8GemmDescriptors desc(batch_size, m, n, k, A_scale, B_scale); + + CuBlasLtMatmulPreference preference; + preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); + + int request_count = (max_algos > kMaxFp8Algorithms) ? 
kMaxFp8Algorithms : max_algos; + std::array results; + int returned_count = 0; + cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( + lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), + desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), + preference.descriptor(), request_count, results.data(), &returned_count); + if (status != CUBLAS_STATUS_SUCCESS) return 0; + + auto* out = static_cast(algo_buf); + for (int i = 0; i < returned_count; ++i) { + std::memcpy(out + i * kAlgoBytes, &results[i].algo, kAlgoBytes); + } + return returned_count; +} + +/*! + * \brief Run FP8 BMM using a pre-resolved algorithm — zero heuristic overhead. + */ +template +cublasStatus_t bmm_fp8_run_with_algo(void* workspace, size_t workspace_size_in_bytes, const AT* A, + const BT* B, DT* D, int batch_size, int m, int n, int k, + const float* A_scale, const float* B_scale, + cublasLtHandle_t lt_handle, cudaStream_t stream, + const void* algo_buf, int algo_idx) { + Fp8GemmDescriptors desc(batch_size, m, n, k, A_scale, B_scale); + + cublasLtMatmulAlgo_t algo; + std::memcpy(&algo, static_cast(algo_buf) + algo_idx * kAlgoBytes, kAlgoBytes); + + const float alpha = 1.0f; + const float beta = 0.0f; + FLASHINFER_CUBLAS_CALL(cublasLtMatmul( + lt_handle, desc.matmul_desc.descriptor(), &alpha, A, desc.a_layout.descriptor(), B, + desc.b_layout.descriptor(), &beta, nullptr, desc.d_layout.descriptor(), D, + desc.d_layout.descriptor(), &algo, workspace, workspace_size_in_bytes, stream)); + return CUBLAS_STATUS_SUCCESS; +} + template cublasStatus_t bmm_fp8_internal_cublaslt(void* workspace, size_t workspace_size_in_bytes, const AT* A, const BT* B, DT* D, int batch_size, int m, diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index 805c41c909..c75f68563f 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -14,6 +14,7 @@ @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize("pdl", [True, 
False]) @pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv", "cublaslt"]) +@pytest.mark.parametrize("auto_tuning", [False, True]) def test_mm_bf16( m: int, n: int, @@ -22,6 +23,7 @@ def test_mm_bf16( enable_bias: bool, pdl: bool, backend: str, + auto_tuning: bool, ): compute_capability = get_compute_capability(torch.device(device="cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] @@ -74,7 +76,7 @@ def test_mm_bf16( reference = torch.mm(input, mat2.T) out = torch.empty([m, n], device="cuda", dtype=res_dtype) - with autotune(): + with autotune(auto_tuning): mm_bf16(input, mat2.T, bias, pdl, out, res_dtype, backend) cos_sim = F.cosine_similarity(reference.reshape(-1), out.reshape(-1), dim=0) From 3c4d422f57fb6c7aa3b334b1a2220d5b07cdb074 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 14:27:50 +0400 Subject: [PATCH 06/15] add backed to tests Signed-off-by: Vadim Gimpelson --- tests/gemm/test_bmm_bf16.py | 9 ++++++--- tests/gemm/test_mm_bf16.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/gemm/test_bmm_bf16.py b/tests/gemm/test_bmm_bf16.py index 493dd38f73..646ac654f5 100644 --- a/tests/gemm/test_bmm_bf16.py +++ b/tests/gemm/test_bmm_bf16.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("n", [80, 64]) @pytest.mark.parametrize("k", [64, 256]) @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32]) -@pytest.mark.parametrize("backend", ["cutlass", "cudnn"]) +@pytest.mark.parametrize("backend", ["cutlass", "cudnn", "auto"]) def test_bmm_bf16(b, m, n, k, res_dtype, backend): compute_capability = get_compute_capability(torch.device(device="cuda")) compute_capability_number = compute_capability[0] * 10 + compute_capability[1] @@ -21,8 +21,11 @@ def test_bmm_bf16(b, m, n, k, res_dtype, backend): f"bmm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." 
) - if not bmm_bf16.is_backend_supported(backend, compute_capability_number): - pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend != "auto": + if not bmm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip( + f"{backend} backend not supported on current compute capability." + ) if backend == "cudnn" and not CUDNN_AVAILABLE: pytest.skip("cuDNN is not available on this system.") diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index c75f68563f..ca964ec924 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16, torch.float32]) @pytest.mark.parametrize("enable_bias", [True, False]) @pytest.mark.parametrize("pdl", [True, False]) -@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv", "cublaslt"]) +@pytest.mark.parametrize("backend", ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]) @pytest.mark.parametrize("auto_tuning", [False, True]) def test_mm_bf16( m: int, @@ -32,12 +32,17 @@ def test_mm_bf16( f"mm_bf16 not supported on current compute capability." f"Detected sm{compute_capability_number}." ) - if not mm_bf16.is_backend_supported(backend, compute_capability_number): - pytest.skip(f"{backend} backend not supported on current compute capability.") + if backend != "auto": + if not mm_bf16.is_backend_supported(backend, compute_capability_number): + pytest.skip( + f"{backend} backend not supported on current compute capability." + ) if backend == "cudnn" and not CUDNN_AVAILABLE: pytest.skip("cuDNN is not available on this system.") + if backend == "auto" and (enable_bias or pdl): + pytest.skip("mm_bf16 with auto backend does not support bias or pdl arguments.") if backend == "cudnn" and (enable_bias or pdl): pytest.skip( "mm_bf16 with cuDNN backend does not support bias or pdl arguments." 
From 214cb05061a33b224c5b421c367422dae2c77400 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Sun, 29 Mar 2026 15:27:01 +0400 Subject: [PATCH 07/15] fix Signed-off-by: Vadim Gimpelson --- flashinfer/autotuner.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index cb0141fdf8..dd8802318c 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -397,7 +397,14 @@ def forward( raise NotImplementedError def __hash__(self): - return hash(tuple(self.__dict__.values())) + hashable_vals = [] + for v in self.__dict__.values(): + try: + hash(v) + hashable_vals.append(v) + except TypeError: + hashable_vals.append(id(v)) + return hash(tuple(hashable_vals)) @contextlib.contextmanager From 90d95117ae595fd58a8c0ac8975e6adf102479d7 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 00:36:50 +0400 Subject: [PATCH 08/15] CR fixes Signed-off-by: Vadim Gimpelson --- csrc/bmm_fp8.cu | 26 ++++++++++++++++++---- csrc/mm_bf16_cublaslt.cu | 35 +++++++++++++++++------------- flashinfer/gemm/gemm_base.py | 42 +++++++++++++++--------------------- 3 files changed, 59 insertions(+), 44 deletions(-) diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index da8a8670e4..7abd3add91 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -49,7 +49,8 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso auto stream = get_stream(A.device()); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( - workspace_buffer.data_ptr(), workspace_buffer.numel(), + workspace_buffer.data_ptr(), + workspace_buffer.numel() * get_element_size(workspace_buffer), static_cast(B.data_ptr()), static_cast(A.data_ptr()), static_cast(D.data_ptr()), batch_size, n, m, k, static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), @@ -65,8 +66,18 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso int64_t bmm_fp8_get_algos(TensorView A, TensorView B, 
TensorView D, TensorView A_scale, TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle, TensorView algo_buffer) { + CHECK_CUDA(A); + CHECK_CUDA(B); + CHECK_CUDA(D); + CHECK_DIM(3, A); + CHECK_DIM(3, B); + CHECK_DIM(3, D); + TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; + TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; + TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) + << "Result tensor has incorrect shape"; + int64_t result = 0; - // Same dtype dispatch and pointer swap as bmm_fp8 DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { @@ -82,7 +93,8 @@ int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A flashinfer::bmm_fp8::kAlgoBytes); result = flashinfer::bmm_fp8::get_fp8_algorithms( batch_size, n, m, k, static_cast(B_scale.data_ptr()), - static_cast(A_scale.data_ptr()), workspace_buffer.numel(), lt_handle, + static_cast(A_scale.data_ptr()), + workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle, algo_buffer.data_ptr(), max_algos); return true; }); @@ -105,6 +117,11 @@ void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) << "Result tensor has incorrect shape"; + int64_t max_algos = + algo_buffer.numel() * get_element_size(algo_buffer) / flashinfer::bmm_fp8::kAlgoBytes; + TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos) + << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")"; + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { @@ -118,7 +135,8 @@ void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, 
TensorView auto stream = get_stream(A.device()); auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo( - workspace_buffer.data_ptr(), workspace_buffer.numel(), + workspace_buffer.data_ptr(), + workspace_buffer.numel() * get_element_size(workspace_buffer), static_cast(B.data_ptr()), static_cast(A.data_ptr()), static_cast(D.data_ptr()), batch_size, n, m, k, static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index 6827ef6f2f..3a131ba90a 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -76,29 +76,30 @@ void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, TensorVi << "mm_bf16_cublaslt failed: " << cublasGetStatusString(status); } -int64_t mm_bf16_cublaslt_tactic_num(TensorView mat1, TensorView mat2, TensorView out, - TensorView workspace_buffer, int64_t cublas_handle) { - int64_t m = mat1.size(0); - int64_t k = mat1.size(1); - int64_t n = mat2.size(0); - cudaDataType_t d_type = get_d_type(out.dtype()); - - ffi::CUDADeviceGuard device_guard(mat1.device().device_id); - auto lt_handle = reinterpret_cast(cublas_handle); - return static_cast(flashinfer::mm_bf16_cublaslt::get_algorithm_count( - static_cast(m), static_cast(n), static_cast(k), d_type, - workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle)); -} - // Serialize all heuristic algorithms into a CPU uint8 tensor for caching. // algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes. // Returns number of algorithms written. 
int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer, int64_t cublas_handle, TensorView algo_buffer) { + CHECK_CUDA(mat1); + CHECK_CUDA(mat2); + CHECK_CUDA(out); + CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); + CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + CHECK_DIM(2, out); + int64_t m = mat1.size(0); int64_t k = mat1.size(1); int64_t n = mat2.size(0); + + TVM_FFI_ICHECK_EQ(mat2.size(1), k) + << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); + TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; + TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + cudaDataType_t d_type = get_d_type(out.dtype()); ffi::CUDADeviceGuard device_guard(mat1.device().device_id); @@ -133,6 +134,11 @@ void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; + int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) / + flashinfer::mm_bf16_cublaslt::kAlgoBytes; + TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos) + << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")"; + auto lt_handle = reinterpret_cast(cublas_handle); ffi::CUDADeviceGuard device_guard(mat1.device().device_id); auto stream = get_stream(mat1.device()); @@ -148,6 +154,5 @@ void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView } TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt, mm_bf16_cublaslt); -TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_tactic_num, mm_bf16_cublaslt_tactic_num); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a011d88713..0437a8678c 
100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -118,7 +118,7 @@ def __init__(self): self._algo_cache: dict = {} def _get_algos(self, inputs): - a, b, _, _, out, workspace_buffer = inputs + a, b, scale_a, scale_b, out, workspace_buffer = inputs key = (a.shape, b.shape, a.dtype, b.dtype, out.dtype) cached = self._algo_cache.get(key) if cached is not None: @@ -131,8 +131,8 @@ def _get_algos(self, inputs): a, b, out, - inputs[2], - inputs[3], + scale_a, + scale_b, workspace_buffer, cublas_handle, algo_buf, @@ -259,6 +259,8 @@ def _cutlass_mm_bf16_requirement( return True +# Gated to Blackwell (SM100/SM103) for the initial scope of this backend. +# cuBLASLt supports BF16 GEMM on SM80+; the gate can be widened in a follow-up. @supported_compute_capability([100, 103]) def _cublaslt_mm_bf16_requirement( a: torch.Tensor, @@ -1008,28 +1010,18 @@ def forward( a, b, _, _, out, workspace_buffer = inputs cublas_handle = torch.cuda.current_blas_handle() b_t = b.transpose(-2, -1) - if tactic >= 0: - algo_buf, count = self._get_algos(inputs) - if tactic >= count: - tactic = 0 - module.mm_bf16_cublaslt_run_with_algo( - a, - b_t, - out, - workspace_buffer, - cublas_handle, - algo_buf, - tactic, - ) - else: - module.mm_bf16_cublaslt( - a, - b_t, - out, - workspace_buffer, - cublas_handle, - tactic, - ) + algo_buf, count = self._get_algos(inputs) + if tactic < 0 or tactic >= count: + tactic = 0 + module.mm_bf16_cublaslt_run_with_algo( + a, + b_t, + out, + workspace_buffer, + cublas_handle, + algo_buf, + tactic, + ) return out return CublasltBf16GemmRunner() From e633cc723e230b90b529312f1ef913516e60fafa Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 01:19:27 +0400 Subject: [PATCH 09/15] support fp16,fp32 output Signed-off-by: Vadim Gimpelson --- csrc/mm_bf16_cublaslt.cu | 38 ----------- flashinfer/gemm/gemm_base.py | 36 ++++++++--- include/flashinfer/gemm/mm_bf16_cublaslt.cuh | 66 -------------------- tests/gemm/test_mm_bf16.py | 2 
- 4 files changed, 27 insertions(+), 115 deletions(-) diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index 3a131ba90a..fa9a9ded2a 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -39,43 +39,6 @@ cudaDataType_t get_d_type(DLDataType dtype) { } // namespace -// mat1: (m, k) bf16 contiguous row-major -// mat2: (n, k) bf16 contiguous row-major (Python passes b.transpose(-2,-1)) -// out: (m, n) bf16/fp16/fp32 contiguous row-major -void mm_bf16_cublaslt(TensorView mat1, TensorView mat2, TensorView out, TensorView workspace_buffer, - int64_t cublas_handle, int64_t tactic) { - CHECK_CUDA(mat1); - CHECK_CUDA(mat2); - CHECK_CUDA(out); - CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); - CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); - CHECK_DIM(2, mat1); - CHECK_DIM(2, mat2); - CHECK_DIM(2, out); - - int64_t m = mat1.size(0); - int64_t k = mat1.size(1); - int64_t n = mat2.size(0); - - TVM_FFI_ICHECK_EQ(mat2.size(1), k) - << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); - TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; - TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; - - auto lt_handle = reinterpret_cast(cublas_handle); - ffi::CUDADeviceGuard device_guard(mat1.device().device_id); - auto stream = get_stream(mat1.device()); - cudaDataType_t d_type = get_d_type(out.dtype()); - - auto status = flashinfer::mm_bf16_cublaslt::run( - static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), - out.data_ptr(), static_cast(m), static_cast(n), static_cast(k), d_type, - workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer), - lt_handle, stream, static_cast(tactic)); - TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) - << "mm_bf16_cublaslt failed: " << cublasGetStatusString(status); -} - // Serialize all heuristic algorithms into a CPU uint8 tensor for caching. // algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes. 
// Returns number of algorithms written. @@ -153,6 +116,5 @@ void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView << "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status); } -TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt, mm_bf16_cublaslt); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos); TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 0437a8678c..e75e517951 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -279,11 +279,6 @@ def _cublaslt_mm_bf16_requirement( raise ValueError( "The cuBLASLt backend does not support PDL. Use the TGV backend instead." ) - if out_dtype == torch.float16: - raise ValueError( - "cuBLASLt BF16 GEMM does not support float16 output. " - "Use bfloat16 or float32, or use the CUTLASS/cuDNN backend for float16 output." - ) _validate_bf16_output_dtype(out_dtype) return True @@ -972,18 +967,32 @@ class CublasltBf16GemmRunner(TunableRunner): def __init__(self): self._algo_cache: dict = {} + @staticmethod + def _compute_dtype(out_dtype): + # cuBLASLt with BF16 inputs supports BF16 or FP32 output natively. + # FP16 output is achieved via BF16 compute + cast. 
+ if out_dtype == torch.float16: + return torch.bfloat16 + return out_dtype + def _get_algos(self, inputs): a, b, _, _, out, workspace_buffer = inputs - key = (a.shape[0], b.shape[0], a.shape[1], out.dtype) + compute_dt = self._compute_dtype(out.dtype) + key = (a.shape[0], b.shape[0], a.shape[1], compute_dt) cached = self._algo_cache.get(key) if cached is not None: return cached algo_buf = torch.empty(_MAX_ALGOS * _ALGO_BYTES, dtype=torch.uint8) cublas_handle = torch.cuda.current_blas_handle() + proxy_out = ( + out + if out.dtype == compute_dt + else torch.empty_like(out, dtype=compute_dt) + ) count = module.mm_bf16_cublaslt_get_algos( a, b.transpose(-2, -1), - out, + proxy_out, workspace_buffer, cublas_handle, algo_buf, @@ -1010,18 +1019,27 @@ def forward( a, b, _, _, out, workspace_buffer = inputs cublas_handle = torch.cuda.current_blas_handle() b_t = b.transpose(-2, -1) + + need_cast = out.dtype == torch.float16 + if need_cast: + compute_out = torch.empty_like(out, dtype=torch.bfloat16) + else: + compute_out = out + algo_buf, count = self._get_algos(inputs) if tactic < 0 or tactic >= count: tactic = 0 module.mm_bf16_cublaslt_run_with_algo( a, b_t, - out, + compute_out, workspace_buffer, cublas_handle, algo_buf, tactic, ) + if need_cast: + out.copy_(compute_out) return out return CublasltBf16GemmRunner() @@ -2626,7 +2644,7 @@ def _cudnn_gemm_mxfp8_override_shape( ) -@functools.lru_cache(maxsize=1024) +@functools.lru_cache(maxsize=2048) def build_cudnn_gemm_with_per_tensor_q_graph( a_shape, a_stride, diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh index ed327ad3ad..e91d035ea1 100644 --- a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -67,28 +67,6 @@ struct GemmDescriptors { } }; -/*! - * \brief Get the number of available cuBLASLt algorithms for a BF16 GEMM. 
- */ -inline int get_algorithm_count(int m, int n, int k, cudaDataType_t d_type, - size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle) { - GemmDescriptors desc(m, n, k, d_type); - - CuBlasLtMatmulPreference preference; - preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); - - std::array results; - int returned_count = 0; - cublasStatus_t status = cublasLtMatmulAlgoGetHeuristic( - lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), - desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), - preference.descriptor(), kMaxAlgorithms, results.data(), &returned_count); - if (status != CUBLAS_STATUS_SUCCESS) { - return 0; - } - return returned_count; -} - /*! * \brief Query heuristics once and serialize all cublasLtMatmulAlgo_t structs into a buffer. * @@ -147,50 +125,6 @@ inline cublasStatus_t run_with_algo(const __nv_bfloat16* mat1, const __nv_bfloat return CUBLAS_STATUS_SUCCESS; } -/*! - * \brief Run a BF16 GEMM using a specific cuBLASLt algorithm (tactic). - * - * \param mat1 Pointer to mat1 data, row-major (m, k) - * \param mat2 Pointer to mat2 data, row-major (n, k) — after b.transpose(-2,-1) in Python - * \param out Pointer to output data, row-major (m, n) - * \param tactic Algorithm index; -1 means use the top heuristic (index 0). - */ -inline cublasStatus_t run(const __nv_bfloat16* mat1, const __nv_bfloat16* mat2, void* out, int m, - int n, int k, cudaDataType_t d_type, void* workspace, - size_t workspace_size_in_bytes, cublasLtHandle_t lt_handle, - cudaStream_t stream, int tactic) { - GemmDescriptors desc(m, n, k, d_type); - - CuBlasLtMatmulPreference preference; - preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspace_size_in_bytes); - - int algo_idx = (tactic < 0) ? 
0 : tactic; - int request_count = algo_idx + 1; - if (request_count > kMaxAlgorithms) { - request_count = kMaxAlgorithms; - } - - std::array results; - int returned_count = 0; - cublasStatus_t heur_status = cublasLtMatmulAlgoGetHeuristic( - lt_handle, desc.matmul_desc.descriptor(), desc.a_layout.descriptor(), - desc.b_layout.descriptor(), desc.d_layout.descriptor(), desc.d_layout.descriptor(), - preference.descriptor(), request_count, results.data(), &returned_count); - if (heur_status != CUBLAS_STATUS_SUCCESS || returned_count <= algo_idx) { - return CUBLAS_STATUS_NOT_SUPPORTED; - } - - const float alpha = 1.0f; - const float beta = 0.0f; - // Note: mat2 is cuBLASLt "A", mat1 is cuBLASLt "B" (swap for row-major output) - FLASHINFER_CUBLAS_CALL( - cublasLtMatmul(lt_handle, desc.matmul_desc.descriptor(), &alpha, mat2, - desc.a_layout.descriptor(), mat1, desc.b_layout.descriptor(), &beta, nullptr, - desc.d_layout.descriptor(), out, desc.d_layout.descriptor(), - &results[algo_idx].algo, workspace, workspace_size_in_bytes, stream)); - return CUBLAS_STATUS_SUCCESS; -} - } // namespace mm_bf16_cublaslt } // namespace flashinfer diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index ca964ec924..eafb7b1926 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -55,8 +55,6 @@ def test_mm_bf16( pytest.skip( "mm_bf16 with cuBLASLt backend does not support bias or pdl arguments." ) - if backend == "cublaslt" and res_dtype == torch.float16: - pytest.skip("mm_bf16 with cuBLASLt backend does not support float16 output.") if res_dtype != torch.bfloat16 and backend == "tgv": pytest.skip( "mm_bf16 with TGV backend does not support specifying non-bfloat16 result dtypes." 
From 98dde9970daeb02f2e82eed83a16049ff390b0cb Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 02:31:39 +0400 Subject: [PATCH 10/15] CR fixes Signed-off-by: Vadim Gimpelson --- csrc/bmm_fp8.cu | 2 ++ csrc/mm_bf16_cublaslt.cu | 2 ++ flashinfer/gemm/gemm_base.py | 22 +++++++++++++------- include/flashinfer/gemm/mm_bf16_cublaslt.cuh | 3 +-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index 7abd3add91..54e8c26898 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -72,6 +72,7 @@ int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A CHECK_DIM(3, A); CHECK_DIM(3, B); CHECK_DIM(3, D); + CHECK_CONTIGUOUS(algo_buffer); TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) @@ -112,6 +113,7 @@ void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView CHECK_DIM(3, A); CHECK_DIM(3, B); CHECK_DIM(3, D); + CHECK_CONTIGUOUS(algo_buffer); TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index fa9a9ded2a..2df28b0777 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -53,6 +53,7 @@ int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView CHECK_DIM(2, mat1); CHECK_DIM(2, mat2); CHECK_DIM(2, out); + CHECK_CONTIGUOUS(algo_buffer); int64_t m = mat1.size(0); int64_t k = mat1.size(1); @@ -87,6 +88,7 @@ void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView CHECK_DIM(2, mat1); CHECK_DIM(2, mat2); CHECK_DIM(2, out); + CHECK_CONTIGUOUS(algo_buffer); int64_t m = mat1.size(0); int64_t k = 
mat1.size(1); diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index e75e517951..a65d1a2f5a 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -90,6 +90,11 @@ DEFAULT_WORKSPACE_SIZE = 32 * 1024 * 1024 +# sizeof(cublasLtMatmulAlgo_t) = uint64_t[8] = 64 bytes. +# Shared by cuBLAS FP8, cuBLASLt BF16, and any other cuBLASLt-based runners. +_CUBLASLT_ALGO_BYTES = 64 +_CUBLASLT_MAX_ALGOS = 100 + # Error messages CUDNN_FP4_MXFP4_SM120_CUDNN_VERSION_ERROR = "cudnn FP4 GEMM with mxfp4 quantization is not supported on SM120/SM121 with cuDNN backend version < 9.14.0." @@ -109,9 +114,6 @@ def _match_sm_version(device: torch.device, sm_version: list[str]): def get_gemm_module(): module = gen_gemm_module().build_and_load() - _FP8_ALGO_BYTES = 64 # sizeof(cublasLtMatmulAlgo_t) - _FP8_MAX_ALGOS = 100 - def cublas_fp8_gemm_runner(): class CublasFp8GemmRunner(TunableRunner): def __init__(self): @@ -124,7 +126,9 @@ def _get_algos(self, inputs): if cached is not None: return cached algo_buf = torch.empty( - _FP8_MAX_ALGOS * _FP8_ALGO_BYTES, dtype=torch.uint8 + _CUBLASLT_MAX_ALGOS * _CUBLASLT_ALGO_BYTES, + dtype=torch.uint8, + device="cpu", ) cublas_handle = torch.cuda.current_blas_handle() count = module.bmm_fp8_get_algos( @@ -378,6 +382,7 @@ def _heuristic_func_mm_bf16( ): heuristic_backends = [] if bias is not None or pdl: + # cuDNN, CUTLASS, and cuBLASLt don't support bias/pdl, only TGV does if "tgv" in suitable_backends: heuristic_backends.append("tgv") else: @@ -959,9 +964,6 @@ def forward( def get_mm_bf16_cublaslt_module(): module = gen_mm_bf16_cublaslt_module().build_and_load() - _ALGO_BYTES = 64 # sizeof(cublasLtMatmulAlgo_t) = uint64_t[8] - _MAX_ALGOS = 100 - def cublaslt_bf16_gemm_runner(): class CublasltBf16GemmRunner(TunableRunner): def __init__(self): @@ -982,7 +984,11 @@ def _get_algos(self, inputs): cached = self._algo_cache.get(key) if cached is not None: return cached - algo_buf = torch.empty(_MAX_ALGOS * 
_ALGO_BYTES, dtype=torch.uint8) + algo_buf = torch.empty( + _CUBLASLT_MAX_ALGOS * _CUBLASLT_ALGO_BYTES, + dtype=torch.uint8, + device="cpu", + ) cublas_handle = torch.cuda.current_blas_handle() proxy_out = ( out diff --git a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh index e91d035ea1..c7f6307597 100644 --- a/include/flashinfer/gemm/mm_bf16_cublaslt.cuh +++ b/include/flashinfer/gemm/mm_bf16_cublaslt.cuh @@ -32,8 +32,7 @@ using bmm_fp8::CuBlasLtMatmulPreference; using bmm_fp8::CuBlasLtMatrixLayout; static constexpr int kMaxAlgorithms = 100; -// cublasLtMatmulAlgo_t is { uint64_t data[8]; } — 64 bytes, trivially serializable per NVIDIA docs -static constexpr size_t kAlgoBytes = sizeof(cublasLtMatmulAlgo_t); +using bmm_fp8::kAlgoBytes; /*! * \brief Set up cuBLASLt descriptors for BF16 GEMM in row-major convention. From 46167338e7490adecac67ceaaaa701066e8f590f Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 02:39:06 +0400 Subject: [PATCH 11/15] CR fixes Signed-off-by: Vadim Gimpelson --- flashinfer/gemm/gemm_base.py | 7 +++++++ tests/gemm/test_mm_bf16.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index a65d1a2f5a..cbf1fe9ff6 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1033,6 +1033,13 @@ def forward( compute_out = out algo_buf, count = self._get_algos(inputs) + if count == 0: + raise RuntimeError( + "cuBLASLt heuristic returned zero algorithms for " + f"M={a.shape[0]}, N={b.shape[0]}, K={a.shape[1]}, " + f"dtype={compute_out.dtype}. " + "This shape/dtype combination may not be supported." 
+ ) if tactic < 0 or tactic >= count: tactic = 0 module.mm_bf16_cublaslt_run_with_algo( diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index eafb7b1926..c7c9244fcd 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -86,5 +86,34 @@ def test_mm_bf16( assert cos_sim > 0.99 +def test_cublaslt_bf16_runner_zero_algos(): + """CublasltBf16GemmRunner.forward() must raise when heuristic returns 0 algorithms.""" + from flashinfer.gemm.gemm_base import get_mm_bf16_cublaslt_module + from flashinfer.utils import get_compute_capability + + compute_capability = get_compute_capability(torch.device("cuda")) + cc_num = compute_capability[0] * 10 + compute_capability[1] + if not mm_bf16.is_backend_supported("cublaslt", cc_num): + pytest.skip("cublaslt backend not supported on this GPU") + + runner = get_mm_bf16_cublaslt_module().cublaslt_bf16_gemm_runner() + + m, n, k = 16, 1024, 1024 + a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) + b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16) + out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16) + workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) + inputs = [a, b, None, None, out, workspace] + + zero_algo_buf = torch.empty(0, dtype=torch.uint8, device="cpu") + original_get_algos = runner._get_algos + runner._get_algos = lambda _inputs: (zero_algo_buf, 0) + try: + with pytest.raises(RuntimeError, match="zero algorithms"): + runner.forward(inputs) + finally: + runner._get_algos = original_get_algos + + if __name__ == "__main__": pytest.main([__file__]) From 37e2abf13eab6238b92eb8df2a42e68fecfc382e Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 13:02:57 +0400 Subject: [PATCH 12/15] CR fixes Signed-off-by: Vadim Gimpelson --- flashinfer/gemm/gemm_base.py | 2 +- tests/gemm/test_mm_bf16.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 
cbf1fe9ff6..7d7f26baf7 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -980,7 +980,7 @@ def _compute_dtype(out_dtype): def _get_algos(self, inputs): a, b, _, _, out, workspace_buffer = inputs compute_dt = self._compute_dtype(out.dtype) - key = (a.shape[0], b.shape[0], a.shape[1], compute_dt) + key = (a.shape[0], b.shape[1], a.shape[1], compute_dt) cached = self._algo_cache.get(key) if cached is not None: return cached diff --git a/tests/gemm/test_mm_bf16.py b/tests/gemm/test_mm_bf16.py index c7c9244fcd..f5371d48d4 100644 --- a/tests/gemm/test_mm_bf16.py +++ b/tests/gemm/test_mm_bf16.py @@ -100,7 +100,7 @@ def test_cublaslt_bf16_runner_zero_algos(): m, n, k = 16, 1024, 1024 a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16) - b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16) + b = torch.randn(n, k, device="cuda", dtype=torch.bfloat16).transpose(-2, -1) out = torch.empty(m, n, device="cuda", dtype=torch.bfloat16) workspace = torch.empty(32 * 1024 * 1024, device="cuda", dtype=torch.uint8) inputs = [a, b, None, None, out, workspace] From cac20a47ee48d3c0aad1230222f9a762d440905d Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 13:50:25 +0400 Subject: [PATCH 13/15] coderabbit review fixes Signed-off-by: Vadim Gimpelson --- csrc/mm_bf16_cublaslt.cu | 4 ++++ flashinfer/aot.py | 3 ++- flashinfer/autotuner.py | 4 +++- flashinfer/gemm/gemm_base.py | 7 ++++++- flashinfer/jit/gemm/core.py | 4 ++++ 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/csrc/mm_bf16_cublaslt.cu b/csrc/mm_bf16_cublaslt.cu index 2df28b0777..d6f0330620 100644 --- a/csrc/mm_bf16_cublaslt.cu +++ b/csrc/mm_bf16_cublaslt.cu @@ -53,7 +53,9 @@ int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView CHECK_DIM(2, mat1); CHECK_DIM(2, mat2); CHECK_DIM(2, out); + CHECK_CPU(algo_buffer); CHECK_CONTIGUOUS(algo_buffer); + CHECK_CUDA(workspace_buffer); int64_t m = mat1.size(0); int64_t k = mat1.size(1); @@ -88,7 
+90,9 @@ void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView CHECK_DIM(2, mat1); CHECK_DIM(2, mat2); CHECK_DIM(2, out); + CHECK_CPU(algo_buffer); CHECK_CONTIGUOUS(algo_buffer); + CHECK_CUDA(workspace_buffer); int64_t m = mat1.size(0); int64_t k = mat1.size(1); diff --git a/flashinfer/aot.py b/flashinfer/aot.py index aaa5b0ca90..e839a13162 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -494,7 +494,6 @@ def gen_all_modules( jit_specs.append(gen_gemm_sm100_module_cutlass_fp4()) jit_specs.append(gen_gemm_sm100_module_cutlass_fp8()) jit_specs.append(gen_gemm_sm100_module_cutlass_mxfp8()) - jit_specs.append(gen_mm_bf16_cublaslt_module()) # Add TGV GEMM modules for both bf16 and fp16 jit_specs.append( gen_tgv_gemm_sm10x_module(torch.bfloat16, use_sm_100f=False) @@ -513,6 +512,8 @@ def gen_all_modules( ) jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True)) jit_specs.append(gen_moe_utils_module()) + if has_sm100 or has_sm103: + jit_specs.append(gen_mm_bf16_cublaslt_module()) if has_sm103: jit_specs.append(gen_fp4_quantization_sm103_module()) jit_specs.append(gen_cutlass_fused_moe_sm103_module()) diff --git a/flashinfer/autotuner.py b/flashinfer/autotuner.py index dd8802318c..48120ba23f 100644 --- a/flashinfer/autotuner.py +++ b/flashinfer/autotuner.py @@ -398,7 +398,9 @@ def forward( def __hash__(self): hashable_vals = [] - for v in self.__dict__.values(): + for k, v in self.__dict__.items(): + if k.endswith("_cache"): + continue try: hash(v) hashable_vals.append(v) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 7d7f26baf7..ea4d045f67 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -164,6 +164,11 @@ def forward( a, b, scale_a, scale_b, out, workspace_buffer = inputs if tactic >= 0: algo_buf, count = self._get_algos(inputs) + if count == 0: + raise RuntimeError( + "cuBLASLt heuristic returned zero FP8 algorithms for " + f"A={tuple(a.shape)}, 
B={tuple(b.shape)}, out={tuple(out.shape)}." + ) if tactic >= count: tactic = 0 module.bmm_fp8_run_with_algo( @@ -1036,7 +1041,7 @@ def forward( if count == 0: raise RuntimeError( "cuBLASLt heuristic returned zero algorithms for " - f"M={a.shape[0]}, N={b.shape[0]}, K={a.shape[1]}, " + f"M={a.shape[0]}, N={b.shape[1]}, K={a.shape[1]}, " f"dtype={compute_out.dtype}. " "This shape/dtype combination may not be supported." ) diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index c435e0fc86..e91c522639 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -51,11 +51,15 @@ def gen_gemm_module() -> JitSpec: def gen_mm_bf16_cublaslt_module() -> JitSpec: + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10] + ) return gen_jit_spec( "mm_bf16_cublaslt", [ jit_env.FLASHINFER_CSRC_DIR / "mm_bf16_cublaslt.cu", ], + extra_cuda_cflags=nvcc_flags, extra_ldflags=["-lcublas", "-lcublasLt"], ) From 9e82e66dc9a4ccf365c2cf2dfd81f823ec608d39 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 14:04:02 +0400 Subject: [PATCH 14/15] coderabbit review fixes Signed-off-by: Vadim Gimpelson --- flashinfer/gemm/gemm_base.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index ea4d045f67..abb4b65a96 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -130,7 +130,8 @@ def _get_algos(self, inputs): dtype=torch.uint8, device="cpu", ) - cublas_handle = torch.cuda.current_blas_handle() + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() count = module.bmm_fp8_get_algos( a, b, @@ -160,8 +161,9 @@ def forward( do_preparation: bool = False, **kwargs, ) -> torch.Tensor: - cublas_handle = torch.cuda.current_blas_handle() a, b, scale_a, scale_b, out, workspace_buffer = inputs + with torch.cuda.device(a.device): + cublas_handle = 
torch.cuda.current_blas_handle() if tactic >= 0: algo_buf, count = self._get_algos(inputs) if count == 0: @@ -994,7 +996,8 @@ def _get_algos(self, inputs): dtype=torch.uint8, device="cpu", ) - cublas_handle = torch.cuda.current_blas_handle() + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() proxy_out = ( out if out.dtype == compute_dt @@ -1028,7 +1031,8 @@ def forward( **kwargs, ) -> torch.Tensor: a, b, _, _, out, workspace_buffer = inputs - cublas_handle = torch.cuda.current_blas_handle() + with torch.cuda.device(a.device): + cublas_handle = torch.cuda.current_blas_handle() b_t = b.transpose(-2, -1) need_cast = out.dtype == torch.float16 From b5f957d4d15bc5d71ce5a1f1e497f241d58f8430 Mon Sep 17 00:00:00 2001 From: Vadim Gimpelson Date: Mon, 30 Mar 2026 14:18:50 +0400 Subject: [PATCH 15/15] Add get_cache_key_extras and fix out docstring Signed-off-by: Vadim Gimpelson --- flashinfer/gemm/gemm_base.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index abb4b65a96..565304b2fe 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -119,6 +119,10 @@ class CublasFp8GemmRunner(TunableRunner): def __init__(self): self._algo_cache: dict = {} + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + def _get_algos(self, inputs): a, b, scale_a, scale_b, out, workspace_buffer = inputs key = (a.shape, b.shape, a.dtype, b.dtype, out.dtype) @@ -441,8 +445,9 @@ def mm_bf16( Whether to enable Programmatic Dependent Launch (PDL). Enabled for TGV backend. Defaults to ``False``. out: Optional[torch.Tensor] - Out tensor, shape (m, n), bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt - backends. Defaults to ``None``. + Out tensor, shape (m, n). Preallocated output is supported by all backends. 
+ TGV requires ``torch.bfloat16``; CUTLASS, cuDNN, and cuBLASLt also support fp16/fp32. + Defaults to ``None``. out_dtype: torch.dtype Output dtype, bf16, fp16, or fp32. Enabled for CUTLASS, cuDNN, and cuBLASLt backends. @@ -976,6 +981,10 @@ class CublasltBf16GemmRunner(TunableRunner): def __init__(self): self._algo_cache: dict = {} + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + _, _, _, _, out, _ = inputs + return (self._compute_dtype(out.dtype),) + @staticmethod def _compute_dtype(out_dtype): # cuBLASLt with BF16 inputs supports BF16 or FP32 output natively. @@ -2971,6 +2980,10 @@ def _cudnn_gemm_fp8( def _cudnn_gemm_fp8_runner(): class CudnnFp8GemmRunner(TunableRunner): + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + def get_valid_tactics( self, inputs: List[torch.Tensor], @@ -7041,6 +7054,10 @@ def _cudnn_gemm_mxfp8( def _cudnn_gemm_mxfp8_runner(): class CudnnMxfp8GemmRunner(TunableRunner): + def get_cache_key_extras(self, inputs: List[torch.Tensor]) -> tuple: + a, b, _, _, out, _ = inputs + return (a.dtype, b.dtype, out.dtype) + def get_valid_tactics( self, inputs: List[torch.Tensor],