From 5b6700bd308b6d2e45962ae175a636ae4cf1777a Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 25 Sep 2025 09:40:35 +0800 Subject: [PATCH 01/23] tf32:bf16x3:use bf16x3 emulate tf32 gemm --- example/15_grouped_gemm/CMakeLists.txt | 10 +++ .../grouped_gemm_xdl_fp32_tf32.cpp | 67 +++++++++++++++++++ .../run_grouped_gemm_example.inc | 9 ++- .../gpu/device/device_grouped_gemm.hpp | 3 +- .../device/impl/device_grouped_gemm_xdl.hpp | 8 +-- 5 files changed, 91 insertions(+), 6 deletions(-) create mode 100644 example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 20cbc5fdca6..f8448aa2a55 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() + +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp) + add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32) + set(target 1) + endif() +endforeach() diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp new file mode 100644 index 00000000000..cd9a3435faa --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>; +// clang-format on + +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 4ef6074f4ac..1f3c9970814 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + struct ProblemSize final { std::vector Ms; @@ -231,7 +236,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co AccDataType, AElementOp, BElementOp, - CDEElementOp>; + CDEElementOp, + ComputeDataType, + ComputeDataType>; for(std::size_t i = 0; i < gemm_descs.size(); i++) { diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 52632785bd4..cf9992942e7 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -94,7 +94,8 @@ template + typename CElementwiseOperation, + typename ComputeDataType = ADataType> struct DeviceGroupedGemm : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 7a1944cc685..0ae1aa321a1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -134,7 +134,8 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename ComputeDataType = ADataType> struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm + CDEElementwiseOperation, + ComputeDataType> { using DeviceOp = DeviceGroupedGemm_Xdl; GET_NXDL_PER_WAVE_IMPL @@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); - using ComputeDataType = ADataType; - // GridwiseGemm template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< From 1549f41cca5e4aad578b93d011bea0a0d36024c1 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 30 Oct 2025 17:39:11 +0800 Subject: [PATCH 02/23] change blockwiseGemm to demo bf16x3 --- include/ck/library/utility/check_err.hpp | 4 + .../gpu/block/blockwise_gemm_xdlops.hpp | 525 +++++++++++++++++- .../cpu/reference_gemm.hpp | 5 +- 3 files changed, 519 insertions(+), 15 deletions(-) diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 3637053e14b..bf6851314d8 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -225,6 +225,10 @@ check_err(const Range& out, double rtol = 1e-5, double atol = 3e-6) { + // TODO: change according device + rtol = 5e-3; + atol = 5e-3; + // std::cout << "check_err: rtol = " << rtol << ", atol = " << atol << std::endl; if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 55015dd30f7..489d9671e54 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -307,6 +307,31 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("MPerBlock: %d, NPerBlock: %d, KPerBlock: %d, A_K0: %d, B_K0: %d, A_K1: %d, " + // "B_K1: %d, MWaves: %d, NWaves: %d, WaveSize: %d, KPerThread: %d\n", + // MPerBlock, + // NPerBlock, + // KPerBlock, + // A_K0, + // B_K0, + // A_K1, + // B_K1, + // MWaves, + // NWaves, + // WaveSize, + // KPerThread); + // printf("a thead size: %ld; b thead size: %ld\n", + // a_thread_desc_.GetElementSpaceSize().value, + // b_thread_desc_.GetElementSpaceSize().value); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("a_block_buf: %d , %d, %d, %d \n", a_block_buf[0], a_block_buf[1], + // a_block_buf[2], a_block_buf[3]); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("b_block_buf: %d , %d, %d, %d \n", b_block_buf[0], b_block_buf[1], + // b_block_buf[2], b_block_buf[3]); auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -330,6 +355,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(I0, I0, I0, I0), b_thread_buf); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("a thead: %d, %d, %d, %d; b thead: %d, %d, %d, %d\n", + // a_thread_buf[I0], + // a_thread_buf[I1], + // a_thread_buf[I2], + // a_thread_buf[I3], + // b_thread_buf[I0], + // b_thread_buf[I1], + // b_thread_buf[I2], + // b_thread_buf[I3]); + static_for<0, KPerThread, KPack>{}([&](auto k) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -603,6 +639,449 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING }; +// use bfx3 simulate tf32. +// - in/out/acc are all float; +// - one input is separated to 2 bf16 registers. -- TODO: layout should be changed. +// - 3 xdlops gemm outputs are same, as accumulation of 3 xdlops gemm results. + +// std::enable_if_t && is_same_v && +// is_same_v && is_same_v, +// bool> = true +template +struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // hard code to bf16. Both input reg and mfma type are bf16. + using DataTypeA = bhalf_t; + using DataTypeB = bhalf_t; + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + if constexpr(is_same_v || is_same_v) + { + static_assert(is_same_v, + "ComputeTypeA and ComputeTypeB must be same when one of them is tf32"); + } + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("MPerBlock: %d, NPerBlock: %d, KPerBlock: %d, A_K0: %d, B_K0: %d, A_K1: %d, " + // "B_K1: %d, MWaves: %d, NWaves: %d, WaveSize: %d, KPerThread: %d\n", + // MPerBlock, + // NPerBlock, + // KPerBlock, + // A_K0, + // B_K0, + // A_K1, + // B_K1, + // MWaves, + // NWaves, + // WaveSize, + // KPerThread); + // printf("a thead size: %ld; b thead size: %ld\n", + // a_thread_desc_.GetElementSpaceSize().value, + // b_thread_desc_.GetElementSpaceSize().value); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("a_block_buf: %d , %d, %d, %d \n", a_block_buf[0], a_block_buf[1], + // a_block_buf[2], a_block_buf[3]); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("b_block_buf: %d , %d, %d, %d \n", b_block_buf[0], b_block_buf[1], + // b_block_buf[2], b_block_buf[3]); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_thread_buf_big = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_big = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_thread_buf_small = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_small = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + static_for<0, a_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { + a_thread_buf_big(Number{}) = type_convert(a_thread_buf[i]); + a_thread_buf_small(Number{}) = type_convert( + a_thread_buf[i] - type_convert(a_thread_buf_big[i])); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + static_for<0, b_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { + b_thread_buf_big(Number{}) = + type_convert(b_thread_buf[i]); + b_thread_buf_small(Number{}) = type_convert( + b_thread_buf[i] - type_convert(b_thread_buf_big[i])); + }); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("a thead: %d, %d, %d, %d; b thead: %d, %d, %d, %d\n", + // a_thread_buf[I0], + // a_thread_buf[I1], + // a_thread_buf[I2], + // a_thread_buf[I3], + // b_thread_buf[I0], + // b_thread_buf[I1], + // b_thread_buf[I2], + // b_thread_buf[I3]); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + // why another register buffer? for index? + vector_type a_thread_vec_big; + vector_type b_thread_vec_big; + vector_type a_thread_vec_small; + vector_type b_thread_vec_small; + + static_for<0, KPack, 1>{}([&](auto i) { + auto a_idx = + Number{}; + auto b_idx = + Number{}; + a_thread_vec_big.template AsType()(i) = a_thread_buf_big[a_idx]; + b_thread_vec_big.template AsType()(i) = b_thread_buf_big[b_idx]; + a_thread_vec_small.template AsType()(i) = + a_thread_buf_small[a_idx]; + b_thread_vec_small.template AsType()(i) = + b_thread_buf_small[b_idx]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec_big.template AsType(), + b_thread_vec_small.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec_small.template AsType(), + b_thread_vec_big.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec_big.template AsType(), + b_thread_vec_big.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + template {}; + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 8b9b973b2d6..63ec05fa088 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -146,8 +146,9 @@ struct ReferenceGemm : public device::BaseOperator is_same_v) { // only for tf32 now v_acc += - ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); + // ck::type_convert(ck::type_convert(v_a)) * + // ck::type_convert(ck::type_convert(v_b)); + ck::type_convert(v_a) * ck::type_convert(v_b); } else { From a7d0da05275b3da33a5c4e15724a48df8b89a649 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Fri, 31 Oct 2025 17:47:47 +0800 Subject: [PATCH 03/23] temp push --- .../15_grouped_gemm/grouped_gemm_xdl_fp32.cpp | 2 +- .../grouped_gemm_xdl_fp32_tf32.cpp | 1 - .../run_grouped_gemm_example.inc | 60 +++++++++++- .../tensor_description/cluster_descriptor.hpp | 1 + .../gpu/block/blockwise_gemm_xdlops.hpp | 95 +++++++++++-------- ...hread_group_tensor_slice_transfer_v4r1.hpp | 8 ++ .../device/impl/device_grouped_gemm_xdl.hpp | 21 ++++ .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 3 + .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 9 ++ .../threadwise_tensor_slice_transfer.hpp | 2 + .../threadwise_tensor_slice_transfer_v3r1.hpp | 11 ++- include/ck/utility/type_convert.hpp | 2 + 12 files changed, 170 insertions(+), 45 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index fb047ae364b..571536344cb 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp index cd9a3435faa..70f758f904b 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -50,7 +50,6 @@ using BElementOp = PassThrough; using CDEElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; - using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl // clang-format off //######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 1f3c9970814..4add691f3da 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -157,6 +157,50 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); + + if(false) + { + // A is row major, B is column major, C is row major -- here are all row major? + ADataType* host_a_ptr = a_tensors[i].mData.data(); + std::cout << "a_tensor[" << i << "]; shape " << a_tensors[i].mDesc.GetLengths()[0] + << " x " << a_tensors[i].mDesc.GetLengths()[1] + << "; stride: " << a_tensors[i].mDesc.GetStrides()[0] << " x " + << a_tensors[i].mDesc.GetStrides()[1] << std::endl; + std::cout << "\t"; + for(std::size_t row_idx = 0; row_idx < a_tensors[i].mDesc.GetLengths()[0]; row_idx++) + { + std::cout << row_idx << ": "; + for(std::size_t col_idx = 0; col_idx < a_tensors[i].mDesc.GetLengths()[1]; + col_idx++) + { + std::cout << host_a_ptr[row_idx * a_tensors[i].mDesc.GetStrides()[0] + + col_idx * a_tensors[i].mDesc.GetStrides()[1]] + << " "; + } + std::cout << std::endl << "\t"; + } + std::cout << std::endl; + // col major + BDataType* host_b_ptr = b_tensors[i].mData.data(); + std::cout << "b_tensor[" << i << "]; shape " << b_tensors[i].mDesc.GetLengths()[0] + << " x " << b_tensors[i].mDesc.GetLengths()[1] + << "; stride: " << b_tensors[i].mDesc.GetStrides()[0] << " x " + << b_tensors[i].mDesc.GetStrides()[1] << std::endl; + std::cout << "\t"; + for(std::size_t row_idx = 0; row_idx < b_tensors[i].mDesc.GetLengths()[0]; row_idx++) + { + std::cout << row_idx << ": "; + for(std::size_t col_idx = 0; col_idx < b_tensors[i].mDesc.GetLengths()[1]; + col_idx++) + { + std::cout << host_b_ptr[row_idx * b_tensors[i].mDesc.GetStrides()[0] + + col_idx * b_tensors[i].mDesc.GetStrides()[1]] + << " "; + } + std::cout << std::endl << "\t"; + } + std::cout << std::endl; + } #endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); @@ -211,7 +255,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(!config.async_hargs) { - invoker.Run(argument, StreamConfig{nullptr, false}); + invoker.Run(argument, StreamConfig{nullptr, false, 1, 0, 1}); + // invoker.Run(argument, StreamConfig{nullptr, false}); } else { @@ -221,7 +266,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co hipEvent_t event0 = nullptr; hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, StreamConfig{nullptr, false, 1, 0, 1}, stream0, event0); + // invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); hip_check_error(hipEventSynchronize(event0)); hip_check_error(hipStreamSynchronize(stream0)); @@ -267,6 +313,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(config.time_kernel) { + std::cout << "run with time kernel" << std::endl; + // float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1, 0, + // 1}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -280,6 +329,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co bool run_grouped_gemm_example(int argc, char* argv[]) { + // std::srand(0); // 固定随机种子为0 ProblemSize problem_size; ExecutionConfig config; @@ -313,7 +363,11 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { problem_size.Ms.push_back(256 + 256 * i); problem_size.Ns.push_back(128 + 128 * i); - problem_size.Ks.push_back(128 + 64 * i); + problem_size.Ks.push_back(64 + 64 * i); + + // problem_size.Ms.push_back(128 + 128 * i); + // problem_size.Ns.push_back(128 + 128 * i); + // problem_size.Ks.push_back(32 + 32 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index 2dfcad8e042..cb15267dcef 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -14,6 +14,7 @@ __host__ __device__ constexpr auto make_cluster_descriptor( const Lengths& lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) { + // A: <4, 64, 1> <1, 0, 2> --> <64, 4, 1> constexpr index_t ndim_low = Lengths::Size(); const auto reordered_lengths = container_reorder_given_new2old(lengths, order); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 489d9671e54..95f9226445e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -993,15 +993,20 @@ struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf[i] - type_convert(b_thread_buf_big[i])); }); // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("a thead: %d, %d, %d, %d; b thead: %d, %d, %d, %d\n", + // printf("a thead: %.10f, %.10f, %.10f, %.10f\n a big: %.10f, %.10f, %.10f, " + // "%.10f\n a small: %.10f, %.10f, %.10f, %.10f\n", // a_thread_buf[I0], // a_thread_buf[I1], // a_thread_buf[I2], // a_thread_buf[I3], - // b_thread_buf[I0], - // b_thread_buf[I1], - // b_thread_buf[I2], - // b_thread_buf[I3]); + // type_convert(a_thread_buf_big[I0]), + // type_convert(a_thread_buf_big[I1]), + // type_convert(a_thread_buf_big[I2]), + // type_convert(a_thread_buf_big[I3]), + // type_convert(a_thread_buf_small[I0]), + // type_convert(a_thread_buf_small[I1]), + // type_convert(a_thread_buf_small[I2]), + // type_convert(a_thread_buf_small[I3])); static_for<0, KPerThread, KPack>{}([&](auto k) { // why another register buffer? for index? @@ -1100,39 +1105,53 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - if constexpr(is_same_v && is_same_v && - is_same_v && is_same_v) - { - return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - } - else - { - return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - } + // if constexpr(is_same_v && is_same_v && + // is_same_v && is_same_v) + // { + // return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + // } + // else + // { + // return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + // } + + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index bbbe012730f..a5f524248db 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -120,6 +120,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1 const SrcBuffer& src_buf, Number thread_scratch_id = Number{}) { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { @@ -132,6 +134,8 @@ struct ThreadGroupTensorSliceTransfer_v4r1 DstBuffer& dst_buf, Number thread_scratch_id = Number{}) { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { @@ -146,10 +150,13 @@ struct ThreadGroupTensorSliceTransfer_v4r1 DstBuffer& dst_buf, Number thread_scratch_id) { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); RunRead(src_desc, src_buf, thread_scratch_id); RunWrite(dst_desc, dst_buf, thread_scratch_id); } + // move to next slice(for pipeline read/write) __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or @@ -171,6 +178,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1 private: static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + // A: <4, 64, 1> <1, 0, 2> -- > <64, 4, 1> using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v3r1()) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -572,6 +588,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; + // std::cout << "grid_size_:" << arg.grid_size_ << ", BlockSize:" << BlockSize + // << std::endl; return launch_and_time_kernel( stream_config, kernel, @@ -824,6 +844,7 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg); if(p_arg_) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 1d9b7eb9781..ad06e1e0eb2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -532,6 +532,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t k_batch = 1, const index_t k_idx = 0) { + // use this + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 0cdb7ce2ca0..97344054859 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -57,6 +57,10 @@ struct GridwiseGemmPipeline_v1<1, true, true> CThreadBuffer& c_thread_buf, index_t num_loop) { + // grid buf is global, block buf is lds + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); + // preload data into LDS a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); @@ -83,6 +87,9 @@ struct GridwiseGemmPipeline_v1<1, true, true> b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%d: mainloop\n\n", __LINE__); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_sync_lds(); @@ -99,6 +106,8 @@ struct GridwiseGemmPipeline_v1<1, true, true> // tail { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%d: tail\n\n", __LINE__); block_sync_lds(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 5da9722a4b5..c6abc70d2a8 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1297,6 +1297,8 @@ struct ThreadwiseTensorSliceTransfer_v4 const DstOriginIdx&, DstBuffer& dst_buf) const { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 4a6ed62c0e2..e41ab315a7c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -119,6 +119,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 const SrcBuffer& src_buf, Number thread_scratch_id = Number{}) { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); @@ -129,18 +131,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access + // <1, 1, 4> constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); + // <1, 2, 1> constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; // <1, 0, 2> constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + // <2,1,1> <4,1,1> // make forward steps const auto src_forward_steps = generate_tuple( @@ -279,7 +284,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Sequence, Sequence, Sequence>; - + // real load static_for<0, tuple_element_t::Size(), 1>{}( [&](auto v_idx) { constexpr auto VectorLoadSize = @@ -522,6 +527,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstBuffer& dst_buf, Number thread_scratch_id = Number{}) { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // printf("%s\n\n", __PRETTY_FUNCTION__); // if there is transpose, it's done here // if there is oob check, it's done here // TODO move this elsewhere diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 701b2686c74..4feba62eb8f 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -79,6 +79,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1); return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16); + // return uint16_t((u.int32) >> 16); #endif } @@ -135,6 +136,7 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else + // error: static_cast drops too much precision return uint16_t(static_cast(x) >> 16); #endif } From 951cb30b57bdd4f23bf1f8b05252b33eb3ad8222 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Mon, 3 Nov 2025 16:59:33 +0800 Subject: [PATCH 04/23] self review --- example/15_grouped_gemm/CMakeLists.txt | 2 +- .../15_grouped_gemm/grouped_gemm_xdl_fp32.cpp | 2 +- .../grouped_gemm_xdl_fp32_tf32.cpp | 3 +- .../run_grouped_gemm_example.inc | 54 +------- include/ck/host_utility/device_prop.hpp | 6 +- .../tensor_description/cluster_descriptor.hpp | 1 - .../gpu/block/blockwise_gemm_xdlops.hpp | 122 +++++------------- ...hread_group_tensor_slice_transfer_v4r1.hpp | 7 - .../device/impl/device_grouped_gemm_xdl.hpp | 21 --- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 3 - .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 9 -- .../threadwise_tensor_slice_transfer.hpp | 2 - .../threadwise_tensor_slice_transfer_v3r1.hpp | 7 +- include/ck/utility/type_convert.hpp | 2 - 14 files changed, 47 insertions(+), 194 deletions(-) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index f8448aa2a55..aff46b34ac9 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -34,7 +34,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index 571536344cb..a8a64693ce4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 16, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp index 70f758f904b..885f4b40d71 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -50,13 +50,14 @@ using BElementOp = PassThrough; using CDEElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl // clang-format off //######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 4add691f3da..062917b9782 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -157,50 +157,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); c_tensors_device[i]->SetZero(); - - if(false) - { - // A is row major, B is column major, C is row major -- here are all row major? - ADataType* host_a_ptr = a_tensors[i].mData.data(); - std::cout << "a_tensor[" << i << "]; shape " << a_tensors[i].mDesc.GetLengths()[0] - << " x " << a_tensors[i].mDesc.GetLengths()[1] - << "; stride: " << a_tensors[i].mDesc.GetStrides()[0] << " x " - << a_tensors[i].mDesc.GetStrides()[1] << std::endl; - std::cout << "\t"; - for(std::size_t row_idx = 0; row_idx < a_tensors[i].mDesc.GetLengths()[0]; row_idx++) - { - std::cout << row_idx << ": "; - for(std::size_t col_idx = 0; col_idx < a_tensors[i].mDesc.GetLengths()[1]; - col_idx++) - { - std::cout << host_a_ptr[row_idx * a_tensors[i].mDesc.GetStrides()[0] + - col_idx * a_tensors[i].mDesc.GetStrides()[1]] - << " "; - } - std::cout << std::endl << "\t"; - } - std::cout << std::endl; - // col major - BDataType* host_b_ptr = b_tensors[i].mData.data(); - std::cout << "b_tensor[" << i << "]; shape " << b_tensors[i].mDesc.GetLengths()[0] - << " x " << b_tensors[i].mDesc.GetLengths()[1] - << "; stride: " << b_tensors[i].mDesc.GetStrides()[0] << " x " - << b_tensors[i].mDesc.GetStrides()[1] << std::endl; - std::cout << "\t"; - for(std::size_t row_idx = 0; row_idx < b_tensors[i].mDesc.GetLengths()[0]; row_idx++) - { - std::cout << row_idx << ": "; - for(std::size_t col_idx = 0; col_idx < b_tensors[i].mDesc.GetLengths()[1]; - col_idx++) - { - std::cout << host_b_ptr[row_idx * b_tensors[i].mDesc.GetStrides()[0] + - col_idx * b_tensors[i].mDesc.GetStrides()[1]] - << " "; - } - std::cout << std::endl << "\t"; - } - std::cout << std::endl; - } #endif p_a.push_back(a_tensors_device[i]->GetDeviceBuffer()); @@ -255,8 +211,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(!config.async_hargs) { - invoker.Run(argument, StreamConfig{nullptr, false, 1, 0, 1}); - // invoker.Run(argument, StreamConfig{nullptr, false}); + invoker.Run(argument, StreamConfig{nullptr, false}); } else { @@ -266,8 +221,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co hipEvent_t event0 = nullptr; hip_check_error(hipEventCreate(&event0)); - invoker.Run(argument, StreamConfig{nullptr, false, 1, 0, 1}, stream0, event0); - // invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); + invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0); hip_check_error(hipEventSynchronize(event0)); hip_check_error(hipStreamSynchronize(stream0)); @@ -313,9 +267,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co if(config.time_kernel) { - std::cout << "run with time kernel" << std::endl; - // float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 1, 0, - // 1}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -329,7 +280,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co bool run_grouped_gemm_example(int argc, char* argv[]) { - // std::srand(0); // 固定随机种子为0 ProblemSize problem_size; ExecutionConfig config; diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 0c4f056a465..385cf47ed95 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,7 +129,11 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } -inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; } +inline bool is_tf32_supported() +{ + return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || + is_gfx12_supported() || is_gfx11_supported(); +} } // namespace ck #endif diff --git a/include/ck/tensor_description/cluster_descriptor.hpp b/include/ck/tensor_description/cluster_descriptor.hpp index cb15267dcef..2dfcad8e042 100644 --- a/include/ck/tensor_description/cluster_descriptor.hpp +++ b/include/ck/tensor_description/cluster_descriptor.hpp @@ -14,7 +14,6 @@ __host__ __device__ constexpr auto make_cluster_descriptor( const Lengths& lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) { - // A: <4, 64, 1> <1, 0, 2> --> <64, 4, 1> constexpr index_t ndim_low = Lengths::Size(); const auto reordered_lengths = container_reorder_given_new2old(lengths, order); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 95f9226445e..54454638287 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -926,31 +926,6 @@ struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("MPerBlock: %d, NPerBlock: %d, KPerBlock: %d, A_K0: %d, B_K0: %d, A_K1: %d, " - // "B_K1: %d, MWaves: %d, NWaves: %d, WaveSize: %d, KPerThread: %d\n", - // MPerBlock, - // NPerBlock, - // KPerBlock, - // A_K0, - // B_K0, - // A_K1, - // B_K1, - // MWaves, - // NWaves, - // WaveSize, - // KPerThread); - // printf("a thead size: %ld; b thead size: %ld\n", - // a_thread_desc_.GetElementSpaceSize().value, - // b_thread_desc_.GetElementSpaceSize().value); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("a_block_buf: %d , %d, %d, %d \n", a_block_buf[0], a_block_buf[1], - // a_block_buf[2], a_block_buf[3]); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("b_block_buf: %d , %d, %d, %d \n", b_block_buf[0], b_block_buf[1], - // b_block_buf[2], b_block_buf[3]); auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -992,21 +967,6 @@ struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf_small(Number{}) = type_convert( b_thread_buf[i] - type_convert(b_thread_buf_big[i])); }); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("a thead: %.10f, %.10f, %.10f, %.10f\n a big: %.10f, %.10f, %.10f, " - // "%.10f\n a small: %.10f, %.10f, %.10f, %.10f\n", - // a_thread_buf[I0], - // a_thread_buf[I1], - // a_thread_buf[I2], - // a_thread_buf[I3], - // type_convert(a_thread_buf_big[I0]), - // type_convert(a_thread_buf_big[I1]), - // type_convert(a_thread_buf_big[I2]), - // type_convert(a_thread_buf_big[I3]), - // type_convert(a_thread_buf_small[I0]), - // type_convert(a_thread_buf_small[I1]), - // type_convert(a_thread_buf_small[I2]), - // type_convert(a_thread_buf_small[I3])); static_for<0, KPerThread, KPack>{}([&](auto k) { // why another register buffer? for index? @@ -1105,53 +1065,41 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - // if constexpr(is_same_v && is_same_v && - // is_same_v && is_same_v) - // { - // return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - // } - // else - // { - // return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - // } - - return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; +#if((defined(__gfx12__) || defined(__gfx11__) || defined(__gfx950__)) && !defined(__gfx942__)) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else +#endif + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp index a5f524248db..ae38bd558f8 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -120,8 +120,6 @@ struct ThreadGroupTensorSliceTransfer_v4r1 const SrcBuffer& src_buf, Number thread_scratch_id = Number{}) { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { @@ -134,8 +132,6 @@ struct ThreadGroupTensorSliceTransfer_v4r1 DstBuffer& dst_buf, Number thread_scratch_id = Number{}) { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { @@ -150,13 +146,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1 DstBuffer& dst_buf, Number thread_scratch_id) { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); RunRead(src_desc, src_buf, thread_scratch_id); RunWrite(dst_desc, dst_buf, thread_scratch_id); } - // move to next slice(for pipeline read/write) __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index 34c4f47d1d5..0ae1aa321a1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -40,22 +40,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation c_element_op) { #if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) - // if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) - // printf("GridDim: (%u,%u,%u), BlockIdx: (%u,%u,%u), BlockDim: (%u,%u,%u), ThreadIdx: - // " - // "(%u,%u,%u)\n", - // gridDim.x, - // gridDim.y, - // gridDim.z, - // blockIdx.x, - // blockIdx.y, - // blockIdx.z, - // blockDim.x, - // blockDim.y, - // blockDim.z, - // threadIdx.x, - // threadIdx.y, - // threadIdx.z); if constexpr(GridwiseGemm::template IsValidCompilationParameter<>()) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -588,8 +572,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm; - // std::cout << "grid_size_:" << arg.grid_size_ << ", BlockSize:" << BlockSize - // << std::endl; return launch_and_time_kernel( stream_config, kernel, @@ -844,7 +824,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm(p_arg); if(p_arg_) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index ad06e1e0eb2..1d9b7eb9781 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -532,9 +532,6 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t k_batch = 1, const index_t k_idx = 0) { - // use this - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 97344054859..0cdb7ce2ca0 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -57,10 +57,6 @@ struct GridwiseGemmPipeline_v1<1, true, true> CThreadBuffer& c_thread_buf, index_t num_loop) { - // grid buf is global, block buf is lds - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); - // preload data into LDS a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); @@ -87,9 +83,6 @@ struct GridwiseGemmPipeline_v1<1, true, true> b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%d: mainloop\n\n", __LINE__); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); block_sync_lds(); @@ -106,8 +99,6 @@ struct GridwiseGemmPipeline_v1<1, true, true> // tail { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%d: tail\n\n", __LINE__); block_sync_lds(); blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index c6abc70d2a8..5da9722a4b5 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1297,8 +1297,6 @@ struct ThreadwiseTensorSliceTransfer_v4 const DstOriginIdx&, DstBuffer& dst_buf) const { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index e41ab315a7c..05f32624b13 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -119,8 +119,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 const SrcBuffer& src_buf, Number thread_scratch_id = Number{}) { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, "wrong!"); @@ -131,21 +129,18 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // scalar per access on each dim // TODO: don't use lambda_scalar_per_access - // <1, 1, 4> constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - // <1, 2, 1> constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; // <1, 0, 2> + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // <2,1,1> <4,1,1> // make forward steps const auto src_forward_steps = generate_tuple( diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 4feba62eb8f..701b2686c74 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -79,7 +79,6 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1); return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16); - // return uint16_t((u.int32) >> 16); #endif } @@ -136,7 +135,6 @@ inline __host__ __device__ constexpr bhalf_t type_convert(float #if CK_USE_RNE_BF16_CONVERSION return bf16_convert_rtn(x); #else - // error: static_cast drops too much precision return uint16_t(static_cast(x) >> 16); #endif } From 007ebbe7f01fa218d0685d168f60880382dac0f8 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Mon, 3 Nov 2025 17:33:09 +0800 Subject: [PATCH 05/23] self review --- example/15_grouped_gemm/CMakeLists.txt | 2 +- .../15_grouped_gemm/grouped_gemm_xdl_fp32.cpp | 2 +- .../run_grouped_gemm_example.inc | 6 +-- include/ck/library/utility/check_err.hpp | 9 ++-- .../gpu/block/blockwise_gemm_xdlops.hpp | 50 +++---------------- ...hread_group_tensor_slice_transfer_v4r1.hpp | 1 - .../threadwise_tensor_slice_transfer_v3r1.hpp | 3 -- .../cpu/reference_gemm.hpp | 15 +----- .../gpu/reference_gemm.hpp | 11 +--- 9 files changed, 17 insertions(+), 82 deletions(-) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index aff46b34ac9..38ac18b32f4 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -34,7 +34,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() -list(APPEND gpu_list_tf32 gfx942 950) +list(APPEND gpu_list_tf32 gfx942 950 1100 1101 1102 1200 1201 12-generic) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp index a8a64693ce4..fb047ae364b 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32.cpp @@ -54,7 +54,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 16, 4, 4, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 2>; // clang-format on #include "run_grouped_gemm_example.inc" diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 062917b9782..1f3c9970814 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -313,11 +313,7 @@ bool run_grouped_gemm_example(int argc, char* argv[]) { problem_size.Ms.push_back(256 + 256 * i); problem_size.Ns.push_back(128 + 128 * i); - problem_size.Ks.push_back(64 + 64 * i); - - // problem_size.Ms.push_back(128 + 128 * i); - // problem_size.Ns.push_back(128 + 128 * i); - // problem_size.Ks.push_back(32 + 32 * i); + problem_size.Ks.push_back(128 + 64 * i); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index bf6851314d8..83f5e9b30d4 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -225,10 +225,11 @@ check_err(const Range& out, double rtol = 1e-5, double atol = 3e-6) { - // TODO: change according device - rtol = 5e-3; - atol = 5e-3; - // std::cout << "check_err: rtol = " << rtol << ", atol = " << atol << std::endl; + if(ck::get_device_name() == "gfx942") + { + rtol = 5e-3; + atol = 5e-3; + } if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 54454638287..cfa413684eb 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -307,31 +307,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("MPerBlock: %d, NPerBlock: %d, KPerBlock: %d, A_K0: %d, B_K0: %d, A_K1: %d, " - // "B_K1: %d, MWaves: %d, NWaves: %d, WaveSize: %d, KPerThread: %d\n", - // MPerBlock, - // NPerBlock, - // KPerBlock, - // A_K0, - // B_K0, - // A_K1, - // B_K1, - // MWaves, - // NWaves, - // WaveSize, - // KPerThread); - // printf("a thead size: %ld; b thead size: %ld\n", - // a_thread_desc_.GetElementSpaceSize().value, - // b_thread_desc_.GetElementSpaceSize().value); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("a_block_buf: %d , %d, %d, %d \n", a_block_buf[0], a_block_buf[1], - // a_block_buf[2], a_block_buf[3]); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("b_block_buf: %d , %d, %d, %d \n", b_block_buf[0], b_block_buf[1], - // b_block_buf[2], b_block_buf[3]); auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -355,17 +330,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(I0, I0, I0, I0), b_thread_buf); - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("a thead: %d, %d, %d, %d; b thead: %d, %d, %d, %d\n", - // a_thread_buf[I0], - // a_thread_buf[I1], - // a_thread_buf[I2], - // a_thread_buf[I3], - // b_thread_buf[I0], - // b_thread_buf[I1], - // b_thread_buf[I2], - // b_thread_buf[I3]); - static_for<0, KPerThread, KPack>{}([&](auto k) { vector_type a_thread_vec; vector_type b_thread_vec; @@ -639,14 +603,12 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING }; -// use bfx3 simulate tf32. -// - in/out/acc are all float; -// - one input is separated to 2 bf16 registers. -- TODO: layout should be changed. -// - 3 xdlops gemm outputs are same, as accumulation of 3 xdlops gemm results. - -// std::enable_if_t && is_same_v && -// is_same_v && is_same_v, -// bool> = true +/* + * @brief blockwise gemm xdlops with bf16x3 simulate tf32 + * in/out/acc are all float; + * one input is separated to 2 bf16 registers. + * 3 xdlops gemm output regs are same, as accumulation of 3 xdlops gemm results. + */ template <1, 0, 2> -- > <64, 4, 1> using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v3r1, Sequence, Sequence>; - // real load static_for<0, tuple_element_t::Size(), 1>{}( [&](auto v_idx) { constexpr auto VectorLoadSize = @@ -522,8 +521,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstBuffer& dst_buf, Number thread_scratch_id = Number{}) { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // printf("%s\n\n", __PRETTY_FUNCTION__); // if there is transpose, it's done here // if there is oob check, it's done here // TODO move this elsewhere diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 63ec05fa088..660ec64f973 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -142,19 +142,8 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - if constexpr(is_same_v && - is_same_v) - { // only for tf32 now - v_acc += - // ck::type_convert(ck::type_convert(v_a)) * - // ck::type_convert(ck::type_convert(v_b)); - ck::type_convert(v_a) * ck::type_convert(v_b); - } - else - { - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); - } + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index cf30bc7ddad..bcc8b955005 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -80,16 +80,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - if constexpr(is_same_v && - is_same_v) - { // only for tf32 now - v_acc += ck::type_convert(ck::type_convert(v_a)) * - ck::type_convert(ck::type_convert(v_b)); - } - else - { - v_acc += type_convert(v_a) * type_convert(v_b); - } + v_acc += type_convert(v_a) * type_convert(v_b); } // apply c_element_op c_element_op(v_c, v_acc); From 54b1642344419f4e9fa86bcddc4e248cda3d2b6d Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 4 Nov 2025 11:41:30 +0800 Subject: [PATCH 06/23] fix multi-device compile error --- example/15_grouped_gemm/CMakeLists.txt | 2 +- .../grouped_gemm_xdl_fp32_tf32.cpp | 2 +- .../run_grouped_gemm_example.inc | 3 ++- include/ck/library/utility/check_err.hpp | 21 ++++++++++++++----- .../gpu/block/blockwise_gemm_xdlops.hpp | 10 +++++++-- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 38ac18b32f4..7119e568028 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -34,7 +34,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() -list(APPEND gpu_list_tf32 gfx942 950 1100 1101 1102 1200 1201 12-generic) +list(APPEND gpu_list_tf32 gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 gfx12-generic) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp index 885f4b40d71..78eb90e3114 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 1f3c9970814..7099dc82991 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -260,7 +260,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); #else - pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + pass &= ck::utils::check_err, Tensor, ComputeDataType>( + c_device_tensors[i], c_host_tensors[i]); #endif } } diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 83f5e9b30d4..9723457c1e9 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -19,6 +19,7 @@ #include "ck/host_utility/io.hpp" #include "ck/library/utility/ranges.hpp" +#include "ck/host_utility/device_prop.hpp" namespace ck { namespace utils { @@ -171,6 +172,21 @@ check_err(const Range& out, double rtol = 1e-5, double atol = 3e-5) { +#ifndef __HIPCC_RTC__ + if(ck::get_device_name() == "gfx942") + { + rtol = 5e-3; + atol = 5e-3; + } +#else +// In RTC mode, use preprocessor macros to check device architecture +#if defined(__gfx942__) + { + rtol = 5e-3; + atol = 5e-3; + } +#endif +#endif if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() @@ -225,11 +241,6 @@ check_err(const Range& out, double rtol = 1e-5, double atol = 3e-6) { - if(ck::get_device_name() == "gfx942") - { - rtol = 5e-3; - atol = 5e-3; - } if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index cfa413684eb..c0579da9abe 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1025,10 +1025,16 @@ template constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { + +#if defined(__gfx12__) || defined(__gfx11__) || defined(__gfx950__) + constexpr bool is_supported_arch = true; +#else + constexpr bool is_supported_arch = false; +#endif + if constexpr(LoopSched == LoopScheduler::Default) { -#if((defined(__gfx12__) || defined(__gfx11__) || defined(__gfx950__)) && !defined(__gfx942__)) - if constexpr(is_same_v && is_same_v && + if constexpr(is_supported_arch && is_same_v && is_same_v && is_same_v && is_same_v) { return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 Date: Tue, 4 Nov 2025 13:42:22 +0800 Subject: [PATCH 07/23] bug fix --- include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index c0579da9abe..62e61f2c6da 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -1052,7 +1052,6 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() ComputeTypeB>{}; } else -#endif { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 Date: Wed, 5 Nov 2025 00:31:04 +0800 Subject: [PATCH 08/23] code refactor --- .../gpu/block/blockwise_gemm_xdlops.hpp | 298 +++--------------- 1 file changed, 38 insertions(+), 260 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 62e61f2c6da..b283267f0bd 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -623,266 +623,48 @@ template struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - using ThisThreadBlock = ThisThreadBlock; + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + using Base::c_thread_desc_; + using Base::I0; + using Base::I1; + using Base::KPerThread; // hard code to bf16. Both input reg and mfma type are bf16. using DataTypeA = bhalf_t; using DataTypeB = bhalf_t; - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; - static constexpr auto xdlops_gemm = XdlopsGemm{}; - static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; - - StaticBufferTupleOfVector - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = ThisThreadBlock::GetThreadId(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); - - return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); - - return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); - } - - template - __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - template - __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); - - return make_tuple(Number{}, - Number{}, - waveId_m, - waveId_n, - blk_idx[I0], - blk_idx[I1], - blk_idx[I2], - blk_idx[I3]); - } - - __host__ __device__ BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, - "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, - "wrong!"); - if constexpr(is_same_v || is_same_v) - { - static_assert(is_same_v, - "ComputeTypeA and ComputeTypeB must be same when one of them is tf32"); - } - } - - __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); - - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - - return make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); - } - - __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); - - constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; - constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; - constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; - constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; - - return make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); - } - - __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); - } - - __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() - { - constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = - make_naive_tensor_descriptor_packed(make_tuple(I1, - Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( - c_block_desc_g_m0_n0_m1_n1_m2_n2); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); - - const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); - - return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); - } - - template - __host__ __device__ static constexpr auto - MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) - { - const auto G = c_grid_desc_g_m_n.GetLength(I0); - const auto M = c_grid_desc_g_m_n.GetLength(I1); - const auto N = c_grid_desc_g_m_n.GetLength(I2); - - const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( - c_grid_desc_g_m_n, - make_tuple(make_pass_through_transform(G), - make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), - make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); - - return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( - c_grid_desc_g_m0_n0_m1_n1_m2_n2); - } - - __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple( - make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); - } - - static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); - static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); - template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, @@ -981,10 +763,6 @@ struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - // C[M, N, NumRegXdlops] - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - A_K1, - A_K1>; + Base::A_K1, + Base::A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - B_K1, - B_K1>; + Base::B_K1, + Base::B_K1>; - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; }; template Date: Wed, 5 Nov 2025 20:00:54 +0800 Subject: [PATCH 09/23] limit to gfx950 --- example/15_grouped_gemm/CMakeLists.txt | 2 +- include/ck/host_utility/device_prop.hpp | 3 +-- include/ck/library/utility/check_err.hpp | 8 ++++---- .../gpu/block/blockwise_gemm_xdlops.hpp | 14 +++++++++---- profiler/src/profile_grouped_conv_fwd.cpp | 20 +++++++++---------- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index 7119e568028..20d9bab7e1c 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -34,7 +34,7 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() -list(APPEND gpu_list_tf32 gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1200 gfx1201 gfx12-generic) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 385cf47ed95..53f4c273994 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -131,8 +131,7 @@ inline bool is_wmma_supported() inline bool is_tf32_supported() { - return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950" || - is_gfx12_supported() || is_gfx11_supported(); + return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; } } // namespace ck diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 9723457c1e9..ecf8ecd977b 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -175,15 +175,15 @@ check_err(const Range& out, #ifndef __HIPCC_RTC__ if(ck::get_device_name() == "gfx942") { - rtol = 5e-3; - atol = 5e-3; + rtol = 1e-2; + atol = 1e-2; } #else // In RTC mode, use preprocessor macros to check device architecture #if defined(__gfx942__) { - rtol = 5e-3; - atol = 5e-3; + rtol = 1e-2; + atol = 1e-2; } #endif #endif diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index b283267f0bd..6e795f9b5ff 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -605,9 +605,15 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 /* * @brief blockwise gemm xdlops with bf16x3 simulate tf32 - * in/out/acc are all float; - * one input is separated to 2 bf16 registers. - * 3 xdlops gemm output regs are same, as accumulation of 3 xdlops gemm results. + * in/out/acc are all float; + * step: + * separate one input to 2 bf16 registers: + * in_bf16_big = f32_to_bf16(in_f32) + * in_bf16_small = in_f32 - in_bf16_big + * run 3 xdlops gemm: all the accumulator registers of gemm are same. + * out_f32 = A_bf16_big * B_bf16_big + * out_f32 += A_bf16_small * B_bf16_big + * out_f32 += A_bf16_big * B_bf16_small */ template Date: Fri, 7 Nov 2025 12:11:40 +0800 Subject: [PATCH 10/23] enhance gemm gfx942 threshold --- example/01_gemm/common.hpp | 11 +++++++++-- example/01_gemm/run_gemm_example.inc | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index e482953e464..32110759f45 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -356,11 +356,18 @@ inline __host__ __device__ constexpr double get_rtol() } template -inline __host__ __device__ constexpr double get_atol() +inline __host__ __device__ constexpr double get_atol(size_t K = 0) { if constexpr(std::is_same_v && std::is_same_v) { - return 1e-3; + if(K == 0) + { + throw std::runtime_error("K is 0"); + } + // tf32 has 10 mantissa bits, so epsilon = 2^(-10) = 1/1024 + constexpr double epsilon_tf32 = 1.0 / 1024.0; // 2^(-10) + constexpr double epsilon_fp32 = std::numeric_limits::epsilon(); + return (epsilon_tf32 - epsilon_fp32) * K; } else if constexpr(std::is_same_v) { diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7fb0c1e812e..cdabcc9fa82 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -212,7 +212,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) c_m_n_host_result, "Error: Incorrect results!", get_rtol(), - get_atol()); + get_atol(K)); #endif } From 81d248dd0ac5e2444e47b568d6ca1fc9bcb64f08 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Mon, 10 Nov 2025 11:51:22 +0800 Subject: [PATCH 11/23] lower change from blockwise to warpwise --- example/01_gemm/CMakeLists.txt | 2 +- example/09_convnd_fwd/CMakeLists.txt | 2 +- .../gpu/block/blockwise_gemm_xdlops.hpp | 438 +++++++++--------- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 52 +++ include/ck/utility/amd_xdlops.hpp | 139 +++++- 5 files changed, 415 insertions(+), 218 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 03bde864214..a9ae0b2a6aa 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 4f174bfcbb2..791d81e2645 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() -list(APPEND gpu_list_tf32 gfx942) +list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 6e795f9b5ff..1a53f374843 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -307,6 +307,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { + // if(threadIdx.x == 0 && blockIdx.x == 0) + // { + // printf("BlockwiseGemmXdlops: Run\n"); + // } auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( @@ -615,183 +619,187 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 * out_f32 += A_bf16_small * B_bf16_big * out_f32 += A_bf16_big * B_bf16_small */ -template -struct BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 - : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 -{ - using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - - using Base::a_block_desc_m0_m1_m2_k; - using Base::b_block_desc_n0_n1_n2_k; - using Base::c_thread_desc_; - using Base::I0; - using Base::I1; - using Base::KPerThread; - - // hard code to bf16. Both input reg and mfma type are bf16. - using DataTypeA = bhalf_t; - using DataTypeB = bhalf_t; - - static constexpr auto xdlops_gemm = - XdlopsGemm{}; - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - auto a_thread_buf_big = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf_big = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - auto a_thread_buf_small = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf_small = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - static_for<0, a_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { - a_thread_buf_big(Number{}) = type_convert(a_thread_buf[i]); - a_thread_buf_small(Number{}) = type_convert( - a_thread_buf[i] - type_convert(a_thread_buf_big[i])); - }); - - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - static_for<0, b_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { - b_thread_buf_big(Number{}) = - type_convert(b_thread_buf[i]); - b_thread_buf_small(Number{}) = type_convert( - b_thread_buf[i] - type_convert(b_thread_buf_big[i])); - }); - - static_for<0, KPerThread, KPack>{}([&](auto k) { - // why another register buffer? for index? - vector_type a_thread_vec_big; - vector_type b_thread_vec_big; - vector_type a_thread_vec_small; - vector_type b_thread_vec_small; - - static_for<0, KPack, 1>{}([&](auto i) { - auto a_idx = - Number{}; - auto b_idx = - Number{}; - a_thread_vec_big.template AsType()(i) = a_thread_buf_big[a_idx]; - b_thread_vec_big.template AsType()(i) = b_thread_buf_big[b_idx]; - a_thread_vec_small.template AsType()(i) = - a_thread_buf_small[a_idx]; - b_thread_vec_small.template AsType()(i) = - b_thread_buf_small[b_idx]; - }); - - using mfma_input_type_a = - typename vector_type::type; - using mfma_input_type_b = - typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - xdlops_gemm.Run(a_thread_vec_big.template AsType(), - b_thread_vec_small.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec_small.template AsType(), - b_thread_vec_big.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - xdlops_gemm.Run(a_thread_vec_big.template AsType(), - b_thread_vec_big.template AsType(), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); - } - - protected: - // A[M0, M1, M2, KPerThread] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - - // B[N0, N1, N2, KPerThread] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - Base::A_K1, - Base::A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - Base::B_K1, - Base::B_K1>; - - AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; -}; +// template +// struct BlockwiseGemmXdlops_BF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +// : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +// { +// using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +// using Base::a_block_desc_m0_m1_m2_k; +// using Base::b_block_desc_n0_n1_n2_k; +// using Base::c_thread_desc_; +// using Base::I0; +// using Base::I1; +// using Base::KPerThread; + +// // hard code to bf16. Both input reg and mfma type are bf16. +// using DataTypeA = bhalf_t; +// using DataTypeB = bhalf_t; + +// static constexpr auto xdlops_gemm = +// XdlopsGemm{}; + +// template +// __device__ void Run(const ABlockBuffer& a_block_buf, +// const BBlockBuffer& b_block_buf, +// CThreadBuffer& c_thread_buf) const +// { +// // if(threadIdx.x == 0 && blockIdx.x == 0) +// // { +// // printf("BlockwiseGemmXdlops_bf16x3: Run\n"); +// // } +// auto a_thread_buf = make_static_buffer( +// a_thread_desc_.GetElementSpaceSize()); +// auto b_thread_buf = make_static_buffer( +// b_thread_desc_.GetElementSpaceSize()); +// auto a_thread_buf_big = make_static_buffer( +// a_thread_desc_.GetElementSpaceSize()); +// auto b_thread_buf_big = make_static_buffer( +// b_thread_desc_.GetElementSpaceSize()); +// auto a_thread_buf_small = make_static_buffer( +// a_thread_desc_.GetElementSpaceSize()); +// auto b_thread_buf_small = make_static_buffer( +// b_thread_desc_.GetElementSpaceSize()); + +// static_for<0, MRepeat, 1>{}([&](auto m0) { +// // read A +// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, +// make_tuple(m0, I0, I0, I0), +// a_block_buf, +// a_thread_desc_, +// make_tuple(I0, I0, I0, I0), +// a_thread_buf); +// static_for<0, a_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { +// a_thread_buf_big(Number{}) = type_convert(a_thread_buf[i]); +// a_thread_buf_small(Number{}) = type_convert( +// a_thread_buf[i] - type_convert(a_thread_buf_big[i])); +// }); + +// static_for<0, NRepeat, 1>{}([&](auto n0) { +// // read B +// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, +// make_tuple(n0, I0, I0, I0), +// b_block_buf, +// b_thread_desc_, +// make_tuple(I0, I0, I0, I0), +// b_thread_buf); +// static_for<0, b_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { +// b_thread_buf_big(Number{}) = +// type_convert(b_thread_buf[i]); +// b_thread_buf_small(Number{}) = type_convert( +// b_thread_buf[i] - type_convert(b_thread_buf_big[i])); +// }); + +// static_for<0, KPerThread, KPack>{}([&](auto k) { +// // why another register buffer? for index? +// vector_type a_thread_vec_big; +// vector_type b_thread_vec_big; +// vector_type a_thread_vec_small; +// vector_type b_thread_vec_small; + +// static_for<0, KPack, 1>{}([&](auto i) { +// auto a_idx = +// Number{}; +// auto b_idx = +// Number{}; +// a_thread_vec_big.template AsType()(i) = a_thread_buf_big[a_idx]; +// b_thread_vec_big.template AsType()(i) = b_thread_buf_big[b_idx]; +// a_thread_vec_small.template AsType()(i) = +// a_thread_buf_small[a_idx]; +// b_thread_vec_small.template AsType()(i) = +// b_thread_buf_small[b_idx]; +// }); + +// using mfma_input_type_a = +// typename vector_type::type; +// using mfma_input_type_b = +// typename vector_type::type; + +// constexpr index_t c_offset = +// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + +// xdlops_gemm.Run(a_thread_vec_big.template AsType(), +// b_thread_vec_small.template AsType(), +// c_thread_buf.GetVectorTypeReference(Number{})); +// xdlops_gemm.Run(a_thread_vec_small.template AsType(), +// b_thread_vec_big.template AsType(), +// c_thread_buf.GetVectorTypeReference(Number{})); +// xdlops_gemm.Run(a_thread_vec_big.template AsType(), +// b_thread_vec_big.template AsType(), +// c_thread_buf.GetVectorTypeReference(Number{})); +// }); +// }); +// }); +// } + +// protected: +// // A[M0, M1, M2, KPerThread] +// static constexpr auto a_thread_desc_ = +// make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + +// // B[N0, N1, N2, KPerThread] +// static constexpr auto b_thread_desc_ = +// make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + +// using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, +// Sequence<0, 1, 2, 3>, +// 3, +// Base::A_K1, +// Base::A_K1>; + +// using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, +// Sequence<0, 1, 2, 3>, +// 3, +// Base::B_K1, +// Base::B_K1>; + +// AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; +// BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; +// }; template && is_same_v && - is_same_v && is_same_v) - { - return BlockwiseGemmXdlopsBF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - } - else - { - return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - } + // if constexpr(is_supported_arch && is_same_v && is_same_v && + // is_same_v && is_same_v) + // { + // return BlockwiseGemmXdlops_BF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + // } + // else + // { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + // } } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index ce2d9299f90..cb97d89dbee 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -82,6 +82,8 @@ enum struct MfmaInstr mfma_scale_f32_16x16x128f8f6f4, mfma_f32_16x16x8xf32, // tf32 mfma_f32_32x32x4xf32, + mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 + mfma_f32_32x32x16xf32, // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -996,6 +998,7 @@ struct mfma_type template <> struct mfma_type { + // gfx942 specific configuration static constexpr index_t wave_size = 64; // fixed static constexpr index_t m_per_blk = 32; // from the instruction static constexpr index_t n_per_blk = 32; // from the instruction @@ -1015,6 +1018,51 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16xf32::Run(a, b, reg_c); + } +}; +template <> +struct mfma_type +{ + // gfx950 specific: use bf16x3 simulate tf32 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32xf32::Run(a, b, reg_c); + } +}; + // gfx11 struct mfma_type_gfx11_base { @@ -1281,6 +1329,8 @@ struct MfmaSelector return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_32x32x4xf32; #else @@ -1295,6 +1345,8 @@ struct MfmaSelector return MfmaInstr::wmma_unsupport_16x16_gfx12; #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32xf32; #elif defined(__gfx942__) return MfmaInstr::mfma_f32_16x16x8xf32; #else diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 7ff8e6b057a..4b940d59da8 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1636,7 +1636,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; -/******************* tf32 *************************************/ +/******************* tf32 on gfx942 *************************************/ template struct intrin_mfma_f32_16x16x8xf32; @@ -1676,5 +1676,142 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> #endif } }; +/******************* tf32 on gfx942 end *********************************/ + +/******************* tf32 on gfx950 *************************************/ +/* bf16x3 simulate tf32: input/output/accumulator are all float; */ +/* step: */ +/* 1. separate one input to 2 bf16 registers: */ +/* in_bf16_big = f32_to_bf16(in_f32) */ +/* in_bf16_small = in_f32 - in_bf16_big */ +/* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. */ +/* out_f32 = A_bf16_big * B_bf16_big */ +/* out_f32 += A_bf16_small * B_bf16_big */ +/* out_f32 += A_bf16_big * B_bf16_small */ +/************************************************************************/ +template +struct intrin_mfma_f32_16x16x32xf32; + +template <> +struct intrin_mfma_f32_16x16x32xf32<16, 16> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + static_for<0, 8, 1>{}([&](auto k) { + using IK = Number; + v_reg_a_bf16_big.template AsType()(k) = + type_convert(reg_a_v.template AsType()[IK{}]); + v_reg_a_bf16_small.template AsType()(k) = type_convert( + reg_a_v.template AsType()[IK{}] - + type_convert(v_reg_a_bf16_big.template AsType()[IK{}])); + v_reg_b_bf16_big.template AsType()(k) = + type_convert(reg_b_v.template AsType()[IK{}]); + v_reg_b_bf16_small.template AsType()(k) = type_convert( + reg_b_v.template AsType()[IK{}] - + type_convert(v_reg_b_bf16_big.template AsType()[IK{}])); + }); + + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_32x32x16xf32; + +template <> +struct intrin_mfma_f32_32x32x16xf32<32, 32> +{ + template + __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + static_for<0, 8, 1>{}([&](auto k) { + using IK = Number; + v_reg_a_bf16_big.template AsType()(k) = + type_convert(reg_a_v.template AsType()[IK{}]); + v_reg_a_bf16_small.template AsType()(k) = type_convert( + reg_a_v.template AsType()[IK{}] - + type_convert(v_reg_a_bf16_big.template AsType()[IK{}])); + v_reg_b_bf16_big.template AsType()(k) = + type_convert(reg_b_v.template AsType()[IK{}]); + v_reg_b_bf16_small.template AsType()(k) = type_convert( + reg_b_v.template AsType()[IK{}] - + type_convert(v_reg_b_bf16_big.template AsType()[IK{}])); + }); + + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); + reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c.template AsType()[I0{}], + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; +/******************* tf32 on gfx950 end ************************************/ } // namespace ck From 51688520da7edf702264f0f7af9c11136cf002a5 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 11 Nov 2025 10:54:51 +0800 Subject: [PATCH 12/23] refact codes --- .../gpu/block/blockwise_gemm_xdlops.hpp | 220 ------------------ .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 31 ++- include/ck/utility/amd_xdlops.hpp | 136 +++++------ 3 files changed, 96 insertions(+), 291 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 1a53f374843..217c4e70753 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -607,200 +607,6 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING }; -/* - * @brief blockwise gemm xdlops with bf16x3 simulate tf32 - * in/out/acc are all float; - * step: - * separate one input to 2 bf16 registers: - * in_bf16_big = f32_to_bf16(in_f32) - * in_bf16_small = in_f32 - in_bf16_big - * run 3 xdlops gemm: all the accumulator registers of gemm are same. - * out_f32 = A_bf16_big * B_bf16_big - * out_f32 += A_bf16_small * B_bf16_big - * out_f32 += A_bf16_big * B_bf16_small - */ -// template -// struct BlockwiseGemmXdlops_BF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 -// : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 -// { -// using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - -// using Base::a_block_desc_m0_m1_m2_k; -// using Base::b_block_desc_n0_n1_n2_k; -// using Base::c_thread_desc_; -// using Base::I0; -// using Base::I1; -// using Base::KPerThread; - -// // hard code to bf16. Both input reg and mfma type are bf16. -// using DataTypeA = bhalf_t; -// using DataTypeB = bhalf_t; - -// static constexpr auto xdlops_gemm = -// XdlopsGemm{}; - -// template -// __device__ void Run(const ABlockBuffer& a_block_buf, -// const BBlockBuffer& b_block_buf, -// CThreadBuffer& c_thread_buf) const -// { -// // if(threadIdx.x == 0 && blockIdx.x == 0) -// // { -// // printf("BlockwiseGemmXdlops_bf16x3: Run\n"); -// // } -// auto a_thread_buf = make_static_buffer( -// a_thread_desc_.GetElementSpaceSize()); -// auto b_thread_buf = make_static_buffer( -// b_thread_desc_.GetElementSpaceSize()); -// auto a_thread_buf_big = make_static_buffer( -// a_thread_desc_.GetElementSpaceSize()); -// auto b_thread_buf_big = make_static_buffer( -// b_thread_desc_.GetElementSpaceSize()); -// auto a_thread_buf_small = make_static_buffer( -// a_thread_desc_.GetElementSpaceSize()); -// auto b_thread_buf_small = make_static_buffer( -// b_thread_desc_.GetElementSpaceSize()); - -// static_for<0, MRepeat, 1>{}([&](auto m0) { -// // read A -// a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, -// make_tuple(m0, I0, I0, I0), -// a_block_buf, -// a_thread_desc_, -// make_tuple(I0, I0, I0, I0), -// a_thread_buf); -// static_for<0, a_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { -// a_thread_buf_big(Number{}) = type_convert(a_thread_buf[i]); -// a_thread_buf_small(Number{}) = type_convert( -// a_thread_buf[i] - type_convert(a_thread_buf_big[i])); -// }); - -// static_for<0, NRepeat, 1>{}([&](auto n0) { -// // read B -// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, -// make_tuple(n0, I0, I0, I0), -// b_block_buf, -// b_thread_desc_, -// make_tuple(I0, I0, I0, I0), -// b_thread_buf); -// static_for<0, b_thread_desc_.GetElementSpaceSize().value, 1>{}([&](auto i) { -// b_thread_buf_big(Number{}) = -// type_convert(b_thread_buf[i]); -// b_thread_buf_small(Number{}) = type_convert( -// b_thread_buf[i] - type_convert(b_thread_buf_big[i])); -// }); - -// static_for<0, KPerThread, KPack>{}([&](auto k) { -// // why another register buffer? for index? -// vector_type a_thread_vec_big; -// vector_type b_thread_vec_big; -// vector_type a_thread_vec_small; -// vector_type b_thread_vec_small; - -// static_for<0, KPack, 1>{}([&](auto i) { -// auto a_idx = -// Number{}; -// auto b_idx = -// Number{}; -// a_thread_vec_big.template AsType()(i) = a_thread_buf_big[a_idx]; -// b_thread_vec_big.template AsType()(i) = b_thread_buf_big[b_idx]; -// a_thread_vec_small.template AsType()(i) = -// a_thread_buf_small[a_idx]; -// b_thread_vec_small.template AsType()(i) = -// b_thread_buf_small[b_idx]; -// }); - -// using mfma_input_type_a = -// typename vector_type::type; -// using mfma_input_type_b = -// typename vector_type::type; - -// constexpr index_t c_offset = -// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - -// xdlops_gemm.Run(a_thread_vec_big.template AsType(), -// b_thread_vec_small.template AsType(), -// c_thread_buf.GetVectorTypeReference(Number{})); -// xdlops_gemm.Run(a_thread_vec_small.template AsType(), -// b_thread_vec_big.template AsType(), -// c_thread_buf.GetVectorTypeReference(Number{})); -// xdlops_gemm.Run(a_thread_vec_big.template AsType(), -// b_thread_vec_big.template AsType(), -// c_thread_buf.GetVectorTypeReference(Number{})); -// }); -// }); -// }); -// } - -// protected: -// // A[M0, M1, M2, KPerThread] -// static constexpr auto a_thread_desc_ = -// make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - -// // B[N0, N1, N2, KPerThread] -// static constexpr auto b_thread_desc_ = -// make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); - -// using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, -// Sequence<0, 1, 2, 3>, -// 3, -// Base::A_K1, -// Base::A_K1>; - -// using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, -// Sequence<0, 1, 2, 3>, -// 3, -// Base::B_K1, -// Base::B_K1>; - -// AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; -// BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; -// }; - template && is_same_v && - // is_same_v && is_same_v) - // { - // return BlockwiseGemmXdlops_BF16X3_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - // } - // else - // { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - // } } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index cb97d89dbee..5c8e06bb569 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -80,10 +80,10 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, - mfma_f32_16x16x8xf32, // tf32 - mfma_f32_32x32x4xf32, - mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 - mfma_f32_32x32x16xf32, + mfma_f32_16x16x8xf32, // tf32 + mfma_f32_32x32x4xf32, // tf32 on gfx942; bf16x3 simulate tf32 on gfx950 + mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950 + mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950 // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -1323,7 +1323,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; @@ -1339,7 +1339,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; @@ -1354,6 +1354,21 @@ struct MfmaSelector #endif } + template <> + constexpr auto GetMfma() + { +#if defined(__gfx12__) + return MfmaInstr::wmma_unsupport_16x16_gfx12; +#elif defined(__gfx11__) + return MfmaInstr::wmma_unsupport_16x16_gfx11; +#elif defined(__gfx942__) || defined(__gfx950__) + // bf16x3 simulate tf32 on gfx950. real tf32 on gfx942. + return MfmaInstr::mfma_f32_32x32x4xf32; +#else + return MfmaInstr::mfma_f32_32x32x2f32; +#endif + } + template <> constexpr auto GetMfma() { @@ -2237,6 +2252,10 @@ struct XdlopsGemm (is_same::value && KPack <= 8) || ((is_same::value || is_same::value) && KPack < 32) || is_same::value) +#if defined(__gfx950__) + // tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16. + || (is_same::value && KPack <= 4) +#endif ? true : false; static constexpr auto mfma = MfmaSelector +__device__ __forceinline__ void +convert_float_to_bf16_pairs(const vector_type& reg_f32, + vector_type& reg_bf16_big, + vector_type& reg_bf16_small) +{ + static_for<0, VecSize, 1>{}([&](auto k) { + using IK = Number; + reg_bf16_big.template AsType()(k) = + type_convert(reg_f32.template AsType()[IK{}]); + reg_bf16_small.template AsType()(k) = type_convert( + reg_f32.template AsType()[IK{}] - + type_convert(reg_bf16_big.template AsType()[IK{}])); + }); +} +/* */ + // fp32 template struct intrin_mfma_f32_32x32x1f32; @@ -1666,7 +1685,33 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx950__) + // simulation: details is same as tf32 on gfx950 + using I0 = Number<0>; + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + vector_type v_reg_a_bf16_big; + vector_type v_reg_a_bf16_small; + vector_type v_reg_b_bf16_big; + vector_type v_reg_b_bf16_small; + + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); + + // Run 3 times: big*big, small*big, big*small + intrin_mfma_f32_32x32x4bf16<32, 32>::Run( + v_reg_a_bf16_small.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x4bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x4bf16<32, 32>::Run(v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_big.template AsType()[I0{}], + reg_c); +#elif defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1676,10 +1721,9 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> #endif } }; -/******************* tf32 on gfx942 end *********************************/ -/******************* tf32 on gfx950 *************************************/ -/* bf16x3 simulate tf32: input/output/accumulator are all float; */ +/******************* tf32/xf32 on gfx950 ********************************/ +/* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */ /* step: */ /* 1. separate one input to 2 bf16 registers: */ /* in_bf16_big = f32_to_bf16(in_f32) */ @@ -1708,41 +1752,22 @@ struct intrin_mfma_f32_16x16x32xf32<16, 16> vector_type v_reg_b_bf16_big; vector_type v_reg_b_bf16_small; - static_for<0, 8, 1>{}([&](auto k) { - using IK = Number; - v_reg_a_bf16_big.template AsType()(k) = - type_convert(reg_a_v.template AsType()[IK{}]); - v_reg_a_bf16_small.template AsType()(k) = type_convert( - reg_a_v.template AsType()[IK{}] - - type_convert(v_reg_a_bf16_big.template AsType()[IK{}])); - v_reg_b_bf16_big.template AsType()(k) = - type_convert(reg_b_v.template AsType()[IK{}]); - v_reg_b_bf16_small.template AsType()(k) = type_convert( - reg_b_v.template AsType()[IK{}] - - type_convert(v_reg_b_bf16_big.template AsType()[IK{}])); - }); + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( - v_reg_a_bf16_big.template AsType()[I0{}], - v_reg_b_bf16_small.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + // Run 3 times: big*big, small*big, big*small + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( v_reg_a_bf16_small.template AsType()[I0{}], v_reg_b_bf16_big.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_16x16x32bf16<16, 16>::Run( v_reg_a_bf16_big.template AsType()[I0{}], v_reg_b_bf16_big.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); + reg_c); #else ignore = reg_a; ignore = reg_b; @@ -1770,41 +1795,22 @@ struct intrin_mfma_f32_32x32x16xf32<32, 32> vector_type v_reg_b_bf16_big; vector_type v_reg_b_bf16_small; - static_for<0, 8, 1>{}([&](auto k) { - using IK = Number; - v_reg_a_bf16_big.template AsType()(k) = - type_convert(reg_a_v.template AsType()[IK{}]); - v_reg_a_bf16_small.template AsType()(k) = type_convert( - reg_a_v.template AsType()[IK{}] - - type_convert(v_reg_a_bf16_big.template AsType()[IK{}])); - v_reg_b_bf16_big.template AsType()(k) = - type_convert(reg_b_v.template AsType()[IK{}]); - v_reg_b_bf16_small.template AsType()(k) = type_convert( - reg_b_v.template AsType()[IK{}] - - type_convert(v_reg_b_bf16_big.template AsType()[IK{}])); - }); + convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); + convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( - v_reg_a_bf16_big.template AsType()[I0{}], - v_reg_b_bf16_small.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + // Run 3 times: big*big, small*big, big*small + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( v_reg_a_bf16_small.template AsType()[I0{}], v_reg_b_bf16_big.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); - reg_c.template AsType()(I0{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( + v_reg_a_bf16_big.template AsType()[I0{}], + v_reg_b_bf16_small.template AsType()[I0{}], + reg_c); + intrin_mfma_f32_32x32x16bf16<32, 32>::Run( v_reg_a_bf16_big.template AsType()[I0{}], v_reg_b_bf16_big.template AsType()[I0{}], - reg_c.template AsType()[I0{}], - 0, - 0, - 0); + reg_c); #else ignore = reg_a; ignore = reg_b; @@ -1813,5 +1819,5 @@ struct intrin_mfma_f32_32x32x16xf32<32, 32> } }; -/******************* tf32 on gfx950 end ************************************/ +/******************* tf32/xf32 on gfx950 end ************************************/ } // namespace ck From 1d013fa75bf33c984ce9a935f44c17b9a8823824 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 11 Nov 2025 14:14:42 +0800 Subject: [PATCH 13/23] refact codes --- include/ck/library/utility/check_err.hpp | 6 +-- .../gpu/block/blockwise_gemm_xdlops.hpp | 4 -- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 46 ++----------------- include/ck/utility/amd_xdlops.hpp | 32 ++----------- 4 files changed, 10 insertions(+), 78 deletions(-) diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index ecf8ecd977b..f34f91acfc2 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -173,20 +173,20 @@ check_err(const Range& out, double atol = 3e-5) { #ifndef __HIPCC_RTC__ - if(ck::get_device_name() == "gfx942") + if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950") { rtol = 1e-2; atol = 1e-2; } #else // In RTC mode, use preprocessor macros to check device architecture -#if defined(__gfx942__) +#if defined(__gfx942__) || defined(__gfx950__) { rtol = 1e-2; atol = 1e-2; } #endif -#endif +#endif // __HIPCC_RTC__ if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 217c4e70753..7648e9a92d0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -307,10 +307,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - // if(threadIdx.x == 0 && blockIdx.x == 0) - // { - // printf("BlockwiseGemmXdlops: Run\n"); - // } auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 5c8e06bb569..0e1a5aa9ac7 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -80,8 +80,8 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, - mfma_f32_16x16x8xf32, // tf32 - mfma_f32_32x32x4xf32, // tf32 on gfx942; bf16x3 simulate tf32 on gfx950 + mfma_f32_16x16x8xf32, // tf32 on gfx942 + mfma_f32_32x32x4xf32, // tf32 on gfx942 mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950 mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950 // gfx11 @@ -995,29 +995,6 @@ struct mfma_type } }; -template <> -struct mfma_type -{ - // gfx942 specific configuration - static constexpr index_t wave_size = 64; // fixed - static constexpr index_t m_per_blk = 32; // from the instruction - static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t num_threads_per_blk = n_per_blk; // 32 - static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 - static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 - static constexpr index_t group_size = 4; // corresponding to CD rows mapping - static constexpr index_t num_groups_per_blk = 4; - static constexpr index_t num_output_blks = 1; - static constexpr index_t k_per_blk = 2; - static constexpr bool is_k_reduction = true; - // AB register size: 2, CD register size: 16 - template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const - { - intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); - } -}; - template <> struct mfma_type { @@ -1323,7 +1300,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; @@ -1339,7 +1316,7 @@ struct MfmaSelector } template <> - constexpr auto GetMfma() + constexpr auto GetMfma() { #if defined(__gfx12__) return MfmaInstr::wmma_unsupport_16x16_gfx12; @@ -1354,21 +1331,6 @@ struct MfmaSelector #endif } - template <> - constexpr auto GetMfma() - { -#if defined(__gfx12__) - return MfmaInstr::wmma_unsupport_16x16_gfx12; -#elif defined(__gfx11__) - return MfmaInstr::wmma_unsupport_16x16_gfx11; -#elif defined(__gfx942__) || defined(__gfx950__) - // bf16x3 simulate tf32 on gfx950. real tf32 on gfx942. - return MfmaInstr::mfma_f32_32x32x4xf32; -#else - return MfmaInstr::mfma_f32_32x32x2f32; -#endif - } - template <> constexpr auto GetMfma() { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0ad6a37d940..b7c9c1702a3 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1633,7 +1633,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> template __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( bit_cast(reg_a), bit_cast(reg_b), @@ -1665,7 +1665,7 @@ struct intrin_mfma_f32_16x16x8xf32<16, 16> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else @@ -1685,33 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx950__) - // simulation: details is same as tf32 on gfx950 - using I0 = Number<0>; - vector_type reg_a_v(reg_a); - vector_type reg_b_v(reg_b); - - vector_type v_reg_a_bf16_big; - vector_type v_reg_a_bf16_small; - vector_type v_reg_b_bf16_big; - vector_type v_reg_b_bf16_small; - - convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small); - convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small); - - // Run 3 times: big*big, small*big, big*small - intrin_mfma_f32_32x32x4bf16<32, 32>::Run( - v_reg_a_bf16_small.template AsType()[I0{}], - v_reg_b_bf16_big.template AsType()[I0{}], - reg_c); - intrin_mfma_f32_32x32x4bf16<32, 32>::Run( - v_reg_a_bf16_big.template AsType()[I0{}], - v_reg_b_bf16_small.template AsType()[I0{}], - reg_c); - intrin_mfma_f32_32x32x4bf16<32, 32>::Run(v_reg_a_bf16_big.template AsType()[I0{}], - v_reg_b_bf16_big.template AsType()[I0{}], - reg_c); -#elif defined(__gfx94__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else From 697f057b095d83d96e44afbf6ab360775d8c8d6c Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 11 Nov 2025 14:42:18 +0800 Subject: [PATCH 14/23] error fix --- include/ck/utility/amd_xdlops.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b7c9c1702a3..b9d171dbea9 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1633,7 +1633,7 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> template __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) { -#if defined(__gfx942__) +#if defined(__gfx94__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( bit_cast(reg_a), bit_cast(reg_b), @@ -1685,7 +1685,7 @@ struct intrin_mfma_f32_32x32x4xf32<32, 32> template __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) { -#if defined(__gfx94__) +#if defined(__gfx942__) reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); #else From c3aa0a9d70daa0f6e574f56fe24c8079842817d7 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 11 Nov 2025 16:56:15 +0800 Subject: [PATCH 15/23] change threshold --- example/15_grouped_gemm/run_grouped_gemm_example.inc | 10 ++++++++-- include/ck/library/utility/check_err.hpp | 8 ++++---- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 13698f3394c..f557b15e54a 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -97,6 +97,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::size_t flop = 0, num_btype = 0; + double max_acc_value = 1.0; for(std::size_t i = 0; i < gemm_descs.size(); i++) { a_tensors.push_back(Tensor(f_host_tensor_descriptor( @@ -127,6 +128,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co case 1: a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + max_acc_value = 10.0; break; case 2: a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -260,8 +262,12 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); #else - pass &= ck::utils::check_err, Tensor, ComputeDataType>( - c_device_tensors[i], c_host_tensors[i]); + auto atol = + ck::utils::get_absolute_threshold(max_acc_value); + auto rtol = + ck::utils::get_relative_threshold(max_acc_value); + pass &= ck::utils::check_err( + c_device_tensors[i], c_host_tensors[i], "Error: Incorrect results!", rtol, atol); #endif } } diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index f34f91acfc2..0273c195953 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -169,18 +169,18 @@ typename std::enable_if< check_err(const Range& out, const RefRange& ref, const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-5, - double atol = 3e-5) + double rtol = 5e-4, + double atol = 5e-4) { #ifndef __HIPCC_RTC__ - if(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950") + if(ck::get_device_name() == "gfx942") { rtol = 1e-2; atol = 1e-2; } #else // In RTC mode, use preprocessor macros to check device architecture -#if defined(__gfx942__) || defined(__gfx950__) +#if defined(__gfx942__) { rtol = 1e-2; atol = 1e-2; From 7431aeea6ce56e419a976cf88e7daa23501689a6 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Tue, 11 Nov 2025 17:34:06 +0800 Subject: [PATCH 16/23] bug fix --- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 0e1a5aa9ac7..0817cf98563 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -995,6 +995,28 @@ struct mfma_type } }; +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 32 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 + static constexpr index_t group_size = 4; // corresponding to CD rows mapping + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + // AB register size: 2, CD register size: 16 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); + } +}; + template <> struct mfma_type { From 812a860288de9db66ba81d161179498ebd3750be Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Wed, 12 Nov 2025 16:40:47 +0800 Subject: [PATCH 17/23] fix threshold error --- example/01_gemm/common.hpp | 5 +---- example/09_convnd_fwd/convnd_fwd_common.hpp | 13 ++++++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 32110759f45..8a13f39b1b4 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -364,10 +364,7 @@ inline __host__ __device__ constexpr double get_atol(size_t K = 0) { throw std::runtime_error("K is 0"); } - // tf32 has 10 mantissa bits, so epsilon = 2^(-10) = 1/1024 - constexpr double epsilon_tf32 = 1.0 / 1024.0; // 2^(-10) - constexpr double epsilon_fp32 = std::numeric_limits::epsilon(); - return (epsilon_tf32 - epsilon_fp32) * K; + return 1e-3 * std::log2(K); } else if constexpr(std::is_same_v) { diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index d82b56ec00b..4a02f3ea221 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -73,11 +73,15 @@ inline __host__ __device__ constexpr double get_rtol() } template -inline __host__ __device__ constexpr double get_atol() +inline __host__ __device__ constexpr double get_atol(std::size_t K_reduce) { if constexpr(std::is_same_v && std::is_same_v) { - return 1e-2; + if(K_reduce == 0) + { + throw std::runtime_error("K_reduce is 0"); + } + return 1e-3 * std::log2(K_reduce); } else if constexpr(std::is_same_v) { @@ -145,6 +149,9 @@ bool run_grouped_conv_fwd(bool do_verification, std::cout << "in: " << in.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "out: " << out_host.mDesc << std::endl; + const auto& wei_lengths = wei.mDesc.GetLengths(); + auto K_reduce = + wei_lengths[1] * wei_lengths[2] * wei_lengths[3] * wei_lengths[4] * wei_lengths[5]; switch(init_method) { @@ -263,7 +270,7 @@ bool run_grouped_conv_fwd(bool do_verification, out_host, "Error: incorrect results!", get_rtol(), - get_atol()); + get_atol(K_reduce)); } return true; From 2c96968a4e74f8619b22caabcf3d5ec6de623141 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 11:19:25 +0800 Subject: [PATCH 18/23] change host reference implement to same as device --- example/01_gemm/common.hpp | 8 +- example/01_gemm/run_gemm_example.inc | 6 +- example/09_convnd_fwd/convnd_fwd_common.hpp | 19 ++-- .../run_grouped_gemm_example.inc | 17 ++-- include/ck/library/utility/check_err.hpp | 15 ---- .../threadwise_tensor_slice_transfer_v3r1.hpp | 1 + .../cpu/reference_conv_fwd.hpp | 89 ++++++++++++++++--- .../cpu/reference_gemm.hpp | 48 ++++++++-- 8 files changed, 141 insertions(+), 62 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 8a13f39b1b4..e482953e464 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -356,15 +356,11 @@ inline __host__ __device__ constexpr double get_rtol() } template -inline __host__ __device__ constexpr double get_atol(size_t K = 0) +inline __host__ __device__ constexpr double get_atol() { if constexpr(std::is_same_v && std::is_same_v) { - if(K == 0) - { - throw std::runtime_error("K is 0"); - } - return 1e-3 * std::log2(K); + return 1e-3; } else if constexpr(std::is_same_v) { diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index cdabcc9fa82..4fc1884d00d 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -24,6 +24,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto StrideB = problem_size.StrideB; auto StrideC = problem_size.StrideC; + auto device_name = ck::get_device_name(); + auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) @@ -192,7 +194,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op, device_name); std::cout << "Running verification on CPU." << std::endl; ref_invoker.Run(ref_argument); @@ -212,7 +214,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) c_m_n_host_result, "Error: Incorrect results!", get_rtol(), - get_atol(K)); + get_atol()); #endif } diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 4a02f3ea221..80220f6398e 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -73,15 +73,11 @@ inline __host__ __device__ constexpr double get_rtol() } template -inline __host__ __device__ constexpr double get_atol(std::size_t K_reduce) +inline __host__ __device__ constexpr double get_atol() { if constexpr(std::is_same_v && std::is_same_v) { - if(K_reduce == 0) - { - throw std::runtime_error("K_reduce is 0"); - } - return 1e-3 * std::log2(K_reduce); + return 1e-3; } else if constexpr(std::is_same_v) { @@ -149,9 +145,6 @@ bool run_grouped_conv_fwd(bool do_verification, std::cout << "in: " << in.mDesc << std::endl; std::cout << "wei: " << wei.mDesc << std::endl; std::cout << "out: " << out_host.mDesc << std::endl; - const auto& wei_lengths = wei.mDesc.GetLengths(); - auto K_reduce = - wei_lengths[1] * wei_lengths[2] * wei_lengths[3] * wei_lengths[4] * wei_lengths[5]; switch(init_method) { @@ -260,7 +253,11 @@ bool run_grouped_conv_fwd(bool do_verification, conv_param.input_right_pads_, in_element_op, wei_element_op, - out_element_op); + out_element_op, + {}, + {}, + {}, + ck::get_device_name()); ref_invoker.Run(ref_argument); @@ -270,7 +267,7 @@ bool run_grouped_conv_fwd(bool do_verification, out_host, "Error: incorrect results!", get_rtol(), - get_atol(K_reduce)); + get_atol()); } return true; diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index f557b15e54a..11522791851 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -39,6 +39,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co #endif int group_count = problem_size.group_count; + auto device_name = ck::get_device_name(); + // GEMM shape std::vector gemm_descs; std::vector p_a, p_b; @@ -97,7 +99,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co std::size_t flop = 0, num_btype = 0; - double max_acc_value = 1.0; for(std::size_t i = 0; i < gemm_descs.size(); i++) { a_tensors.push_back(Tensor(f_host_tensor_descriptor( @@ -128,7 +129,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co case 1: a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); - max_acc_value = 10.0; break; case 2: a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); @@ -239,7 +239,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co AElementOp, BElementOp, CDEElementOp, - ComputeDataType, ComputeDataType>; for(std::size_t i = 0; i < gemm_descs.size(); i++) @@ -253,7 +252,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_host_tensors[i], a_element_op, b_element_op, - c_element_op); + c_element_op, + device_name); ref_invoker.Run(ref_argument); @@ -262,12 +262,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]); #else - auto atol = - ck::utils::get_absolute_threshold(max_acc_value); - auto rtol = - ck::utils::get_relative_threshold(max_acc_value); - pass &= ck::utils::check_err( - c_device_tensors[i], c_host_tensors[i], "Error: Incorrect results!", rtol, atol); + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); #endif } } diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 0273c195953..9106c09c6c4 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -172,21 +172,6 @@ check_err(const Range& out, double rtol = 5e-4, double atol = 5e-4) { -#ifndef __HIPCC_RTC__ - if(ck::get_device_name() == "gfx942") - { - rtol = 1e-2; - atol = 1e-2; - } -#else -// In RTC mode, use preprocessor macros to check device architecture -#if defined(__gfx942__) - { - rtol = 1e-2; - atol = 1e-2; - } -#endif -#endif // __HIPCC_RTC__ if(out.size() != ref.size()) { std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index c8643a4087f..4a6ed62c0e2 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -279,6 +279,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 Sequence, Sequence, Sequence>; + static_for<0, tuple_element_t::Size(), 1>{}( [&](auto v_idx) { constexpr auto VectorLoadSize = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 573571bc07d..7ef9525dc0b 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -79,7 +79,8 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op, const std::array, NumAElementwiseTensor>& elementwise_a_tensors, const std::array, NumBElementwiseTensor>& elementwise_b_tensors, - const std::array, NumDElementwiseTensor>& elementwise_d_tensors) + const std::array, NumDElementwiseTensor>& elementwise_d_tensors, + const ::std::string& device_name = "unknown") : input_{input}, weight_{weight}, output_{output}, @@ -92,7 +93,8 @@ struct ReferenceConvFwd : public device::BaseOperator in_right_pads_{input_right_pads}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, - out_element_op_{out_element_op} + out_element_op_{out_element_op}, + device_name_{device_name} { } @@ -112,6 +114,7 @@ struct ReferenceConvFwd : public device::BaseOperator InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; + const ::std::string& device_name_; // the device which this conv is compared with }; struct Invoker : public device::BaseInvoker @@ -251,10 +254,39 @@ struct ReferenceConvFwd : public device::BaseOperator x); if constexpr(is_same_v) { - v_acc += ck::type_convert( - ck::type_convert(v_in)) * - ck::type_convert( - ck::type_convert(v_wei)); + if(arg.device_name_ == "gfx942") + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_in_bf16_big = + ck::type_convert(v_in); + ck::bhalf_t v_in_bf16_small = + ck::type_convert( + v_in - type_convert(v_in_bf16_big)); + ck::bhalf_t v_wei_bf16_big = + ck::type_convert(v_wei); + ck::bhalf_t v_wei_bf16_small = + ck::type_convert( + v_wei - type_convert(v_wei_bf16_big)); + + v_acc += ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_small) + + ck::type_convert(v_in_bf16_small) * + ck::type_convert(v_wei_bf16_big) + + ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_big); + } + else + { + throw std::runtime_error( + "Unsupported device: " + arg.device_name_ + + " for tf32 computation"); + } } else { @@ -350,10 +382,41 @@ struct ReferenceConvFwd : public device::BaseOperator x); if constexpr(is_same_v) { - v_acc += ck::type_convert( - ck::type_convert(v_in)) * - ck::type_convert( - ck::type_convert(v_wei)); + if(arg.device_name_ == "gfx942") + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_in_bf16_big = + ck::type_convert(v_in); + ck::bhalf_t v_in_bf16_small = + ck::type_convert( + v_in - type_convert(v_in_bf16_big)); + ck::bhalf_t v_wei_bf16_big = + ck::type_convert(v_wei); + ck::bhalf_t v_wei_bf16_small = + ck::type_convert( + v_wei - + type_convert(v_wei_bf16_big)); + + v_acc += + ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_small) + + ck::type_convert(v_in_bf16_small) * + ck::type_convert(v_wei_bf16_big) + + ck::type_convert(v_in_bf16_big) * + ck::type_convert(v_wei_bf16_big); + } + else + { + throw std::runtime_error( + "Unsupported device: " + arg.device_name_ + + " for tf32 computation"); + } } else { @@ -463,7 +526,8 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op, const std::array, NumAElementwiseTensor>& elementwise_a_tensors = {}, const std::array, NumBElementwiseTensor>& elementwise_b_tensors = {}, - const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}) + const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}, + const ::std::string& device_name = "unknown") { return Argument{input, weight, @@ -477,7 +541,8 @@ struct ReferenceConvFwd : public device::BaseOperator out_element_op, elementwise_a_tensors, elementwise_b_tensors, - elementwise_d_tensors}; + elementwise_d_tensors, + device_name}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 660ec64f973..173bad2e874 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -39,13 +39,15 @@ struct ReferenceGemm : public device::BaseOperator Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const ::std::string& device_name = "unknown") : a_m_k_{a_m_k}, b_k_n_{b_k_n}, c_m_n_{c_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, - c_element_op_{c_element_op} + c_element_op_{c_element_op}, + device_name_{device_name} { } @@ -56,6 +58,7 @@ struct ReferenceGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; + const ::std::string& device_name_; // the device which this gemm is compared with }; // Invoker @@ -142,8 +145,40 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); + if(is_same_v && is_same_v) + { + if(arg.device_name_ == "gfx942") + { + v_acc += + ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else if(arg.device_name_ == "gfx950") + { + ck::bhalf_t v_a_bf16_big = ck::type_convert(v_a); + ck::bhalf_t v_a_bf16_small = ck::type_convert( + v_a - type_convert(v_a_bf16_big)); + ck::bhalf_t v_b_bf16_big = ck::type_convert(v_b); + ck::bhalf_t v_b_bf16_small = ck::type_convert( + v_b - type_convert(v_b_bf16_big)); + + v_acc += ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_small) + + ck::type_convert(v_a_bf16_small) * + ck::type_convert(v_b_bf16_big) + + ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_big); + } + else + { + throw std::runtime_error("Unsupported device: " + arg.device_name_); + } + } + else + { + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } } CDataType v_c{0}; @@ -180,9 +215,10 @@ struct ReferenceGemm : public device::BaseOperator Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const ::std::string& device_name = "unknown") { - return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op, device_name}; } static auto MakeInvoker() { return Invoker{}; } From e8a679042008e6548ae51f10feb58edb6c14ccb3 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 12:41:07 +0800 Subject: [PATCH 19/23] bug fix --- example/01_gemm/gemm_xdl_fp8.cpp | 6 +++++- example/01_gemm/gemm_xdl_fp8_bf8.cpp | 4 ++++ example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp | 8 ++++---- example/01_gemm/run_gemm_example.inc | 10 +++++----- example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp | 1 - .../reference_tensor_operation/cpu/reference_gemm.hpp | 5 ++++- 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index af9b7978f59..a61d9b738d5 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -5,6 +5,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using ADataType = ck::f8_t; using BDataType = ck::f8_t; using CDataType = ck::f8_t; @@ -32,7 +34,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; - // this instance has been tested working on gfx950 + // this instance has been tested working on gfx950 // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on @@ -64,3 +66,5 @@ int main(int argc, char* argv[]) return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index ce2a466a621..f3c02a3a4a2 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -5,6 +5,8 @@ #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using ADataType = ck::f8_t; using BDataType = ck::bf8_t; using CDataType = ck::half_t; @@ -66,3 +68,5 @@ int main(int argc, char* argv[]) return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp index 9b92fad779b..481e08137a2 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -21,7 +21,7 @@ using BDataType = F32; using AccDataType = F32; using CShuffleDataType = F32; using CDataType = F32; -using ComputeDataType = ck::tf32_t; +using ComputeTypeA = ck::tf32_t; using ALayout = Row; using BLayout = Col; @@ -45,7 +45,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, - 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeTypeA>; // clang-format on #else // clang-format off @@ -64,8 +64,8 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + ComputeTypeA, + ComputeTypeA>; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm @@ -213,8 +213,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -244,8 +244,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp index 78eb90e3114..c9a3ede1513 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp @@ -24,7 +24,6 @@ template using S = ck::Sequence; -using F16 = ck::half_t; using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 173bad2e874..875d13abbbd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -145,7 +145,10 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - if(is_same_v && is_same_v) + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v && + is_same_v) { if(arg.device_name_ == "gfx942") { From 673a46fcf2297a0f5d352b6e868dbd433953260b Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 12:46:53 +0800 Subject: [PATCH 20/23] bug fix --- example/01_gemm/gemm_xdl_fp8.cpp | 6 +----- example/01_gemm/gemm_xdl_fp8_bf8.cpp | 4 ---- example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp | 8 ++++---- example/01_gemm/run_gemm_example.inc | 10 +++++----- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index a61d9b738d5..af9b7978f59 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -5,8 +5,6 @@ #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" -#define EXAMPLE_WITH_COMPUTE_DATATYPE - using ADataType = ck::f8_t; using BDataType = ck::f8_t; using CDataType = ck::f8_t; @@ -34,7 +32,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; - // this instance has been tested working on gfx950 + // this instance has been tested working on gfx950 // < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on @@ -66,5 +64,3 @@ int main(int argc, char* argv[]) return !run_gemm_example(argc, argv); } - -#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index f3c02a3a4a2..ce2a466a621 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -5,8 +5,6 @@ #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" -#define EXAMPLE_WITH_COMPUTE_DATATYPE - using ADataType = ck::f8_t; using BDataType = ck::bf8_t; using CDataType = ck::half_t; @@ -68,5 +66,3 @@ int main(int argc, char* argv[]) return !run_gemm_example(argc, argv); } - -#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp index 481e08137a2..9b92fad779b 100644 --- a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -21,7 +21,7 @@ using BDataType = F32; using AccDataType = F32; using CShuffleDataType = F32; using CDataType = F32; -using ComputeTypeA = ck::tf32_t; +using ComputeDataType = ck::tf32_t; using ALayout = Row; using BLayout = Col; @@ -45,7 +45,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, - 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeTypeA>; + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; // clang-format on #else // clang-format off @@ -64,8 +64,8 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + ComputeDataType, + ComputeDataType>; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm @@ -213,8 +213,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -244,8 +244,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; From 60b519c767dec50710faa59ff08733d64b8cf548 Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 14:22:39 +0800 Subject: [PATCH 21/23] code refact --- example/09_convnd_fwd/convnd_fwd_common.hpp | 6 +++- include/ck/library/utility/check_err.hpp | 1 - .../gpu/block/blockwise_gemm_xdlops.hpp | 1 - .../gpu/reference_gemm.hpp | 29 ++++++++++++++++++- 4 files changed, 33 insertions(+), 4 deletions(-) diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 80220f6398e..0f565d53573 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -10,6 +10,9 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" + + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -243,6 +246,7 @@ bool run_grouped_conv_fwd(bool do_verification, 0, ComputeDataType>(); + auto device_name = ck::get_device_name(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, wei, @@ -257,7 +261,7 @@ bool run_grouped_conv_fwd(bool do_verification, {}, {}, {}, - ck::get_device_name()); + device_name); ref_invoker.Run(ref_argument); diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index 9106c09c6c4..fccd5c8e75f 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -19,7 +19,6 @@ #include "ck/host_utility/io.hpp" #include "ck/library/utility/ranges.hpp" -#include "ck/host_utility/device_prop.hpp" namespace ck { namespace utils { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 7648e9a92d0..55015dd30f7 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -619,7 +619,6 @@ template constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { - if constexpr(LoopSched == LoopScheduler::Default) { return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(v_a) * type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { +#if defined(__gfx942__) + v_acc += ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); +#elif defined(__gfx950__) + ck::bhalf_t v_a_bf16_big = ck::type_convert(v_a); + ck::bhalf_t v_a_bf16_small = + ck::type_convert(v_a - type_convert(v_a_bf16_big)); + ck::bhalf_t v_b_bf16_big = ck::type_convert(v_b); + ck::bhalf_t v_b_bf16_small = + ck::type_convert(v_b - type_convert(v_b_bf16_big)); + + v_acc += ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_small) + + ck::type_convert(v_a_bf16_small) * + ck::type_convert(v_b_bf16_big) + + ck::type_convert(v_a_bf16_big) * + ck::type_convert(v_b_bf16_big); +#else + v_acc += type_convert(v_a) * type_convert(v_b); +#endif + } + else + { + v_acc += type_convert(v_a) * type_convert(v_b); + } } // apply c_element_op c_element_op(v_c, v_acc); From d25721ede50757c3d614cb2ed45f7482107feefa Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 14:28:42 +0800 Subject: [PATCH 22/23] fix clang-format fail --- example/09_convnd_fwd/convnd_fwd_common.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 0f565d53573..4b293decb35 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -12,7 +12,6 @@ #include "ck/host_utility/device_prop.hpp" - #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" From 3ffebf4aa0d84f03f13bb5a869d2d5548effae0b Mon Sep 17 00:00:00 2001 From: yingmaolu Date: Thu, 13 Nov 2025 15:21:07 +0800 Subject: [PATCH 23/23] code refine --- example/01_gemm/run_gemm_example.inc | 4 +--- example/09_convnd_fwd/convnd_fwd_common.hpp | 9 +-------- .../15_grouped_gemm/run_grouped_gemm_example.inc | 5 +---- .../cpu/reference_conv_fwd.hpp | 15 +++++++-------- .../cpu/reference_gemm.hpp | 13 ++++++------- 5 files changed, 16 insertions(+), 30 deletions(-) diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 4fc1884d00d..7fb0c1e812e 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -24,8 +24,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto StrideB = problem_size.StrideB; auto StrideC = problem_size.StrideC; - auto device_name = ck::get_device_name(); - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) @@ -194,7 +192,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op, device_name); + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); std::cout << "Running verification on CPU." << std::endl; ref_invoker.Run(ref_argument); diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 4b293decb35..2a972c13eb5 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -10,8 +10,6 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/host_utility/device_prop.hpp" - #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" @@ -245,7 +243,6 @@ bool run_grouped_conv_fwd(bool do_verification, 0, ComputeDataType>(); - auto device_name = ck::get_device_name(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, wei, @@ -256,11 +253,7 @@ bool run_grouped_conv_fwd(bool do_verification, conv_param.input_right_pads_, in_element_op, wei_element_op, - out_element_op, - {}, - {}, - {}, - device_name); + out_element_op); ref_invoker.Run(ref_argument); diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 11522791851..62f0f3673d4 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -39,8 +39,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co #endif int group_count = problem_size.group_count; - auto device_name = ck::get_device_name(); - // GEMM shape std::vector gemm_descs; std::vector p_a, p_b; @@ -252,8 +250,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_host_tensors[i], a_element_op, b_element_op, - c_element_op, - device_name); + c_element_op); ref_invoker.Run(ref_argument); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 7ef9525dc0b..f47ce05cacd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -14,6 +14,8 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/host_utility/device_prop.hpp" + #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/fill.hpp" @@ -79,8 +81,7 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op, const std::array, NumAElementwiseTensor>& elementwise_a_tensors, const std::array, NumBElementwiseTensor>& elementwise_b_tensors, - const std::array, NumDElementwiseTensor>& elementwise_d_tensors, - const ::std::string& device_name = "unknown") + const std::array, NumDElementwiseTensor>& elementwise_d_tensors) : input_{input}, weight_{weight}, output_{output}, @@ -94,7 +95,7 @@ struct ReferenceConvFwd : public device::BaseOperator in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op}, - device_name_{device_name} + device_name_{ck::get_device_name()} { } @@ -114,7 +115,7 @@ struct ReferenceConvFwd : public device::BaseOperator InElementwiseOperation in_element_op_; WeiElementwiseOperation wei_element_op_; OutElementwiseOperation out_element_op_; - const ::std::string& device_name_; // the device which this conv is compared with + ::std::string device_name_; // the device which this conv is compared with }; struct Invoker : public device::BaseInvoker @@ -526,8 +527,7 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op, const std::array, NumAElementwiseTensor>& elementwise_a_tensors = {}, const std::array, NumBElementwiseTensor>& elementwise_b_tensors = {}, - const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}, - const ::std::string& device_name = "unknown") + const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}) { return Argument{input, weight, @@ -541,8 +541,7 @@ struct ReferenceConvFwd : public device::BaseOperator out_element_op, elementwise_a_tensors, elementwise_b_tensors, - elementwise_d_tensors, - device_name}; + elementwise_d_tensors}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 875d13abbbd..c5afebf75d7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -6,6 +6,7 @@ #include #include +#include "ck/host_utility/device_prop.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/library/utility/host_tensor.hpp" @@ -39,15 +40,14 @@ struct ReferenceGemm : public device::BaseOperator Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - const ::std::string& device_name = "unknown") + CElementwiseOperation c_element_op) : a_m_k_{a_m_k}, b_k_n_{b_k_n}, c_m_n_{c_m_n}, a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op}, - device_name_{device_name} + device_name_{ck::get_device_name()} { } @@ -58,7 +58,7 @@ struct ReferenceGemm : public device::BaseOperator AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_; CElementwiseOperation c_element_op_; - const ::std::string& device_name_; // the device which this gemm is compared with + ::std::string device_name_; // the device which this gemm is compared with }; // Invoker @@ -218,10 +218,9 @@ struct ReferenceGemm : public device::BaseOperator Tensor& c_m_n, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op, - const ::std::string& device_name = "unknown") + CElementwiseOperation c_element_op) { - return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op, device_name}; + return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; } static auto MakeInvoker() { return Invoker{}; }