Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,20 @@
#ifndef CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP
#define CK_REDUCTION_FUNCTIONS_BLOCKWISE_HPP

#include "data_type.hpp"

#include "reduction_common.hpp"
#include "reduction_operator.hpp"
#include "reduction_functions_accumulate.hpp"

#include "cluster_descriptor.hpp"

namespace ck {

// clang-format off
// Assume:
// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
index_t BlockSize,
typename ThreadClusterLengths_M_K,
Expand All @@ -61,8 +65,11 @@ struct PartitionedBlockwiseReduction
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;

template <typename BufferType>
__device__ static void Reduce(BufferType& block_buffer, AccDataType& accuData)
__device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");

constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

const auto thread_cluster_idx =
Expand All @@ -71,6 +78,10 @@ struct PartitionedBlockwiseReduction
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

__syncthreads();

static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

Expand All @@ -80,21 +91,28 @@ struct PartitionedBlockwiseReduction
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));

AccDataType opData1 = type_convert<AccDataType>(block_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_buffer[offset2]);
AccDataType opData1 = work_buffer[offset1];
Comment thread
asroy marked this conversation as resolved.
AccDataType opData2 = work_buffer[offset2];
Accumulation::Calculate(opData1, opData2);
block_buffer(offset1) = type_convert<AccDataType>(opData1);
work_buffer(offset1) = opData1;
}

__syncthreads();
});

index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

accuData = type_convert<AccDataType>(block_buffer[offset]);
in_out_value = work_buffer[offset];
};
};

// clang-format off
// Assume:
// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data
// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize
// 3) in_out_value/in_out_index is the input data in vgpr from each thread
// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
typename IndexDataType,
index_t BlockSize,
Expand Down Expand Up @@ -123,11 +141,16 @@ struct PartitionedBlockwiseReductionWithIndex

// This interface accumulates on both data values and indices
template <typename BufferType, typename IdxBufferType>
__device__ static void Reduce(BufferType& block_val_buffer,
IdxBufferType& block_idx_buffer,
AccDataType& accuData,
IndexDataType& accuIndex)
__device__ static void Reduce(BufferType& work_val_buffer,
IdxBufferType& work_idx_buffer,
AccDataType& in_out_value,
IndexDataType& in_out_index)
{
static_assert(is_same<typename BufferType::type, AccDataType>{},
"Buffer data type should be consistent as AccDataType!");
static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
"Buffer data type should be consistent as IndexDataType!");

constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

const auto thread_cluster_idx =
Expand All @@ -136,6 +159,11 @@ struct PartitionedBlockwiseReductionWithIndex
const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;

__syncthreads();

static_for<0, cluster_len_shift, 1>{}([&](auto I) {
constexpr index_t indOffset = 1 << I();

Expand All @@ -145,24 +173,24 @@ struct PartitionedBlockwiseReductionWithIndex
index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
make_tuple(0, indOffset));

AccDataType opData1 = type_convert<AccDataType>(block_val_buffer[offset1]);
AccDataType opData2 = type_convert<AccDataType>(block_val_buffer[offset2]);
IndexDataType currIndex1 = block_idx_buffer[offset1];
IndexDataType currIndex2 = block_idx_buffer[offset2];
AccDataType opData1 = work_val_buffer[offset1];
Comment thread
asroy marked this conversation as resolved.
AccDataType opData2 = work_val_buffer[offset2];
IndexDataType currIndex1 = work_idx_buffer[offset1];
IndexDataType currIndex2 = work_idx_buffer[offset2];

Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
block_val_buffer(offset1) = type_convert<AccDataType>(opData1);
block_idx_buffer(offset1) = currIndex1;
work_val_buffer(offset1) = opData1;
work_idx_buffer(offset1) = currIndex1;
}

__syncthreads();
});

index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

accuData = type_convert<AccDataType>(block_val_buffer[offset]);
accuIndex = block_idx_buffer[offset];
}
in_out_value = work_val_buffer[offset];
in_out_index = work_idx_buffer[offset];
};
};

}; // end of namespace ck
Expand Down
Loading