Commit af75beb
[CUTLASS] Add GeMM kernels for Blackwell GPUs
This PR introduces CUTLASS GEMM kernels, groupwise-scaled GEMM kernels, and group GEMM kernels for Blackwell GPUs. Files are reorganized a bit so that the exposed global functions are now architecture-agnostic. Prior to this PR, our global function names for CUTLASS kernels usually ended with `"_sm90"`, which added extra complexity when the frontend compiler dispatched kernels across multiple supported architectures, such as Hopper and Blackwell. This PR therefore renames those global functions so that the names are arch-agnostic (see the registration sketch below). At build time, only the kernels supported by the target architecture are built.
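For illustration only, the arch-agnostic registration could look like the sketch below. The guard macros, the global-function name "cutlass.group_gemm", and the Arch template arguments are assumptions made for this sketch, not code from the PR:

// Hypothetical per-arch translation units (e.g. fp16_group_gemm_sm90.cu and
// fp16_group_gemm_sm100.cu): both register the same arch-agnostic global
// name, and CMake compiles only the unit matching CMAKE_CUDA_ARCHITECTURES,
// so the frontend no longer needs suffix-based dispatch.
#include <tvm/runtime/registry.h>

#if defined(TVM_CUTLASS_BUILD_SM90)  // assumed guard macro
TVM_REGISTER_GLOBAL("cutlass.group_gemm")  // assumed global name
    .set_body_typed(tvm::runtime::tvm_cutlass_group_gemm_impl<90>);
#elif defined(TVM_CUTLASS_BUILD_SM100)  // assumed guard macro
TVM_REGISTER_GLOBAL("cutlass.group_gemm")
    .set_body_typed(tvm::runtime::tvm_cutlass_group_gemm_impl<100>);
#endif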
1 parent: bb14b27

18 files changed: +1220 −255 lines

3rdparty/cutlass

Submodule cutlass updated 530 files

cmake/modules/contrib/CUTLASS.cmake

Lines changed: 12 additions & 4 deletions

@@ -58,19 +58,27 @@ if(USE_CUDA AND USE_CUTLASS)
   set(TVM_CUTLASS_RUNTIME_SRCS "")
 
   if (CMAKE_CUDA_ARCHITECTURES MATCHES "90a")
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm.cu)
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm90.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_group_gemm_sm90.cu)
     list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_gemm.cu)
-    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_blockwise_scaled_gemm.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm90.cu)
   endif()
+  if (CMAKE_CUDA_ARCHITECTURES MATCHES "100a")
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp16_group_gemm_sm100.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_gemm_sm100.cu)
+    list(APPEND TVM_CUTLASS_RUNTIME_SRCS src/runtime/contrib/cutlass/fp8_groupwise_scaled_group_gemm_sm100.cu)
+  endif()
   if(TVM_CUTLASS_RUNTIME_SRCS)
     add_library(tvm_cutlass_objs OBJECT ${TVM_CUTLASS_RUNTIME_SRCS})
-    target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>)
+    target_compile_options(tvm_cutlass_objs PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-lineinfo --expt-relaxed-constexpr>)
     target_include_directories(tvm_cutlass_objs PRIVATE
       ${CUTLASS_DIR}/include
       ${PROJECT_SOURCE_DIR}/3rdparty/cutlass_fpA_intB_gemm/cutlass_extensions/include
     )
+    target_link_libraries(tvm_cutlass_objs PRIVATE tvm_ffi_header)
     target_compile_definitions(tvm_cutlass_objs PRIVATE DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+    # Note: enable this to get more detailed logs for cutlass kernels
+    # target_compile_definitions(tvm_cutlass_objs PRIVATE CUTLASS_DEBUG_TRACE_LEVEL=2)
     list(APPEND CUTLASS_RUNTIME_OBJS "$<${CUTLASS_GEN_COND}:$<TARGET_OBJECTS:tvm_cutlass_objs>>")
   endif()

Lines changed: 72 additions & 0 deletions

@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <cuda_fp16.h>
+#include <float.h>
+#include <tvm/ffi/function.h>
+#include <tvm/runtime/ndarray.h>
+
+#include "cutlass/bfloat16.h"
+#include "cutlass/half.h"
+
+namespace tvm {
+namespace runtime {
+
+template <int Arch, typename ElementA, typename ElementB, typename ElementC>
+struct CutlassGroupGemm;
+
+template <int Arch>
+void tvm_cutlass_group_gemm_impl(NDArray x, NDArray weight, NDArray indptr, NDArray workspace,
+                                 NDArray out) {
+  // Workspace is used for storing device-side group gemm arguments and cutlass internal workspace.
+  // Recommended size is 4MB.
+  static auto func = tvm::ffi::Function::GetGlobalRequired("runtime.get_cuda_stream");
+  CHECK_EQ(x->ndim, 2);
+  CHECK_EQ(weight->ndim, 3);
+  CHECK_EQ(indptr->ndim, 1);
+  CHECK_EQ(workspace->ndim, 1);
+  CHECK_EQ(out->ndim, 2);
+  int num_groups = weight->shape[0];
+  int n = weight->shape[1];
+  int k = weight->shape[2];
+  float alpha = 1.0f;
+  float beta = 0.0f;
+  cudaStream_t stream = static_cast<cudaStream_t>(func().cast<void*>());
+
+  if (DataType(x->dtype) == DataType::Float(16)) {
+    CHECK(DataType(weight->dtype) == DataType::Float(16));
+    CHECK(DataType(out->dtype) == DataType::Float(16));
+    using Dtype = cutlass::half_t;
+    CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
+  } else if (DataType(x->dtype) == DataType::BFloat(16)) {
+    CHECK(DataType(weight->dtype) == DataType::BFloat(16));
+    CHECK(DataType(out->dtype) == DataType::BFloat(16));
+    using Dtype = cutlass::bfloat16_t;
+    CutlassGroupGemm<Arch, Dtype, Dtype, Dtype>::run(
+        static_cast<Dtype*>(x->data), static_cast<Dtype*>(weight->data),
+        static_cast<int64_t*>(indptr->data), static_cast<uint8_t*>(workspace->data),
+        workspace->shape[0], n, k, num_groups, alpha, beta, static_cast<Dtype*>(out->data), stream);
+  }
+}
+
+}  // namespace runtime
+}  // namespace tvm
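
A usage sketch, not part of the diff: a caller allocates the recommended 4 MB workspace once and invokes the arch-agnostic entry point through the FFI. The global-function name "cutlass.group_gemm" is an assumption here, as is the exact workspace-allocation call:

#include <tvm/ffi/function.h>
#include <tvm/runtime/ndarray.h>

using tvm::runtime::NDArray;

void group_gemm_example(NDArray x, NDArray weight, NDArray indptr, NDArray out) {
  // 4 MB uint8 buffer: device-side argument arrays plus CUTLASS internal workspace.
  NDArray workspace =
      NDArray::Empty({4 << 20}, DLDataType{kDLUInt, 8, 1}, DLDevice{kDLCUDA, 0});
  // Resolves to whichever per-arch kernel was compiled into this build.
  static auto gemm = tvm::ffi::Function::GetGlobalRequired("cutlass.group_gemm");
  gemm(x, weight, indptr, workspace, out);
}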
Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <variant>
+#include <vector>
+
+#include "../../cuda/cuda_common.h"
+
+// clang-format off
+#include "cutlass/cutlass.h"
+
+#include "cute/tensor.hpp"
+#include "cutlass/tensor_ref.h"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/dispatch_policy.hpp"
+#include "cutlass/gemm/group_array_problem_shape.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+// clang-format on
+
+#define CUTLASS_CHECK(status)                                       \
+  {                                                                 \
+    cutlass::Status error = status;                                 \
+    CHECK(error == cutlass::Status::kSuccess)                       \
+        << "Got cutlass error: " << cutlassGetStatusString(error);  \
+  }
+
+using namespace cute;
+using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group
+
+inline size_t aligned(size_t value, size_t alignment = 16) {
+  return (value + alignment - 1) / alignment * alignment;
+}
+
+template <typename ElementA>
+struct MMA1SMConfig {
+  using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;  // Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;  // Epilogue to launch
+};
+
+template <typename ElementA>
+struct MMA2SMConfig {
+  using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;  // Kernel to launch
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;  // Epilogue to launch
+};
+
+template <typename ScheduleConfig, typename ElementA, typename ElementB, typename ElementC,
+          typename LayoutA = cutlass::layout::RowMajor,
+          typename LayoutB = cutlass::layout::ColumnMajor,
+          typename LayoutC = cutlass::layout::RowMajor>
+struct CutlassGroupGemmRunner {
+  static constexpr int AlignmentA =
+      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentB =
+      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  static constexpr int AlignmentC =
+      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
+                                                    // (up to 16 bytes)
+
+  // Core kernel configurations
+  using ElementAccumulator = float;  // Element type for internal accumulation
+  using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
+  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
+  using StageCountType =
+      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size
+
+  // Different configs for 1SM and 2SM MMA kernel
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, OperatorClass, typename ScheduleConfig::MmaTileShape,
+      typename ScheduleConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*,
+      AlignmentC, typename ScheduleConfig::EpilogueSchedule,
+      cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*,
+      AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
+      typename ScheduleConfig::ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+          sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      typename ScheduleConfig::KernelSchedule>::CollectiveOp;
+
+  using GemmKernel =
+      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
+  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
+  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
+  using StrideD = typename Gemm::GemmKernel::InternalStrideD;
+
+  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
+                      ElementC** ptr_D,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
+                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
+                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
+                      uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha,
+                      ScaleType beta, cudaStream_t stream) {
+    typename Gemm::Arguments arguments;
+    decltype(arguments.epilogue.thread) fusion_args;
+    [&]() {
+      ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
+      if (std::holds_alternative<ElementAccumulator>(alpha)) {
+        fusion_args.alpha = std::get<ElementAccumulator>(alpha);
+        fusion_args.beta = std::get<ElementAccumulator>(beta);
+      } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
+        fusion_args.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
+        fusion_args.beta_ptr = std::get<const ElementAccumulator*>(beta);
+      } else {
+        LOG(FATAL) << "Unsupported alpha and beta type";
+        throw;
+      }
+    }();
+
+    cutlass::KernelHardwareInfo hw_info;
+    hw_info.device_id = 0;
+    hw_info.sm_count =
+        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
+    arguments = typename Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
+                                         {num_groups, problem_sizes, problem_sizes_host},
+                                         {ptr_A, stride_A, ptr_B, stride_B},
+                                         {fusion_args, ptr_C, stride_C, ptr_D, stride_D},
+                                         hw_info};
+    Gemm gemm_op;
+    CUTLASS_CHECK(gemm_op.can_implement(arguments));
+    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
+    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
+    CUTLASS_CHECK(gemm_op.run(stream));
+  }
+};
+
+template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
+          typename StrideB, typename StrideC>
+__global__ void prepare_group_gemm_arguments(
+    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
+    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
+    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
+    int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
+  int group_id = threadIdx.x;
+  if (group_id >= num_groups) return;
+  int prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
+  ptr_A[group_id] = x + prev_rows * k;
+  ptr_B[group_id] = weight + group_id * k * n;
+  ptr_D[group_id] = out + prev_rows * n;
+  problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
+                             static_cast<int>(k)};
+  stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
+  stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
+}
+
+template <typename ElementA, typename ElementB, typename ElementC>
+void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
+                              int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
+                              std::variant<float, const float*> alpha,
+                              std::variant<float, const float*> beta, ElementC* out,
+                              cudaStream_t stream) {
+  // Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if needed.
+  using Runner = CutlassGroupGemmRunner<MMA2SMConfig<ElementA>, ElementA, ElementB, ElementC>;
+  using StrideA = typename Runner::StrideA;
+  using StrideB = typename Runner::StrideB;
+  using StrideC = typename Runner::StrideC;
+
+  Runner runner;
+  std::ptrdiff_t offset = 0;
+  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
+  offset += aligned(sizeof(ElementA*) * num_groups);
+  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
+  offset += aligned(sizeof(ElementB*) * num_groups);
+  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
+  offset += aligned(sizeof(ElementC*) * num_groups);
+  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
+      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
+  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
+  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
+  offset += aligned(sizeof(StrideA) * num_groups);
+  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
+  offset += aligned(sizeof(StrideB) * num_groups);
+  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
+  offset += aligned(sizeof(StrideC) * num_groups);
+  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes,
+                                                             stride_A, stride_B, stride_D, x,
+                                                             weight, out, indptr, n, k, num_groups);
+  offset = aligned(offset, 256);
+  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
+                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
+                        workspace_size - offset, num_groups, alpha, beta, stream);
+}
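
To make the workspace carving above concrete, the following standalone sketch (a reconstruction for illustration, not code from this PR) computes the metadata prefix that cutlass_group_gemm_sm100 consumes before the 256-byte-aligned region handed to CUTLASS. Here problem_shape_bytes and stride_bytes stand in for sizeof(UnderlyingProblemShape) and sizeof(StrideA), which depend on the instantiated kernel types:

#include <cstddef>
#include <cstdint>

// Same rounding helper as aligned() in the runner above.
inline size_t aligned(size_t value, size_t alignment = 16) {
  return (value + alignment - 1) / alignment * alignment;
}

size_t metadata_prefix_bytes(int64_t num_groups, size_t problem_shape_bytes,
                             size_t stride_bytes) {
  size_t offset = 0;
  offset += 3 * aligned(sizeof(void*) * num_groups);    // ptr_A, ptr_B, ptr_D arrays
  offset += aligned(problem_shape_bytes * num_groups);  // per-group (M, N, K)
  offset += 3 * aligned(stride_bytes * num_groups);     // stride_A, stride_B, stride_D
  return aligned(offset, 256);  // CUTLASS-internal workspace starts here
}

Whatever remains of the recommended 4 MB after this prefix must still satisfy gemm_op.get_workspace_size(arguments), which run_group_gemm enforces with CHECK_GE.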

src/runtime/contrib/cutlass/group_gemm_runner.cuh renamed to src/runtime/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh

Lines changed: 5 additions & 5 deletions

@@ -169,11 +169,11 @@ __global__ void prepare_group_gemm_arguments(
 }
 
 template <typename ElementA, typename ElementB, typename ElementC>
-void cutlass_group_gemm(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
-                        int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
-                        std::variant<float, const float*> alpha,
-                        std::variant<float, const float*> beta, ElementC* out,
-                        cudaStream_t stream) {
+void cutlass_group_gemm_sm90(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
+                             int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
+                             std::variant<float, const float*> alpha,
+                             std::variant<float, const float*> beta, ElementC* out,
+                             cudaStream_t stream) {
   using Runner = CutlassGroupGemmRunner<ElementA, ElementB, ElementC>;
   using StrideA = typename Runner::StrideA;
   using StrideB = typename Runner::StrideB;

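The renamed per-arch runners plug into the arch-agnostic CutlassGroupGemm template declared in the header earlier in this diff. The PR's per-arch .cu files are not shown in this excerpt, but the glue presumably resembles this sketch (an assumption, included for orientation):

// Hypothetical glue in an SM90 translation unit: specialize the dispatch
// template so tvm_cutlass_group_gemm_impl<90> forwards to the SM90 runner.
namespace tvm {
namespace runtime {

template <typename ElementA, typename ElementB, typename ElementC>
struct CutlassGroupGemm<90, ElementA, ElementB, ElementC> {
  static void run(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
                  int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
                  std::variant<float, const float*> alpha,
                  std::variant<float, const float*> beta, ElementC* out, cudaStream_t stream) {
    cutlass_group_gemm_sm90(x, weight, indptr, workspace, workspace_size, n, k, num_groups,
                            alpha, beta, out, stream);
  }
};

}  // namespace runtime
}  // namespace tvm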