-
Notifications
You must be signed in to change notification settings - Fork 300
Add gridwise GEMM pipeline #89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
2c53a47
a3fa42b
8340275
e775a3d
447b9f2
b6885ac
d55aa1c
f3caf6e
4f59507
b67083c
7553f5b
a3311eb
a301414
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,245 @@ | ||
| #ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP | ||
| #define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP | ||
|
|
||
| #include "common_header.hpp" | ||
|
|
||
| namespace ck { | ||
|
|
||
| template <typename AGridDesc, | ||
| typename ABlockDesc, | ||
| typename ABlockTransfer, | ||
| typename AGridBuffer, | ||
| typename ABlockBuffer, | ||
| typename ABlockTransferStep, | ||
| typename BGridDesc, | ||
| typename BBlockDesc, | ||
| typename BBlockTransfer, | ||
| typename BGridBuffer, | ||
| typename BBlockBuffer, | ||
| typename BBlockTransferStep, | ||
| typename BlockwiseGemm, | ||
| typename CThreadBuffer, | ||
| index_t NumPrefetch, | ||
| bool HasMainLoop> | ||
| struct GridwiseGemmPipeline_v1 | || // Issues the RunRead / RunWrite / blockwise_gemm.Run sequence for A and B; NumPrefetch (1 or 2) selects which schedule Run() compiles, HasMainLoop whether the do/while loop is compiled in.
| { | ||
| static constexpr auto I0 = Number<0>{}; | ||
| static constexpr auto I1 = Number<1>{}; | ||
|
|
||
| static __device__ void Run(const AGridDesc& a_grid_desc, | ||
| const ABlockDesc& a_block_desc, | ||
| ABlockTransfer& a_blockwise_copy, | ||
| const AGridBuffer& a_grid_buf, | ||
| ABlockBuffer& a_block_buf, | ||
| const ABlockTransferStep& a_block_copy_step, | ||
| const BGridDesc& b_grid_desc, | ||
| const BBlockDesc& b_block_desc, | ||
| BBlockTransfer& b_blockwise_copy, | ||
| const BGridBuffer& b_grid_buf, | ||
| BBlockBuffer& b_block_buf, | ||
| const BBlockTransferStep& b_block_copy_step, | ||
| const BlockwiseGemm& blockwise_gemm, | ||
| CThreadBuffer& c_thread_buf, | ||
| index_t NumLoop) | || // NumLoop bounds the main loops (i < NumLoop - 1, or i < NumLoop - 2 with i += 2). NOTE(review): the NumPrefetch == 2 tail always runs two GEMMs, so NumLoop presumably has to be even and >= 2 there - confirm with callers.
| { | ||
| static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!"); | || // only the two schedules below exist; any other NumPrefetch fails to compile
|
|
||
| if constexpr(NumPrefetch == 1) | ||
| { | ||
| #if 0 | ||
| // disabled variant (#if 0): moves the source windows at the top of each loop iteration rather than after the reads; starts with a RunRead of A and B from the grid buffers (preload data into LDS) | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); | ||
|
|
||
| // Initialize C: clear the C thread buffer before the first blockwise_gemm.Run | ||
| c_thread_buf.Clear(); | ||
|
|
||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); | ||
|
|
||
| // main body (only when HasMainLoop): move windows, read A, sync, read B, GEMM, sync, write A/B; repeats while i < NumLoop - 1 | ||
| if constexpr(HasMainLoop) | ||
| { | ||
| index_t i = 0; | ||
|
|
||
| do | ||
| { | ||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); | ||
|
|
||
| block_sync_lds(); | ||
|
|
||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); | ||
|
|
||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
|
|
||
| block_sync_lds(); | ||
|
|
||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); | ||
|
|
||
| ++i; | ||
| } while(i < (NumLoop - 1)); | ||
| } | ||
|
|
||
| // tail: sync, then one GEMM on the data most recently written to the block buffers; no further reads are issued | ||
| { | ||
| block_sync_lds(); | ||
|
|
||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
| } | ||
| #else | ||
| // active variant - preload data into LDS: RunRead of A and B from the grid buffers, then advance both source windows by one copy step | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); | ||
|
|
||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| // Initialize C: clear the C thread buffer before the first blockwise_gemm.Run | ||
| c_thread_buf.Clear(); | ||
|
|
||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); | ||
|
|
||
| // main body (only when HasMainLoop): read A, sync, read B, GEMM on the current block buffers, sync, move both source windows, write the newly read A/B; repeats while i < NumLoop - 1 | ||
| if constexpr(HasMainLoop) | ||
| { | ||
| index_t i = 0; | ||
|
|
||
| do | ||
| { | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); | ||
|
|
||
| block_sync_lds(); | ||
|
|
||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); | ||
|
|
||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
|
|
||
| block_sync_lds(); | ||
|
|
||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); | ||
|
|
||
| ++i; | ||
| } while(i < (NumLoop - 1)); | ||
| } | ||
|
|
||
| // tail: sync, then one GEMM on the data most recently written to the block buffers; no further reads are issued | ||
| { | ||
| block_sync_lds(); | ||
|
|
||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
| } | ||
| #endif | ||
| } | ||
| else if constexpr(NumPrefetch == 2) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd preferably move these different pipelines to their own template specializations. Things like: template <
typename Params,
index_t NumPrefetch>
struct GridwiseGemmPipeline_v1; // class template declaration
template <typename Params>
struct GridwiseGemmPipeline_v1<Params, 1>
{
// implements 1 stage prefetch
}
template <typename Params>
struct GridwiseGemmPipeline_v1<Params, 2>
{
// implements 2 stage prefetch
}The upside of this approach is that when someone accidentally calls nonexistent 3 stage prefetch it triggers compilation error immediately instead of silently compiles. But of course you already have
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
| { | ||
| // preload data into LDS | ||
| { | ||
| // Read 0 | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); | ||
|
|
||
| // Move | ||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| // Read 1 | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); | ||
| } | ||
|
|
||
| // Initialize C | ||
| c_thread_buf.Clear(); | ||
|
|
||
| // main body | ||
| if constexpr(HasMainLoop) | ||
| { | ||
| index_t i = 0; | ||
|
|
||
| do | ||
| { | ||
| // Move | ||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| // Write i | ||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); | ||
|
|
||
| // Read i+2 | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Gemm i | ||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Move | ||
| a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); | ||
| b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); | ||
|
|
||
| // Write i+1 | ||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); | ||
|
|
||
| // Read i+3 | ||
| a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); | ||
| b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Gemm i+1 | ||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| i += 2; | ||
| } while(i < (NumLoop - 2)); | ||
| } | ||
|
|
||
| // tail | ||
| { | ||
| // Write NumLoop - 2 | ||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Gemm NumLoop - 2 | ||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Write NumLoop - 1 | ||
| a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); | ||
| b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); | ||
|
|
||
| // Sync | ||
| block_sync_lds(); | ||
|
|
||
| // Gemm NumLoop - 1 | ||
| blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); | ||
| } | ||
| } | ||
| } | ||
| }; | ||
|
|
||
| } // namespace ck | ||
| #endif | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we were supposed to use snake_case for variables?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good catch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed