-
Notifications
You must be signed in to change notification settings - Fork 888
Support for MXFP4 and NVFP4 group GEMMs on GeForce and Spark #2738
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
7ce16f2
Support for MXFP4 and NVFP4 group GEMMs on Geforce and Spark
aaf5518
Run precommit
97800c9
Just enabling for SM100 is sufficient
2af092a
Clean up and some fixes
9cad471
Apply suggestions from code review
depaulmillz 7182f84
Fixup UE4M3
af0f4aa
Address coderabbit comments
57df4a7
Fix partial commit
c0f5c64
Finish addressing code rabbit comments
depaulmillz aac9c5b
Merge remote-tracking branch 'origin' into geforce_and_spark
6659ad0
Ensure no race conditions after revert
a2cd02d
Address comments from code rabbit and some formatting
b93397d
Merge branch 'main' into geforce_and_spark
depaulmillz File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
118 changes: 118 additions & 0 deletions
118
benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,118 @@ | ||
| """ | ||
| Copyright (c) 2025-2026 by FlashInfer team. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| """ | ||
|
|
||
| from itertools import product | ||
|
|
||
| import numpy as np | ||
| import torch | ||
|
|
||
| import flashinfer | ||
| from flashinfer.testing.utils import bench_gpu_time | ||
| from flashinfer.utils import get_compute_capability | ||
|
|
||
|
|
||
| def bench_groupwise_grouped_gemm_nvfp4_blackwell(group_size, m, n, k, out_dtype): | ||
| compute_capability = get_compute_capability(torch.device("cuda")) | ||
| if compute_capability[0] not in [12]: | ||
| print("group_gemm_nvfp4_nt_groupwise is only supported on SM120/SM121 GPUs.") | ||
| return | ||
| torch.random.manual_seed(0) | ||
| assert n % 8 == 0 | ||
| assert k % 128 == 0 | ||
| tile_size = 16 | ||
| alignment_sf = 128 | ||
| a = torch.randint( | ||
| 0, 256, (group_size * m, k // 2), dtype=torch.uint8, device="cuda:0" | ||
| ) | ||
| b = torch.randint( | ||
| 0, 256, (group_size, n, k // 2), dtype=torch.uint8, device="cuda:0" | ||
| ) | ||
| out = torch.empty(group_size * m, n, dtype=out_dtype, device="cuda:0") | ||
|
|
||
| a_scale = torch.randint( | ||
| 0, | ||
| 256, | ||
| ( | ||
| (group_size * m + (alignment_sf - 1) * group_size) | ||
| // alignment_sf | ||
| * alignment_sf, | ||
| k // tile_size, | ||
| ), | ||
| dtype=torch.uint8, | ||
| device="cuda:0", | ||
| ) | ||
| b_scale = torch.randint( | ||
| 0, | ||
| 256, | ||
| ( | ||
| group_size, | ||
| (n + alignment_sf - 1) // alignment_sf * alignment_sf, | ||
| k // tile_size, | ||
| ), | ||
| dtype=torch.uint8, | ||
| device="cuda:0", | ||
| ) | ||
|
|
||
| segment_offsets = torch.arange( | ||
| 0, (group_size + 1) * m, m, device="cuda:0", dtype=torch.int32 | ||
| ) | ||
|
|
||
| tile_m_list = [128] | ||
| tile_n_list = [128] | ||
| tile_k_list = [128, 256] | ||
|
|
||
| ms_best = float("inf") | ||
| config_best = None | ||
| for tile_m, tile_n, tile_k in product(tile_m_list, tile_n_list, tile_k_list): | ||
| measurements = bench_gpu_time( | ||
| lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise( | ||
| a, | ||
| b, | ||
| a_scale, | ||
| b_scale, | ||
| segment_offsets, | ||
| out=out, | ||
| tile_m=tile_m, | ||
| tile_n=tile_n, | ||
| tile_k=tile_k, | ||
| ), | ||
| dry_run_time_ms=10, | ||
| repeat_time_ms=100, | ||
| ) | ||
| ms = np.median(measurements) | ||
| if ms < ms_best: | ||
| ms_best = ms | ||
| config_best = { | ||
| "tile_m": tile_m, | ||
| "tile_n": tile_n, | ||
| "tile_k": tile_k, | ||
| } | ||
| tflops_per_second = 2 * group_size * m * n * k * 1e-9 / ms_best | ||
| print( | ||
| f"group_gemm_nvfp4_nt_groupwise group_size={group_size} m={m} n={n} k={k} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s" | ||
| ) | ||
| print(f"best config: {config_best}") | ||
| print() | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| for group_size in [1, 3, 8, 16]: | ||
| for m in [128, 512, 1024, 2048, 4096, 8192]: | ||
| for n in [1024, 2048, 4096, 8192]: | ||
| for k in [1024, 2048, 4096, 8192]: | ||
| bench_groupwise_grouped_gemm_nvfp4_blackwell( | ||
| group_size, m, n, k, torch.bfloat16 | ||
| ) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| /* | ||
| * Copyright (c) 2026 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include <flashinfer/cutlass_utils.cuh> | ||
|
|
||
| #include "tvm_ffi_utils.h" | ||
|
|
||
| using namespace flashinfer; | ||
|
|
||
| #define DISPATCH_TILE_M(tile_m, TILE_M, ...) \ | ||
| [&]() -> bool { \ | ||
| if (tile_m == 128) { \ | ||
| constexpr int TILE_M = 128; \ | ||
| return __VA_ARGS__(); \ | ||
| } \ | ||
| TVM_FFI_ICHECK(false) << "Unsupported TILE M"; \ | ||
| return false; \ | ||
| }() | ||
|
|
||
| #define DISPATCH_TILE_N(tile_n, TILE_N, ...) \ | ||
| [&]() -> bool { \ | ||
| if (tile_n == 128) { \ | ||
| constexpr int TILE_N = 128; \ | ||
| return __VA_ARGS__(); \ | ||
| } \ | ||
| TVM_FFI_ICHECK(false) << "Unsupported TILE N"; \ | ||
| return false; \ | ||
| }() | ||
|
|
||
| #define DISPATCH_TILE_K(tile_k, TILE_K, ...) \ | ||
| [&]() -> bool { \ | ||
| if (tile_k == 128) { \ | ||
| constexpr int TILE_K = 128; \ | ||
| return __VA_ARGS__(); \ | ||
| } \ | ||
| TVM_FFI_ICHECK(false) << "Unsupported TILE K"; \ | ||
| return false; \ | ||
| }() | ||
|
|
||
| #define DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(input_a_dtype, input_b_dtype, sf_a_dtype, sf_b_dtype, \ | ||
| output_dtype, c_type_in_a, c_type_in_b, c_type_sf_a, \ | ||
| c_type_sf_b, c_type_out, ...) \ | ||
| [&]() -> bool { \ | ||
| return DISPATCH_DLPACK_DTYPE_TO_CTYPE(output_dtype, c_type_out, [&] { \ | ||
| return DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF(sf_b_dtype, c_type_sf_b, [&] { \ | ||
| return DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF(sf_a_dtype, c_type_sf_a, [&] { \ | ||
| return DISPATCH_DLPACK_DTYPE_TO_CTYPE(input_b_dtype, c_type_in_b, [&] { \ | ||
| return DISPATCH_DLPACK_DTYPE_TO_CTYPE(input_a_dtype, c_type_in_a, \ | ||
| [&] { return __VA_ARGS__(); }); \ | ||
| }); \ | ||
| }); \ | ||
| }); \ | ||
| }); \ | ||
| }() | ||
|
|
||
| template <typename T_A, typename T_B, typename T_SFA, typename T_SFB, typename T_OUT> | ||
| constexpr bool is_valid_config() { | ||
| if constexpr ((std::is_same_v<T_A, __nv_fp8_e4m3> || std::is_same_v<T_A, __nv_fp8_e5m2>) && | ||
| std::is_same_v<T_B, __nv_fp4_e2m1> && std::is_same_v<T_SFA, __nv_fp8_e8m0> && | ||
| std::is_same_v<T_SFB, __nv_fp8_e8m0> && | ||
| (std::is_same_v<T_OUT, nv_half> || std::is_same_v<T_OUT, nv_bfloat16>)) { | ||
| return true; | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| namespace flashinfer { | ||
| namespace group_gemm { | ||
|
|
||
| template <int TileM, int TileN, int TileK, typename DTypeInA, typename DTypeInB, typename DTypeSFA, | ||
| typename DTypeSFB, typename DTypeOut> | ||
| cudaError_t CutlassMXFP4GroupwiseScaledGroupGEMMSM120( | ||
| void* int_buffer, size_t int_buffer_size_in_bytes, void* float_buffer, | ||
| size_t float_buffer_size_in_bytes, DTypeInA* A, DTypeInB* B, DTypeSFA* SFA, DTypeSFB* SFB, | ||
| DTypeOut* D, int* m_indptr, int n, int k, int num_groups, cudaStream_t stream, int device_id); | ||
|
|
||
| } // namespace group_gemm | ||
| } // namespace flashinfer | ||
|
|
||
| void CutlassGroupGemmMXFP4GroupwiseScaledSM120(TensorView int_workspace_buffer, | ||
| TensorView float_workspace_buffer, TensorView A, | ||
| TensorView B, TensorView SFA, TensorView SFB, | ||
| TensorView D, TensorView m_indptr, int64_t n, | ||
| int64_t k, int64_t tile_m, int64_t tile_n, | ||
| int64_t tile_k) { | ||
| int device_id = float_workspace_buffer.device().device_id; | ||
| ffi::CUDADeviceGuard device_guard(device_id); | ||
| auto stream = get_stream(float_workspace_buffer.device()); | ||
| int num_groups = m_indptr.size(0) - 1; | ||
| DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE( | ||
| A.dtype(), B.dtype(), SFA.dtype(), SFB.dtype(), D.dtype(), c_type_in_a, c_type_in_b, | ||
| c_type_sf_a, c_type_sf_b, c_type_out, [&] { | ||
| return DISPATCH_TILE_M(tile_m, TILE_M, [&] { | ||
| return DISPATCH_TILE_N(tile_n, TILE_N, [&] { | ||
| return DISPATCH_TILE_K(tile_k, TILE_K, [&] { | ||
| if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b, | ||
| c_type_out>()) { | ||
| using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>; | ||
| using cutlass_t_in_b = cutlass_dtype_t<c_type_in_b>; | ||
| using cutlass_t_sf_a = cutlass_dtype_t<c_type_sf_a>; | ||
| using cutlass_t_sf_b = cutlass_dtype_t<c_type_sf_b>; | ||
| using cutlass_t_out = cutlass_dtype_t<c_type_out>; | ||
| auto status = flashinfer::group_gemm::CutlassMXFP4GroupwiseScaledGroupGEMMSM120< | ||
| TILE_M, TILE_N, TILE_K>( | ||
| static_cast<int*>(int_workspace_buffer.data_ptr()), | ||
| get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0), | ||
| static_cast<float*>(float_workspace_buffer.data_ptr()), | ||
| get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), | ||
| static_cast<cutlass_t_in_a*>(A.data_ptr()), | ||
| static_cast<cutlass_t_in_b*>(B.data_ptr()), | ||
| static_cast<cutlass_t_sf_a*>(SFA.data_ptr()), | ||
| static_cast<cutlass_t_sf_b*>(SFB.data_ptr()), | ||
| static_cast<cutlass_t_out*>(D.data_ptr()), | ||
| static_cast<int*>(m_indptr.data_ptr()), n, k, num_groups, stream, device_id); | ||
| return status == cudaSuccess; | ||
| } else { | ||
| TVM_FFI_ICHECK(false) << "Unsupported input data type"; | ||
| return false; | ||
| } | ||
| }); | ||
| }); | ||
| }); | ||
| }); | ||
| } | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| /* | ||
| * Copyright (c) 2026 by FlashInfer team. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #include <flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh> | ||
|
|
||
| using namespace flashinfer; | ||
| using namespace flashinfer::group_gemm; | ||
|
|
||
| namespace flashinfer { | ||
| namespace group_gemm { | ||
|
|
||
| {% for tile_m in [128] %} | ||
| {% for tile_n in [128] %} | ||
| {% for tile_k in [128] %} | ||
| {% for dtype_sfa in ["cutlass::float_ue8m0_t"] %} | ||
| {% for dtype_sfb in ["cutlass::float_ue8m0_t"] %} | ||
|
|
||
|
|
||
| INSTANTIATE_GROUP_GEMM_MXFP4_GROUPWISE_SM120( | ||
| {{ tile_m }}, | ||
| {{ tile_n }}, | ||
| {{ tile_k }}, | ||
| {{ dtype_a | trim }}, | ||
| {{ dtype_b | trim }}, | ||
| {{ dtype_sfa | trim }}, | ||
| {{ dtype_sfb | trim }}, | ||
| {{ dtype_d | trim }}, | ||
| {{ dtype_a | replace("cutlass::", "") }}, | ||
| {{ dtype_b | replace("cutlass::", "")}}, | ||
| {{ dtype_sfa | replace("cutlass::", "")}}, | ||
| {{ dtype_sfb | replace("cutlass::", "")}}, | ||
| {{ dtype_d | replace("cutlass::", "")}} | ||
| ) | ||
|
|
||
| {% endfor %} | ||
| {% endfor %} | ||
| {% endfor %} | ||
| {% endfor %} | ||
| {% endfor %} | ||
|
|
||
| }; // namespace group_gemm | ||
| }; // namespace flashinfer |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.