Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
{
// sanity check
{
static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
"wrong! inconsitent number of transform");

constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
NewLowerDimensionOldVisibleIdss{});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ template <index_t BlockSize,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool ThreadTransferSrcResetCoordinateAfterRun,
bool ThreadTransferDstResetCoordinateAfterRun>
bool ThreadTransferDstResetCoordinateAfterRun,
index_t NumThreadScratch = 1>
struct BlockwiseTensorSliceTransfer_v4r1
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
Expand Down Expand Up @@ -86,45 +87,39 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}

template <typename SrcBuffer, typename SrcStepHacks>
__device__ void
RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf, src_step_hacks);
threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
}
}

template <typename SrcBuffer>
__device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_desc, src_buf);
threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
}
}

template <typename DstBuffer>
__device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_desc, dst_buf);
}
}

template <typename SrcBuffer, typename DstBuffer>
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
DstBuffer& dst_buf,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_desc, src_buf);
RunWrite(dst_desc, dst_buf);
RunRead(src_desc, src_buf, thread_scratch_id);
RunWrite(dst_desc, dst_buf, thread_scratch_id);
}

__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
Expand All @@ -136,21 +131,6 @@ struct BlockwiseTensorSliceTransfer_v4r1
}
}

// SrcMoveSliceWindowStepHack to control index calculation move slice window
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& step,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(
src_desc, step, src_move_slice_window_step_hack);
}
}

__device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
{
if(BlockSize == thread_cluster_desc_.GetElementSize() or
Expand Down Expand Up @@ -182,7 +162,8 @@ struct BlockwiseTensorSliceTransfer_v4r1
SrcScalarStrideInVector,
DstScalarStrideInVector,
ThreadTransferSrcResetCoordinateAfterRun,
ThreadTransferDstResetCoordinateAfterRun>;
ThreadTransferDstResetCoordinateAfterRun,
NumThreadScratch>;

ThreadwiseTransfer threadwise_transfer_;
};
Expand Down
Loading