/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include <fstream>
#include <iostream>
#include <sstream>
#include <variant>
#include <vector>

#include "../../cuda/cuda_common.h"

// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/group_array_problem_shape.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on

#define CUTLASS_CHECK(status)                                      \
  {                                                                \
    cutlass::Status error = status;                                \
    CHECK(error == cutlass::Status::kSuccess)                      \
        << "Got cutlass error: " << cutlassGetStatusString(error); \
  }

using namespace cute;
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;  // <M,N,K> per group

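// Round `value` up to the next multiple of `alignment` (16 bytes by default),
// e.g. aligned(100) == 112; used below to carve the per-group argument arrays
// out of a single workspace buffer at aligned offsets.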
inline size_t aligned(size_t value, size_t alignment = 16) {
  return (value + alignment - 1) / alignment * alignment;
}

template <typename ElementA>
struct MMA1SMConfig {
  using MmaTileShape = Shape<_128, _256, Int<128 / sizeof(ElementA)>>;
  using ClusterShape = Shape<_2, _2, _1>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;  // Kernel to launch
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;  // Epilogue to launch
};

template <typename ElementA>
struct MMA2SMConfig {
  using MmaTileShape = Shape<_256, _256, Int<128 / sizeof(ElementA)>>;
  using ClusterShape = Shape<_2, _2, _1>;
  using KernelSchedule =
      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;  // Kernel to launch
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;  // Epilogue to launch
};
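// Note: the 2SM schedule runs each MMA cooperatively across two SMs (hence the
// doubled tile M of 256), while the 1SM config above is a drop-in alternative
// for shapes where the 2SM variant is not profitable.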

template <typename ScheduleConfig, typename ElementA, typename ElementB, typename ElementC,
          typename LayoutA = cutlass::layout::RowMajor,
          typename LayoutB = cutlass::layout::ColumnMajor,
          typename LayoutC = cutlass::layout::RowMajor>
struct CutlassGroupGemmRunner {
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementA>::value;  // Alignment of A matrix in units of elements
                                                    // (up to 16 bytes)

  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementB>::value;  // Alignment of B matrix in units of elements
                                                    // (up to 16 bytes)

  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementC>::value;  // Alignment of C matrix in units of elements
                                                    // (up to 16 bytes)

  // Core kernel configurations
  using ElementAccumulator = float;  // Element type for internal accumulation
  using ScaleType = std::variant<ElementAccumulator, const ElementAccumulator*>;
  using OperatorClass = cutlass::arch::OpClassTensorOp;  // Operator class tag
  using StageCountType =
      cutlass::gemm::collective::StageCountAuto;  // Stage count maximized based on the tile size

  // Different configs for 1SM and 2SM MMA kernel
  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      cutlass::arch::Sm100, OperatorClass, typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*,
      AlignmentC, typename ScheduleConfig::EpilogueSchedule,
      cutlass::epilogue::fusion::LinearCombination<ElementC, ElementAccumulator>>::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      cutlass::arch::Sm100, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, LayoutB*,
      AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
      typename ScheduleConfig::ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      typename ScheduleConfig::KernelSchedule>::CollectiveOp;

  using GemmKernel =
      cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop, CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideA = typename Gemm::GemmKernel::InternalStrideA;
  using StrideB = typename Gemm::GemmKernel::InternalStrideB;
  using StrideC = typename Gemm::GemmKernel::InternalStrideC;
  using StrideD = typename Gemm::GemmKernel::InternalStrideD;

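  // Assemble CUTLASS grouped-GEMM arguments from the per-group device arrays
  // and launch the kernel on `stream`. `workspace` must provide at least
  // Gemm::get_workspace_size(arguments) bytes (checked below).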
  void run_group_gemm(const ElementA** ptr_A, const ElementB** ptr_B, const ElementC** ptr_C,
                      ElementC** ptr_D,
                      typename ProblemShape::UnderlyingProblemShape* problem_sizes,
                      typename ProblemShape::UnderlyingProblemShape* problem_sizes_host,
                      StrideA* stride_A, StrideB* stride_B, StrideC* stride_C, StrideD* stride_D,
                      uint8_t* workspace, int64_t workspace_size, int num_groups, ScaleType alpha,
                      ScaleType beta, cudaStream_t stream) {
    typename Gemm::Arguments arguments;
    decltype(arguments.epilogue.thread) fusion_args;
    ICHECK(alpha.index() == beta.index()) << "alpha and beta must have the same type";
    if (std::holds_alternative<ElementAccumulator>(alpha)) {
      // Host-side scalar scaling factors.
      fusion_args.alpha = std::get<ElementAccumulator>(alpha);
      fusion_args.beta = std::get<ElementAccumulator>(beta);
    } else if (std::holds_alternative<const ElementAccumulator*>(alpha)) {
      // Scaling factors read from device memory at epilogue time.
      fusion_args.alpha_ptr = std::get<const ElementAccumulator*>(alpha);
      fusion_args.beta_ptr = std::get<const ElementAccumulator*>(beta);
    } else {
      LOG(FATAL) << "Unsupported alpha and beta type";
    }

    cutlass::KernelHardwareInfo hw_info;
    hw_info.device_id = 0;
    hw_info.sm_count =
        cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    arguments = typename Gemm::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped,
                                         {num_groups, problem_sizes, problem_sizes_host},
                                         {ptr_A, stride_A, ptr_B, stride_B},
                                         {fusion_args, ptr_C, stride_C, ptr_D, stride_D},
                                         hw_info};
    Gemm gemm_op;
    CUTLASS_CHECK(gemm_op.can_implement(arguments));
    CHECK_GE(workspace_size, gemm_op.get_workspace_size(arguments));
    CUTLASS_CHECK(gemm_op.initialize(arguments, workspace, stream));
    CUTLASS_CHECK(gemm_op.run(stream));
  }
};

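// One thread per group: derive the group's A/B/D pointers from the row indptr,
// record its <M, N, K> problem shape, and set up the per-group strides.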
template <typename ElementA, typename ElementB, typename ElementC, typename StrideA,
          typename StrideB, typename StrideC>
__global__ void prepare_group_gemm_arguments(
    const ElementA** ptr_A, const ElementB** ptr_B, ElementC** ptr_D,
    typename ProblemShape::UnderlyingProblemShape* problem_sizes, StrideA* stride_A,
    StrideB* stride_B, StrideC* stride_D, const ElementA* x, const ElementB* weight, ElementC* out,
    int64_t* indptr, int64_t n, int64_t k, int64_t num_groups) {
  int group_id = threadIdx.x;
  if (group_id >= num_groups) return;
  // Use 64-bit arithmetic for the row offset so large inputs do not overflow.
  int64_t prev_rows = group_id == 0 ? 0 : indptr[group_id - 1];
  ptr_A[group_id] = x + prev_rows * k;
  ptr_B[group_id] = weight + group_id * k * n;
  ptr_D[group_id] = out + prev_rows * n;
  problem_sizes[group_id] = {static_cast<int>(indptr[group_id] - prev_rows), static_cast<int>(n),
                             static_cast<int>(k)};
  stride_A[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
  stride_B[group_id] = cute::make_stride(k, Int<1>{}, Int<0>{});
  stride_D[group_id] = cute::make_stride(n, Int<1>{}, Int<0>{});
}

template <typename ElementA, typename ElementB, typename ElementC>
void cutlass_group_gemm_sm100(ElementA* x, ElementB* weight, int64_t* indptr, uint8_t* workspace,
                              int64_t workspace_size, int64_t n, int64_t k, int64_t num_groups,
                              std::variant<float, const float*> alpha,
                              std::variant<float, const float*> beta, ElementC* out,
                              cudaStream_t stream) {
  // Note: We use MMA2SMConfig for now. It can be changed to MMA1SMConfig if needed.
  using Runner = CutlassGroupGemmRunner<MMA2SMConfig<ElementA>, ElementA, ElementB, ElementC>;
  using StrideA = typename Runner::StrideA;
  using StrideB = typename Runner::StrideB;
  using StrideC = typename Runner::StrideC;

  Runner runner;
  std::ptrdiff_t offset = 0;
  const ElementA** ptr_A = reinterpret_cast<const ElementA**>(workspace + offset);
  offset += aligned(sizeof(ElementA*) * num_groups);
  const ElementB** ptr_B = reinterpret_cast<const ElementB**>(workspace + offset);
  offset += aligned(sizeof(ElementB*) * num_groups);
  ElementC** ptr_D = reinterpret_cast<ElementC**>(workspace + offset);
  offset += aligned(sizeof(ElementC*) * num_groups);
  typename ProblemShape::UnderlyingProblemShape* problem_sizes =
      reinterpret_cast<typename ProblemShape::UnderlyingProblemShape*>(workspace + offset);
  offset += aligned(sizeof(typename ProblemShape::UnderlyingProblemShape) * num_groups);
  StrideA* stride_A = reinterpret_cast<StrideA*>(workspace + offset);
  offset += aligned(sizeof(StrideA) * num_groups);
  StrideB* stride_B = reinterpret_cast<StrideB*>(workspace + offset);
  offset += aligned(sizeof(StrideB) * num_groups);
  StrideC* stride_D = reinterpret_cast<StrideC*>(workspace + offset);
  offset += aligned(sizeof(StrideC) * num_groups);
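  // Launch one thread per group; this assumes num_groups does not exceed the
  // CUDA limit of 1024 threads per block.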
  prepare_group_gemm_arguments<<<1, num_groups, 0, stream>>>(ptr_A, ptr_B, ptr_D, problem_sizes,
                                                             stride_A, stride_B, stride_D, x,
                                                             weight, out, indptr, n, k, num_groups);
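  // The remainder of the workspace, re-aligned to 256 bytes, is handed to
  // CUTLASS as the grouped-GEMM kernel's own workspace.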
  offset = aligned(offset, 256);
  runner.run_group_gemm(ptr_A, ptr_B, const_cast<const ElementC**>(ptr_D), ptr_D, problem_sizes,
                        nullptr, stride_A, stride_B, stride_D, stride_D, workspace + offset,
                        workspace_size - offset, num_groups, alpha, beta, stream);
}
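
// A minimal usage sketch (hypothetical caller, not part of this file). Assuming
// `x` holds all groups' rows of A concatenated, `weight` holds num_groups
// row-major [n, k] weight matrices, indptr[i] is the cumulative row count
// through group i (on device), and `workspace` is a sufficiently large device
// buffer:
//
//   cutlass_group_gemm_sm100<cutlass::half_t, cutlass::half_t, cutlass::half_t>(
//       x, weight, indptr, workspace, workspace_size,
//       /*n=*/4096, /*k=*/4096, /*num_groups=*/8,
//       /*alpha=*/1.0f, /*beta=*/0.0f, out, stream);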