-
Notifications
You must be signed in to change notification settings - Fork 300
[Perf][Bwd-weights]Lds re-layout to avoid ds read/write bank conflict and balance ds ops with address calculations #190
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 19 commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
f9fde06
add some instance to develop
shaojiewang 09f365a
avoid bank conflicts for wrw for all instance
shaojiewang 0a1e41a
add small K1 test
shaojiewang 0fd4df3
delete some unused instance
shaojiewang 230a41c
reset buffer load oob and ds memcpy to default option
shaojiewang 81ffce2
remove useless instances
shaojiewang a6ebdb4
remove redandunt space
shaojiewang c58fc51
remove printf code
shaojiewang fc17eb4
Merge branch 'develop' into wrw_conv_impr
shaojiewang 1654fcc
clang-format-10 change
shaojiewang 260dcdb
fix clang format for the other files
shaojiewang 93871ca
add bank length computation
shaojiewang 1e5c712
add template to distinguish the instance that need lds padding for wrw
shaojiewang eb09227
use rocm5.1 as docker
shaojiewang 2e6eaf6
Merge branch 'develop' into wrw_conv_impr
shaojiewang 579e8e7
use integer value for GEMM test
507e149
Merge remote-tracking branch 'origin/develop' into wrw_conv_impr
60ee26a
Merge remote-tracking branch 'origin/fix_test' into wrw_conv_impr
5c8ad21
Merge branch 'develop' into wrw_conv_impr
shaojiewang 28bb628
1. move dedicated transform into gridwisegemm's head file. 2. make ld…
shaojiewang a7297c2
use a new gridwise gemm header for bwd-weight
shaojiewang 72b5309
revert gridwise gemm v2r4r2
shaojiewang 3243112
Merge branch 'develop' into wrw_conv_impr
shaojiewang 10a0802
change foramt
shaojiewang 9459a68
rename kernel invoker
shaojiewang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
161 changes: 161 additions & 0 deletions
161
include/ck/tensor_description/merge_transform_for_wrw.hpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| #pragma once | ||
|
|
||
| #include "common_header.hpp" | ||
| #include "multi_index_transform.hpp" | ||
|
|
||
| namespace ck { | ||
|
|
||
| // Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to | ||
| // be used for low_lengths that are known at compile time and are power of 2, otherwise performance | ||
| // will be very bad | ||
| template <typename LowLengths> | ||
| struct Merge_v3_division_mod_for_wrw | ||
| { | ||
| static constexpr index_t NDimLow = LowLengths::Size(); | ||
|
|
||
| using LowerIndex = MultiIndex<NDimLow>; | ||
| using UpperIndex = MultiIndex<1>; | ||
|
|
||
| using LowLengthsScan = | ||
| decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); | ||
|
|
||
| using UpLengths = | ||
| decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); | ||
|
|
||
| LowLengths low_lengths_; | ||
| LowLengthsScan low_lengths_scan_; | ||
| UpLengths up_lengths_; | ||
|
|
||
| __host__ __device__ constexpr Merge_v3_division_mod_for_wrw() = default; | ||
|
|
||
| __host__ __device__ constexpr Merge_v3_division_mod_for_wrw(const LowLengths& low_lengths) | ||
| : low_lengths_{low_lengths}, | ||
| low_lengths_scan_{ | ||
| container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, | ||
| up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} | ||
| { | ||
| static_assert(LowerIndex::Size() == NDimLow, "wrong!"); | ||
| } | ||
|
|
||
| __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } | ||
|
|
||
| __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } | ||
|
|
||
| __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } | ||
|
|
||
| template <typename LowIdx, typename UpIdx> | ||
| __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, | ||
| const UpIdx& idx_up) const | ||
| { | ||
| static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, | ||
| "wrong! inconsistent # of dimension"); | ||
|
|
||
| index_t tmp = idx_up[Number<0>{}]; | ||
|
|
||
| // division and mod | ||
| static_for<0, NDimLow - 1, 1>{}([&](auto i) { | ||
| idx_low(i) = tmp / this->low_lengths_scan_[i]; | ||
| tmp %= this->low_lengths_scan_[i]; | ||
| }); | ||
|
|
||
| idx_low(Number<NDimLow - 1>{}) = tmp; | ||
| } | ||
|
|
||
| template <typename LowIdxDiff, | ||
| typename UpIdxDiff, | ||
| typename LowIdx, | ||
| typename UpIdx, | ||
| index_t Hack> | ||
| __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, | ||
| const UpIdxDiff& idx_up_diff, | ||
| LowIdx& idx_low, | ||
| const UpIdx& idx_up_new, | ||
| Number<Hack>) const | ||
| { | ||
| static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && | ||
| LowIdx::Size() == NDimLow && UpIdx::Size() == 1, | ||
| "wrong! inconsistent # of dimension"); | ||
|
|
||
| constexpr auto I0 = Number<0>{}; | ||
| constexpr auto INm1 = Number<NDimLow - 1>{}; | ||
|
|
||
| index_t tmp = idx_up_new[I0]; | ||
|
|
||
| // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){ | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| // //printf("%d, %d, %d\n", __LINE__, tmp, tmp2); | ||
| // //printf("%d, %d, %d\n", | ||
| // // __LINE__, | ||
| // // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())), | ||
| // // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>()))); | ||
| // printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()), | ||
| // idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>())); | ||
| //} | ||
|
|
||
| // static_for<0, NDimLow - 1, 1>{}([&](auto i) { | ||
| // const index_t tmp2 = idx_low[i]; | ||
| // idx_low(i) = tmp / this->low_lengths_scan_[i]; | ||
| // idx_diff_low(i) = idx_low[i] - tmp2; | ||
| // tmp %= this->low_lengths_scan_[i]; | ||
| //}); | ||
|
|
||
| // const index_t tmp2 = idx_low[INm1]; | ||
| // idx_low(INm1) = tmp; | ||
| // idx_diff_low(INm1) = idx_low[INm1] - tmp2; | ||
|
|
||
| idx_low(INm1) = tmp; | ||
| idx_diff_low(INm1) = idx_up_diff[I0]; | ||
|
|
||
| // if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0){ | ||
| // //printf("%d, %d, %d\n", __LINE__, tmp, tmp2); | ||
| // printf("%d, %d, %d\n", | ||
| // __LINE__, | ||
| // static_cast<index_t>(this->low_lengths_scan_.At(Number<0>())), | ||
| // static_cast<index_t>(this->low_lengths_scan_.At(Number<1>()))); | ||
| // printf("%d, %d, %d, %d, %d, %d\n", __LINE__, NDimLow, idx_low.At(Number<0>()), | ||
| // idx_low.At(Number<1>()), idx_diff_low.At(Number<0>()), idx_diff_low.At(Number<1>())); | ||
| //} | ||
| } | ||
|
|
||
| __host__ __device__ static constexpr bool IsLinearTransform() { return false; } | ||
|
|
||
| __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() | ||
| { | ||
| return true; | ||
| } | ||
|
|
||
| __host__ __device__ static constexpr bool IsKnownAtCompileTime() | ||
| { | ||
| return is_known_at_compile_time<LowLengths>::value && | ||
| is_known_at_compile_time<LowLengthsScan>::value && | ||
| is_known_at_compile_time<UpLengths>::value; | ||
| } | ||
|
|
||
| template <typename UpIdx> | ||
| __host__ __device__ static constexpr bool | ||
| IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) | ||
| { | ||
| return true; | ||
| } | ||
|
|
||
| __host__ __device__ void Print() const | ||
| { | ||
| printf("{"); | ||
| printf("Merge_v3_direct_division_mod_wrw, "); | ||
| printf("low_lengths_ "); | ||
| print_multi_index(low_lengths_); | ||
| printf("low_lengths_scan_ "); | ||
| print_multi_index(low_lengths_scan_); | ||
| printf("up_lengths_ "); | ||
| print_multi_index(up_lengths_); | ||
| printf("}"); | ||
| } | ||
| }; | ||
|
|
||
| template <typename LowLengths> | ||
| __host__ __device__ constexpr auto | ||
| make_merge_transform_v3_division_mod_for_wrw(const LowLengths& low_lengths) | ||
|
asroy marked this conversation as resolved.
Outdated
|
||
| { | ||
| return Merge_v3_division_mod_for_wrw<LowLengths>{low_lengths}; | ||
| } | ||
|
|
||
| } // namespace ck | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.