2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 530 files
16 changes: 12 additions & 4 deletions cmake/modules/contrib/CUTLASS.cmake
@@ -58,19 +58,27 @@ if(USE_CUDA AND USE_CUTLASS)
set(TVM_CUTLASS_RUNTIME_SRCS "")

if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu)
endif()
if (CMAKE_CUDA_ARCHITECTURES MATCHES "100a")
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu)
list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu)
endif()
if(TVM_CUTLASS_RUNTIME_SRCS)
add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-lineinfo --expt-relaxed-constexpr>)
target_include_directories(tvm_cutlass_objs PRIVATE
${CUTLASS_DIR}/include
${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
)
target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header)
target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
# Note: enable this to get more detailed logs for cutlass kernels
# target_compile_definitions(tvm_cutlass_objs PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=2)
list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
endif()
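
These arch-gated sources only compile when the matching value is set at configure time, e.g. cmake -DCMAKE_CUDA_ARCHITECTURES=90a for Hopper or 100a for Blackwell; a list value such as "90a;100a" satisfies both MATCHES checks and builds both kernel sets.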

72 changes: 72 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm.cuh
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <cuda_fp16.h>
#include <float.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/ndarray.h>

#include "cutlass/bfloat16.h"
#include "cutlass/half.h"

namespace tvm {
namespace runtime {

template <int Arch, typename ElementA, typename ElementB, typename ElementC>
struct CutlassGroupGemm;

template <int Arch>
void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
NDArray out) {
// Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
// Recommended size is 4MB.
static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
CHECK_EQ(x->ndim, 2);
CHECK_EQ(weight->ndim, 3);
CHECK_EQ(indptr->ndim, 1);
CHECK_EQ(workspace->ndim, 1);
CHECK_EQ(out->ndim, 2);
int num_groups = weight->shape[0];
int n = weight->shape[1];
int k = weight->shape[2];
float alpha = 1.0f;
float beta = 0.0f;
cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());

if (DataType(x->dtype) == DataType::Float(16)) {
CHECK(DataType(weight->dtype) == DataType::Float(16));
CHECK(DataType(out->dtype) == DataType::Float(16));
using Dtype = cutlass::half_t;
CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
} else if (DataType(x->dtype) == DataType::BFloat(16)) {
CHECK(DataType(weight->dtype) == DataType::BFloat(16));
CHECK(DataType(out->dtype) == DataType::BFloat(16));
using Dtype = cutlass::bfloat16_t;
CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
}
}

} // namespace runtime
} // namespace tvm
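
The dispatcher above only declares CutlassGroupGemm and relies on each arch-specific translation unit to specialize it. A hedged sketch of how such a specialization can plug in (the actual fp16_group_gemm_sm100.cu in this PR may be organized differently):

// Hypothetical sketch, not the PR's actual file contents: an SM100
// specialization (inside namespace tvm::runtime) forwarding to the runner
// defined in fp16_group_gemm_runner_sm100.cuh.
template <typename ElementA, typename ElementB, typename ElementC>
struct CutlassGroupGemm<100, ElementA, ElementB, ElementC> {
  static void run(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
                  int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
                  float alpha, float beta, ElementC* out, cudaStream_t stream) {
    cutlass_group_gemm_sm100(x, weight, indptr, workspace, workspace_size, n, k, num_groups,
                             alpha, beta, out, stream);
  }
};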
221 changes: 221 additions & 0 deletions src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh
@@ -0,0 +1,221 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <fstream>
#include <iostream>
#include <sstream>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
CHECK(error == cutlass::Status::kSuccess) \
<< "Got cutlass error: " << cutlassGetStatusString(error); \
}

using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>; // <M,N,K> per group

inline size_t aligned(size_t value, size_t alignment = 16) {
return (value + alignment - 1) / alignment * alignment;
}
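
A couple of worked values for this rounding helper:

// aligned(10)       == 16   (rounds up to the default 16-byte boundary)
// aligned(32)       == 32   (already a multiple of 16)
// aligned(300, 256) == 512  (custom alignment; 256 is used for the CUTLASS workspace below)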

template <typename ElementA>
struct MMA1SMConfig {
using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>;
using ClusterShape = Shape<_2, _2, _1>;
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch
};

template <typename ElementA>
struct MMA2SMConfig {
using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;
using ClusterShape = Shape<_2, _2, _1>;
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100; // Kernel to launch
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm; // Epilogue to launch
};
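
Since cutlass::half_t and cutlass::bfloat16_t are both 2 bytes wide, Int<128 / sizeof(ElementA)> resolves to Int<64>, so the 1SM config tiles the MMA as 128x256x64 and the 2SM config as 256x256x64.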

template <typename ScheduleConfig, typename ElementA, typename ElementB, typename ElementC,
typename LayoutA = cutlass::layout::RowMajor,
typename LayoutB = cutlass::layout::ColumnMajor,
typename LayoutC = cutlass::layout::RowMajor>
struct CutlassGroupGemmRunner {
static constexpr int AlignmentA =
128 / cutlass::sizeof_bits<ElementA>::value; // Alignment of A matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentB =
128 / cutlass::sizeof_bits<ElementB>::value; // Alignment of B matrix in units of elements
// (up to 16 bytes)

static constexpr int AlignmentC =
128 / cutlass::sizeof_bits<ElementC>::value; // Alignment of C matrix in units of elements
// (up to 16 bytes)

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using StageCountType =
cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size

// Different configs for 1SM and 2SM MMA kernel
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm100, OperatorClass, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*,
AlignmentC, typename ScheduleConfig::EpilogueSchedule,
cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*,
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
typename ScheduleConfig::ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
typename ScheduleConfig::KernelSchedule>::CollectiveOp;

using GemmKernel =
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::InternalStrideA;
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
using StrideD = typename Gemm::GemmKernel::InternalStrideD;

void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
ElementC** ptr_D,
typename ProblemShape::UnderlyingProblemShape* problem_sizes,
typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha,
ScaleType beta, cudaStream_t stream) {
typename Gemm::Arguments arguments;
decltype(arguments.epilogue.thread) fusion_args;
[&]() {
ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
if (std::holds_alternative<ElementAccumulator>(alpha)) {
fusion_args.alpha = std::get<ElementAccumulator>(alpha);
fusion_args.beta = std::get<ElementAccumulator>(beta);
} else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
fusion_args.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
fusion_args.beta_ptr = std::get<const ElementAccumulator*>(beta);
} else {
LOG(FATAL) << "Unsupported alpha and beta type";
throw;
}
}();

cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = 0;
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
arguments = typename Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
{num_groups, problem_sizes, problem_sizes_host},
{ptr_A, stride_A, ptr_B, stride_B},
{fusion_args, ptr_C, stride_C, ptr_D, stride_D},
hw_info};
Gemm gemm_op;
CUTLASS_CHECK(gemm_op.can_implement(arguments));
CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
CUTLASS_CHECK(gemm_op.run(stream));
}
};
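
run_group_gemm accepts alpha and beta either as host scalars or as device pointers, and the ICHECK at the top requires both to use the same mode. Illustrative call shapes (argument lists abbreviated):

// Host scalars, as the fp16/bf16 entry points use:
//   runner.run_group_gemm(..., /*alpha=*/1.0f, /*beta=*/0.0f, stream);
// Device pointers (e.g. scales computed on the GPU); both must be pointers:
//   const float *d_alpha = ..., *d_beta = ...;
//   runner.run_group_gemm(..., ScaleType{d_alpha}, ScaleType{d_beta}, stream);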

template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
typename StrideB, typename StrideC>
__global__ void prepare_group_gemm_arguments(
const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
int group_id = threadIdx.x;
if (group_id >= num_groups) return;
int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
ptr_A[group_id] = x + prev_rows * k;
ptr_B[group_id] = weight + group_id * k * n;
ptr_D[group_id] = out + prev_rows * n;
problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
static_cast<int>(k)};
stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
}
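
To make the indexing concrete: indptr holds cumulative row counts (an inclusive prefix sum), so for three groups with 3, 2, and 4 rows of x, indptr = [3, 5, 9] and each thread derives:

// group 0: rows [0, 3) -> ptr_A = x + 0 * k, problem M = 3
// group 1: rows [3, 5) -> ptr_A = x + 3 * k, problem M = 2
// group 2: rows [5, 9) -> ptr_A = x + 5 * k, problem M = 4
// N and K are the same for every group; ptr_B advances by one k*n weight matrix per group.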

template <typename ElementA, typename ElementB, typename ElementC>
void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
std::variant<float, const float*> alpha,
std::variant<float, const float*> beta, ElementC* out,
cudaStream_t stream) {
// Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if needed.
using Runner = CutlassGroupGemmRunner<MMA2SMConfig<ElementA>, ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;
using StrideC = typename Runner::StrideC;

Runner runner;
std::ptrdiff_t offset = 0;
const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
offset += aligned(sizeof(ElementA*) * num_groups);
const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
offset += aligned(sizeof(ElementB*) * num_groups);
ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
offset += aligned(sizeof(ElementC*) * num_groups);
typename ProblemShape::UnderlyingProblemShape* problem_sizes =
reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
offset += aligned(sizeof(StrideA) * num_groups);
StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
offset += aligned(sizeof(StrideB) * num_groups);
StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
offset += aligned(sizeof(StrideC) * num_groups);
prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes,
stride_A, stride_B, stride_D, x,
weight, out, indptr, n, k, num_groups);
offset = aligned(offset, 256);
runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
workspace_size - offset, num_groups, alpha, beta, stream);
}
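
The function carves seven 16-byte-aligned argument arrays from the front of workspace and hands the remainder, from a 256-byte-aligned offset, to CUTLASS as its internal workspace; the 4 MB recommendation in fp16_group_gemm.cuh covers both for typical group counts. A hypothetical sizing helper (not part of this PR) mirroring the carving above:

template <typename Runner>
size_t group_gemm_argument_bytes(int64_t num_groups) {
  size_t offset = 0;
  offset += 3 * aligned(sizeof(void*) * num_groups);  // ptr_A, ptr_B, ptr_D
  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
  offset += aligned(sizeof(typename Runner::StrideA) * num_groups);
  offset += aligned(sizeof(typename Runner::StrideB) * num_groups);
  offset += aligned(sizeof(typename Runner::StrideC) * num_groups);  // stride_D reuses StrideC
  return aligned(offset, 256);  // CUTLASS's own workspace begins at this offset
}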
@@ -169,11 +169,11 @@ __global__ void prepare_group_gemm_arguments(
}

template <typename ElementA, typename ElementB, typename ElementC>
void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
std::variant<float, const float*> alpha,
std::variant<float, const float*> beta, ElementC* out,
cudaStream_t stream) {
void cutlass_group_gemm_sm90(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
std::variant<float, const float*> alpha,
std::variant<float, const float*> beta, ElementC* out,
cudaStream_t stream) {
using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
using StrideA = typename Runner::StrideA;
using StrideB = typename Runner::StrideB;