Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f9fde06
add some instance to develop
shaojiewang Apr 14, 2022
09f365a
avoid bank conflicts for wrw for all instance
shaojiewang Apr 17, 2022
0a1e41a
add small K1 test
shaojiewang Apr 18, 2022
0fd4df3
delete some unused instance
shaojiewang Apr 19, 2022
230a41c
reset buffer load oob and ds memcpy to default option
shaojiewang Apr 26, 2022
81ffce2
remove useless instances
shaojiewang Apr 26, 2022
a6ebdb4
remove redundant space
shaojiewang Apr 26, 2022
c58fc51
remove printf code
shaojiewang Apr 26, 2022
fc17eb4
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 26, 2022
1654fcc
clang-format-10 change
shaojiewang Apr 26, 2022
260dcdb
fix clang format for the other files
shaojiewang Apr 27, 2022
93871ca
add bank length computation
shaojiewang Apr 28, 2022
1e5c712
add template to distinguish the instance that need lds padding for wrw
shaojiewang Apr 28, 2022
eb09227
use rocm5.1 as docker
shaojiewang Apr 29, 2022
2e6eaf6
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 29, 2022
579e8e7
use integer value for GEMM test
Apr 30, 2022
507e149
Merge remote-tracking branch 'origin/develop' into wrw_conv_impr
Apr 30, 2022
60ee26a
Merge remote-tracking branch 'origin/fix_test' into wrw_conv_impr
Apr 30, 2022
5c8ad21
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 30, 2022
28bb628
1. move dedicated transform into gridwisegemm's head file. 2. make ld…
shaojiewang May 4, 2022
a7297c2
use a new gridwise gemm header for bwd-weight
shaojiewang May 13, 2022
72b5309
revert gridwise gemm v2r4r2
shaojiewang May 13, 2022
3243112
Merge branch 'develop' into wrw_conv_impr
shaojiewang May 13, 2022
10a0802
change format
shaojiewang May 13, 2022
9459a68
rename kernel invoker
shaojiewang May 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4r2.hpp"
#include "gridwise_gemm_xdlops_bwd_weight.hpp"

namespace ck {
namespace tensor_operation {
Expand Down Expand Up @@ -81,6 +81,20 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
static constexpr auto K1Number = Number<K1>{};
static constexpr auto GemmK1Number = K1Number;

// Compile-time LDS layout constants for the A and B block tiles. The
// *PerBlock and *Padding values below are passed to the gridwise GEMM as
// template arguments.
//
// Total bytes spanned by the 32 LDS banks: 32 banks * 4 bytes each.
static constexpr auto BankLength = 128;
// Number of ADataType elements that fit in BankLength bytes. Because of the
// sizeof() operand this deduces to an unsigned size type, as does every value
// derived from it below.
// NOTE(review): the same ADataType-based count is reused for the B tile;
// presumably A and B have the same element size -- confirm.
static constexpr auto ElePerBank = BankLength / sizeof(ADataType);

// M is split as M0 x M1 for the A tile in LDS:
//   M1 = how many groups of K1 elements fit in BankLength bytes,
//   M0 = MPerBlock / M1.
// NOTE(review): both are truncating integer divisions. If K1 > ElePerBank,
// M1 becomes 0 and the M0 computation divides by zero at compile time;
// presumably every instance keeps K1 <= ElePerBank and MPerBlock a multiple
// of M1 -- confirm against the instance list.
static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
// Padding applied on the M1 side of the A tile.
// NOTE(review): the unit and exact placement of this padding are defined by
// the gridwise GEMM that consumes it -- confirm there.
static constexpr auto ABlockLdsM1Padding = 4;

// N is split as N0 x N1 for the B tile in LDS, mirroring the M split above:
//   N1 = how many groups of K1 elements fit in BankLength bytes,
//   N0 = NPerBlock / N1.
// The same integer-division caveats apply as for M1 / M0.
static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
// Padding applied on the N1 side of the B tile (see NOTE on the A padding).
static constexpr auto BBlockLdsN1Padding = 4;

static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
Expand Down Expand Up @@ -205,7 +219,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

// GridwiseGemm
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
Expand Down Expand Up @@ -233,6 +247,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Expand All @@ -241,12 +258,17 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;

using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
BlockSize,
ADataType, // TODO: distinguish A/B datatype
AccDataType,
Expand Down Expand Up @@ -274,6 +296,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
ABlockTransferDstScalarPerVector_K1,
false, // AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM,
ABlockLdsM1PerBlock,
ABlockLdsM0PerBlock,
ABlockLdsM1Padding,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
Expand All @@ -282,10 +307,15 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
BBlockTransferDstScalarPerVector_K1,
false, // BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN,
BBlockLdsN1PerBlock,
BBlockLdsN0PerBlock,
BBlockLdsN1Padding,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
Expand Down Expand Up @@ -465,7 +495,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Expand All @@ -482,7 +512,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Expand All @@ -502,7 +532,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
{
if(kbatch == 1)
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Expand All @@ -519,7 +549,7 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
}
else
{
const auto kernel = kernel_gemm_xdlops_v2r4r2<
const auto kernel = kernel_gemm_xdlops_bwd_weight<
GridwiseGemmAtomicAdd,
ADataType, // TODO: distiguish A/B datatype
CDataType,
Expand Down
Loading