[icafe-31094] Add function comments and instructions to the API
AnnaTrainingG committed Sep 23, 2021
1 parent 91f25ee commit 80ee8a9
Showing 2 changed files with 296 additions and 198 deletions.
236 changes: 152 additions & 84 deletions paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -54,8 +54,8 @@ class MPTypeTrait<platform::float16> {
};

/**
- * @brief will be used in BlockYReduce, get the index of reduce_num in shared
- * memory
+ * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
+ * memory.
*/
__device__ __forceinline__ int SharedMemoryIndex(int index) {
return (threadIdx.y + index) * blockDim.x + threadIdx.x;
@@ -83,7 +83,7 @@ __device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
*/

/**
- * @brief BlockXReduce reduce along blockDim.x
+ * @brief BlockXReduce reduces along blockDim.x.
*/
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
@@ -115,7 +115,7 @@ __device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
}

/**
- * @brief BlockYReduce reduce along blockDim.y
+ * @brief BlockYReduce reduces along blockDim.y.
*/
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
@@ -135,24 +135,33 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
} // namespace details

/**
- * @brief unary function
- * @param
- * T: data type of in
- * OutT: data type of out
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- * template <typename T, typename OutT>
+ * @brief Perform unary calculation according to OpFunc. Sizes of the input
+ * and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ * template <typename InT, typename OutT>
* struct XxxFunctor {
- * HOSTDEVICE OutT operator()(const T& a) const {
+ * HOSTDEVICE OutT operator()(const InT& a) const {
* return ...;
* }
* };
+ *
+ * @param:
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * compute: Compute function declared as OpFunc<InT, OutT>().
*/
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
-__device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
+__device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
@@ -161,25 +170,35 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const T* in,
}
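
As an aside (an editor's sketch, not code from this commit), the unary contract documented above can be exercised as follows. ScaleFunctor and the constants are hypothetical names, and HOSTDEVICE is Paddle's host/device qualifier macro:

// Hypothetical usage sketch of ElementwiseUnary: each thread transforms the
// NX * NY elements it holds in registers.
template <typename InT, typename OutT>
struct ScaleFunctor {
  HOSTDEVICE OutT operator()(const InT& a) const {
    return static_cast<OutT>(a + a);  // any per-element transform
  }
};

// Inside a __global__ kernel, with thread-local register buffers
// float in[4], out[4]:
//   ElementwiseUnary<float, float, 4, 1, 1, ScaleFunctor<float, float>>(
//       out, in, ScaleFunctor<float, float>());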

/**
- * @brief binary function, in1 and in2 have same shape
- * @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- * template <typename T, typename OutT>
+ * @brief Binary calculation according to OpFunc. Sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ * template <typename InT, typename OutT>
* struct XxxFunctor {
- * HOSTDEVICE OutT operator()(const T& a, const T& b) const {
+ * HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
* return ...;
* }
* };
+ *
+ * @param:
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * NY.
+ * in2: The register pointer of second input, size is NX * NY.
+ * compute: Compute function declared as OpFunc<InT, OutT>().
*/
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
-__device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
-const T* in2,
+__device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
+const InT* in2,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
@@ -188,25 +207,38 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const T* in1,
}
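
Again as a hypothetical sketch (not part of this commit), a binary functor and call matching the documented operator() shape, with AddFunctor an assumed name:

// Hypothetical usage sketch of ElementwiseBinary: element-wise sum of two
// NX * NY register tiles.
template <typename InT, typename OutT>
struct AddFunctor {
  HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
    return static_cast<OutT>(a + b);
  }
};

// Inside a __global__ kernel:
//   ElementwiseBinary<float, float, 4, 1, 1, AddFunctor<float, float>>(
//       out, in1, in2, AddFunctor<float, float>());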

/**
- * @brief ternary function, in1, in2 and in3 have same shape
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- * template <typename T, typename OutT>
+ * @brief Ternary calculation according to OpFunc. Sizes of the input and
+ * output are the same.
+ *
+ * @template parameters
+ * InT: Data type of in1, in2 and in3.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ * template <typename InT, typename OutT>
* struct XxxFunctor {
- * HOSTDEVICE OutT operator()(const T& a, const T& b, const T& c) const {
+ * HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
+ * const {
* return ...;
* }
* };
+ *
+ * @param:
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * NY.
+ * in2: The register pointer of second input, size is NX * NY.
+ * in3: The register pointer of third input, size is NX * NY.
+ * compute: Compute function declared as OpFunc<InT, OutT>().
*/
-template <typename T, typename OutT, int NX, int NY, int BlockSize,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
-__device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
-const T* in2, const T* in3,
+__device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
+const InT* in2,
+const InT* in3,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
@@ -215,27 +247,36 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const T* in1,
}
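
A corresponding hypothetical sketch for the ternary form (not part of this commit), using an assumed fused multiply-add functor:

// Hypothetical usage sketch of ElementwiseTernary: out = a * b + c per element.
template <typename InT, typename OutT>
struct FmaFunctor {
  HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c) const {
    return static_cast<OutT>(a * b + c);
  }
};

// Inside a __global__ kernel:
//   ElementwiseTernary<float, float, 4, 1, 1, FmaFunctor<float, float>>(
//       out, in1, in2, in3, FmaFunctor<float, float>());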

/**
- * @brief a general function for elementwise computation, all inputs have
- * the same shape.
- * @param
- * T: data type of in1, in2, in3
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor which have an operator() as following
- * template <typename T, typename OutT>
+ * @brief Multivariate calculation according to OpFunc. Sizes of the input
+ * and output are the same.
+ *
+ * @template parameters
+ * InT: Data type of ins.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * Arity: The size of ins.
+ * OpFunc: Compute functor which has an operator() as following:
+ * template <typename InT, typename OutT>
* struct XxxFunctor {
- * HOSTDEVICE OutT operator()(const T* args) const {
+ * HOSTDEVICE OutT operator()(const InT* args) const {
* return ...;
* }
* };
+ *
+ * @param:
+ * out: The register pointer of out, the size is NX * NY.
+ * ins: An array of pointers consisting of multiple inputs.
+ * compute: Compute function declared as OpFunc<InT, OutT>().
*/
-template <typename T, typename OutT, int NX, int NY, int BlockSize, int Arity,
+template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
class OpFunc>
-__device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
+__device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
OpFunc compute) {
-T args[Arity];
+InT args[Arity];
#pragma unroll
for (int idx = 0; idx < NX * NY; ++idx) {
#pragma unroll
@@ -247,15 +288,31 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, T (*ins)[NX * NY],
}
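
A hypothetical sketch of the variadic form (not part of this commit): the functor receives all Arity operands for one element through a single pointer, so a three-input sum with an assumed SumArgsFunctor looks like:

// Hypothetical usage sketch of ElementwiseAny with Arity == 3.
template <typename InT, typename OutT>
struct SumArgsFunctor {
  HOSTDEVICE OutT operator()(const InT* args) const {
    return static_cast<OutT>(args[0] + args[1] + args[2]);
  }
};

// Inside a __global__ kernel, with float ins[3][4] held in registers:
//   ElementwiseAny<float, float, 4, 1, 1, 3, SumArgsFunctor<float, float>>(
//       out, ins, SumArgsFunctor<float, float>());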

/**
- * @brief cycle binary function, in1's shape size is [1, NX], in2's shape size
- * is [NY, NX], out's shape size is [NY, NX]
+ * @brief Binary calculation according to OpFunc. Shapes of in1 and in2 are
+ * different: the shape of in1 is [1, NX], while in2's shape is [NY, NX]. The
+ * output shape is [NY, NX].
+ *
+ * @template parameters
+ * T: Data type of in1 and in2.
+ * OutT: Data type of out.
+ * NX: The number of data columns loaded by each thread.
+ * NY: The number of data rows loaded by each thread.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * OpFunc: Compute functor which has an operator() as following:
+ * template <typename InT, typename OutT>
+ * struct XxxFunctor {
+ * HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
+ * return ...;
+ * }
+ * };
+ *
* @param
- * T: data type of in1, in2
- * OutT: data type of out
- * NX: the cols of in1, in2
- * NY: the rows of in1, in2
- * BlockSize: the config of this device
- * OpFunc: compute functor eg: in1 + in2, in1 - in2
+ * out: The register pointer of out, the size is NX * NY.
+ * in1: The register pointer of first input, size is NX * 1.
+ * in2: The register pointer of second input, size is NX * NY.
+ * compute: Compute function declared as OpFunc<T, OutT>().
*/
template <typename T, typename OutT, int NX, int NY, int BlockSize,
class OpFunc>
@@ -272,26 +329,37 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
}
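
To make the broadcast explicit (an editor's reference sketch, not code from this commit), CycleBinary behaves like this serial loop: the NX values of in1 are cycled over each of the NY rows of in2:

// Hypothetical reference semantics of CycleBinary.
for (int idy = 0; idy < NY; ++idy) {
  for (int idx = 0; idx < NX; ++idx) {
    // in1 holds NX elements; in2 and out hold NY * NX elements.
    out[idx + idy * NX] = compute(in1[idx], in2[idx + idy * NX]);
  }
}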

/**
- * @brief reduce function, in's shape size is [NX, NY].
- * If ReduceMode == kLocalMode then reduce NX, the shape of out is [NY, 1],
- * if ReduceMode == kGlobalMode then reduce between different threads, the
- * shape of out is [NY, NX]. If reduce_last_dim is false and reduce_num was
- * split, BlockYReduce will be called. If reduce_last_dim is true and
- * reduce_num was split, BlockXReduce will be called
- * @typename
- * T: data type of in
- * NX: the cols of in
- * NY: the rows of in
- * BlockSize: the config of this device
- * OpFunc: reduce functor, eg: CustomSum, CustomMean in reduce_functor_op.h
- * @param:
- * reducer: reduce functor, eg: CustomSum<T>()
- * reduce_last_dim: if in's last dim need to be reduce then reduce_last_dim =
- * true
+ * @brief Reduce provides collective methods for computing a parallel
+ * reduction of items partitioned across a CUDA block and within each thread.
+ * When ReduceMode == kLocalMode, each thread reduces along NX. When
+ * ReduceMode == kGlobalMode, shared memory is used to reduce between threads.
+ *
+ * @template parameters
+ * T: The type of data.
+ * NX: The number of data continuously loaded by each thread.
+ * NY: The number of data rows loaded by each thread, only NY = 1 is
+ * supported.
+ * BlockSize: Identifies the current device thread index method. For GPU,
+ * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
+ * the index. Currently only GPU is supported.
+ * ReduceFunctor: Compute functor which has an operator() as following:
+ * template <typename InT>
+ * struct ReduceFunctor {
+ * HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
+ * return ...;
+ * }
+ * };
+ * ReduceMode: Reduce mode, can be kLocalMode, kGlobalMode.
+ *
+ * @param:
+ * out: The register pointer of out, the size is NX * NY.
+ * in: The register pointer of in, the size is NX * NY.
+ * reducer: Compute function declared as ReduceFunctor<T>().
+ * reduce_last_dim: whether the last dimension is involved in the reduction.
*/
-template <typename T, int NX, int NY, int BlockSize, class OpFunc,
+template <typename T, int NX, int NY, int BlockSize, class ReduceFunctor,
details::ReduceMode Mode>
-__device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
+__device__ __forceinline__ void Reduce(T* out, const T* in,
+ReduceFunctor reducer,
bool reduce_last_dim) {
int block_index = blockDim.y;

@@ -302,15 +370,15 @@ __device__ __forceinline__ void Reduce(T* out, const T* in, OpFunc reducer,
if (block_reduce_y) {
#pragma unroll
for (int i = 0; i < NY * NX; i++) { // reduce along blockdim.y
-out[i] = details::BlockYReduce<T, OpFunc>(out[i], reducer);
+out[i] = details::BlockYReduce<T, ReduceFunctor>(out[i], reducer);
}
}

// when last dimension need to be reduced
if (reduce_last_dim) {
#pragma unroll
for (int i = 0; i < NY * NX; i++) { // reduce along blockDim.x
-out[i] = details::BlockXReduce<T, OpFunc>(out[i], reducer);
+out[i] = details::BlockXReduce<T, ReduceFunctor>(out[i], reducer);
}
}
} else { // else kLocalMode
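
Finally, a hypothetical usage sketch of Reduce (not code from this commit); MaxFunctor and the kLocalMode/kGlobalMode call sequence are assumed for illustration:

// Hypothetical usage sketch of Reduce: a per-thread local pass followed by a
// cross-thread pass over the partial results.
template <typename T>
struct MaxFunctor {
  HOSTDEVICE T operator()(const T& a, const T& b) const {
    return a > b ? a : b;
  }
};

// Inside a __global__ kernel, with float in[NX] and float partial in registers:
//   Reduce<float, NX, 1, 1, MaxFunctor<float>, details::kLocalMode>(
//       &partial, in, MaxFunctor<float>(), reduce_last_dim);
//   Reduce<float, 1, 1, 1, MaxFunctor<float>, details::kGlobalMode>(
//       &partial, &partial, MaxFunctor<float>(), reduce_last_dim);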
