-
Notifications
You must be signed in to change notification settings - Fork 836
feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners
#2914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vadiklyutiy
wants to merge
15
commits into
flashinfer-ai:main
Choose a base branch
from
vadiklyutiy:mm-bf16-cublaslt
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
feat: Add cuBLASLt backend for mm_bf16 and enable multi-tactic autotuning for FP8/MXFP8 runners
#2914
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
044dc7e
add cublaslt to mm_bf16
vadiklyutiy f6073b4
fixes
vadiklyutiy 40d0ac3
perf fix
vadiklyutiy 17233aa
fix
vadiklyutiy bdeaa6e
Enable multi-tactic autotuning for CublasFp8, CudnnFp8, and CudnnMxfp…
vadiklyutiy 3c4d422
add backend to tests
vadiklyutiy 214cb05
fix
vadiklyutiy 90d9511
CR fixes
vadiklyutiy e633cc7
support fp16,fp32 output
vadiklyutiy 98dde99
CR fixes
vadiklyutiy 4616733
CR fixes
vadiklyutiy 37e2abf
CR fixes
vadiklyutiy cac20a4
coderabbit review fixes
vadiklyutiy 9e82e66
coderabbit review fixes
vadiklyutiy b5f957d
Add get_cache_key_extras and fix out docstring
vadiklyutiy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,126 @@ | ||
| /* | ||
| * Copyright (c) 2026 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <cuda_bf16.h> | ||
| #include <driver_types.h> | ||
|
|
||
| #include <flashinfer/gemm/mm_bf16_cublaslt.cuh> | ||
|
|
||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| namespace { | ||
|
|
||
| cudaDataType_t get_d_type(DLDataType dtype) { | ||
| switch (encode_dlpack_dtype(dtype)) { | ||
| case bfloat16_code: | ||
| return CUDA_R_16BF; | ||
| case float16_code: | ||
| return CUDA_R_16F; | ||
| case float32_code: | ||
| return CUDA_R_32F; | ||
| default: | ||
| TVM_FFI_LOG_AND_THROW(NotImplementedError) << "out_dtype must be one of bf16/fp16/fp32."; | ||
| return CUDA_R_16BF; | ||
| } | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| // Serialize all heuristic algorithms into a CPU uint8 tensor for caching. | ||
| // algo_buffer: CPU uint8 tensor of size >= kMaxAlgorithms * kAlgoBytes. | ||
| // Returns number of algorithms written. | ||
| int64_t mm_bf16_cublaslt_get_algos(TensorView mat1, TensorView mat2, TensorView out, | ||
| TensorView workspace_buffer, int64_t cublas_handle, | ||
| TensorView algo_buffer) { | ||
| CHECK_CUDA(mat1); | ||
| CHECK_CUDA(mat2); | ||
| CHECK_CUDA(out); | ||
| CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); | ||
| CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); | ||
| CHECK_DIM(2, mat1); | ||
| CHECK_DIM(2, mat2); | ||
| CHECK_DIM(2, out); | ||
| CHECK_CPU(algo_buffer); | ||
| CHECK_CONTIGUOUS(algo_buffer); | ||
| CHECK_CUDA(workspace_buffer); | ||
|
|
||
| int64_t m = mat1.size(0); | ||
| int64_t k = mat1.size(1); | ||
| int64_t n = mat2.size(0); | ||
|
|
||
| TVM_FFI_ICHECK_EQ(mat2.size(1), k) | ||
| << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); | ||
| TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; | ||
| TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; | ||
|
|
||
| cudaDataType_t d_type = get_d_type(out.dtype()); | ||
|
|
||
| ffi::CUDADeviceGuard device_guard(mat1.device().device_id); | ||
| auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle); | ||
| int max_algos = static_cast<int>(algo_buffer.numel() * get_element_size(algo_buffer) / | ||
| flashinfer::mm_bf16_cublaslt::kAlgoBytes); | ||
| return static_cast<int64_t>(flashinfer::mm_bf16_cublaslt::get_algorithms( | ||
| static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type, | ||
| workspace_buffer.numel() * get_element_size(workspace_buffer), lt_handle, | ||
| algo_buffer.data_ptr(), max_algos)); | ||
| } | ||
|
|
||
| // Run matmul using a pre-cached algorithm — zero heuristic overhead. | ||
| void mm_bf16_cublaslt_run_with_algo(TensorView mat1, TensorView mat2, TensorView out, | ||
| TensorView workspace_buffer, int64_t cublas_handle, | ||
| TensorView algo_buffer, int64_t algo_idx) { | ||
| CHECK_CUDA(mat1); | ||
| CHECK_CUDA(mat2); | ||
| CHECK_CUDA(out); | ||
| CHECK_INPUT_AND_TYPE(mat1, dl_bfloat16); | ||
| CHECK_INPUT_AND_TYPE(mat2, dl_bfloat16); | ||
| CHECK_DIM(2, mat1); | ||
| CHECK_DIM(2, mat2); | ||
| CHECK_DIM(2, out); | ||
| CHECK_CPU(algo_buffer); | ||
| CHECK_CONTIGUOUS(algo_buffer); | ||
| CHECK_CUDA(workspace_buffer); | ||
|
|
||
| int64_t m = mat1.size(0); | ||
| int64_t k = mat1.size(1); | ||
| int64_t n = mat2.size(0); | ||
|
|
||
| TVM_FFI_ICHECK_EQ(mat2.size(1), k) | ||
| << "mat2 K dimension mismatch: expected " << k << ", got " << mat2.size(1); | ||
| TVM_FFI_ICHECK_EQ(out.size(0), m) << "out M dimension mismatch"; | ||
| TVM_FFI_ICHECK_EQ(out.size(1), n) << "out N dimension mismatch"; | ||
|
|
||
| int64_t max_algos = algo_buffer.numel() * get_element_size(algo_buffer) / | ||
| flashinfer::mm_bf16_cublaslt::kAlgoBytes; | ||
| TVM_FFI_ICHECK(algo_idx >= 0 && algo_idx < max_algos) | ||
| << "algo_idx " << algo_idx << " out of range [0, " << max_algos << ")"; | ||
|
|
||
| auto lt_handle = reinterpret_cast<cublasLtHandle_t>(cublas_handle); | ||
| ffi::CUDADeviceGuard device_guard(mat1.device().device_id); | ||
| auto stream = get_stream(mat1.device()); | ||
| cudaDataType_t d_type = get_d_type(out.dtype()); | ||
|
|
||
| auto status = flashinfer::mm_bf16_cublaslt::run_with_algo( | ||
| static_cast<__nv_bfloat16*>(mat1.data_ptr()), static_cast<__nv_bfloat16*>(mat2.data_ptr()), | ||
| out.data_ptr(), static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), d_type, | ||
| workspace_buffer.data_ptr(), workspace_buffer.numel() * get_element_size(workspace_buffer), | ||
| lt_handle, stream, algo_buffer.data_ptr(), static_cast<int>(algo_idx)); | ||
| TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) | ||
| << "mm_bf16_cublaslt_run_with_algo failed: " << cublasGetStatusString(status); | ||
| } | ||
|
|
||
// Expose both entry points through the TVM FFI typed-function export macro.
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_get_algos, mm_bf16_cublaslt_get_algos);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(mm_bf16_cublaslt_run_with_algo, mm_bf16_cublaslt_run_with_algo);
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.