Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f9fde06
add some instance to develop
shaojiewang Apr 14, 2022
09f365a
avoid bank conflicts for wrw for all instance
shaojiewang Apr 17, 2022
0a1e41a
add small K1 test
shaojiewang Apr 18, 2022
0fd4df3
delete some unused instance
shaojiewang Apr 19, 2022
230a41c
reset buffer load oob and ds memcpy to default option
shaojiewang Apr 26, 2022
81ffce2
remove useless instances
shaojiewang Apr 26, 2022
a6ebdb4
remove redandunt space
shaojiewang Apr 26, 2022
c58fc51
remove printf code
shaojiewang Apr 26, 2022
fc17eb4
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 26, 2022
1654fcc
clang-format-10 change
shaojiewang Apr 26, 2022
260dcdb
fix clang format for the other files
shaojiewang Apr 27, 2022
93871ca
add bank length computation
shaojiewang Apr 28, 2022
1e5c712
add template to distinguish the instance that need lds padding for wrw
shaojiewang Apr 28, 2022
eb09227
use rocm5.1 as docker
shaojiewang Apr 29, 2022
2e6eaf6
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 29, 2022
579e8e7
use integer value for GEMM test
Apr 30, 2022
507e149
Merge remote-tracking branch 'origin/develop' into wrw_conv_impr
Apr 30, 2022
60ee26a
Merge remote-tracking branch 'origin/fix_test' into wrw_conv_impr
Apr 30, 2022
5c8ad21
Merge branch 'develop' into wrw_conv_impr
shaojiewang Apr 30, 2022
28bb628
1. move dedicated transform into gridwisegemm's head file. 2. make ld…
shaojiewang May 4, 2022
a7297c2
use a new gridwise gemm header for bwd-weight
shaojiewang May 13, 2022
72b5309
revert gridwise gemm v2r4r2
shaojiewang May 13, 2022
3243112
Merge branch 'develop' into wrw_conv_impr
shaojiewang May 13, 2022
10a0802
change foramt
shaojiewang May 13, 2022
9459a68
rename kernel invoker
shaojiewang May 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions include/ck/tensor_description/merge_transform_for_wrw.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#pragma once

#include "common_header.hpp"
#include "multi_index_transform.hpp"

namespace ck {

// Implementation of the "Merge" transformation primitive that uses division and mod. It is
// intended for low_lengths that are known at compile time and are powers of 2; otherwise
// performance will be very bad.
template <typename LowLengths>
struct Merge_v3_division_mod_for_wrw
Comment thread
asroy marked this conversation as resolved.
Outdated
{
static constexpr index_t NDimLow = LowLengths::Size();

using LowerIndex = MultiIndex<NDimLow>;
using UpperIndex = MultiIndex<1>;

using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;

__host__ __device__ constexpr Merge_v3_division_mod_for_wrw() = default;

__host__ __device__ constexpr Merge_v3_division_mod_for_wrw(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
{
static_assert(LowerIndex::Size() == NDimLow, "wrong!");
}

__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

template <typename LowIdx, typename UpIdx>
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");

index_t tmp = idx_up[Number<0>{}];

// division and mod
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp %= this->low_lengths_scan_[i];
});

idx_low(Number<NDimLow - 1>{}) = tmp;
}

template <typename LowIdxDiff,
typename UpIdxDiff,
typename LowIdx,
typename UpIdx,
index_t Hack>
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_up_diff,
LowIdx& idx_low,
const UpIdx& idx_up_new,
Number<Hack>) const
{
static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
"wrong! inconsistent # of dimension");

constexpr auto I0 = Number<0>{};
constexpr auto INm1 = Number<NDimLow - 1>{};

index_t tmp = idx_up_new[I0];

// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){
Comment thread
asroy marked this conversation as resolved.
Outdated
// //printf("%d, %d, %d\n", __LINE__, tmp, tmp2);
// //printf("%d, %d, %d\n",
// // __LINE__,
// // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())),
// // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>())));
// printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()),
// idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>()));
//}

// static_for<0, NDimLow - 1, 1>{}([&](auto i) {
// const index_t tmp2 = idx_low[i];
// idx_low(i) = tmp / this->low_lengths_scan_[i];
// idx_diff_low(i) = idx_low[i] - tmp2;
// tmp %= this->low_lengths_scan_[i];
//});

// const index_t tmp2 = idx_low[INm1];
// idx_low(INm1) = tmp;
// idx_diff_low(INm1) = idx_low[INm1] - tmp2;

idx_low(INm1) = tmp;
idx_diff_low(INm1) = idx_up_diff[I0];

// if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){
// //printf("%d, %d, %d\n", __LINE__, tmp, tmp2);
// printf("%d, %d, %d\n",
// __LINE__,
// static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())),
// static_cast<index_t>(this->low_lengths_scan_.At(Number<1>())));
// printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()),
// idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>()));
//}
}

__host__ __device__ static constexpr bool IsLinearTransform() { return false; }

__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
{
return true;
}

__host__ __device__ static constexpr bool IsKnownAtCompileTime()
{
return is_known_at_compile_time<LowLengths>::value &&
is_known_at_compile_time<LowLengthsScan>::value &&
is_known_at_compile_time<UpLengths>::value;
}

template <typename UpIdx>
__host__ __device__ static constexpr bool
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
{
return true;
}

__host__ __device__ void Print() const
{
printf("{");
printf("Merge_v3_direct_division_mod_wrw, ");
printf("low_lengths_ ");
print_multi_index(low_lengths_);
printf("low_lengths_scan_ ");
print_multi_index(low_lengths_scan_);
printf("up_lengths_ ");
print_multi_index(up_lengths_);
printf("}");
}
};

// Factory helper: returns a Merge_v3_division_mod_for_wrw transform constructed from
// `low_lengths`, deducing the LowLengths template argument from the call.
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v3_division_mod_for_wrw(const LowLengths& low_lengths)
{
    return Merge_v3_division_mod_for_wrw<LowLengths>{low_lengths};
}

} // namespace ck
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;

using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
BlockSize,
Expand Down Expand Up @@ -285,7 +287,9 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferScalarPerVector_NWaveNPerXdl,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
true,
true>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
Expand Down
Loading