15 changes: 12 additions & 3 deletions benchmarks/routines/gemm.py
@@ -140,7 +140,16 @@ def parse_gemm_args(line, parser):
required=False,
nargs="+",
default=["cudnn"],
choices=["cudnn", "cublas", "trtllm", "cutlass", "tgv", "cute-dsl", "auto"],
choices=[
"cudnn",
"cublas",
"trtllm",
"cutlass",
"tgv",
"cublaslt",
"cute-dsl",
"auto",
],
help="Kernel backends to test. Default: cudnn",
)
parser.add_argument(
@@ -1553,7 +1562,7 @@ def testMmBf16(args):
use_pdl = getattr(args, "enable_pdl", False)
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
autotune_supported_backends = ["cudnn", "cutlass", "tgv", "auto"]
autotune_supported_backends = ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]
res = []

out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
@@ -1618,7 +1627,7 @@ def testMmBf16(args):
return res

def run_backend(backend, a, b, bias, use_pdl, out_dtype):
if backend in ["cudnn", "cutlass", "tgv", "auto"]:
if backend in ["cudnn", "cutlass", "tgv", "cublaslt", "auto"]:
return flashinfer.mm_bf16(
a=a,
b=b,
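The new `cublaslt` choice rides the existing `mm_bf16` dispatch: `run_backend` forwards it to `flashinfer.mm_bf16` alongside `cudnn`/`cutlass`/`tgv`/`auto`, and it is registered as autotune-capable. Below is a minimal sketch of exercising it directly; the `mm_bf16` keyword set beyond what the diff shows (in particular the `backend` argument) is an assumption inferred from the dispatch above, and the shape convention is assumed to be `(m, k) @ (k, n)`.

```python
import torch
import flashinfer

a = torch.randn(1024, 512, dtype=torch.bfloat16, device="cuda")
b = torch.randn(512, 2048, dtype=torch.bfloat16, device="cuda")

# "cublaslt" now sits alongside cudnn/cutlass/tgv/auto, including under
# autotuning (see autotune_supported_backends above).
out = flashinfer.mm_bf16(a=a, b=b, backend="cublaslt")  # backend kwarg assumed
```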
91 changes: 90 additions & 1 deletion csrc/bmm_fp8.cu
@@ -49,7 +49,8 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso
auto stream = get_stream(A.device());

auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt(
workspace_buffer.data_ptr(), workspace_buffer.numel(),
workspace_buffer.data_ptr(),
workspace_buffer.numel() * get_element_size(workspace_buffer),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
@@ -61,3 +62,91 @@ void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, Tenso
});
});
}

int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer) {
CHECK_CUDA(A);
CHECK_CUDA(B);
CHECK_CUDA(D);
CHECK_DIM(3, A);
CHECK_DIM(3, B);
CHECK_DIM(3, D);
CHECK_CONTIGUOUS(algo_buffer);
TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
<< "Result tensor has incorrect shape";

int64_t result = 0;
DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(A.device().device_id);

int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::bmm_fp8::kAlgoBytes);
result = flashinfer::bmm_fp8::get_fp8_algorithms<b_type, a_type, d_type>(
batch_size, n, m, k, static_cast<float*>(B_scale.data_ptr()),
static_cast<float*>(A_scale.data_ptr()),
workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
algo_buffer.data_ptr(), max_algos);
return true;
});
});
});
return static_cast<int64_t>(result);
}

void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx) {
CHECK_CUDA(A);
CHECK_CUDA(B);
CHECK_CUDA(D);
CHECK_DIM(3, A);
CHECK_DIM(3, B);
CHECK_DIM(3, D);
CHECK_CONTIGUOUS(algo_buffer);
TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match";
TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes";
TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2))
<< "Result tensor has incorrect shape";

int64_t max_algos =
algo_buffer.numel() * get_element_size(algo_buffer) / flashinfer::bmm_fp8::kAlgoBytes;
TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
<< "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";

DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] {
return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] {
auto batch_size = A.size(0);
auto m = A.size(1);
auto k = A.size(2);
auto n = B.size(2);

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(A.device().device_id);
auto stream = get_stream(A.device());

auto status = flashinfer::bmm_fp8::bmm_fp8_run_with_algo<b_type, a_type, d_type>(
workspace_buffer.data_ptr(),
workspace_buffer.numel() * get_element_size(workspace_buffer),
static_cast<b_type*>(B.data_ptr()), static_cast<a_type*>(A.data_ptr()),
static_cast<d_type*>(D.data_ptr()), batch_size, n, m, k,
static_cast<float*>(B_scale.data_ptr()), static_cast<float*>(A_scale.data_ptr()),
lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx));
TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
<< "bmm_fp8_run_with_algo failed: " << cublasGetStatusString(status);
return true;
});
});
});
}
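Two things happen in this file. First, the pre-existing call now passes the workspace size in bytes (`numel() * get_element_size(...)`) instead of a raw element count, which undercounted whenever the workspace was not a uint8 tensor. Second, the new `bmm_fp8_get_algos` / `bmm_fp8_run_with_algo` pair splits cuBLASLt heuristic selection from execution, so the heuristic cost is paid once and a serialized algorithm is replayed afterwards. A hypothetical Python sketch of the intended call pattern follows; the FFI entry points are the ones exported here, but `gemm_module`, `cublas_handle`, and `KALGO_BYTES` are stand-ins, not FlashInfer API.

```python
import torch

KALGO_BYTES = 1024  # stand-in for flashinfer::bmm_fp8::kAlgoBytes
MAX_ALGOS = 8

batch, m, k, n = 4, 16, 256, 128
A = torch.randn(batch, m, k, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(batch, k, n, device="cuda").to(torch.float8_e4m3fn)
D = torch.empty(batch, m, n, dtype=torch.bfloat16, device="cuda")
A_scale = torch.ones((), device="cuda")  # float32 scalar, per the float* cast
B_scale = torch.ones((), device="cuda")
workspace = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device="cuda")
algo_buffer = torch.empty(MAX_ALGOS * KALGO_BYTES, dtype=torch.uint8, device="cuda")

# gemm_module / cublas_handle: assumed preloaded FFI module and cublasLt handle.
# One-time: serialize up to max_algos heuristic results into algo_buffer.
n_algos = gemm_module.bmm_fp8_get_algos(
    A, B, D, A_scale, B_scale, workspace, cublas_handle, algo_buffer)

# Steady state: replay a chosen algorithm; algo_idx is range-checked in C++.
gemm_module.bmm_fp8_run_with_algo(
    A, B, D, A_scale, B_scale, workspace, cublas_handle, algo_buffer, 0)
```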
10 changes: 10 additions & 0 deletions csrc/flashinfer_gemm_binding.cu
@@ -19,9 +19,19 @@
void bmm_fp8(TensorView A, TensorView B, TensorView D, TensorView A_scale, TensorView B_scale,
TensorView workspace_buffer, int64_t cublas_handle);

int64_t bmm_fp8_get_algos(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer);

void bmm_fp8_run_with_algo(TensorView A, TensorView B, TensorView D, TensorView A_scale,
TensorView B_scale, TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx);

void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr,
TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld,
TensorView y_ld, TensorView empty_x_data, bool weight_column_major);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(cutlass_segment_gemm, CutlassSegmentGEMM);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8, bmm_fp8);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_get_algos, bmm_fp8_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(bmm_fp8_run_with_algo, bmm_fp8_run_with_algo);
126 changes: 126 additions & 0 deletions csrc/mm_bf16_cublaslt.cu
@@ -0,0 +1,126 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuda_bf16.h>
#include <driver_types.h>

#include <flashinfer/gemm/mm_bf16_cublaslt.cuh>

#include "tvm_ffi_utils.h"

namespace {

cudaDataType_t get_d_type(DLDataType dtype) {
switch (encode_dlpack_dtype(dtype)) {
case bfloat16_code:
return CUDA_R_16BF;
case float16_code:
return CUDA_R_16F;
case float32_code:
return CUDA_R_32F;
default:
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32.";
return CUDA_R_16BF;
}
}

} // namespace

// Serialize all heuristic algorithms into a CPU uint8 tensor for caching.
// algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes.
// Returns number of algorithms written.
int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out,
TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer) {
CHECK_CUDA(mat1);
CHECK_CUDA(mat2);
CHECK_CUDA(out);
CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
CHECK_DIM(2, out);
CHECK_CPU(algo_buffer);
CHECK_CONTIGUOUS(algo_buffer);
CHECK_CUDA(workspace_buffer);

int64_t m = mat1.size(0);
int64_t k = mat1.size(1);
int64_t n = mat2.size(0);

TVM_FFI_ICHECK_EQ(mat2.size(1), k)
<< "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1);
TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch";
TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch";

cudaDataType_t d_type = get_d_type(out.dtype());

ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::mm_bf16_cublaslt::kAlgoBytes);
return static_cast<int64_t>(flashinfer::mm_bf16_cublaslt::get_algorithms(
static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type,
workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle,
algo_buffer.data_ptr(), max_algos));
}

// Run matmul using a pre-cached algorithm — zero heuristic overhead.
void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out,
TensorView workspace_buffer, int64_t cublas_handle,
TensorView algo_buffer, int64_t algo_idx) {
CHECK_CUDA(mat1);
CHECK_CUDA(mat2);
CHECK_CUDA(out);
CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16);
CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
CHECK_DIM(2, out);
CHECK_CPU(algo_buffer);
CHECK_CONTIGUOUS(algo_buffer);
CHECK_CUDA(workspace_buffer);

int64_t m = mat1.size(0);
int64_t k = mat1.size(1);
int64_t n = mat2.size(0);

TVM_FFI_ICHECK_EQ(mat2.size(1), k)
<< "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1);
TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch";
TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch";

int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) /
flashinfer::mm_bf16_cublaslt::kAlgoBytes;
TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos)
<< "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")";

auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle);
ffi::CUDADeviceGuard device_guard(mat1.device().device_id);
auto stream = get_stream(mat1.device());
cudaDataType_t d_type = get_d_type(out.dtype());

auto status = flashinfer::mm_bf16_cublaslt::run_with_algo(
static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()),
out.data_ptr(), static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type,
workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer),
lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx));
TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS)
<< "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status);
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo);
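Both entry points size the algorithm list from the buffer itself, `max_algos = buffer_bytes / kAlgoBytes`, and the buffer lives on the CPU (`CHECK_CPU`), so a natural consumer is a per-shape cache that queries heuristics once and replays thereafter. A hypothetical sketch of such a cache; `KALGO_BYTES`, `mod`, and `handle` are placeholders, not FlashInfer API.

```python
import torch

KALGO_BYTES = 1024  # placeholder for flashinfer::mm_bf16_cublaslt::kAlgoBytes
MAX_ALGOS = 16
_algo_cache: dict = {}  # (m, n, k, out_dtype) -> (cpu buffer, algo count)

def run_cached(mod, mat1, mat2, out, workspace, handle):
    # mat1 is (m, k), mat2 is (n, k) per the checks above; out is (m, n).
    key = (mat1.shape[0], mat2.shape[0], mat1.shape[1], out.dtype)
    if key not in _algo_cache:
        buf = torch.empty(MAX_ALGOS * KALGO_BYTES, dtype=torch.uint8)  # CPU tensor
        count = mod.mm_bf16_cublaslt_get_algos(mat1, mat2, out, workspace, handle, buf)
        _algo_cache[key] = (buf, count)
    buf, count = _algo_cache[key]
    # Replay the first (typically best-ranked) cached algorithm.
    mod.mm_bf16_cublaslt_run_with_algo(mat1, mat2, out, workspace, handle, buf, 0)
```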
3 changes: 3 additions & 0 deletions flashinfer/aot.py
@@ -77,6 +77,7 @@
gen_gemm_sm100_module_cutlass_mxfp8,
gen_gemm_sm120_module,
gen_gemm_sm120_module_cutlass_fp4,
gen_mm_bf16_cublaslt_module,
gen_tgv_gemm_sm10x_module,
gen_trtllm_gen_gemm_module,
gen_trtllm_low_latency_gemm_module,
@@ -511,6 +512,8 @@ def gen_all_modules(
)
jit_specs.append(gen_tgv_gemm_sm10x_module(torch.float16, use_sm_100f=True))
jit_specs.append(gen_moe_utils_module())
if has_sm100 or has_sm103:
jit_specs.append(gen_mm_bf16_cublaslt_module())
if has_sm103:
jit_specs.append(gen_fp4_quantization_sm103_module())
jit_specs.append(gen_cutlass_fused_moe_sm103_module())
11 changes: 10 additions & 1 deletion flashinfer/autotuner.py
@@ -397,7 +397,16 @@ def forward(
raise NotImplementedError

def __hash__(self):
return hash(tuple(self.__dict__.values()))
hashable_vals = []
for k, v in self.__dict__.items():
if k.endswith("_cache"):
continue
try:
hash(v)
hashable_vals.append(v)
except TypeError:
hashable_vals.append(id(v))
return hash(tuple(hashable_vals))


@contextlib.contextmanager
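The previous `__hash__` hashed every attribute value directly and raised `TypeError` as soon as a runner carried a dict, list, or other unhashable member. The rewrite skips `*_cache` attributes outright and falls back to `id(v)` for anything unhashable. One caveat: the `id()` fallback means two distinct-but-equal unhashable attributes hash differently, which at worst costs a duplicate tuning pass rather than a wrong cache hit (assuming `__eq__` is consistent). A minimal illustration with a hypothetical class (not the actual autotuner runner):

```python
class Runner:
    def __init__(self):
        self.tactic = 3
        self.tuning_config = {"tile_m": 128}  # dict: unhashable, uses id(v)
        self._kernel_cache = {}               # skipped: name ends in "_cache"

    def __hash__(self):  # same logic as the patched method above
        hashable_vals = []
        for k, v in self.__dict__.items():
            if k.endswith("_cache"):
                continue
            try:
                hash(v)
                hashable_vals.append(v)
            except TypeError:
                hashable_vals.append(id(v))
        return hash(tuple(hashable_vals))

print(hash(Runner()))  # previously: TypeError: unhashable type: 'dict'
```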