diff --git a/csrc/fp4_gemm_cutlass.jinja b/csrc/fp4_gemm_cutlass.jinja index 27d6f3f659..b45ee09dca 100644 --- a/csrc/fp4_gemm_cutlass.jinja +++ b/csrc/fp4_gemm_cutlass.jinja @@ -26,6 +26,7 @@ INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ ct INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM) INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM) +INSTANTIATE_FP4_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM) } // namespace gemm } // namespace flashinfer diff --git a/csrc/fp4_gemm_cutlass_sm103.cu b/csrc/fp4_gemm_cutlass_sm103.cu new file mode 100644 index 0000000000..8bb5ebf984 --- /dev/null +++ b/csrc/fp4_gemm_cutlass_sm103.cu @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include + +#include "flashinfer/gemm/cutlass_gemm_configs.h" +#include "flashinfer/gemm/fp4_gemm_cutlass.h" +#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h" +#include "tvm_ffi_utils.h" + +using flashinfer::gemm::ClusterShape; +using flashinfer::gemm::CutlassFp4GemmRunner; +using flashinfer::gemm::CutlassFp4GemmRunnerInterface; +using flashinfer::gemm::CutlassGemmConfig; +using flashinfer::gemm::CutlassTileConfigSM100; +using flashinfer::gemm::EpilogueScheduleType; +using flashinfer::gemm::FP4GemmType; +using flashinfer::gemm::MainloopScheduleType; + +namespace flashinfer { +namespace gemm { +template class CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4>; +template class CutlassFp4GemmRunner; +} // namespace gemm +} // namespace flashinfer + +namespace torch_ext { + +namespace { + +CutlassGemmConfig getFp4GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tactic) { + auto getCutlassFp4GemmConfigs = []() { + CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner; + return gemmRunner.getConfigs(); + }; + static std::vector globalConfigs = getCutlassFp4GemmConfigs(); + TVM_FFI_ICHECK(tactic >= 0 && tactic < globalConfigs.size()) + << "tactic must be between 0 and " << globalConfigs.size(); + return globalConfigs[tactic]; +} + +template +void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Scale, + TensorView mat2Scale, TensorView globalScale, int64_t m, int64_t n, int64_t k, + int64_t batch_count, CutlassGemmConfig const& gemmConfig, + TensorView workspace_buffer) { + CutlassFp4GemmRunner gemmRunner; + + int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k, batch_count); + int64_t const provided_workspace_size = + workspace_buffer.numel() * get_element_size(workspace_buffer); + + auto runKernel = [&](void* workspace) { + gemmRunner.gemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(), mat1Scale.data_ptr(), + mat2Scale.data_ptr(), static_cast(globalScale.data_ptr()), m, n, k, + batch_count, gemmConfig, reinterpret_cast(workspace), + required_workspace_size, get_stream(mat1.device())); + }; + + if (provided_workspace_size < required_workspace_size) { + Tensor new_workspace = + alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device()); + runKernel(new_workspace.data_ptr()); + } else { + runKernel(workspace_buffer.data_ptr()); + } +} + +constexpr auto FLOAT4_E2M1X2 = dl_uint8; // uint8_t +constexpr auto SF_DTYPE = dl_uint8; // uint8_t + +// mat1: [B, M, K / 2], FLOAT4_E2M1X2 or [B, M, K], FLOAT8_E4M3FN +// mat2: [B, N, K / 2], FLOAT4_E2M1X2 +// out: [B, M, N], fp16/bf16/fp32 +// mat1Scale: ceil(M / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) +// mat2Scale: ceil(N / 128) * 128 * ceil(K / sfVecSize / 4) * 4, SF_DTYPE (UE4M3 or UE8M0) +// globalScale: [1], 1 / (((448 * 6) / mat1.abs().max()) * ((448 * 6) / mat2.abs().max())) +// B = 1 for GEMM op as a special case +void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, + int64_t tactic) { + CHECK_INPUT_AND_TYPE(mat1, FLOAT4_E2M1X2); + CHECK_INPUT_AND_TYPE(mat2, FLOAT4_E2M1X2); + + int mat2_k_scale = 1; + + CHECK_INPUT_AND_TYPE(mat1Scale, SF_DTYPE); + CHECK_INPUT_AND_TYPE(mat2Scale, SF_DTYPE); + + CHECK_INPUT_AND_TYPE(globalScale, dl_float32); + + int64_t m, n, k, b; + if (mat1.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix"; + TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1) + << " and " << mat2.size(0) << "x" << mat2.size(1) << ")"; + m = mat1.size(0); + n = mat2.size(0); + k = mat2.size(1) * 2; + b = 1; + } else if (mat1.ndim() == 3) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices"; + TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size (" + << mat1.size(0) << " and " << mat2.size(0) << ")"; + TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2) + << " and " << mat2.size(1) << "x" << mat2.size(2) << ")"; + m = mat1.size(1); + n = mat2.size(1); + k = mat2.size(2) * 2; + b = mat1.size(0); + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices"; + } + + // No heuristic for now, we rely on the autotuner to select the best tactic. + if (tactic == -1) { + tactic = 0; + } + auto config = getFp4GemmConfig(m, n, k, tactic); + + constexpr int alignment = 32; + TVM_FFI_ICHECK_EQ(k % alignment, 0) + << "Expected k to be divisible by " << alignment << ", but got mat1 shape: (" << mat1.size(0) + << "x" << mat1.size(1) << "), k: " << k << "."; + TVM_FFI_ICHECK_EQ(n % alignment, 0) + << "Expected n to be divisible by " << alignment << ", but got mat2 shape: (" << mat2.size(0) + << "x" << mat2.size(1) << ")."; + + // Validate out dimensions + std::vector out_shape = + mat1.ndim() == 2 ? std::vector{m, n} : std::vector{b, m, n}; + TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size()) + << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim(); + for (int i = 0; i < out_shape.size(); ++i) { + TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i]) + << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got " + << out.size(i); + } + + switch (encode_dlpack_dtype(out.dtype())) { + case float16_code: + runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config, + workspace_buffer); + break; + case bfloat16_code: + runGemm<__nv_bfloat16>(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config, + workspace_buffer); + break; + default: + TVM_FFI_ICHECK(false) << "out_dtype must be one of fp16/bf16."; + } +} + +} // namespace + +void fp4_gemm(TensorView mat1, TensorView mat2, TensorView mat1Scale, TensorView mat2Scale, + TensorView globalScale, TensorView out, TensorView workspace_buffer, int64_t tactic) { + fp4_bmm_impl(mat1, mat2, mat1Scale, mat2Scale, globalScale, out, workspace_buffer, tactic); +} + +int64_t fp4_gemm_tactic_num() { + auto getCutlassConfigs = []() { + CutlassFp4GemmRunner<__nv_bfloat16, FP4GemmType::W4A4_NVFP4_NVFP4> gemmRunner; + return gemmRunner.getConfigs(); + }; + static int64_t totalTactics = getCutlassConfigs().size(); + return totalTactics; +} + +} // namespace torch_ext + +TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_gemm, torch_ext::fp4_gemm); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(fp4_gemm_tactic_num, torch_ext::fp4_gemm_tactic_num); diff --git a/csrc/fp4_gemm_cutlass_sm103.jinja b/csrc/fp4_gemm_cutlass_sm103.jinja new file mode 100644 index 0000000000..21dc67176c --- /dev/null +++ b/csrc/fp4_gemm_cutlass_sm103.jinja @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h" + +namespace flashinfer { +namespace gemm { +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 1, 1, _1SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 2, 1, _1SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 1, 4, 1, _1SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 1, 1, _2SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 2, 1, _2SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 2, 4, 1, _2SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 2, 1, _2SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 4, 1, _2SM_sm103) +INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER({{ type }}, {{ cta_m }}, {{ cta_n }}, {{ cta_k }}, 4, 1, 1, _2SM_sm103) + +} // namespace gemm +} // namespace flashinfer diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index 4175126827..98c1ea70ef 100644 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -51,6 +51,7 @@ from ..jit.gemm import gen_gemm_sm120_module from ..jit.gemm import gen_gemm_sm120_module_cutlass_fp4 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp4 +from ..jit.gemm import gen_gemm_sm103_module_cutlass_fp4 from ..jit.gemm import gen_gemm_sm100_module_cutlass_fp8 from ..jit.gemm import gen_trtllm_gen_gemm_module from ..jit.gemm import gen_tgv_gemm_sm10x_module @@ -515,13 +516,22 @@ def forward( @functools.cache def get_gemm_sm100_module_cutlass_fp4(): - """Get the SM100/103/110 FP4 GEMM module.""" + """Get the SM100/110 FP4 GEMM module.""" module = gen_gemm_sm100_module_cutlass_fp4().build_and_load() return _create_cutlass_fp4_gemm_module( module, "flashinfer::cutlass_fp4_gemm", "cutlass_fp4_gemm" ) +@functools.cache +def get_gemm_sm103_module_cutlass_fp4(): + """Get the SM103 FP4 GEMM module.""" + module = gen_gemm_sm103_module_cutlass_fp4().build_and_load() + return _create_cutlass_fp4_gemm_module( + module, "flashinfer::cutlass_fp4_gemm", "cutlass_fp4_gemm" + ) + + @functools.cache def get_gemm_sm120_module_cutlass_fp4(): """Get the SM120/121 FP4 GEMM module.""" @@ -533,9 +543,13 @@ def get_gemm_sm120_module_cutlass_fp4(): def get_cutlass_fp4_gemm_module( sm_major: int, + sm_minor: int, ): if sm_major in [10, 11]: - return get_gemm_sm100_module_cutlass_fp4() + if sm_minor == 3: + return get_gemm_sm103_module_cutlass_fp4() + else: + return get_gemm_sm100_module_cutlass_fp4() elif sm_major == 12: return get_gemm_sm120_module_cutlass_fp4() else: @@ -2273,14 +2287,16 @@ def mm_fp4( # At this point, backends contains a supported backend if specified, or all supported backends if backend='auto'. # Lazy initialization of runners to avoid overhead of creating a new runner that will not be used - major, _ = get_compute_capability(a.device) + major, minor = get_compute_capability(a.device) backend_to_runner_factory = { "cudnn": lambda: _cudnn_gemm_fp4_runner(), "trtllm": lambda: get_trtllm_fp4_gemm_module().trtllm_fp4_gemm_runner( use_8x4_sf_layout ), - "cutlass": lambda: get_cutlass_fp4_gemm_module(major).cutlass_fp4_gemm_runner(), + "cutlass": lambda: get_cutlass_fp4_gemm_module( + major, minor + ).cutlass_fp4_gemm_runner(), } runners = [backend_to_runner_factory[cur_backend]() for cur_backend in backends] diff --git a/flashinfer/jit/gemm/__init__.py b/flashinfer/jit/gemm/__init__.py index e81d51e15f..a8527cbd4b 100644 --- a/flashinfer/jit/gemm/__init__.py +++ b/flashinfer/jit/gemm/__init__.py @@ -17,6 +17,7 @@ from .core import ( gen_gemm_module, gen_gemm_sm100_module_cutlass_fp4, + gen_gemm_sm103_module_cutlass_fp4, gen_gemm_sm120_module_cutlass_fp4, gen_gemm_sm100_module_cutlass_fp8, gen_gemm_sm100_module, @@ -32,6 +33,7 @@ __all__ = [ "gen_gemm_module", "gen_gemm_sm100_module_cutlass_fp4", + "gen_gemm_sm103_module_cutlass_fp4", "gen_gemm_sm120_module_cutlass_fp4", "gen_gemm_sm100_module_cutlass_fp8", "gen_gemm_sm100_module", diff --git a/flashinfer/jit/gemm/core.py b/flashinfer/jit/gemm/core.py index 7873d0de14..360bde4cd8 100644 --- a/flashinfer/jit/gemm/core.py +++ b/flashinfer/jit/gemm/core.py @@ -94,6 +94,73 @@ def gen_gemm_sm100_module_cutlass_fp4() -> JitSpec: ) +def gen_gemm_sm103_module_cutlass_fp4() -> JitSpec: + gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm103_cutlass_fp4" + os.makedirs(gen_directory, exist_ok=True) + source_paths = [ + jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass_sm103.cu", + ] + + with open(jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass_sm103.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + dtype_list = ["__nv_bfloat16", "half"] + cta_m_n_k_list = [(128, 128, 768), (128, 192, 768), (128, 256, 768)] + for cta_m, cta_n, cta_k in cta_m_n_k_list: + for dtype in dtype_list: + dest_path = ( + gen_directory + / f"fp4_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu" + ) + source_paths.append(dest_path) + source = kernel_inst_templ.render( + type=dtype, + cta_m=cta_m, + cta_n=cta_n, + cta_k=cta_k, + ) + write_if_different(dest_path, source) + + with open(jit_env.FLASHINFER_CSRC_DIR / "fp4_gemm_cutlass.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) + dtype_list = ["__nv_bfloat16", "half"] + cta_m_n_k_list = [ + (128, 64, 128), + (128, 256, 128), + (128, 128, 256), + (128, 256, 256), + ] + for cta_m, cta_n, cta_k in cta_m_n_k_list: + for dtype in dtype_list: + dest_path = ( + gen_directory + / f"fp4_gemm_cutlass_{dtype}_{cta_m}_{cta_n}_{cta_k}.cu" + ) + source_paths.append(dest_path) + source = kernel_inst_templ.render( + type=dtype, + cta_m=cta_m, + cta_n=cta_n, + cta_k=cta_k, + ) + write_if_different(dest_path, source) + + nvcc_flags = current_compilation_context.get_nvcc_flags_list( + supported_major_versions=[10, 11, 12] + ) + return gen_jit_spec( + "fp4_gemm_cutlass_sm103", + source_paths, + extra_cuda_cflags=nvcc_flags + + [ + "-DENABLE_BF16", + "-DENABLE_FP4", + ], + extra_cflags=[ + "-DFAST_BUILD", + ], + ) + + def gen_gemm_sm120_module_cutlass_fp4() -> JitSpec: gen_directory = jit_env.FLASHINFER_GEN_SRC_DIR / "gen_gemm_sm120_cutlass_fp4" os.makedirs(gen_directory, exist_ok=True) diff --git a/include/flashinfer/gemm/cutlass_gemm_configs.h b/include/flashinfer/gemm/cutlass_gemm_configs.h index a0a6775153..768fc179c2 100644 --- a/include/flashinfer/gemm/cutlass_gemm_configs.h +++ b/include/flashinfer/gemm/cutlass_gemm_configs.h @@ -133,6 +133,11 @@ enum class CutlassTileConfigSM100 { CtaShape256x64x128B, CtaShape256x128x128B, CtaShape256x256x128B, + + // SM103 + CtaShape128x128x768B, + CtaShape128x192x768B, + CtaShape128x256x768B, }; enum class CutlassTileConfigSM120 { @@ -188,7 +193,11 @@ enum class TileShape { TileShape_128x32x128, TileShape_128x64x128, TileShape_128x128x128, - TileShape_128x256x128 + TileShape_128x256x128, + // SM103 + TileShape_128x128x768, + TileShape_128x192x768, + TileShape_128x256x768 }; template @@ -216,6 +225,12 @@ constexpr auto get_tile_shape() { return cute::Shape<_128, _128, _128>{}; } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x128) { return cute::Shape<_128, _256, _128>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x128x768) { // SM103 + return cute::Shape<_128, _128, _768>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x192x768) { // SM103 + return cute::Shape<_128, _192, _768>{}; + } else if constexpr (Shape_MNK == TileShape::TileShape_128x256x768) { // SM103 + return cute::Shape<_128, _256, _768>{}; } } @@ -242,6 +257,12 @@ static auto get_tile_shape_name(TileShape Shape_MNK) { return "128x128x128"; } else if (Shape_MNK == TileShape::TileShape_128x256x128) { return "128x256x128"; + } else if (Shape_MNK == TileShape::TileShape_128x128x768) { // SM103 + return "128x128x768"; + } else if (Shape_MNK == TileShape::TileShape_128x192x768) { // SM103 + return "128x192x768"; + } else if (Shape_MNK == TileShape::TileShape_128x256x768) { // SM103 + return "128x256x768"; } return "Unknown shape"; } @@ -256,7 +277,8 @@ enum class ClusterShape { ClusterShape_2x4x1, ClusterShape_4x4x1, ClusterShape_1x8x1, - ClusterShape_8x1x1 + ClusterShape_8x1x1, + ClusterShape_4x1x1 }; static auto get_cluster_shape_name(ClusterShape Shape_MNK) { @@ -272,6 +294,8 @@ static auto get_cluster_shape_name(ClusterShape Shape_MNK) { return "1x8x1"; } else if (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { return "8x1x1"; + } else if (Shape_MNK == ClusterShape::ClusterShape_4x1x1) { + return "4x1x1"; } return "Unknown shape"; } @@ -291,6 +315,8 @@ constexpr auto get_cluster_shape() { return cute::Shape<_1, _8, _1>{}; } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_8x1x1) { return cute::Shape<_8, _1, _1>{}; + } else if constexpr (Shape_MNK == ClusterShape::ClusterShape_4x1x1) { + return cute::Shape<_4, _1, _1>{}; } } diff --git a/include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h b/include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h new file mode 100644 index 0000000000..ab551b11c7 --- /dev/null +++ b/include/flashinfer/gemm/fp4_gemm_cutlass_template_sm103.h @@ -0,0 +1,379 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_ +#define FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "flashinfer/gemm/cutlass_gemm_configs.h" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +#include "flashinfer/gemm/fp4_gemm_cutlass.h" +#include "fp4_gemm_template_sm100.h" +#include "fp4_gemm_template_sm103.h" + +namespace flashinfer { +namespace gemm { +using namespace cute; + +template +size_t dispatchNVFP4xNVFP4GemmClusterShapeSm100(T* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, + int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, const size_t workspaceBytes, + cudaStream_t stream, int* occupancy = nullptr) { + switch (gemmConfig.cluster_shape) { + case ClusterShape::ClusterShape_1x1x1: + return genericFp4GemmKernelLauncher, cute::Int<1>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x1x1: + return genericFp4GemmKernelLauncher, cute::Int<1>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_1x2x1: + return genericFp4GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x2x1: + return genericFp4GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_1x4x1: + return genericFp4GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _1SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x2x1: + return genericFp4GemmKernelLauncher, cute::Int<2>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x4x1: + return genericFp4GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x4x1: + return genericFp4GemmKernelLauncher, cute::Int<4>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x1x1: + return genericFp4GemmKernelLauncher, cute::Int<1>, + cute::Int<1>, _2SM>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + default: + throw std::runtime_error( + "[Error][FP4][dispatch_gemm_cluster_shape] Config is invalid for FP4 GEMM."); + break; + } +} + +template +size_t dispatchNVFP4xNVFP4GemmClusterShapeSm103(T* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, + int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, const size_t workspaceBytes, + cudaStream_t stream, int* occupancy = nullptr) { + switch (gemmConfig.cluster_shape) { + case ClusterShape::ClusterShape_1x1x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<1>, cute::Int<1>, _1SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x1x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<1>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_1x2x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<2>, cute::Int<1>, _1SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x2x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<2>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_1x4x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<4>, cute::Int<1>, _1SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x2x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<2>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_2x4x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<4>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x4x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<4>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case ClusterShape::ClusterShape_4x1x1: + return genericFp4UltraGemmKernelLauncher, + cute::Int<1>, cute::Int<1>, _2SM_sm103>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + default: + throw std::runtime_error( + "[Error][FP4][dispatch_gemm_cluster_shape] Config is invalid for FP4 GEMM."); + break; + } +} + +template +size_t dispatchNVFP4xNVFP4GemmCTAShapeSm100(T* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, + int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, const size_t workspaceBytes, + cudaStream_t stream, int* occupancy = nullptr) { + // Several constraints: + // Cta N should be one of 128/192/256. + // M-mode size should be 128 or 256 for 2 CTA cluster MMA; + // M-mode size should be 128 for 1 CTA cluster OMMA. + // K256 looks to be better than K128 + switch (gemmConfig.tile_config_sm100) { + case CutlassTileConfigSM100::CtaShape128x64x128B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<64>, + cute::Int<128>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x256x128B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<256>, + cute::Int<128>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x128x256B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<128>, + cute::Int<256>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x256x256B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm100, cute::Int<256>, + cute::Int<256>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x128x768B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm103, cute::Int<128>, + cute::Int<768>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x192x768B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm103, cute::Int<192>, + cute::Int<768>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::CtaShape128x256x768B: + return dispatchNVFP4xNVFP4GemmClusterShapeSm103, cute::Int<256>, + cute::Int<768>>( + D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + break; + case CutlassTileConfigSM100::Undefined: + throw std::runtime_error("[Error][FP4][dispatch_gemm_cta_shape] Gemm config undefined."); + break; + case CutlassTileConfigSM100::ChooseWithHeuristic: + throw std::runtime_error( + "[Error][FP4][dispatch_gemm_cta_shape] Gemm config should have already been " + "set by " + "heuristic."); + break; + default: + throw std::runtime_error( + "[Error][FP4][dispatch_gemm_cta_shape] Config is invalid for FP4 GEMM."); + break; + } +} +template +CutlassFp4GemmRunner::CutlassFp4GemmRunner() {} + +template +CutlassFp4GemmRunner::~CutlassFp4GemmRunner() {} + +template +size_t CutlassFp4GemmRunner::dispatchToArch( + T* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { + if constexpr (fp4GemmType == FP4GemmType::W4A4_NVFP4_NVFP4) { + return dispatchNVFP4xNVFP4GemmCTAShapeSm100(D, A, B, input_sf, weight_sf, global_sf, m, n, k, + batch_count, gemmConfig, workspace, + workspaceBytes, stream, occupancy); + } else { + throw std::runtime_error( + "[Error][CutlassFp4GemmRunner][GEMM Dispatch] FP4 Gemm type unsupported for " + "CUTLASS FP4 GEMM"); + } +} + +template +void CutlassFp4GemmRunner::gemm(void* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, + int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, const size_t workspaceBytes, + cudaStream_t stream) { + CutlassFp4GemmRunner::dispatchToArch( + reinterpret_cast(D), A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count, + gemmConfig, workspace, workspaceBytes, stream); +} + +template +std::vector CutlassFp4GemmRunner::getConfigs() const { + std::vector candidateConfigs; + + std::vector tilesSm100 = { + CutlassTileConfigSM100::CtaShape128x64x128B, CutlassTileConfigSM100::CtaShape128x256x128B, + CutlassTileConfigSM100::CtaShape128x128x256B, CutlassTileConfigSM100::CtaShape128x256x256B, + CutlassTileConfigSM100::CtaShape128x128x768B, CutlassTileConfigSM100::CtaShape128x192x768B, + CutlassTileConfigSM100::CtaShape128x256x768B, + }; + std::vector clusterShapes = { + ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1, + ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1, + ClusterShape::ClusterShape_1x4x1, ClusterShape::ClusterShape_4x2x1, + ClusterShape::ClusterShape_2x4x1, ClusterShape::ClusterShape_4x4x1, + ClusterShape::ClusterShape_4x1x1, + }; + for (auto const& tile_config : tilesSm100) { + for (auto const& cluster_config : clusterShapes) { + CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + cluster_config); + candidateConfigs.push_back(config); + } + } + + // There’s no heuristic yet, so for users without autotuning, we provide an ordering based on + // performance sweeps from common workloads. + std::vector best_tactics_index = {22, 20, 29, 4, 18}; + std::vector newCandidateConfigs; + for (auto const& tactic_index : best_tactics_index) { + newCandidateConfigs.push_back(candidateConfigs[tactic_index]); + } + for (int64_t i = 0; i < candidateConfigs.size(); i++) { + if (std::find(best_tactics_index.begin(), best_tactics_index.end(), i) == + best_tactics_index.end()) { + newCandidateConfigs.push_back(candidateConfigs[i]); + } + } + return newCandidateConfigs; +} + +template +size_t CutlassFp4GemmRunner::getWorkspaceSizeImpl(int const m, int const n, + int const k, + int const batch_count) { + size_t workspace_size = 0; + auto gemmConfigs = CutlassFp4GemmRunner{}.getConfigs(); + for (auto const& gemmConfig : gemmConfigs) { + try { + size_t curr_workspace_size = CutlassFp4GemmRunner::dispatchToArch( + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, m, n, k, batch_count, gemmConfig, + nullptr, 0, 0); + workspace_size = std::max(workspace_size, curr_workspace_size); + } catch (std::runtime_error& e) { + // Swallow errors when SMEM exceeds maximum allowed + continue; + } + } + return workspace_size; +} + +template +size_t CutlassFp4GemmRunner::getWorkspaceSize(int const m, int const n, int const k, + int const batch_count) { + // Custom hash function for the MNKB type + using MNK = std::tuple; + + struct MNKHash { + size_t operator()(const MNK& mnk) const { + auto h1 = std::hash{}(std::get<0>(mnk)); + auto h2 = std::hash{}(std::get<1>(mnk)); + auto h3 = std::hash{}(std::get<2>(mnk)); + auto h4 = std::hash{}(std::get<3>(mnk)); + return h1 ^ h2 ^ h3 ^ h4; + } + }; + + static std::unordered_map workspace_hashmap; + + size_t workspace_size = 0; + if (workspace_hashmap.find(std::make_tuple(m, n, k, batch_count)) == workspace_hashmap.end()) { + workspace_size = + CutlassFp4GemmRunner::getWorkspaceSizeImpl(m, n, k, batch_count); + workspace_hashmap[std::make_tuple(m, n, k, batch_count)] = workspace_size; + } else { + workspace_size = workspace_hashmap[std::make_tuple(m, n, k, batch_count)]; + } + return workspace_size; +} + +} // namespace gemm +} // namespace flashinfer +#endif // FLASHINFER_FP4_GEMM_CUTLASS_TEMPLATE_H_ diff --git a/include/flashinfer/gemm/fp4_gemm_template_sm103.h b/include/flashinfer/gemm/fp4_gemm_template_sm103.h new file mode 100644 index 0000000000..024ed23598 --- /dev/null +++ b/include/flashinfer/gemm/fp4_gemm_template_sm103.h @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_FP4_GEMM_TEMPLATE_SM103_H_ +#define FLASHINFER_FP4_GEMM_TEMPLATE_SM103_H_ + +#ifndef _WIN32 +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // #ifndef _WIN32 + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/gemm.h" +#include "flashinfer/arch_condition.h" +#include "flashinfer/cutlass_utils.cuh" + +#ifndef _WIN32 +#pragma GCC diagnostic pop +#endif // #ifndef _WIN32 + +namespace flashinfer { +namespace gemm { +using namespace cute; + +#ifdef ENABLE_BF16 +using SafeBF16_sm103 = __nv_bfloat16; +#else +using SafeBF16_sm103 = void; +#endif + +struct _1SM_sm103 {}; + +struct _2SM_sm103 {}; + +template +struct SMTypeAdapter_sm103 {}; + +template <> +struct SMTypeAdapter_sm103<_1SM_sm103> { + static int const Scale = 1; + using AtomThrShape = cute::Shape<_1, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized1Sm; + using MainloopSchedule = + cutlass::gemm::KernelTmaWarpSpecialized1SmBlockScaledMxNvf4UltraVs16Sm103; +}; + +template <> +struct SMTypeAdapter_sm103<_2SM_sm103> { + static int const Scale = 2; + using AtomThrShape = cute::Shape<_2, _1, _1>; + using EpilogueSchedule = cutlass::epilogue::NoSmemWarpSpecialized2Sm; + using MainloopSchedule = + cutlass::gemm::KernelTmaWarpSpecialized2SmBlockScaledMxNvf4UltraVs16Sm103; +}; + +template +constexpr auto always_false_sm103 = false; + +template +size_t genericFp4UltraGemmKernelLauncher(void* D, void const* A, void const* B, + void const* input_sf, void const* weight_sf, + float const* global_sf, int m, int n, int k, + int batch_count, CutlassGemmConfig gemmConfig, + char* workspace, size_t const workspaceBytes, + cudaStream_t stream, int* occupancy); + +#ifdef PLACEHOLDER_KERNELS + +#define INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, \ + CGA_K_, XSM_) \ + template <> \ + size_t genericFp4UltraGemmKernelLauncher, cute::Int, \ + cute::Int, cute::Int, \ + cute::Int, cute::Int, XSM_>( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, \ + float const* global_sf, int m, int n, int k, int batch_count, CutlassGemmConfig gemmConfig, \ + char* workspace, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { \ + throw std::runtime_error( \ + "FP4 gemm kernel is not compiled with support for " \ + "this Architecture."); \ + } + +#else + +#define INSTANTIATE_FP4_ULTRA_GEMM_KERNEL_LAUNCHER(T, CTA_M_, CTA_N_, CTA_K_, CGA_M_, CGA_N_, \ + CGA_K_, XSM_) \ + struct \ + DeviceGemmFp4GemmSm103_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_ { \ + using OutElementType = flashinfer::cutlass_dtype::type; \ + using CTAShape = cute::Shape, cute::Int, cute::Int>; \ + /*using ClusterShape = cute::Shape, cute::Int, cute::Int>;*/ \ + using ClusterShape = cute::Shape; \ + using ElementType = cutlass::float_e2m1_t; \ + using Arch = cutlass::arch::Sm103; \ + /* // Input A */ \ + using ElementA = ElementType; \ + using LayoutA = cutlass::layout::RowMajor; \ + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; \ + /* // Input B */ \ + using ElementB = ElementType; \ + using LayoutB = cutlass::layout::ColumnMajor; \ + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; \ + /* // Input C */ \ + using ElementC = void; \ + using LayoutC = cutlass::layout::RowMajor; \ + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; \ + \ + using SFType = cutlass::float_ue4m3_t; \ + using ElementCompute = float; \ + using ElementAccumulator = float; \ + using OperatorClass = cutlass::arch::OpClassTensorOp; \ + using EpilogueTileType = std::conditional_t, \ + cutlass::epilogue::collective::EpilogueTileAuto>; \ + using EpilogueSchedule = SMTypeAdapter_sm103::EpilogueSchedule; \ + using MainloopSchedule = SMTypeAdapter_sm103::MainloopSchedule; \ + using MmaTileShape = cute::Shape::Scale>, \ + cute::Int, cute::Int>; \ + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< \ + Arch, OperatorClass, MmaTileShape, ClusterShape, EpilogueTileType, ElementAccumulator, \ + ElementCompute, ElementC, LayoutC, AlignmentC, OutElementType, LayoutC, AlignmentC, \ + EpilogueSchedule, \ + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; \ + \ + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< \ + Arch, cutlass::arch::OpClassBlockScaledTensorOp, cute::tuple, LayoutA, \ + AlignmentA, cute::tuple, LayoutB, AlignmentB, ElementAccumulator, \ + MmaTileShape, ClusterShape, \ + cutlass::gemm::collective::StageCountAutoCarveout( \ + sizeof(typename CollectiveEpilogue::SharedStorage))>, \ + MainloopSchedule>::CollectiveOp; \ + \ + template \ + struct Sm103Only : Base { \ + using typename Base::Params; \ + CUTLASS_DEVICE \ + void operator()(Params const& params, char* smem_buf) { \ + if constexpr (flashinfer::arch::is_match_v<103>) { \ + this->Base::operator()(params, smem_buf); \ + } else { \ + if (cute::thread0()) { \ + printf("%s : This kernel shall only run on SM103 devices.\n", __PRETTY_FUNCTION__); \ + __trap(); \ + } \ + } \ + } \ + }; \ + using GemmKernel = \ + Sm103Only, \ + CollectiveMainloop, CollectiveEpilogue, \ + cutlass::gemm::PersistentScheduler>>; \ + \ + using Gemm = typename cutlass::gemm::device::GemmUniversalAdapter; \ + }; \ + \ + template \ + typename Gemm::Arguments \ + prepareGemmArgsSm103_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, \ + float const* global_sf, int m, int n, int k, int batch_count) { \ + using Sm1xxBlkScaledConfig = \ + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; \ + using ElementA = typename Gemm::ElementA; \ + using ElementB = typename Gemm::ElementB; \ + using ElementSFA = cutlass::float_ue4m3_t; \ + using ElementSFB = cutlass::float_ue4m3_t; \ + using ElementC = void; \ + using ElementD = typename Gemm::ElementD; \ + using ElementCompute = float; \ + \ + typename Gemm::Arguments operator_args; \ + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGemm; \ + auto& fusion_args = operator_args.epilogue.thread; \ + fusion_args.alpha_ptr = static_cast(global_sf); \ + \ + operator_args.problem_shape = cute::make_shape(m, n, k, batch_count); \ + \ + operator_args.mainloop.ptr_A = static_cast(A); \ + operator_args.mainloop.ptr_B = static_cast(B); \ + operator_args.mainloop.ptr_SFA = static_cast(input_sf); \ + operator_args.mainloop.ptr_SFB = static_cast(weight_sf); \ + operator_args.epilogue.ptr_C = static_cast(D); \ + operator_args.epilogue.ptr_D = static_cast(D); \ + \ + int const stride_A = batch_count == 1 ? 0 : m * k; \ + int const stride_B = batch_count == 1 ? 0 : n * k; \ + int const stride_C = batch_count == 1 ? 0 : m * n; \ + \ + operator_args.mainloop.dA = \ + cute::make_int_tuple_from(k, stride_A); \ + operator_args.mainloop.dB = \ + cute::make_int_tuple_from(k, stride_B); \ + operator_args.epilogue.dC = \ + cute::make_int_tuple_from(n, stride_C); \ + operator_args.epilogue.dD = operator_args.epilogue.dC; \ + \ + operator_args.mainloop.layout_SFA = \ + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); \ + operator_args.mainloop.layout_SFB = \ + Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); \ + \ + if constexpr (!std::is_const_v) { \ + operator_args.scheduler.max_swizzle_size = 1; \ + } \ + if constexpr (!std::is_const_v) { \ + using Enum_t = decltype(operator_args.scheduler.raster_order); \ + operator_args.scheduler.raster_order = Enum_t::Heuristic; \ + } \ + operator_args.hw_info.cluster_shape = dim3(CGA_M_, CGA_N_, CGA_K_); \ + operator_args.hw_info.cluster_shape_fallback = dim3(SMTypeAdapter_sm103::Scale, 1, 1); \ + \ + return operator_args; \ + } \ + \ + template <> \ + size_t genericFp4UltraGemmKernelLauncher, cute::Int, \ + cute::Int, cute::Int, \ + cute::Int, cute::Int, XSM_>( \ + void* D, void const* A, void const* B, void const* input_sf, void const* weight_sf, \ + float const* global_sf, int m, int n, int k, int batch_count, CutlassGemmConfig gemmConfig, \ + char* workspace, const size_t workspaceBytes, cudaStream_t stream, int* occupancy) { \ + using ElementOutput__ = \ + typename cutlass::platform::conditional::value, \ + cutlass::half_t, T>::type; \ + using ElementOutput_ = typename cutlass::platform::conditional< \ + cutlass::platform::is_same::value, float, ElementOutput__>::type; \ + using ElementOutput = typename cutlass::platform::conditional< \ + cutlass::platform::is_same::value, cutlass::bfloat16_t, \ + ElementOutput_>::type; \ + \ + using Fp4GemmOperator = \ + DeviceGemmFp4GemmSm103_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_:: \ + Gemm; \ + Fp4GemmOperator gemm; \ + auto args = \ + prepareGemmArgsSm103_##T##_##CTA_M_##_##CTA_N_##_##CTA_K_##_##CGA_M_##_##CGA_N_##_##CGA_K_##XSM_< \ + Fp4GemmOperator>(D, A, B, input_sf, weight_sf, global_sf, m, n, k, batch_count); \ + /* // Return workspace size */ \ + if (!A && !B && !D) { \ + return gemm.get_workspace_size(args); \ + } \ + if (gemm.get_workspace_size(args) > workspaceBytes) { \ + std::string errMsg("Requested workspace size insufficient. Required " + \ + std::to_string(gemm.get_workspace_size(args)) + ", got " + \ + std::to_string(workspaceBytes)); \ + throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ + } \ + auto can_implement = gemm.can_implement(args); \ + if (can_implement != cutlass::Status::kSuccess) { \ + std::string errMsg = "FP4 Gemm cutlass kernel will fail for params. Error: " + \ + std::string(cutlassGetStatusString(can_implement)); \ + throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ + } \ + auto initStatus = gemm.initialize(args, workspace, stream); \ + if (initStatus != cutlass::Status::kSuccess) { \ + std::string errMsg = "Failed to initialize cutlass FP4 gemm on sm103. Error: " + \ + std::string(cutlassGetStatusString(initStatus)); \ + throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ + } \ + auto runStatus = gemm.run(args, workspace, stream, nullptr, /*enablePDL=*/true); \ + if (runStatus != cutlass::Status::kSuccess) { \ + std::string errMsg = "Failed to run cutlass FP4 gemm on sm103. Error: " + \ + std::string(cutlassGetStatusString(runStatus)); \ + throw std::runtime_error("[FP4 gemm Runner] " + errMsg); \ + } \ + return gemm.get_workspace_size(args); \ + } + +#endif + +} // namespace gemm +} // namespace flashinfer +#endif // FLASHINFER_FP4_GEMM_TEMPLATE_SM103_H_