Skip to content
Merged
20 changes: 14 additions & 6 deletions benchmarks/bench_groupwise_grouped_gemm_mxfp4_blackwell.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (c) 2025 by FlashInfer team.
Copyright (c) 2025-2026 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@

from itertools import product

from flashinfer.utils import is_sm12x_supported
import numpy as np
import torch

Expand Down Expand Up @@ -69,14 +70,21 @@ def bench_groupwise_grouped_gemm_mxfp4_blackwell(
segment_offsets = torch.arange(
0, (group_size + 1) * m, m, device="cuda:0", dtype=torch.int32
)
if is_sm12x_supported(a.device):
mma_sm_list = [1]
tile_m_list = [128]
tile_n_list = [128]
tile_k_list = [128]
swap_ab_list = [False]
else:
mma_sm_list = [1, 2]
tile_m_list = [128]
tile_n_list = [64, 128, 192, 256]
tile_k_list = [128, 256]
swap_ab_list = [True, False]

ms_best = float("inf")
config_best = None
mma_sm_list = [1, 2]
tile_m_list = [128]
tile_n_list = [64, 128, 192, 256]
tile_k_list = [128, 256]
swap_ab_list = [True, False]
for mma_sm, tile_m, tile_n, tile_k, swap_ab in product(
mma_sm_list, tile_m_list, tile_n_list, tile_k_list, swap_ab_list
):
Expand Down
118 changes: 118 additions & 0 deletions benchmarks/bench_groupwise_grouped_gemm_nvfp4_blackwell_geforce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Copyright (c) 2025-2026 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from itertools import product

import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time
from flashinfer.utils import get_compute_capability


def bench_groupwise_grouped_gemm_nvfp4_blackwell(group_size, m, n, k, out_dtype):
    """Benchmark ``group_gemm_nvfp4_nt_groupwise`` over a small tile-config sweep.

    Fills A/B and their nvfp4 scale factors with random packed uint8 bytes,
    times every (tile_m, tile_n, tile_k) combination, and prints the best
    throughput in TFLOPs/s together with the winning configuration.
    Only runs on SM12x (compute capability 12.x) devices; returns early
    otherwise.
    """
    cc = get_compute_capability(torch.device("cuda"))
    if cc[0] != 12:
        print("group_gemm_nvfp4_nt_groupwise is only supported on SM120/SM121 GPUs.")
        return

    torch.random.manual_seed(0)
    assert n % 8 == 0
    assert k % 128 == 0

    device = "cuda:0"
    quant_tile = 16  # k-elements covered by one scale factor (nvfp4)
    sf_align = 128  # row alignment required for the scale-factor tensors

    def rand_u8(shape):
        # fp4 values / scale factors are benchmarked as raw packed bytes.
        return torch.randint(0, 256, shape, dtype=torch.uint8, device=device)

    # Two fp4 values per byte along k, hence k // 2 packed columns.
    a = rand_u8((group_size * m, k // 2))
    b = rand_u8((group_size, n, k // 2))
    out = torch.empty(group_size * m, n, dtype=out_dtype, device=device)

    # Scale-factor rows padded so each group's rows stay sf_align-aligned.
    a_sf_rows = (group_size * m + (sf_align - 1) * group_size) // sf_align * sf_align
    a_scale = rand_u8((a_sf_rows, k // quant_tile))
    b_sf_rows = (n + sf_align - 1) // sf_align * sf_align
    b_scale = rand_u8((group_size, b_sf_rows, k // quant_tile))

    # Uniform group sizes: rows [i*m, (i+1)*m) belong to group i.
    segment_offsets = torch.arange(
        0, (group_size + 1) * m, m, device=device, dtype=torch.int32
    )

    best_ms = float("inf")
    best_cfg = None
    for tile_m, tile_n, tile_k in product([128], [128], [128, 256]):
        timings = bench_gpu_time(
            lambda: flashinfer.gemm.group_gemm_nvfp4_nt_groupwise(
                a,
                b,
                a_scale,
                b_scale,
                segment_offsets,
                out=out,
                tile_m=tile_m,
                tile_n=tile_n,
                tile_k=tile_k,
            ),
            dry_run_time_ms=10,
            repeat_time_ms=100,
        )
        median_ms = np.median(timings)
        if median_ms < best_ms:
            best_ms = median_ms
            best_cfg = {
                "tile_m": tile_m,
                "tile_n": tile_n,
                "tile_k": tile_k,
            }

    # 2*M*N*K flops per GEMM; ms -> TFLOPs/s via the 1e-9 factor.
    tflops_per_second = 2 * group_size * m * n * k * 1e-9 / best_ms
    print(
        f"group_gemm_nvfp4_nt_groupwise group_size={group_size} m={m} n={n} k={k} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s"
    )
    print(f"best config: {best_cfg}")
    print()


if __name__ == "__main__":
    # Sweep group counts and GEMM shapes with bfloat16 output; iteration
    # order matches the equivalent four nested for-loops.
    for group_size, m, n, k in product(
        [1, 3, 8, 16],
        [128, 512, 1024, 2048, 4096, 8192],
        [1024, 2048, 4096, 8192],
        [1024, 2048, 4096, 8192],
    ):
        bench_groupwise_grouped_gemm_nvfp4_blackwell(
            group_size, m, n, k, torch.bfloat16
        )
136 changes: 136 additions & 0 deletions csrc/group_gemm_mxfp4_groupwise_sm120.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/cutlass_utils.cuh>

#include "tvm_ffi_utils.h"

using namespace flashinfer;

// Map a runtime tile_m value to the compile-time constant TILE_M expected by
// the kernel template, then invoke the continuation. Only 128 is instantiated
// for SM120; any other value trips TVM_FFI_ICHECK and never returns.
#define DISPATCH_TILE_M(tile_m, TILE_M, ...)         \
  [&]() -> bool {                                    \
    if (tile_m == 128) {                             \
      constexpr int TILE_M = 128;                    \
      return __VA_ARGS__();                          \
    }                                                \
    TVM_FFI_ICHECK(false) << "Unsupported TILE M";   \
    return false;                                    \
  }()

// Same dispatch for tile_n; only 128 is instantiated.
#define DISPATCH_TILE_N(tile_n, TILE_N, ...)         \
  [&]() -> bool {                                    \
    if (tile_n == 128) {                             \
      constexpr int TILE_N = 128;                    \
      return __VA_ARGS__();                          \
    }                                                \
    TVM_FFI_ICHECK(false) << "Unsupported TILE N";   \
    return false;                                    \
  }()

// Same dispatch for tile_k; only 128 is instantiated.
#define DISPATCH_TILE_K(tile_k, TILE_K, ...)         \
  [&]() -> bool {                                    \
    if (tile_k == 128) {                             \
      constexpr int TILE_K = 128;                    \
      return __VA_ARGS__();                          \
    }                                                \
    TVM_FFI_ICHECK(false) << "Unsupported TILE K";   \
    return false;                                    \
  }()

// Resolve the five runtime DLPack dtypes (inputs A/B, scale factors SFA/SFB,
// output) to C++ types in one nested dispatch, binding the c_type_* names for
// the continuation. Dispatchers nest output-outermost / A-innermost; the
// continuation runs only when every dtype resolves.
#define DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(input_a_dtype, input_b_dtype, sf_a_dtype, sf_b_dtype, \
                                           output_dtype, c_type_in_a, c_type_in_b, c_type_sf_a,  \
                                           c_type_sf_b, c_type_out, ...)                         \
  [&]() -> bool {                                                                                \
    return DISPATCH_DLPACK_DTYPE_TO_CTYPE(output_dtype, c_type_out, [&] {                        \
      return DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF(sf_b_dtype, c_type_sf_b, [&] {                    \
        return DISPATCH_DLPACK_DTYPE_TO_CTYPE_SF(sf_a_dtype, c_type_sf_a, [&] {                  \
          return DISPATCH_DLPACK_DTYPE_TO_CTYPE(input_b_dtype, c_type_in_b, [&] {                \
            return DISPATCH_DLPACK_DTYPE_TO_CTYPE(input_a_dtype, c_type_in_a,                    \
                                                  [&] { return __VA_ARGS__(); });                \
          });                                                                                    \
        });                                                                                      \
      });                                                                                        \
    });                                                                                          \
  }()

// True iff the (A, B, SFA, SFB, OUT) dtype combination has a kernel
// instantiation: fp8 (e4m3 or e5m2) A, fp4 (e2m1) B, ue8m0 scale factors
// for both operands, and half or bfloat16 output.
template <typename T_A, typename T_B, typename T_SFA, typename T_SFB, typename T_OUT>
constexpr bool is_valid_config() {
  constexpr bool a_ok =
      std::is_same_v<T_A, __nv_fp8_e4m3> || std::is_same_v<T_A, __nv_fp8_e5m2>;
  constexpr bool b_ok = std::is_same_v<T_B, __nv_fp4_e2m1>;
  constexpr bool sf_ok =
      std::is_same_v<T_SFA, __nv_fp8_e8m0> && std::is_same_v<T_SFB, __nv_fp8_e8m0>;
  constexpr bool out_ok =
      std::is_same_v<T_OUT, nv_half> || std::is_same_v<T_OUT, nv_bfloat16>;
  return a_ok && b_ok && sf_ok && out_ok;
}

namespace flashinfer {
namespace group_gemm {

// Forward declaration only — the definition is not in this translation unit
// (presumably provided by the generated SM120 kernel instantiations; confirm
// against the jinja instantiation file in the build).
template <int TileM, int TileN, int TileK, typename DTypeInA, typename DTypeInB, typename DTypeSFA,
          typename DTypeSFB, typename DTypeOut>
cudaError_t CutlassMXFP4GroupwiseScaledGroupGEMMSM120(
    void* int_buffer, size_t int_buffer_size_in_bytes, void* float_buffer,
    size_t float_buffer_size_in_bytes, DTypeInA* A, DTypeInB* B, DTypeSFA* SFA, DTypeSFB* SFB,
    DTypeOut* D, int* m_indptr, int n, int k, int num_groups, cudaStream_t stream, int device_id);

}  // namespace group_gemm
}  // namespace flashinfer

// FFI entry point for the groupwise-scaled MXFP4 grouped GEMM on SM120.
// Resolves the runtime dtypes of A/B/SFA/SFB/D and the runtime tile shape to
// the statically instantiated kernel, then launches it on the stream of
// float_workspace_buffer's device. m_indptr has num_groups + 1 entries
// (per-group row offsets); n and k are the GEMM dimensions shared by all
// groups. Unsupported dtype or tile combinations abort via TVM_FFI_ICHECK.
void CutlassGroupGemmMXFP4GroupwiseScaledSM120(TensorView int_workspace_buffer,
                                               TensorView float_workspace_buffer, TensorView A,
                                               TensorView B, TensorView SFA, TensorView SFB,
                                               TensorView D, TensorView m_indptr, int64_t n,
                                               int64_t k, int64_t tile_m, int64_t tile_n,
                                               int64_t tile_k) {
  // All work happens on the workspace buffer's device/stream.
  int device_id = float_workspace_buffer.device().device_id;
  ffi::CUDADeviceGuard device_guard(device_id);
  auto stream = get_stream(float_workspace_buffer.device());
  // indptr convention: one more offset than groups.
  int num_groups = m_indptr.size(0) - 1;
  DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(
      A.dtype(), B.dtype(), SFA.dtype(), SFB.dtype(), D.dtype(), c_type_in_a, c_type_in_b,
      c_type_sf_a, c_type_sf_b, c_type_out, [&] {
        return DISPATCH_TILE_M(tile_m, TILE_M, [&] {
          return DISPATCH_TILE_N(tile_n, TILE_N, [&] {
            return DISPATCH_TILE_K(tile_k, TILE_K, [&] {
              // Compile-time guard: only emit a kernel call for dtype
              // combinations that were actually instantiated.
              if constexpr (is_valid_config<c_type_in_a, c_type_in_b, c_type_sf_a, c_type_sf_b,
                                            c_type_out>()) {
                using cutlass_t_in_a = cutlass_dtype_t<c_type_in_a>;
                using cutlass_t_in_b = cutlass_dtype_t<c_type_in_b>;
                using cutlass_t_sf_a = cutlass_dtype_t<c_type_sf_a>;
                using cutlass_t_sf_b = cutlass_dtype_t<c_type_sf_b>;
                using cutlass_t_out = cutlass_dtype_t<c_type_out>;
                auto status = flashinfer::group_gemm::CutlassMXFP4GroupwiseScaledGroupGEMMSM120<
                    TILE_M, TILE_N, TILE_K>(
                    static_cast<int*>(int_workspace_buffer.data_ptr()),
                    get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0),
                    static_cast<float*>(float_workspace_buffer.data_ptr()),
                    get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0),
                    static_cast<cutlass_t_in_a*>(A.data_ptr()),
                    static_cast<cutlass_t_in_b*>(B.data_ptr()),
                    static_cast<cutlass_t_sf_a*>(SFA.data_ptr()),
                    static_cast<cutlass_t_sf_b*>(SFB.data_ptr()),
                    static_cast<cutlass_t_out*>(D.data_ptr()),
                    static_cast<int*>(m_indptr.data_ptr()), n, k, num_groups, stream, device_id);
                return status == cudaSuccess;
              } else {
                TVM_FFI_ICHECK(false) << "Unsupported input data type";
                return false;
              }
            });
          });
        });
      });
}
54 changes: 54 additions & 0 deletions csrc/group_gemm_mxfp4_groupwise_sm120_kernel_inst.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2026 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/gemm/group_gemm_mxfp4_groupwise_sm120.cuh>

using namespace flashinfer;
using namespace flashinfer::group_gemm;

namespace flashinfer {
namespace group_gemm {

{# Emit one kernel instantiation per (tile shape, scale-factor dtype)
   combination. dtype_a / dtype_b / dtype_d are supplied by the code
   generator; the second group of arguments repeats each dtype with the
   "cutlass::" qualifier stripped (used to build the instantiation name). #}
{% for tile_m in [128] %}
{% for tile_n in [128] %}
{% for tile_k in [128] %}
{% for dtype_sfa in ["cutlass::float_ue8m0_t"] %}
{% for dtype_sfb in ["cutlass::float_ue8m0_t"] %}


INSTANTIATE_GROUP_GEMM_MXFP4_GROUPWISE_SM120(
    {{ tile_m }},
    {{ tile_n }},
    {{ tile_k }},
    {{ dtype_a | trim }},
    {{ dtype_b | trim }},
    {{ dtype_sfa | trim }},
    {{ dtype_sfb | trim }},
    {{ dtype_d | trim }},
    {{ dtype_a | replace("cutlass::", "") }},
    {{ dtype_b | replace("cutlass::", "")}},
    {{ dtype_sfa | replace("cutlass::", "")}},
    {{ dtype_sfb | replace("cutlass::", "")}},
    {{ dtype_d | replace("cutlass::", "")}}
)

{% endfor %}
{% endfor %}
{% endfor %}
{% endfor %}
{% endfor %}

};  // namespace group_gemm
};  // namespace flashinfer
Loading
Loading