-
Notifications
You must be signed in to change notification settings - Fork 584
make DeepGEMM swapAB available for linear gemm SM90 #2101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xuanzic
wants to merge
5
commits into
flashinfer-ai:main
Choose a base branch
from
xuanzic:vchen/dg_swapab_linear
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
3871b93
add swapab linear gemm binding
xuanzic 1eeba0f
fix binding
xuanzic e2cee34
Merge branch 'flashinfer-ai:main' into vchen/dg_swapab_linear
xuanzic 364ee70
rename function for SM90
xuanzic 6f3449d
Merge branch 'main' into vchen/dg_swapab_linear
xuanzic File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
|
|
||
| #include <tvm/ffi/extra/module.h> | ||
| #include "tvm_ffi_utils.h" | ||
| #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" | ||
|
|
||
| #include <cstddef> | ||
| #include <cstdint> | ||
| #include <functional> | ||
| #include <type_traits> | ||
| #include <vector> | ||
|
|
||
| namespace kernels = tensorrt_llm::kernels::fp8_blockscale_gemm; | ||
|
|
||
| using tvm::ffi::Function; | ||
| using tvm::ffi::Optional; | ||
| using tvm::ffi::TensorView; | ||
|
|
||
| #ifdef FLASHINFER_ENABLE_FP8_E4M3 | ||
| inline bool is_fp8_e4m3fn(DLDataType dtype) { | ||
| return encode_dlpack_dtype(dtype) == float8_e4m3fn_code; | ||
| } | ||
| #else | ||
| inline bool is_fp8_e4m3fn(DLDataType) { return false; } | ||
| #endif | ||
|
|
||
| /** | ||
| * @brief FP8 Block-Scale GEMM binding for SM90 | ||
| * | ||
| * Supports: | ||
| * - BF16 + BF16 β BF16 | ||
| * - BF16 + FP8 β BF16 | ||
| * | ||
| * @note Output is BF16 | ||
| */ | ||
| class Fp8BlockScaleGemmRunner : public tvm::ffi::ModuleObj { | ||
| public: | ||
| Fp8BlockScaleGemmRunner() { | ||
| // Instantiate runners | ||
| runner_bf16_bf16_ = std::make_unique<kernels::CutlassFp8BlockScaleGemmRunner< | ||
| __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>>(); | ||
|
|
||
| runner_bf16_fp8_ = std::make_unique<kernels::CutlassFp8BlockScaleGemmRunner< | ||
| __nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>>(); | ||
| } | ||
|
|
||
| ~Fp8BlockScaleGemmRunner() = default; | ||
|
|
||
| const char* type_key() const { return "flashinfer.Fp8BlockScaleGemmRunner"; } | ||
| const char* kind() const final { return "fp8_blockscale_gemm_runner"; } | ||
|
|
||
| Optional<Function> GetFunction(const tvm::ffi::String& name) { | ||
| if (name == "gemm") { | ||
| return Function::FromTyped( | ||
| [this](TensorView input, TensorView weight, TensorView output, | ||
| Optional<TensorView> scales_a, Optional<TensorView> scales_b) { | ||
| runGemm(input, weight, output, scales_a, scales_b); | ||
| }); | ||
| } else if (name == "get_workspace_size") { | ||
| return Function::FromTyped( | ||
| [this](int64_t shape_m, int64_t shape_n, int64_t shape_k) -> int64_t { | ||
| return getWorkspaceSize(shape_m, shape_n, shape_k); | ||
| }); | ||
| } else if (name == "configure_workspace") { | ||
| return Function::FromTyped([this](TensorView workspace) { | ||
| configureWorkspace(workspace); | ||
| }); | ||
| } | ||
| return Function(nullptr); | ||
| } | ||
|
|
||
| private: | ||
| /** | ||
| * @brief Runtime dtype dispatch | ||
| */ | ||
| kernels::CutlassFp8BlockScaleGemmRunnerInterface* selectRunner( | ||
| bool input_is_fp8, bool weight_is_fp8) { | ||
|
|
||
| if (!input_is_fp8 && !weight_is_fp8) { | ||
| return runner_bf16_bf16_.get(); | ||
| } else if (!input_is_fp8 && weight_is_fp8) { | ||
| return runner_bf16_fp8_.get(); | ||
| } else { | ||
| return nullptr; | ||
| } | ||
| } | ||
|
|
||
| void runGemm(const TensorView& input, const TensorView& weight, const TensorView& output, | ||
| const Optional<TensorView>& scales_a, const Optional<TensorView>& scales_b) { | ||
| auto stream = get_stream(input.device()); | ||
|
|
||
| auto input_ptr = input.data_ptr(); | ||
| auto weight_ptr = weight.data_ptr(); | ||
| auto output_ptr = output.data_ptr(); | ||
|
|
||
| int shape_m = input.size(0); | ||
| int shape_k = input.size(1); | ||
| int shape_n = weight.size(0); | ||
|
|
||
| TVM_FFI_ICHECK(input_ptr != nullptr) << "input is null"; | ||
| TVM_FFI_ICHECK(weight_ptr != nullptr) << "weight is null"; | ||
| TVM_FFI_ICHECK(output_ptr != nullptr) << "output is null"; | ||
| TVM_FFI_ICHECK(shape_k == weight.size(1)) << "K dimension mismatch"; | ||
|
|
||
| // Determine dtypes for runner selection | ||
| bool input_is_fp8 = is_fp8_e4m3fn(input.dtype()); | ||
| bool weight_is_fp8 = is_fp8_e4m3fn(weight.dtype()); | ||
|
|
||
| // Extract scale pointers | ||
| float const* scales_a_ptr = scales_a.has_value() | ||
| ? reinterpret_cast<float const*>(scales_a.value().data_ptr()) | ||
| : nullptr; | ||
| float const* scales_b_ptr = scales_b.has_value() | ||
| ? reinterpret_cast<float const*>(scales_b.value().data_ptr()) | ||
| : nullptr; | ||
|
|
||
| // Select appropriate runner | ||
| auto* runner = selectRunner(input_is_fp8, weight_is_fp8); | ||
| TVM_FFI_ICHECK(runner != nullptr) << "Unsupported dtype combination"; | ||
| TVM_FFI_ICHECK(workspace_ != nullptr) << "Workspace not configured. Call configure_workspace first."; | ||
|
|
||
| runner->gemm(output_ptr, input_ptr, weight_ptr, shape_m, shape_n, shape_k, | ||
| stream, scales_a_ptr, scales_b_ptr); | ||
| } | ||
|
|
||
| int64_t getWorkspaceSize(int64_t shape_m, int64_t shape_n, int64_t shape_k) { | ||
| size_t max_size = 0; | ||
|
|
||
| max_size = std::max(max_size, | ||
| runner_bf16_bf16_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); | ||
| max_size = std::max(max_size, | ||
| runner_bf16_fp8_->getWorkspaceSizeBase(shape_m, shape_n, shape_k, 1)); | ||
|
|
||
| return max_size; | ||
| } | ||
|
|
||
| void configureWorkspace(const TensorView& workspace) { | ||
| auto workspace_ptr = reinterpret_cast<char*>(workspace.data_ptr()); | ||
| workspace_ = workspace_ptr; | ||
|
|
||
| runner_bf16_bf16_->configureWorkspace(workspace_ptr); | ||
| runner_bf16_fp8_->configureWorkspace(workspace_ptr); | ||
| } | ||
|
|
||
| std::unique_ptr<kernels::CutlassFp8BlockScaleGemmRunner< | ||
| __nv_bfloat16, __nv_bfloat16, __nv_bfloat16>> runner_bf16_bf16_; | ||
| std::unique_ptr<kernels::CutlassFp8BlockScaleGemmRunner< | ||
| __nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>> runner_bf16_fp8_; | ||
|
|
||
| char* workspace_ = nullptr; | ||
| }; | ||
|
|
||
| tvm::ffi::Module init() { | ||
| auto ptr = tvm::ffi::make_object<Fp8BlockScaleGemmRunner>(); | ||
| return tvm::ffi::Module(ptr); | ||
| } | ||
|
|
||
| TVM_FFI_DLL_EXPORT_TYPED_FUNC(init, init); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please fix the pre-commit issues, by running |
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C++ runner only supports BF16 input; any FP8 input combination is rejected
selectRunnerdeliberately only returns a runner for(!input_is_fp8 && !weight_is_fp8)and(!input_is_fp8 && weight_is_fp8). Any case withinput_is_fp8 == truewill yieldnullptrand trip the"Unsupported dtype combination"check inrunGemm. This is fine as long as the Python API never calls this path with FP8 inputs. The current Python wrapper (fp8_blockscale_gemm_swapabingemm_base.py) allows FP8 inputs and advertises them in the docstring, which does not match what this binding actually supports.Iβd recommend either:
"Unsupported dtype combination"from this C++ layer.