From febd76e4cad54c6f0cd6a958154da9f48287e4f8 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 20 Jul 2023 05:20:14 +0000 Subject: [PATCH 1/6] Temp save --- example/49_fpAintB_gemm/CMakeLists.txt | 5 + .../49_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 75 ++ example/49_fpAintB_gemm/run_gemm_example.inc | 167 +++ .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 555 +++++++++ .../gpu/block/blockwise_gemm_wmma.hpp | 4 +- .../gpu/device/device_gemm_dequantB.hpp | 46 + .../device/impl/device_fpAintB_gemm_wmma.hpp | 660 ++++++++++ .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 1104 +++++++++++++++++ .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 108 ++ 9 files changed, 2722 insertions(+), 2 deletions(-) create mode 100644 example/49_fpAintB_gemm/CMakeLists.txt create mode 100644 example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp create mode 100644 example/49_fpAintB_gemm/run_gemm_example.inc create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp diff --git a/example/49_fpAintB_gemm/CMakeLists.txt b/example/49_fpAintB_gemm/CMakeLists.txt new file mode 100644 index 00000000000..34059c7ff90 --- /dev/null +++ b/example/49_fpAintB_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") + add_custom_target(example_fpAintB_gemm_wmma) + add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) + add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) +endif() diff --git a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp new file mode 100644 index 00000000000..96e4f747816 --- /dev/null +++ b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ 
-0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" + +using ADataType = ck::half_t; +using BDataType = int8_t; +using ScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + ScaleDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerBlock / (M-Repeat * M-PerWmma) = M-Wave + 2, // N-Repeat // N-PerBlock / (N-Repeat * N-PerWmma) = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/49_fpAintB_gemm/run_gemm_example.inc b/example/49_fpAintB_gemm/run_gemm_example.inc new file mode 100644 index 00000000000..7d06ec4cb01 --- /dev/null +++ b/example/49_fpAintB_gemm/run_gemm_example.inc @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro
Devices, Inc. All rights reserved. + +#pragma once + +bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << 
std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + 
sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp new file mode 100644 index 00000000000..283f5f87dae --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,555 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#define CK_MNK_LOOP + +namespace ck { + +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct Blockwise_fpAintB_GemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = AEnableLds ? 1 : 2; + static constexpr index_t B_KRow = BEnableLds ? 
1 : 2; + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = 
wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong!
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MThreadPerSubGroup |NRepeat |NWave + // |NSubGroup |NAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register descriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const
CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static 
constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + const ScaleBlockBuffer& scale_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto scale_thread_buf = make_static_buffer( + scale_thread_desc_.GetElementSpaceSize()); + auto converted_b_thread_buf = b_thread_buf; + + static constexpr auto dequantizer = Dequantizer{}; + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / WmmaK, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read weight scale + scale_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + scale_block_buf, + scale_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_thread_buf); + + // convert B from int8 to fp16 + converted_b_thread_buf = type_convert(b_thread_buf); + + // multiply scale + dequantize(converted_b_thread_buf, scale_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + converted_b_thread_buf[Number{}]; + }); + + using wmma_input_type_a =
typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType()(Number<0>{}), + b_thread_vec.template AsType()(Number<0>{}), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + 
make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + ADataType, + ADataType, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? false : true>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + BDataType, + BDataType, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? 
true : false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index 576a83f6b67..679da465dab 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -362,11 +362,11 @@ struct BlockwiseGemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, ... read B + // k=0,kpack*1, .. + // read B b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, make_tuple(Number{}, n0, I0, I0, I0, I0), diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp new file mode 100644 index 00000000000..acb18efabfe --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline +// As input tensor thread buffer declared inside blockwise-gemm pipeline. 
+ +template +struct DeviceGemm_dequantB : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp new file mode 100644 index 00000000000..41ecbbb5321 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -0,0 +1,660 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) +// 2. 
C(M, N) = A(M, K) * DequantB(K, N) + +template +struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = 
a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + // When K = 1, it might be scale tensor. 
+ assert(K % K1 == 0 && K != 1 ); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given problem.
+ using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using ScaleGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); // scale reuses the B descriptor builder + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseFpAintBGemm_Wmma; + + // Argument: raw device pointers plus the grid descriptors and block-to-tile map derived from M/N/K and the strides + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const ScaleDataType* p_scale, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_scale_grid_{p_scale}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_{}, + scale_grid_desc_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(1, N, 1); // KRaw = 1, unit stride + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const ScaleDataType* p_scale_grid_; // was a second `p_b_grid_` declaration; the constructor initialises p_scale_grid_ + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + ScaleGridDesc scale_grid_desc_; + 
CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_fpAintB_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + ScaleDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + 
arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic: forwards to the typed Run overload after casting the base argument + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + // Compile-time parameter validation hook; currently accepts every configuration. + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + // Returns false unless the device is gfx1100/gfx1101/gfx1102, the constexpr type check passes, and the raw M/N/K extents are divisible by the configured vector widths. + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % 
CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceFpAintBGemm_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << 
KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp new file mode 100644 index 00000000000..2ded2c0d1bb --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,1104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ 
p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +template +struct GridwiseFpAintBGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1_dequant; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc 
= [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = 
KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // 
make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto 
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + 
GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + 
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const ScaleGridDesc& scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off 
+/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto scale_grid_buf = make_dynamic_buffer( + p_scale_grid, scale_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + constexpr auto scale_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = 
make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + 
make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, 
+ get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy, scale_blockwise_copy); + } + }; + + auto scale_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto scale_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::scale_block_space_offset, + SharedMemTrait::scale_block_space_size_aligned); + + auto scale_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ScaleDataType, + ScaleDataType, + decltype(scale_grid_desc), + decltype(scale_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + 1>( + scale_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + ck::tensor_operation::element_wise::PassThrough{}, + scale_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(scale_block_buf, scale_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto scale_block_buf = make_static_buffer( + scale_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto scale_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + scale_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 
0)); + + return make_tuple(scale_block_buf, scale_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; + + auto scale_block_buf = scale_block_trait()[I0]; + auto scale_blockwise_copy = scale_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + Blockwise_fpAintB_GemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + /* + scale_blockwise_copy + */ + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + scale_grid_desc, + scale_block_desc, + scale_blockwise_copy, + scale_grid_buf, + scale_block_buf, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. 
+ constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per 
shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto 
access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 3ce216e2454..dd4112939db 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -550,6 +550,114 @@ struct GridwiseGemmPipeline_v1<1, false, false> } }; +template +struct GridwiseGemmPipeline_v1_dequant; + +template <> +struct GridwiseGemmPipeline_v1_dequant<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + 
ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleBlockDesc& scale_block_desc, + const ScaleGridBuffer& scale_grid_buf, + ScaleBlockBuffer& scale_block_buf, + ScaleBlockTransfer& scale_blockwise_copy, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + scale_blockwise_copy.RunRead(scale_grid_desc, scale_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + scale_blockwise_copy.RunWrite(scale_block_desc, scale_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + template struct GridwiseGemmPipelineInterwave_v1; 
From 0c51a35ea8a60adb8feb2fc7da876ea45c272730 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 25 Jul 2023 08:46:39 +0000 Subject: [PATCH 2/6] fpAintB kernel compile pass --- example/49_fpAintB_gemm/common.hpp | 89 +++++++++ .../49_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 12 +- example/49_fpAintB_gemm/run_gemm_example.inc | 14 +- .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 110 ++++++++--- .../device/impl/device_fpAintB_gemm_wmma.hpp | 155 +++++++-------- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 99 +++++++--- .../grid/gridwise_gemm_pipeline_selector.hpp | 5 + .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 4 +- include/ck/utility/data_type.hpp | 10 + .../cpu/reference_fpAintB_gemm.hpp | 177 ++++++++++++++++++ ...emm_wmma_f16_f16_f16_km_kn_mn_instance.cpp | 9 +- ...emm_wmma_f16_f16_f16_km_nk_mn_instance.cpp | 9 +- ...emm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp | 7 +- 13 files changed, 558 insertions(+), 142 deletions(-) create mode 100644 example/49_fpAintB_gemm/common.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp diff --git a/example/49_fpAintB_gemm/common.hpp b/example/49_fpAintB_gemm/common.hpp new file mode 100644 index 00000000000..1f67d53de2b --- /dev/null +++ b/example/49_fpAintB_gemm/common.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp" + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + } + else + { + std::cerr << "arg1: verification 
(0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} diff --git a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp index 96e4f747816..618f1e90982 100644 --- a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp +++ b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -37,7 +37,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_ BElementOp, CElementOp, GemmDefault, - 2, // Prefetch stage + 1, // Prefetch stage 128, // BlockSize 128, // MPerBlock 64, // NPerBlock @@ -67,8 +67,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_ 8>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm; #include "run_gemm_example.inc" diff --git a/example/49_fpAintB_gemm/run_gemm_example.inc b/example/49_fpAintB_gemm/run_gemm_example.inc index 7d06ec4cb01..d50b592fec1 100644 --- a/example/49_fpAintB_gemm/run_gemm_example.inc +++ b/example/49_fpAintB_gemm/run_gemm_example.inc @@ -27,6 +27,8 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + // assume scale tensor is [1, n] + Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{})); switch(config.init_method) { @@ -34,26 +36,32 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 
5.f}(scale_k_n); break; case 2: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); break; case 3: ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(scale_k_n); break; case 4: ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(scale_k_n); break; case 5: ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(scale_k_n); break; default: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -61,6 +69,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; #ifdef BUILD_INT4_EXAMPLE @@ -77,10 +86,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) #else DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); 
b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + scale_k_n_device_buf.ToDevice(scale_k_n.mData.data()); #endif auto a_element_op = AElementOp{}; @@ -98,6 +109,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) #else static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(scale_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), #endif M, @@ -136,7 +148,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + a_m_k, b_k_n, scale_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); ref_invoker.Run(ref_argument); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp index 283f5f87dae..472d6154a95 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -20,7 +20,7 @@ template {}; + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -178,9 +179,10 @@ struct Blockwise_fpAintB_GemmWMMA } using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); - __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), - Tuple6 b_origin = CalculateBThreadOriginDataIndex()) - : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + __host__ __device__ + Blockwise_fpAintB_GemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin), scale_thread_copy_(b_origin) { static_assert(ABlockDesc::IsKnownAtCompileTime() && 
BBlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); @@ -290,8 +292,12 @@ struct Blockwise_fpAintB_GemmWMMA // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + static constexpr ScaleBlockDesc scale_block_desc_1_n0_n1_n2_1; - template + template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, const ScaleBlockBuffer& scale_block_buf, @@ -305,8 +311,6 @@ struct Blockwise_fpAintB_GemmWMMA scale_thread_desc_.GetElementSpaceSize()); auto converted_b_thread_buf = b_thread_buf; - static constexpr auto dequantizer = Dequantizer{}; - // basic intrinsic to determine loopover direction if constexpr(MRepeat < NRepeat) { @@ -333,21 +337,22 @@ struct Blockwise_fpAintB_GemmWMMA b_thread_buf); // read weight scale scale_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, + scale_block_desc_1_n0_n1_n2_1, make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_scale_thread_desc_, + scale_block_buf, + scale_thread_desc_, make_tuple(I0, n0, I0, I0, I0, I0), - b_scale_thread_buf); + scale_thread_buf); - // convert B from int8 to fp16 - converted_b_thread_buf = type_convert(b_thread_buf); - - // multiply scale - dequantize(converted_b_thread_buf, scale_thread_buf); + // convert B from int8 to fp16, multiply scale + static_for<0, b_thread_buf.size(), 1>{}([&](auto i) { + converted_b_thread_buf(i) = + scale_thread_buf[i / WmmaK] * + type_convert(b_thread_buf[i]); + }); vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { a_thread_vec.template AsType()(i) = @@ -358,7 +363,7 @@ struct Blockwise_fpAintB_GemmWMMA (i / A_K1) % A_KRow, 0, i % A_K1))>{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = converted_b_thread_buf[Number::type; - using wmma_input_type_b = typename vector_type::type; + 
using wmma_input_type_b = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -396,6 +401,20 @@ struct Blockwise_fpAintB_GemmWMMA b_thread_desc_, make_tuple(I0, n0, I0, I0, I0, I0), b_thread_buf); + // read weight scale + scale_thread_copy_.Run( + scale_block_desc_1_n0_n1_n2_1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + scale_block_buf, + scale_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_thread_buf); + + // convert B from int8 to fp16, multiply scale + static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) { + converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * + type_convert(b_thread_buf[i]); + }); // read A a_thread_copy_.Run( a_block_desc_k0_m0_m1_m2_k1, @@ -406,11 +425,11 @@ struct Blockwise_fpAintB_GemmWMMA a_thread_buf); vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - b_thread_buf[Number()(i) = + converted_b_thread_buf[Number::type; - using wmma_input_type_b = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -472,6 +491,15 @@ struct Blockwise_fpAintB_GemmWMMA Number{}, Number<1>{})); + static constexpr auto scale_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(I0, I1, I0, I0, I0, I0)); + // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); @@ -548,8 +576,42 @@ struct Blockwise_fpAintB_GemmWMMA TransposeC ? 
true : false>; }; + template + struct ScaleThreadCopySelector; + + template <> + struct ScaleThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct ScaleThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic< + ScaleDataType, + ScaleDataType, + decltype(scale_block_desc_1_n0_n1_n2_1), + decltype(scale_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + 1>; + }; + typename AThreadCopySelector::type a_thread_copy_; typename BThreadCopySelector::type b_thread_copy_; + typename ScaleThreadCopySelector::type scale_thread_copy_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index 41ecbbb5321..0cff0aae769 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -22,7 +22,7 @@ namespace tensor_operation { namespace device { // 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) -// 2. C(M, N) = A(M, K) * DequantB(K, N) +// 2. 
C(M, N) = A(M, K) * DequantB(K, N) template + ck::PipelineVersion PipelineVer = ck::PipelineVersion::dequant_v1> struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -103,7 +103,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB{MPerBlock, NPerBlock, KPerBlock}; - + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; // Describe how data read from Global memory @@ -183,7 +183,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB; + using GridwiseGemm = GridwiseFpAintBGemm_Wmma< + BlockSize, + ADataType, + BDataType, + ScaleDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc, + BGridDesc, + ScaleGridDesc, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; // Argument struct 
Argument : public BaseArgument { Argument(const ADataType* p_a_grid, const BDataType* p_b_grid, - const ScaleDataType* p_scale, + const ScaleDataType* p_scale_grid, CDataType* p_c_grid, index_t M, index_t N, @@ -310,7 +311,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB, remove_reference_t, - remove_reference_t, + remove_reference_t, remove_reference_t< typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, AElementwiseOperation, @@ -422,9 +423,11 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB MakeArgumentPointer(const void* p_a, const void* p_b, + const void* p_scale, void* p_c, index_t M, index_t N, @@ -595,6 +599,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB(static_cast(p_a), static_cast(p_b), + static_cast(p_scale), static_cast(p_c), M, N, @@ -623,8 +628,10 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB LoopSchedToString{ {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; - std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, - {PipelineVersion::v2, "v2"}}; + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}, + {PipelineVersion::dequant_v1, "dequant_v1"}}; // clang-format off str << "DeviceFpAintBGemm_Wmma_CShuffle" diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 2ded2c0d1bb..3f5af4bf9d7 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -20,9 +20,11 @@ namespace ck { template (p_a_grid, p_b_grid, + p_scale_grid, p_c_grid, p_shared, a_grid_desc, b_grid_desc, + scale_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, b_element_op, @@ -63,9 +69,11 @@ __global__ void #else ignore = p_a_grid; ignore = p_b_grid; + ignore = 
p_scale_grid; ignore = p_c_grid; ignore = a_grid_desc; ignore = b_grid_desc; + ignore = scale_grid_desc; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; ignore = b_element_op; @@ -77,12 +85,14 @@ __global__ void template + PipelineVersion PipelineVer = PipelineVersion::dequant_v1> struct GridwiseFpAintBGemm_Wmma { static constexpr auto I0 = Number<0>{}; @@ -140,7 +150,12 @@ struct GridwiseFpAintBGemm_Wmma using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1_dequant; + using GridwiseGemmPipe = + remove_cvref_t())>; // Describe how data store to (LDS/VGPR) buffer from Global memory __host__ __device__ static constexpr auto MakeABlockDescriptor() @@ -237,6 +252,38 @@ struct GridwiseFpAintBGemm_Wmma return b_block_desc; } + __host__ __device__ static constexpr auto MakeScaleBlockDescriptor() + { + // Scale [1, N], all K related dimension reduce to 1 + constexpr auto scale_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(I0, I1, I0)); + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(I0, I1, I0, I0, I0, I0, I0)); + } + }(); + + return scale_block_desc; + } + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() { constexpr auto a_block_copy_step = [&]() { @@ -537,9 +584,15 @@ struct GridwiseFpAintBGemm_Wmma BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), max_lds_align) : 0; + static constexpr auto scale_block_space_size_aligned = + BEnableLds ? 
math::integer_least_multiple( + MakeScaleBlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; static constexpr auto a_block_space_offset = 0; static constexpr auto b_block_space_offset = a_block_space_size_aligned; + static constexpr auto scale_block_space_offset = + b_block_space_offset + b_block_space_size_aligned; // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_space_size = @@ -551,7 +604,8 @@ struct GridwiseFpAintBGemm_Wmma static constexpr auto lds_size = math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)); + b_block_space_size_aligned * sizeof(BDataType) + + scale_block_space_size_aligned * sizeof(ScaleDataType)); }; template @@ -609,7 +663,7 @@ struct GridwiseFpAintBGemm_Wmma constexpr auto a_block_desc = MakeABlockDescriptor(); constexpr auto b_block_desc = MakeBBlockDescriptor(); - constexpr auto scale_block_desc = MakeBBlockDescriptor(); + constexpr auto scale_block_desc = MakeScaleBlockDescriptor(); auto a_block_trait = [&](){ // A matrix blockwise copy @@ -768,7 +822,7 @@ struct GridwiseFpAintBGemm_Wmma get_thread_local_1d_id() % 16, 0)); - return make_tuple(b_block_buf, b_blockwise_copy, scale_blockwise_copy); + return make_tuple(b_block_buf, b_blockwise_copy); } }; @@ -776,13 +830,14 @@ struct GridwiseFpAintBGemm_Wmma if constexpr(BEnableLds) { constexpr auto K0PerBlock = KPerBlock/ K1; + auto scale_block_buf = make_dynamic_buffer( static_cast(p_shared) + SharedMemTrait::scale_block_space_offset, SharedMemTrait::scale_block_space_size_aligned); auto scale_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1, @@ -802,10 +857,10 @@ struct GridwiseFpAintBGemm_Wmma 1, BThreadTransferSrcResetCoordinateAfterRun, true, - 1>( + NumGemmKPrefetchStage>( scale_grid_desc, make_multi_index(0, n_block_data_idx_on_grid, 0), - ck::tensor_operation::element_wise::PassThrough{}, + b_element_op, scale_block_desc, 
make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); @@ -815,13 +870,12 @@ struct GridwiseFpAintBGemm_Wmma else { // Thread-wise copy - // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; constexpr auto K0PerWmma = WmmaK/2/K1Value; + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 auto scale_block_buf = make_static_buffer( scale_block_desc.GetElementSpaceSize()); - // Limitation: NumDim of Src and Dst descriptor should be identical auto scale_blockwise_copy = ThreadwiseTensorSliceTransfer_v2(a_grid_desc, a_block_desc, a_blockwise_copy, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index 48bd22a764a..48295b638cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -12,6 +12,7 @@ enum struct PipelineVersion { v1, v2, + dequant_v1, }; template {}; + } else { std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index dd4112939db..cf5c9066b9a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -600,9 +600,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true> const BBlockTransferStep& b_block_copy_step, const ScaleGridDesc& scale_grid_desc, const ScaleBlockDesc& scale_block_desc, + ScaleBlockTransfer& scale_blockwise_copy, const ScaleGridBuffer& scale_grid_buf, ScaleBlockBuffer& scale_block_buf, - ScaleBlockTransfer& scale_blockwise_copy, const BlockwiseGemm& blockwise_gemm, CThreadBuffer& c_thread_buf, index_t num_loop) @@ -653,7 +653,7 @@ struct 
GridwiseGemmPipeline_v1_dequant<1, true, true> { block_sync_lds(); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf); } } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8d3f2dbd633..0e07c20ae55 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1090,6 +1090,16 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +// convert int8 to fp16 via fp32 +template <> +inline __host__ __device__ constexpr half_t type_convert(int8_t x) +{ + // TODO: replace it with fast_converter + float x_fp32 = static_cast(x); + + return type_convert(x_fp32); +} + // Declare a template function for bf16 conversion using RTN template __host__ __device__ constexpr Y bf16_convert_rtn(X x); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp new file mode 100644 index 00000000000..ac392f09069 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferencefpAintBGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + scale_k_n_{scale_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& scale_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferencefpAintBGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + ScaleDataType v_scale; + ADataType v_converted_b; + + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + // same for scale matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_scale, + 
arg.scale_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_scale, arg.scale_k_n_(k, n)); + } + + v_converted_b = type_convert(v_b) * v_scale; + v_acc += ck::type_convert(v_a) * + ck::type_convert(v_converted_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp index e757049b4e1..f3665eb8d8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp @@ -29,9 +29,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | @@ -62,8 +61,8 @@ using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = // 1 Wave DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 
1, 1, S<1, 16, 1, 2>, 8>, DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> #endif // clang-format on - >; + >; void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( std::vector Date: Fri, 28 Jul 2023 07:29:32 +0000 Subject: [PATCH 3/6] Sanity pass. 
--- .../49_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 2 +- example/49_fpAintB_gemm/run_gemm_example.inc | 48 +++++++- .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 104 +++++++++++++--- .../device/impl/device_fpAintB_gemm_wmma.hpp | 53 ++++++++- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 75 ++++++++++-- .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 112 +++++++++++++++++- 6 files changed, 359 insertions(+), 35 deletions(-) diff --git a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp index 618f1e90982..8ff1077da4a 100644 --- a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp +++ b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -9,7 +9,7 @@ using ADataType = ck::half_t; using BDataType = int8_t; using ScaleDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = float; +using CShuffleDataType = ck::half_t; using CDataType = ck::half_t; using ALayout = Row; diff --git a/example/49_fpAintB_gemm/run_gemm_example.inc b/example/49_fpAintB_gemm/run_gemm_example.inc index d50b592fec1..913a18d7a4b 100644 --- a/example/49_fpAintB_gemm/run_gemm_example.inc +++ b/example/49_fpAintB_gemm/run_gemm_example.inc @@ -28,7 +28,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] - Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{})); + Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); switch(config.init_method) { @@ -51,7 +51,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) case 4: ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(scale_k_n); + ck::utils::FillUniformDistributionIntegerValue{2.f, 2.f}(scale_k_n); break; case 5: 
ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); @@ -64,6 +64,50 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); } +#if 0 + printf("Matrix A:\n"); + for (int im = 0; im < M; im++) + { + for (int ik = 0; ik < K; ik++) + { + if(ik % 16 == 0){ + printf("|"); + } + + printf(" %04x", *(reinterpret_cast(&a_m_k(im,ik)))); + } + printf("\n"); + } + + printf("Matrix B:\n"); + for (int in = 0; in < N; in++) + { + for (int ik = 0; ik < K; ik++) + { + if(ik % 16 == 0){ + printf("|"); + } + + printf(" %02x", b_k_n(ik,in)); + } + printf("\n"); + } + + printf("Matrix Scale:\n"); + for (int in = 0; in < N; in++) + { + for (int ik = 0; ik < K; ik++) + { + if(ik % 16 == 0){ + printf("|"); + } + + printf(" %04x", *(reinterpret_cast(&scale_k_n(ik,in)))); + } + printf("\n"); + } + #endif + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp index 472d6154a95..434b69d5786 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -309,7 +309,8 @@ struct Blockwise_fpAintB_GemmWMMA b_thread_desc_.GetElementSpaceSize()); auto scale_thread_buf = make_static_buffer( scale_thread_desc_.GetElementSpaceSize()); - auto converted_b_thread_buf = b_thread_buf; + auto converted_b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); // basic intrinsic to determine loopover direction if constexpr(MRepeat < NRepeat) @@ -345,7 +346,7 @@ struct Blockwise_fpAintB_GemmWMMA scale_thread_buf); // convert B from int8 to fp16, multiply scale - static_for<0, b_thread_buf.size(), 1>{}([&](auto i) { + static_for<0, 
b_thread_buf.Size(), 1>{}([&](auto i) { converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * type_convert(b_thread_buf[i]); @@ -390,6 +391,20 @@ struct Blockwise_fpAintB_GemmWMMA else { static_for<0, NRepeat, 1>{}([&](auto n0) { + // read weight scale + scale_thread_copy_.Run( + scale_block_desc_1_n0_n1_n2_1, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_block_buf, + scale_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_thread_buf); +#if 0 + printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n", + get_thread_local_1d_id(), n0.value, + *(reinterpret_cast(&scale_thread_buf[n0])) + ); +#endif static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of // k=0,kpack*1, .. @@ -400,16 +415,7 @@ struct Blockwise_fpAintB_GemmWMMA b_block_buf, b_thread_desc_, make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // read weight scale - scale_thread_copy_.Run( - scale_block_desc_1_n0_n1_n2_1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - scale_block_buf, - scale_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - scale_thread_buf); - + b_thread_buf); // convert B from int8 to fp16, multiply scale static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) { converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * @@ -423,7 +429,71 @@ struct Blockwise_fpAintB_GemmWMMA a_thread_desc_, make_tuple(I0, m0, I0, I0, I0, I0), a_thread_buf); - + if (true){ +#if 0 + printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n", + get_thread_local_1d_id(), m0.value, n0.value, k.value, + *(reinterpret_cast(&a_thread_buf[Number<0>{}])), + *(reinterpret_cast(&a_thread_buf[Number<1>{}])), + *(reinterpret_cast(&a_thread_buf[Number<2>{}])), + *(reinterpret_cast(&a_thread_buf[Number<3>{}])), + *(reinterpret_cast(&a_thread_buf[Number<4>{}])), + *(reinterpret_cast(&a_thread_buf[Number<5>{}])), + *(reinterpret_cast(&a_thread_buf[Number<6>{}])), 
+ *(reinterpret_cast(&a_thread_buf[Number<7>{}])), + *(reinterpret_cast(&a_thread_buf[Number<8>{}])), + *(reinterpret_cast(&a_thread_buf[Number<9>{}])), + *(reinterpret_cast(&a_thread_buf[Number<10>{}])), + *(reinterpret_cast(&a_thread_buf[Number<11>{}])), + *(reinterpret_cast(&a_thread_buf[Number<12>{}])), + *(reinterpret_cast(&a_thread_buf[Number<13>{}])), + *(reinterpret_cast(&a_thread_buf[Number<14>{}])), + *(reinterpret_cast(&a_thread_buf[Number<15>{}])) + ); +#endif +#if 0 + printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n", + get_thread_local_1d_id(), m0.value, n0.value, k.value, + b_thread_buf[Number<0>{}], + b_thread_buf[Number<1>{}], + b_thread_buf[Number<2>{}], + b_thread_buf[Number<3>{}], + b_thread_buf[Number<4>{}], + b_thread_buf[Number<5>{}], + b_thread_buf[Number<6>{}], + b_thread_buf[Number<7>{}], + b_thread_buf[Number<8>{}], + b_thread_buf[Number<9>{}], + b_thread_buf[Number<10>{}], + b_thread_buf[Number<11>{}], + b_thread_buf[Number<12>{}], + b_thread_buf[Number<13>{}], + b_thread_buf[Number<14>{}], + b_thread_buf[Number<15>{}] + ); +#endif +#if 0 + printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n", + get_thread_local_1d_id(), m0.value, n0.value, k.value, + *(reinterpret_cast(&converted_b_thread_buf[Number<0>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<1>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<2>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<3>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<4>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<5>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<6>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<7>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<8>{}])), + 
*(reinterpret_cast(&converted_b_thread_buf[Number<9>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<10>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<11>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<12>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<13>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<14>{}])), + *(reinterpret_cast(&converted_b_thread_buf[Number<15>{}])) + ); +#endif + } vector_type a_thread_vec; vector_type b_thread_vec; @@ -497,7 +567,7 @@ struct Blockwise_fpAintB_GemmWMMA I1, Number{}, I1, - Number{}), + I1), make_tuple(I0, I1, I0, I0, I0, I0)); // C[M, N, NumRegWMMA] @@ -587,11 +657,11 @@ struct Blockwise_fpAintB_GemmWMMA ScaleDataType, decltype(scale_block_desc_1_n0_n1_n2_1), decltype(scale_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, - B_K1, - B_K1>; + 1, + 1>; }; template <> diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index 0cff0aae769..cb6678e391b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -182,8 +182,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, 
Sequence<1, 2, 5>{})); + } + } + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) { const auto c_grid_desc_mraw_nraw = [&]() { @@ -237,7 +282,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB(p_a_grid, p_b_grid, @@ -262,7 +268,7 @@ struct GridwiseFpAintBGemm_Wmma constexpr auto K0PerBlock = KPerBlock / K1; return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, I1), make_tuple(I0, I1, I0)); } else @@ -276,7 +282,7 @@ struct GridwiseFpAintBGemm_Wmma Number{}, I1, I1, - K1), + I1), make_tuple(I0, I1, I0, I0, I0, I0, I0)); } }(); @@ -424,6 +430,52 @@ struct GridwiseFpAintBGemm_Wmma return b_wave_desc; } + template + __host__ __device__ static constexpr auto MakeScaleWaveDescriptor(const ScaleBlockDesc_&) + { + constexpr auto scale_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = ScaleBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + ScaleBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ScaleBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ScaleBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = ScaleBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(I0, + I1, + I0, + I0, + I0, + I0)); + } + }(); + 
+ return scale_wave_desc; + } + __host__ __device__ static constexpr auto // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() @@ -590,9 +642,10 @@ struct GridwiseFpAintBGemm_Wmma : 0; static constexpr auto a_block_space_offset = 0; - static constexpr auto b_block_space_offset = a_block_space_size_aligned; + static constexpr auto b_block_space_offset = + (a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType)/sizeof(BDataType); static constexpr auto scale_block_space_offset = - b_block_space_offset + b_block_space_size_aligned; + (b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType)/sizeof(ScaleDataType); // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_space_size = @@ -753,7 +806,7 @@ struct GridwiseFpAintBGemm_Wmma auto b_block_buf = make_dynamic_buffer( static_cast(p_shared) + SharedMemTrait::b_block_space_offset, SharedMemTrait::b_block_space_size_aligned); - + // printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset); auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1( static_cast(p_shared) + SharedMemTrait::scale_block_space_offset, SharedMemTrait::scale_block_space_size_aligned); - + // printf("scale_lds_offset: %lu\n", SharedMemTrait::scale_block_space_offset); + auto scale_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1, + // Reduce slice length K1 to 1 + Sequence, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, ScaleDataType, @@ -851,10 +906,10 @@ struct GridwiseFpAintBGemm_Wmma Sequence<0, 1, 2>, BBlockTransferSrcVectorDim, 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, 1, 1, + 1, // no effect + 1, // no effect BThreadTransferSrcResetCoordinateAfterRun, true, NumGemmKPrefetchStage>( @@ -926,7 +981,7 @@ struct GridwiseFpAintBGemm_Wmma AccDataType, decltype(MakeAWaveDescriptor(a_block_desc)), decltype(MakeBWaveDescriptor(b_block_desc)), - 
decltype(MakeBWaveDescriptor(scale_block_desc)), + decltype(MakeScaleWaveDescriptor(scale_block_desc)), MPerBlock, NPerBlock, KPerBlock, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index cf5c9066b9a..3a04213a9a2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -581,9 +581,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true> typename BBlockTransferStep, typename ScaleGridDesc, typename ScaleBlockDesc, + typename ScaleBlockTransfer, typename ScaleGridBuffer, typename ScaleBlockBuffer, - typename ScaleBlockTransfer, typename BlockwiseGemm, typename CThreadBuffer> __device__ static void Run(const AGridDesc& a_grid_desc, @@ -658,6 +658,116 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true> } }; +template <> +struct GridwiseGemmPipeline_v1_dequant<1, true, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleBlockDesc& scale_block_desc, + ScaleBlockTransfer& scale_blockwise_copy, + const ScaleGridBuffer& scale_grid_buf, + ScaleBlockBuffer& scale_block_buf, + const BlockwiseGemm& blockwise_gemm, 
+ CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + scale_blockwise_copy.Run(scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + template struct GridwiseGemmPipelineInterwave_v1; From 32bac6f3bc432beeda4a9033170b33bf06c7a493 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Tue, 1 Aug 2023 06:40:46 +0000 Subject: [PATCH 4/6] Temp save --- .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 2 +- .../gpu/thread/threadwise_tensor_slice_transfer.hpp | 9 ++++----- .../thread/threadwise_tensor_slice_transfer_v3r1.hpp | 10 ++++++++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git 
a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp index 434b69d5786..cfd49668597 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -419,7 +419,7 @@ struct Blockwise_fpAintB_GemmWMMA // convert B from int8 to fp16, multiply scale static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) { converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * - type_convert(b_thread_buf[i]); + type_convert(b_thread_buf[i]); // call byte permute }); // read A a_thread_copy_.Run( diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 570d4e725bd..3832b522ef4 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1143,7 +1143,9 @@ struct ThreadwiseTensorSliceTransfer_v4 const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc, src_data_coord); - +#if 0 + printf("Tid: %03d, LDS read offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetOffset()); +#endif // copy data from src_buf into src_tmp_vector if constexpr(SrcBuffer::IsDynamicBuffer()) { @@ -1417,10 +1419,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow 1, 0); v_theother_row = type_convert_sp(temp); - // if (get_thread_local_1d_id() == 0){ - // printf("src_offset:%d, dst_offset for this row: %d, dst_offset - // for the other row: %d \n", - // src_offset, dst_offset, dst_offset+DstScalarPerVector);} + if(get_thread_local_1d_id() % 32 < 16) { // apply type convert diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp 
index 6665d765f81..78f25091eac 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -207,7 +207,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // copy data from src_buf into src_vector_container auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - + if (false){ + printf("Tid: %03d, a_grid_buf: %04x\n", + get_thread_local_1d_id(), + *(reinterpret_cast(&src_vector_container.template AsType()[Number<0>{}]))); + } // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( @@ -442,7 +446,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); - +#if 0 + printf("Tid: %03d, LDS write offset: %d\n", get_thread_local_1d_id(), dst_coord_.GetOffset()); +#endif using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; From 5cf73a5e3a42175aa1fa6d7e40113dd97a7b6f7e Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 3 Aug 2023 01:37:00 +0000 Subject: [PATCH 5/6] debug code enabled --- example/49_fpAintB_gemm/common.hpp | 34 +++++ .../49_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 18 ++- example/49_fpAintB_gemm/run_gemm_example.inc | 58 +++++--- .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 125 +++++++++--------- .../device/impl/device_fpAintB_gemm_wmma.hpp | 5 +- .../element/unary_element_wise_operation.hpp | 91 +++++++++++++ .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 31 +++-- .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 3 +- .../threadwise_tensor_slice_transfer.hpp | 4 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 8 +- include/ck/utility/amd_buffer_addressing.hpp | 111 +++++++++++++++- include/ck/utility/data_type.hpp | 16 +++ script/clang-format-overwrite.sh | 4 +- 13 files changed, 
402 insertions(+), 106 deletions(-) diff --git a/example/49_fpAintB_gemm/common.hpp b/example/49_fpAintB_gemm/common.hpp index 1f67d53de2b..4fb4c41d056 100644 --- a/example/49_fpAintB_gemm/common.hpp +++ b/example/49_fpAintB_gemm/common.hpp @@ -48,6 +48,40 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +template +struct UnsignedWeightPreprocessor +{ +}; + +template <> +struct UnsignedWeightPreprocessor +{ + using UnsignedWeight = Tensor; + using SignedWeight = Tensor; + static UnsignedWeight convert(SignedWeight const& Input) + { + + UnsignedWeight Output = Input.template CopyAsType(); + + auto f_kn = [&](auto k, auto n) { + const uint8_t adder = 128; + int8_t v_signed_weight; + uint8_t v_unsigned_weight; + + ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n)); + v_unsigned_weight = ck::type_convert(v_signed_weight) + adder; + Output(k, n) = v_unsigned_weight; + }; + + make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return Output; + } + + UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); } +}; + inline bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) { diff --git a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp index 8ff1077da4a..e8776a94bcd 100644 --- a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp +++ b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -5,8 +5,18 @@ #include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" +// Implementation follows the paper: +// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run: +// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022. +// https://doi.org/10.48550/arXiv.2211.10017. 
Assume the weight (matrix B) has been preprocessed +// to unsigned by adding 128. + +// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType +// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType + using ADataType = ck::half_t; -using BDataType = int8_t; +using QuantDataType = int8_t; +using BDataType = uint8_t; using ScaleDataType = ck::half_t; using AccDataType = float; using CShuffleDataType = ck::half_t; @@ -40,13 +50,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_ 1, // Prefetch stage 128, // BlockSize 128, // MPerBlock - 64, // NPerBlock + 128, // NPerBlock 64, // KPerBlock 8, // K1 16, // MPerWmma 16, // NPerWmma 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave - 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -68,7 +78,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_ // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // assume scale tensor is [1, n] Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); @@ -35,35 +35,38 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) case 0: break; case 1: ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(quant_b_k_n); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(scale_k_n); break; case 2: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); ck::utils::FillUniformDistribution{-1.f,
1.f}(scale_k_n); break; case 3: ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(quant_b_k_n); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(scale_k_n); break; case 4: ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(quant_b_k_n); ck::utils::FillUniformDistributionIntegerValue{2.f, 2.f}(scale_k_n); break; case 5: ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); - ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(quant_b_k_n); ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(scale_k_n); break; default: ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); } + UnsignedWeightPreprocessor preprocessor; + Tensor b_k_n = preprocessor(quant_b_k_n); + #if 0 printf("Matrix A:\n"); for (int im = 0; im < M; im++) @@ -78,8 +81,9 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) } printf("\n"); } - - printf("Matrix B:\n"); +#endif +#if 0 + printf("Matrix QuantB:\n"); for (int in = 0; in < N; in++) { for (int ik = 0; ik < K; ik++) @@ -88,12 +92,29 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) printf("|"); } - printf(" %02x", b_k_n(ik,in)); + printf(" %02x", *(reinterpret_cast(&quant_b_k_n(ik,in)))); } printf("\n"); } - +#endif +#if 0 printf("Matrix Scale:\n"); + for(int in = 0; in < N; in++) + { + for(int ik = 0; ik < 1; ik++) + { + if(ik % 16 == 0) + { + printf("|"); + } + + printf(" %04x", *(reinterpret_cast(&scale_k_n(ik, in)))); + } + 
printf("\n"); + } +#endif +#if 0 + printf("Matrix B:\n"); for (int in = 0; in < N; in++) { for (int ik = 0; ik < K; ik++) @@ -102,12 +123,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) printf("|"); } - printf(" %04x", *(reinterpret_cast(&scale_k_n(ik,in)))); + printf(" %02x", b_k_n(ik,in)); } printf("\n"); } - #endif - +#endif + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -191,8 +212,13 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); - auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, scale_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + quant_b_k_n, + scale_k_n, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); ref_invoker.Run(ref_argument); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp index cfd49668597..981fa70a69b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -309,8 +309,10 @@ struct Blockwise_fpAintB_GemmWMMA b_thread_desc_.GetElementSpaceSize()); auto scale_thread_buf = make_static_buffer( scale_thread_desc_.GetElementSpaceSize()); - auto converted_b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); + // auto converted_b_thread_buf = make_static_buffer( + // b_thread_desc_.GetElementSpaceSize()); + tensor_operation::element_wise::FastNumericArrayConverter + fast_numeric_converter; // basic intrinsic to determine loopover direction if constexpr(MRepeat < NRepeat) @@ -345,15 +347,29 @@ struct Blockwise_fpAintB_GemmWMMA make_tuple(I0, n0, I0, I0, I0, 
I0), scale_thread_buf); - // convert B from int8 to fp16, multiply scale - static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) { - converted_b_thread_buf(i) = - scale_thread_buf[i / WmmaK] * - type_convert(b_thread_buf[i]); + vector_type b_int_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + b_int_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + // convert B from uint8 to fp16, multiply scale + b_thread_vec = fast_numeric_converter(b_int_vec); + static_for<0, WmmaK, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + scale_thread_buf[n0] * + b_thread_vec.template AsType()(i); }); vector_type a_thread_vec; - vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { a_thread_vec.template AsType()(i) = @@ -364,14 +380,6 @@ struct Blockwise_fpAintB_GemmWMMA (i / A_K1) % A_KRow, 0, i % A_K1))>{}]; - b_thread_vec.template AsType()(i) = - converted_b_thread_buf[Number{}]; }); using wmma_input_type_a = typename vector_type::type; @@ -390,37 +398,48 @@ struct Blockwise_fpAintB_GemmWMMA } else { - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read weight scale - scale_thread_copy_.Run( - scale_block_desc_1_n0_n1_n2_1, - make_tuple(I0, n0, I0, I0, I0, I0), - scale_block_buf, - scale_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - scale_thread_buf); + static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read weight scale + scale_thread_copy_.Run(scale_block_desc_1_n0_n1_n2_1, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_block_buf, + scale_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + scale_thread_buf); #if 0 printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n", get_thread_local_1d_id(), n0.value, *(reinterpret_cast(&scale_thread_buf[n0])) ); #endif - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of - // k=0,kpack*1, .. 
- // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0, I0), - b_thread_buf); - // convert B from int8 to fp16, multiply scale - static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) { - converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * - type_convert(b_thread_buf[i]); // call byte permute - }); + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type b_int_vec; + vector_type b_thread_vec; + + static_for<0, WmmaK, 1>{}([&](auto i) { + b_int_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + // convert B from uint8 to fp16, multiply scale + b_thread_vec = fast_numeric_converter(b_int_vec); + static_for<0, WmmaK, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + scale_thread_buf[n0] * b_thread_vec.template AsType()(i); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { // read A a_thread_copy_.Run( a_block_desc_k0_m0_m1_m2_k1, @@ -429,7 +448,8 @@ struct Blockwise_fpAintB_GemmWMMA a_thread_desc_, make_tuple(I0, m0, I0, I0, I0, I0), a_thread_buf); - if (true){ + if(true) + { #if 0 printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n", get_thread_local_1d_id(), m0.value, n0.value, k.value, @@ -495,17 +515,8 @@ struct Blockwise_fpAintB_GemmWMMA #endif } vector_type a_thread_vec; - vector_type b_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { - b_thread_vec.template AsType()(i) = - converted_b_thread_buf[Number{}]; a_thread_vec.template AsType()(i) = a_thread_buf[Number{}, Number<1>{})); - static constexpr auto scale_thread_desc_ = - make_naive_tensor_descriptor(make_tuple(Number{}, - Number{}, - I1, - Number{}, - I1, - I1), - make_tuple(I0, I1, I0, I0, I0, I0)); + static 
constexpr auto scale_thread_desc_ = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, I1, Number{}, I1, I1), + make_tuple(I0, I1, I0, I0, I0, I0)); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp index cb6678e391b..64aaaf034c2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -95,8 +95,9 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB 1); static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c3e7706ef3f..57aa8638a31 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -6,6 +6,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/math.hpp" #include "ck/utility/math_v2.hpp" +#include "ck/utility/get_id.hpp" namespace ck { namespace tensor_operation { @@ -68,6 +69,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(uint8_t& y, const uint8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { @@ -371,6 +378,90 @@ struct Swish float beta_ = 1.0f; }; +// support fastconvert of int8 to fp16 + +template +struct FastNumericArrayConverter +{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = 
reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + // printf("Tid: %03d, uint8_4: %08x\n", + // get_thread_local_1d_id(), + // uint8_4); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + // printf("Tid: %03d, Part1 converted: %08x | %08x\n", + // get_thread_local_1d_id(), + // half_2[Number<0>{}], + // half_2[Number<1>{}]); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed + // integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + // printf("Tid: %03d, Part2 converted: %08x | %08x\n", + // get_thread_local_1d_id(), + // half_2[Number<0>{}], + // half_2[Number<1>{}]); + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = 
converter(uint8_4_ptr[i]); }); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index f205b3a18f2..8010550e040 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -52,11 +52,13 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ defined(__gfx1102__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; - if (false && get_thread_local_1d_id()==0){ + if(false && get_thread_local_1d_id() == 0) + { printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size); printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned); printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned); - printf("lds_scale_size: %d\n", GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned); + printf("lds_scale_size: %d\n", + GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned); } GridwiseGemm::template Run(p_a_grid, @@ -459,17 +461,12 @@ struct GridwiseFpAintBGemm_Wmma // Workaround, Freeze transform return make_naive_tensor_descriptor(make_tuple(Number{}, - Number{}, - I1, - Number{}, - I1, - Number{}), - make_tuple(I0, - I1, - I0, - I0, - I0, - I0)); + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(I0, I1, I0, I0, I0, I0)); } }(); @@ -642,10 +639,12 @@ struct GridwiseFpAintBGemm_Wmma : 0; static constexpr auto a_block_space_offset = 0; - static constexpr auto b_block_space_offset = - (a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType)/sizeof(BDataType); + static constexpr auto b_block_space_offset = + (a_block_space_offset 
+ a_block_space_size_aligned) * sizeof(ADataType) / + sizeof(BDataType); static constexpr auto scale_block_space_offset = - (b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType)/sizeof(ScaleDataType); + (b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType) / + sizeof(ScaleDataType); // LDS allocation for C shuffle in LDS static constexpr auto c_shuffle_block_space_size = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 3a04213a9a2..0ff11a531f8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -719,7 +719,8 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, false> a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); b_blockwise_copy.Run( b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); - scale_blockwise_copy.Run(scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf); + scale_blockwise_copy.Run( + scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 3832b522ef4..5f350c98564 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1145,7 +1145,7 @@ struct ThreadwiseTensorSliceTransfer_v4 src_desc, src_data_coord); #if 0 printf("Tid: %03d, LDS read offset: %d\n", get_thread_local_1d_id(), src_data_coord.GetOffset()); -#endif +#endif // copy data from src_buf into src_tmp_vector if 
constexpr(SrcBuffer::IsDynamicBuffer()) { @@ -1419,7 +1419,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow 1, 0); v_theother_row = type_convert_sp(temp); - + if(get_thread_local_1d_id() % 32 < 16) { // apply type convert diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 78f25091eac..096e93bf202 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -207,10 +207,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // copy data from src_buf into src_vector_container auto src_vector_container = src_vector_type{ src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; - if (false){ + if(false) + { printf("Tid: %03d, a_grid_buf: %04x\n", - get_thread_local_1d_id(), - *(reinterpret_cast(&src_vector_container.template AsType()[Number<0>{}]))); + get_thread_local_1d_id(), + *(reinterpret_cast( + &src_vector_container.template AsType()[Number<0>{}]))); } // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 38ee76d8836..f9bb7d0fa2c 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -312,7 +312,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), 
"wrong! not implemented"); if constexpr(is_same::value) @@ -614,6 +615,114 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w static_cast(coherence)); return bit_cast(tmp); +#endif + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); +#else + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); +#endif + } + else if constexpr(N == 4) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + return llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); +#else + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); +#endif + } + else if constexpr(N == 8) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + static_cast(coherence)); + + return tmp.AsType()(Number<0>{}); +#else + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); 
+#endif + } + else if constexpr(N == 16) + { +#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE + vector_type tmp; + + tmp.AsType()(Number<0>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + tmp.AsType()(Number<1>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int8_t), + static_cast(coherence)); + + tmp.AsType()(Number<2>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int8_t), + static_cast(coherence)); + + tmp.AsType()(Number<3>{}) = + llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int8_t), + static_cast(coherence)); + + return tmp.AsType()(Number<0>{}); +#else + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); #endif } } diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 0e07c20ae55..0c09d74428e 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -133,6 +133,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = uint8_t; + static constexpr index_t vector_size = 1; +}; + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> struct scalar_type @@ -944,6 +951,15 @@ using int8x16_t = typename vector_type::type; using int8x32_t = typename vector_type::type; using int8x64_t = typename vector_type::type; +// u8 +// i8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using 
uint8x64_t = typename vector_type::type; + // Convert X to Y template __host__ __device__ constexpr Y type_convert(X x) diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 2ddbb6440d8..3a09d6038a4 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' From b5083bfef4a1f7600c8c30030e677ad79b07d2fb Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Thu, 3 Aug 2023 02:01:44 +0000 Subject: [PATCH 6/6] Fp16AInt8B_GEMM sanity --- .../49_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 8 +- example/49_fpAintB_gemm/run_gemm_example.inc | 62 ---------------- .../gpu/block/blockwise_fpAintB_gemm_wmma.hpp | 74 +------------------ .../element/unary_element_wise_operation.hpp | 16 +--- .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 11 +-- .../threadwise_tensor_slice_transfer.hpp | 4 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 12 +-- script/clang-format-overwrite.sh | 4 +- 8 files changed, 14 insertions(+), 177 deletions(-) diff --git a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp 
b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp index e8776a94bcd..138c8f1f86a 100644 --- a/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp +++ b/example/49_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -14,6 +14,8 @@ // The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType // The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType +// TODO: The current implementation consumes more VGPRs than expected. + using ADataType = ck::half_t; using QuantDataType = int8_t; using BDataType = uint8_t; @@ -49,13 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_ GemmDefault, 1, // Prefetch stage 128, // BlockSize - 128, // MPerBlock - 128, // NPerBlock + 64, // MPerBlock + 128, // NPerBlock 64, // KPerBlock 8, // K1 16, // MPerWmma 16, // NPerWmma - 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave S<4, 32, 1>, S<1, 0, 2>, diff --git a/example/49_fpAintB_gemm/run_gemm_example.inc b/example/49_fpAintB_gemm/run_gemm_example.inc index 5aca18fd5cc..87c8d6a70a1 100644 --- a/example/49_fpAintB_gemm/run_gemm_example.inc +++ b/example/49_fpAintB_gemm/run_gemm_example.inc @@ -67,68 +67,6 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) UnsignedWeightPreprocessor preprocessor; Tensor b_k_n = preprocessor(quant_b_k_n); -#if 0 - printf("Matrix A:\n"); - for (int im = 0; im < M; im++) - { - for (int ik = 0; ik < K; ik++) - { - if(ik % 16 == 0){ - printf("|"); - } - - printf(" %04x", *(reinterpret_cast(&a_m_k(im,ik)))); - } - printf("\n"); - } -#endif -#if 0 - printf("Matrix QuantB:\n"); - for (int in = 0; in < N; in++) - { - for (int ik = 0; ik < K; ik++) - { - if(ik % 16 == 0){ - printf("|"); - } - - printf(" %02x", *(reinterpret_cast(&quant_b_k_n(ik,in)))); - } - printf("\n"); - } -#endif -#if 0 - printf("Matrix Scale:\n"); - for(int in = 0; in < N; in++) - { - for(int ik = 0; ik < 1; ik++) - {
- if(ik % 16 == 0) - { - printf("|"); - } - - printf(" %04x", *(reinterpret_cast(&scale_k_n(ik, in)))); - } - printf("\n"); - } -#endif -#if 0 - printf("Matrix B:\n"); - for (int in = 0; in < N; in++) - { - for (int ik = 0; ik < K; ik++) - { - if(ik % 16 == 0){ - printf("|"); - } - - printf(" %02x", b_k_n(ik,in)); - } - printf("\n"); - } -#endif - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp index 981fa70a69b..7aab2c77c2b 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_fpAintB_gemm_wmma.hpp @@ -408,12 +408,7 @@ struct Blockwise_fpAintB_GemmWMMA scale_thread_desc_, make_tuple(I0, n0, I0, I0, I0, I0), scale_thread_buf); -#if 0 - printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n", - get_thread_local_1d_id(), n0.value, - *(reinterpret_cast(&scale_thread_buf[n0])) - ); -#endif + // read B b_thread_copy_.Run( b_block_desc_k0_n0_n1_n2_k1, @@ -448,72 +443,7 @@ struct Blockwise_fpAintB_GemmWMMA a_thread_desc_, make_tuple(I0, m0, I0, I0, I0, I0), a_thread_buf); - if(true) - { -#if 0 - printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n", - get_thread_local_1d_id(), m0.value, n0.value, k.value, - *(reinterpret_cast(&a_thread_buf[Number<0>{}])), - *(reinterpret_cast(&a_thread_buf[Number<1>{}])), - *(reinterpret_cast(&a_thread_buf[Number<2>{}])), - *(reinterpret_cast(&a_thread_buf[Number<3>{}])), - *(reinterpret_cast(&a_thread_buf[Number<4>{}])), - *(reinterpret_cast(&a_thread_buf[Number<5>{}])), - *(reinterpret_cast(&a_thread_buf[Number<6>{}])), - *(reinterpret_cast(&a_thread_buf[Number<7>{}])), - 
*(reinterpret_cast(&a_thread_buf[Number<8>{}])), - *(reinterpret_cast(&a_thread_buf[Number<9>{}])), - *(reinterpret_cast(&a_thread_buf[Number<10>{}])), - *(reinterpret_cast(&a_thread_buf[Number<11>{}])), - *(reinterpret_cast(&a_thread_buf[Number<12>{}])), - *(reinterpret_cast(&a_thread_buf[Number<13>{}])), - *(reinterpret_cast(&a_thread_buf[Number<14>{}])), - *(reinterpret_cast(&a_thread_buf[Number<15>{}])) - ); -#endif -#if 0 - printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n", - get_thread_local_1d_id(), m0.value, n0.value, k.value, - b_thread_buf[Number<0>{}], - b_thread_buf[Number<1>{}], - b_thread_buf[Number<2>{}], - b_thread_buf[Number<3>{}], - b_thread_buf[Number<4>{}], - b_thread_buf[Number<5>{}], - b_thread_buf[Number<6>{}], - b_thread_buf[Number<7>{}], - b_thread_buf[Number<8>{}], - b_thread_buf[Number<9>{}], - b_thread_buf[Number<10>{}], - b_thread_buf[Number<11>{}], - b_thread_buf[Number<12>{}], - b_thread_buf[Number<13>{}], - b_thread_buf[Number<14>{}], - b_thread_buf[Number<15>{}] - ); -#endif -#if 0 - printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n", - get_thread_local_1d_id(), m0.value, n0.value, k.value, - *(reinterpret_cast(&converted_b_thread_buf[Number<0>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<1>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<2>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<3>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<4>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<5>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<6>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<7>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<8>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<9>{}])), - 
*(reinterpret_cast(&converted_b_thread_buf[Number<10>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<11>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<12>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<13>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<14>{}])), - *(reinterpret_cast(&converted_b_thread_buf[Number<15>{}])) - ); -#endif - } + vector_type a_thread_vec; static_for<0, WmmaK, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 57aa8638a31..28d60e3ca90 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -398,23 +398,12 @@ struct FastNumericArrayConverter uint32_t* half_2 = reinterpret_cast(&Output); uint32_t const uint8_4 = reinterpret_cast(Input); - // printf("Tid: %03d, uint8_4: %08x\n", - // get_thread_local_1d_id(), - // uint8_4); - static constexpr uint32_t byte_selector_01 = 0x05010500; static constexpr uint32_t byte_selector_23 = 0x05030502; static constexpr uint32_t fp16_adder = 0x64646464; half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); - // printf("Tid: %03d, Part1 converted: %08x | %08x\n", - // get_thread_local_1d_id(), - // half_2[Number<0>{}], - // half_2[Number<1>{}]); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed - // integer as fp16. 
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n" : "=v"(half_2[0]) @@ -422,10 +411,7 @@ struct FastNumericArrayConverter asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]\n" : "=v"(half_2[1]) : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); - // printf("Tid: %03d, Part2 converted: %08x | %08x\n", - // get_thread_local_1d_id(), - // half_2[Number<0>{}], - // half_2[Number<1>{}]); + return Output; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 8010550e040..b44f8d0e0eb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -52,14 +52,6 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ defined(__gfx1102__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; - if(false && get_thread_local_1d_id() == 0) - { - printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size); - printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned); - printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned); - printf("lds_scale_size: %d\n", - GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned); - } GridwiseGemm::template Run(p_a_grid, p_b_grid, @@ -805,7 +797,7 @@ struct GridwiseFpAintBGemm_Wmma auto b_block_buf = make_dynamic_buffer( static_cast(p_shared) + SharedMemTrait::b_block_space_offset, SharedMemTrait::b_block_space_size_aligned); - // printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset); + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1( static_cast(p_shared) + SharedMemTrait::scale_block_space_offset, SharedMemTrait::scale_block_space_size_aligned); - // printf("scale_lds_offset: %lu\n", 
SharedMemTrait::scale_block_space_offset); auto scale_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1(src_coord_.GetOffset(), is_src_valid)}; - if(false) - { - printf("Tid: %03d, a_grid_buf: %04x\n", - get_thread_local_1d_id(), - *(reinterpret_cast( - &src_vector_container.template AsType()[Number<0>{}]))); - } + // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( @@ -448,9 +442,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); -#if 0 - printf("Tid: %03d, LDS write offset: %d\n", get_thread_local_1d_id(), dst_coord_.GetOffset()); -#endif + using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 3a09d6038a4..2ddbb6440d8 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' -# git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'