Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");

constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ template <index_t BlockSize,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
Expand Down Expand Up @@ -86,45 +87,39 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}

template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}

template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf);
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}

template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}

template <typename SrcBuffer, typename DstBuffer>
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf);
RunWrite(dst_desc, dst_buf);
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}

__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
Expand All @@ -136,21 +131,6 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}

// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}

__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
Expand Down Expand Up @@ -182,7 +162,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;

ThreadwiseTransfer threadwise_transfer_;
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP

#include "common_header.hpp"

namespace ck {

template <typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename BlockwiseGemm,
typename CThreadBuffer,
index_t NumPrefetch,
bool HasMainLoop>
struct GridwiseGemmPipeline_v1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

static __device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t NumLoop)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we are supposedly to use snake_case for variables?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

{
static_assert(NumPrefetch == 1 || NumPrefetch == 2, "wrong!");

if constexpr(NumPrefetch == 1)
{
#if 0
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

// Initialize C
c_thread_buf.Clear();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

block_sync_lds();

b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

++i;
} while(i < (NumLoop - 1));
}

// tail
{
block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#else
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Initialize C
c_thread_buf.Clear();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

block_sync_lds();

b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

block_sync_lds();

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

++i;
} while(i < (NumLoop - 1));
}

// tail
{
block_sync_lds();

blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
#endif
}
else if constexpr(NumPrefetch == 2)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd preferably move these different pipelines to their own template specializations. Things like:

template <
    typename Params,
    index_t NumPrefetch>
struct GridwiseGemmPipeline_v1; // class template declaration

template <typename Params>
struct GridwiseGemmPipeline_v1<Params, 1>
{
     // implements 1 stage prefetch
}

template <typename Params>
struct GridwiseGemmPipeline_v1<Params, 2>
{
     // implements 2 stage prefetch
}

The upside of this approach is that when someone accidentally calls nonexistent 3 stage prefetch it triggers compilation error immediately instead of silently compiles. But of course you already have static_assert to guard against exactly the case. So, just matter of personal preference.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

{
// preload data into LDS
{
// Read 0
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Read 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
}

// Initialize C
c_thread_buf.Clear();

// main body
if constexpr(HasMainLoop)
{
index_t i = 0;

do
{
// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Write i
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

// Read i+2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

// Sync
block_sync_lds();

// Gemm i
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

// Sync
block_sync_lds();

// Move
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

// Write i+1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);

// Read i+3
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

// Sync
block_sync_lds();

// Gemm i+1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

// Sync
block_sync_lds();

i += 2;
} while(i < (NumLoop - 2));
}

// tail
{
// Write NumLoop - 2
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

// Sync
block_sync_lds();

// Gemm NumLoop - 2
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);

// Sync
block_sync_lds();

// Write NumLoop - 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1);

// Sync
block_sync_lds();

// Gemm NumLoop - 1
blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
}
}
}
};

} // namespace ck
#endif
Loading