diff --git a/3rdparty/cutlass b/3rdparty/cutlass index ff61a49dd1a7..bbe579a9e3be 160000 --- a/3rdparty/cutlass +++ b/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit ff61a49dd1a728a96e9a8434ed408a2a52d73119 +Subproject commit bbe579a9e3beb6ea6626d9227ec32d0dae119a49 diff --git a/CMakeLists.txt b/CMakeLists.txt index c9d836b6812c..906509004a23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -369,6 +369,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/minrpc/*.cc src/runtime/relax_vm/*.cc ) +set(TVM_RUNTIME_EXT_OBJS "") if(BUILD_FOR_HEXAGON) if(NOT BUILD_STATIC_RUNTIME) @@ -595,18 +596,32 @@ add_library(tvm_libinfo_objs OBJECT ${LIBINFO_FILE}) include(GNUInstallDirs) if(NOT BUILD_DUMMY_LIBTVM) - add_library(tvm SHARED $ $ $) + add_library(tvm SHARED + $ + $ + $ + ${TVM_RUNTIME_EXT_OBJS} + ) + else() # dummy version of libtvm that can be used by downstream to specify dependencies # the real runner still need a full version of libtvm - add_library(tvm SHARED $ $) + add_library(tvm SHARED + $ + $ + ${TVM_RUNTIME_EXT_OBJS} + ) endif() target_include_directories(tvm PUBLIC "$") set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") set_property(TARGET tvm APPEND PROPERTY LINK_OPTIONS "${TVM_VISIBILITY_FLAG}") if(BUILD_STATIC_RUNTIME) - add_library(tvm_runtime STATIC $ $) + add_library(tvm_runtime STATIC + $ + $ + ${TVM_RUNTIME_EXT_OBJS} + ) set(NOTICE_MULTILINE "You have build static version of the TVM runtime library. Make " "sure to use --whole-archive when linking it into your project.") @@ -614,7 +629,11 @@ if(BUILD_STATIC_RUNTIME) add_custom_command(TARGET tvm_runtime POST_BUILD COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --yellow --bold ${NOTICE}) else() - add_library(tvm_runtime SHARED $ $) + add_library(tvm_runtime SHARED + $ + $ + ${TVM_RUNTIME_EXT_OBJS} + ) set_property(TARGET tvm_runtime APPEND PROPERTY LINK_OPTIONS "${TVM_NO_UNDEFINED_SYMBOLS}") endif() diff --git a/cmake/modules/contrib/CUTLASS.cmake b/cmake/modules/contrib/CUTLASS.cmake index 9ce27820b8f2..fa4a608f6161 100644 --- a/cmake/modules/contrib/CUTLASS.cmake +++ b/cmake/modules/contrib/CUTLASS.cmake @@ -16,16 +16,59 @@ # under the License. if(USE_CUDA AND USE_CUTLASS) - tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc src/relax/backend/contrib/cutlass/*.cc) + set(CUTLASS_GEN_COND "$,$>") + set(CUTLASS_RUNTIME_OBJS "") + + tvm_file_glob(GLOB CUTLASS_CONTRIB_SRC + src/relay/backend/contrib/cutlass/*.cc + src/relax/backend/contrib/cutlass/*.cc + ) list(APPEND COMPILER_SRCS ${CUTLASS_CONTRIB_SRC}) set(FPA_INTB_GEMM_TVM_BINDING ON) set(FPA_INTB_GEMM_TVM_HOME ${PROJECT_SOURCE_DIR}) - set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass) + ### Build cutlass runtime objects for fpA_intB_gemm using its cutlass submodule add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm) + target_include_directories(fpA_intB_gemm PRIVATE + ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm + ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include + ) + set(CUTLASS_FPA_INTB_RUNTIME_SRCS "") + list(APPEND CUTLASS_FPA_INTB_RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) + add_library(fpA_intB_cutlass_objs OBJECT ${CUTLASS_FPA_INTB_RUNTIME_SRCS}) + target_compile_definitions(fpA_intB_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=) + target_include_directories(fpA_intB_cutlass_objs PRIVATE + ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm + ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass/include + ) + list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$>") + + ### Build cutlass runtime objects for flash attention add_subdirectory(${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn) - list(APPEND RUNTIME_SRCS src/runtime/contrib/cutlass/weight_preprocess.cc) + target_include_directories(flash_attn PRIVATE + ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn + ${PROJECT_SOURCE_DIR}/3rdparty/libflash_attn/cutlass/include + ) + + ### Build cutlass runtime objects using TVM's 3rdparty/cutlass submodule + set(CUTLASS_DIR ${PROJECT_SOURCE_DIR}/3rdparty/cutlass) + set(TVM_CUTLASS_RUNTIME_SRCS "") + + if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a") + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu) + list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu) + endif() + if(TVM_CUTLASS_RUNTIME_SRCS) + add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS}) + target_compile_options(tvm_cutlass_objs PRIVATE $<$:--expt-relaxed-constexpr>) + target_include_directories(tvm_cutlass_objs PRIVATE ${CUTLASS_DIR}/include) + target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=) + list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$>") + endif() + + ### Add cutlass objects to list of TVM runtime extension objs + list(APPEND TVM_RUNTIME_EXT_OBJS "${CUTLASS_RUNTIME_OBJS}") message(STATUS "Build with CUTLASS") endif() diff --git a/src/runtime/contrib/cutlass/fp16_group_gemm.cu b/src/runtime/contrib/cutlass/fp16_group_gemm.cu new file mode 100644 index 000000000000..3c051819b232 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp16_group_gemm.cu @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "group_gemm_runner.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +template <> +struct KernelTraits { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size + using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster +}; + +namespace tvm { +namespace runtime { + +template +void tvm_cutlass_group_gemm_sm90(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, + NDArray out) { + // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + CHECK_EQ(x->ndim, 2); + CHECK_EQ(weight->ndim, 3); + CHECK_EQ(indptr->ndim, 1); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 2); + int num_groups = weight->shape[0]; + int n = weight->shape[1]; + int k = weight->shape[2]; + float alpha = 1.0f; + float beta = 0.0f; + cudaStream_t stream = static_cast((*func)().operator void*()); + cutlass_group_gemm(static_cast(x->data), static_cast(weight->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, alpha, beta, + static_cast(out->data), stream); +} + +TVM_REGISTER_GLOBAL("cutlass.group_gemm_fp16_sm90") + .set_body_typed(tvm_cutlass_group_gemm_sm90); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/fp8_group_gemm.cu b/src/runtime/contrib/cutlass/fp8_group_gemm.cu new file mode 100644 index 000000000000..c93da6ff5766 --- /dev/null +++ b/src/runtime/contrib/cutlass/fp8_group_gemm.cu @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "group_gemm_runner.cuh" + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +template <> +struct KernelTraits { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; + using TileShape = Shape<_128, _256, _64>; // Threadblock-level tile size + using ClusterShape = Shape<_2, _2, _1>; // Shape of the threadblocks in a cluster +}; + +template <> +struct KernelTraits : KernelTraits {}; + +namespace tvm { +namespace runtime { + +template +void tvm_cutlass_fp8_group_gemm(NDArray x, NDArray weight, NDArray indptr, NDArray workspace, + NDArray alpha, NDArray out) { + // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace. + // Recommened size is 4MB. + auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream"); + ICHECK(func != nullptr); + CHECK_EQ(x->ndim, 2); + CHECK_EQ(weight->ndim, 3); + CHECK_EQ(indptr->ndim, 1); + CHECK_EQ(workspace->ndim, 1); + CHECK_EQ(out->ndim, 2); + CHECK_EQ(alpha->dtype.code, kDLFloat); + CHECK_EQ(alpha->dtype.bits, 32); + int num_groups = weight->shape[0]; + int n = weight->shape[1]; + int k = weight->shape[2]; + const float* beta = nullptr; + cudaStream_t stream = static_cast((*func)().operator void*()); + cutlass_group_gemm(static_cast(x->data), static_cast(weight->data), + static_cast(indptr->data), static_cast(workspace->data), + workspace->shape[0], n, k, num_groups, static_cast(alpha->data), beta, + static_cast(out->data), stream); +} + +TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e5m2_fp16") + .set_body_typed( + tvm_cutlass_fp8_group_gemm); + +TVM_REGISTER_GLOBAL("cutlass.group_gemm_e5m2_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_group_gemm); + +TVM_REGISTER_GLOBAL("cutlass.group_gemm_e4m3_e4m3_fp16") + .set_body_typed( + tvm_cutlass_fp8_group_gemm); + +} // namespace runtime +} // namespace tvm + +#endif // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED diff --git a/src/runtime/contrib/cutlass/group_gemm_runner.cuh b/src/runtime/contrib/cutlass/group_gemm_runner.cuh new file mode 100644 index 000000000000..50bdcf7becfa --- /dev/null +++ b/src/runtime/contrib/cutlass/group_gemm_runner.cuh @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../../cuda/cuda_common.h" + +// clang-format off +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +// clang-format on + +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group + +inline size_t aligned(size_t value, size_t alignment = 16) { + return (value + alignment - 1) / alignment * alignment; +} + +template +struct KernelTraits; + +template +struct CutlassGroupGemmRunner { + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements + // (up to 16 bytes) + + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements + // (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ScaleType = std::variant; + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = typename KernelTraits::TileShape; + using ClusterShape = typename KernelTraits::ClusterShape; + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + using KernelSchedule = typename KernelTraits::KernelSchedule; // Kernel to launch + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, + EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Gemm::GemmKernel::UnderlyingStrideA; + using StrideB = typename Gemm::GemmKernel::UnderlyingStrideB; + using StrideC = typename Gemm::GemmKernel::UnderlyingStrideC; + using StrideD = typename Gemm::GemmKernel::UnderlyingStrideD; + + void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C, + ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, + typename ProblemShape::UnderlyingProblemShape* problem_sizes_host, + StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D, + uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha, + ScaleType beta, cudaStream_t stream) { + typename Gemm::EpilogueOutputOp::Params epilogue_params = [&]() { + ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type"; + if (std::holds_alternative(alpha)) { + return typename Gemm::EpilogueOutputOp::Params{std::get(alpha), + std::get(beta)}; + } else if (std::holds_alternative(alpha)) { + return typename Gemm::EpilogueOutputOp::Params{std::get(alpha), + std::get(beta)}; + } else { + LOG(FATAL) << "Unsupported alpha and beta type"; + throw; + } + }(); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_groups, problem_sizes, problem_sizes_host}, + {ptr_A, stride_A, ptr_B, stride_B}, + {epilogue_params, ptr_C, stride_C, ptr_D, stride_D}, + hw_info}; + Gemm gemm_op; + CUTLASS_CHECK(gemm_op.can_implement(arguments)); + CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments)); + CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream)); + CUTLASS_CHECK(gemm_op.run()); + } +}; + +template +__global__ void prepare_group_gemm_arguments( + const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D, + typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A, + StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out, + int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) { + int group_id = threadIdx.x; + if (group_id >= num_groups) return; + int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1]; + ptr_A[group_id] = x + prev_rows * k; + ptr_B[group_id] = weight + group_id * k * n; + ptr_D[group_id] = out + prev_rows * n; + problem_sizes[group_id] = {static_cast(indptr[group_id] - prev_rows), static_cast(n), + static_cast(k)}; + stride_A[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0}); + stride_B[group_id] = cute::make_stride(k, Int<1>{}, int64_t{0}); + stride_D[group_id] = cute::make_stride(n, Int<1>{}, int64_t{0}); +} + +template +void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace, + int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups, + std::variant alpha, + std::variant beta, ElementC* out, + cudaStream_t stream) { + using Runner = CutlassGroupGemmRunner; + using StrideA = typename Runner::StrideA; + using StrideB = typename Runner::StrideB; + using StrideC = typename Runner::StrideC; + + Runner runner; + std::ptrdiff_t offset = 0; + const ElementA** ptr_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementA*) * num_groups); + const ElementB** ptr_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementB*) * num_groups); + ElementC** ptr_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(ElementC*) * num_groups); + typename ProblemShape::UnderlyingProblemShape* problem_sizes = + reinterpret_cast(workspace + offset); + offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups); + StrideA* stride_A = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideA) * num_groups); + StrideB* stride_B = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideB) * num_groups); + StrideC* stride_D = reinterpret_cast(workspace + offset); + offset += aligned(sizeof(StrideC) * num_groups); + prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes, + stride_A, stride_B, stride_D, x, + weight, out, indptr, n, k, num_groups); + offset = aligned(offset, 256); + runner.run_group_gemm(ptr_A, ptr_B, const_cast(ptr_D), ptr_D, problem_sizes, + nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset, + workspace_size - offset, num_groups, alpha, beta, stream); +} diff --git a/src/runtime/contrib/cutlass/weight_preprocess.cc b/src/runtime/contrib/cutlass/weight_preprocess.cc index 4b378fa4a739..5fded82762a3 100644 --- a/src/runtime/contrib/cutlass/weight_preprocess.cc +++ b/src/runtime/contrib/cutlass/weight_preprocess.cc @@ -21,7 +21,7 @@ #include #include -#include "../../../3rdparty/cutlass_fpA_intB_gemm/cutlass_kernels/cutlass_preprocessors.h" +#include "cutlass_kernels/cutlass_preprocessors.h" namespace tvm { namespace runtime { diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 6eaf10c2ab6a..154a68e1169c 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -17,6 +17,7 @@ import logging import tempfile import math +import ml_dtypes import tvm from tvm import relay from tvm.contrib.cudnn import conv_output_shape @@ -32,6 +33,7 @@ finalize_modules, finalize_modules_vm, ) +from tvm.contrib.pickle_memoize import memoize import tvm.testing logging.basicConfig(level=logging.INFO) @@ -1105,5 +1107,101 @@ def test_dense_transpose_dense(): verify_dense_transpose_dense(get_dense_transpose_dense(M, N, K), M, N, K) +def verify_group_gemm( + func_name, M, N, K, num_groups, x_dtype, weight_dtype, out_dtype, use_scale, rtol, atol +): + group_gemm_func = tvm.get_global_func(func_name, allow_missing=True) + if group_gemm_func is None: + print(f"Skipped as {func_name} is not available") + return + + @memoize("tvm.contrib.cutlass.test_group_gemm_sm90") + def get_ref_data(): + assert M % num_groups == 0 + M_per_group = M // num_groups + a_np = get_random_ndarray((M, K), "float16") + b_np = get_random_ndarray((num_groups, N, K), "float16") + indptr_np = np.arange(1, num_groups + 1).astype("int64") * M_per_group + c_np = np.concatenate( + [a_np[i * M_per_group : (i + 1) * M_per_group] @ b_np[i].T for i in range(num_groups)], + axis=0, + ) + return a_np, b_np, indptr_np, c_np + + def to_numpy_dtype(dtype): + mapping = {"e5m2_float8": ml_dtypes.float8_e5m2, "e4m3_float8": ml_dtypes.float8_e4m3fn} + return mapping.get(dtype, dtype) + + a_np, b_np, indptr_np, c_np = get_ref_data() + dev = tvm.cuda(0) + a_nd = tvm.nd.array(a_np.astype(to_numpy_dtype(x_dtype)), device=dev) + b_nd = tvm.nd.array(b_np.astype(to_numpy_dtype(weight_dtype)), device=dev) + c_nd = tvm.nd.empty(c_np.shape, dtype=out_dtype, device=dev) + indptr_nd = tvm.nd.array(indptr_np, device=dev) + workspace = tvm.nd.empty((4096 * 1024,), dtype="uint8", device=dev) + if use_scale: + scale = tvm.nd.array(np.array([1.0], dtype="float32"), device=dev) + group_gemm_func(a_nd, b_nd, indptr_nd, workspace, scale, c_nd) + else: + group_gemm_func(a_nd, b_nd, indptr_nd, workspace, c_nd) + tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=rtol, atol=atol) + + +@tvm.testing.requires_cutlass +def test_group_gemm_sm90(): + verify_group_gemm( + "cutlass.group_gemm_fp16_sm90", + 8, + 128, + 128, + 4, + "float16", + "float16", + "float16", + False, + rtol=1e-3, + atol=1e-3, + ) + verify_group_gemm( + "cutlass.group_gemm_e5m2_e5m2_fp16", + 8, + 16, + 16, + 4, + "e5m2_float8", + "e5m2_float8", + "float16", + True, + rtol=1e-1, + atol=1, + ) + verify_group_gemm( + "cutlass.group_gemm_e4m3_e4m3_fp16", + 8, + 16, + 16, + 4, + "e4m3_float8", + "e4m3_float8", + "float16", + True, + rtol=1e-1, + atol=1, + ) + verify_group_gemm( + "cutlass.group_gemm_e4m3_e5m2_fp16", + 8, + 16, + 16, + 4, + "e4m3_float8", + "e5m2_float8", + "float16", + True, + rtol=1e-1, + atol=1, + ) + + if __name__ == "__main__": tvm.testing.main()