
Commit 6840cf5

Fix a bug in ReadData, ReadDataBc and ReadDataReduce when NX != 1 (#36373) (#36616)

* Fix a bug in ReadData, ReadDataBc and ReadDataReduce when NX != 1
* Update the implementation of reduceAnyKernel according to the kernel primitives API
AnnaTrainingG authored Oct 22, 2021
1 parent 3090988 commit 6840cf5
Showing 5 changed files with 286 additions and 139 deletions.
@@ -171,7 +171,7 @@ __device__ __forceinline__ void LoadData(
   // num: how many data will be deal with in this time
   if (need_broadcast) {
     kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
-                                                        config, numel, 1, 1);
+                                                        config, numel);
   } else {
     kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
   }
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -76,14 +76,14 @@ __global__ void BroadcastKernelBinary(
   // load in0
   if (use_broadcast[0]) {
     kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
-        arg0, in0, fix, configlists[0], numel, 1, 1);
+        arg0, in0, fix, configlists[0], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
   }
   // load in1
   if (use_broadcast[1]) {
     kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
-        arg1, in1, fix, configlists[1], numel, 1, 1);
+        arg1, in1, fix, configlists[1], numel);
   } else {
     kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
   }
74 changes: 34 additions & 40 deletions paddle/fluid/operators/kernel_primitives/compute_primitives.h
@@ -135,17 +135,16 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
 } // namespace details
 
 /**
- * @brief Perform unary calculation according to OpFunc. Size of input and
+ * @brief Perform unary calculation according to OpFunc. Shape of input and
  * output are the same.
  *
  * @template paraments
- * InT: Data type of in.
- * OutT: Data type of out.
+ * InT: The data type of in.
+ * OutT: The data type of out.
  * NX: The number of data columns loaded by each thread.
  * NY: The number of data rows loaded by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * OpFunc: Compute functor which has an operator() as following:
  * template <typename InT, typename OutT>
  * struct XxxFunctor {
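To make the documented contract concrete, here is a hedged sketch of a functor matching the ElementwiseUnary description above. The operator() body is cut off in this hunk, so the single-argument form is an assumption; ReluFunctor and the commented call are illustrative, not part of the commit.

template <typename InT, typename OutT>
struct ReluFunctor {
  // HOSTDEVICE is Paddle's __host__ __device__ macro.
  HOSTDEVICE OutT operator()(const InT& a) const {
    return a > static_cast<InT>(0) ? static_cast<OutT>(a) : static_cast<OutT>(0);
  }
};

// Inside a kernel, with InT in[NX * NY] and OutT out[NX * NY] in registers:
// kps::ElementwiseUnary<InT, OutT, NX, NY, 1, ReluFunctor<InT, OutT>>(
//     out, in, ReluFunctor<InT, OutT>());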
@@ -170,21 +169,20 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
 }
 
 /**
- * @brief Binary calculation according to OpFunc. Size of The input and output
+ * @brief Binary calculation according to OpFunc. Shape of The input and output
  * are the same.
  *
  * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
- * NX: The number of data columns loaded by each thread.
- * NY: The number of data rows loaded by each thread.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
+ * NX: The number of data columns computed by each thread.
+ * NY: The number of data rows computed by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * OpFunc: Compute functor which has an operator() as following:
- * template <typename InT, typename OutT>
+ * template <typename InT>
  * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT& a, const InT& b) const {
+ *   HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
  *     return ...;
  *   }
  * };
@@ -193,7 +191,7 @@ __device__ __forceinline__ void ElementwiseUnary(OutT* out, const InT* in,
  * out: The register pointer of out, the size is NX * NY.
  * in1: The register pointer of fist input, size is NX * NY.
  * in2: The register pointer of second input, size is NX * NY.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * compute: Compute function which was declared like OpFunc<InT>().
  */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
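A hedged sketch of a functor that fits the updated single-template-parameter contract documented above for ElementwiseBinary; AddFunctor and the commented call are illustrative only, not part of the commit.

template <typename InT>
struct AddFunctor {
  HOSTDEVICE InT operator()(const InT& a, const InT& b) const { return a + b; }
};

// Inside a kernel, with InT in1[NX * NY], in2[NX * NY] and OutT out[NX * NY]:
// kps::ElementwiseBinary<InT, OutT, NX, NY, 1, AddFunctor<InT>>(
//     out, in1, in2, AddFunctor<InT>());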
@@ -207,21 +205,20 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
 }
 
 /**
- * @brief Ternary calculation according to OpFunc. Size of input and output
+ * @brief Ternary calculation according to OpFunc. Shape of input and output
  * are the same.
  *
  * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
  * NX: The number of data columns loaded by each thread.
  * NY: The number of data rows loaded by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * OpFunc: Compute functor which has an operator() as following
- * template <typename InT, typename OutT>
+ * template <typename InT>
  * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT& a, const InT& b, const InT& c)
+ *   HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c)
  * const {
  *     return ...;
  *   }
@@ -232,7 +229,7 @@ __device__ __forceinline__ void ElementwiseBinary(OutT* out, const InT* in1,
  * in1: The register pointer of fist input, size is NX * NY.
  * in2: The register pointer of second input, size is NX * NY.
  * in3: The register pointer of third input, size is NX * NY.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * compute: Compute function which was declared like OpFunc<InT>().
  */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize,
           class OpFunc>
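Again as a hedged illustration of the updated ElementwiseTernary contract documented above; FmaFunctor and the commented call are not part of the commit.

template <typename InT>
struct FmaFunctor {
  // Fused multiply-add: a * b + c, element by element.
  HOSTDEVICE InT operator()(const InT& a, const InT& b, const InT& c) const {
    return a * b + c;
  }
};

// Inside a kernel, with InT in1/in2/in3[NX * NY] and OutT out[NX * NY]:
// kps::ElementwiseTernary<InT, OutT, NX, NY, 1, FmaFunctor<InT>>(
//     out, in1, in2, in3, FmaFunctor<InT>());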
@@ -247,30 +244,29 @@ __device__ __forceinline__ void ElementwiseTernary(OutT* out, const InT* in1,
 }
 
 /**
- * @brief Multivariate calculation according to OpFunc. Size of input and output
- * are the same.
+ * @brief Multivariate calculation according to OpFunc. Shape of inputs and
+ * output are the same.
  *
  * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
+ * InT: The data type of in1, in2 and in3.
+ * OutT: The data type of out.
  * NX: The number of data columns loaded by each thread.
  * NY: The number of data rows loaded by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
- * Arity: The size of ins
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
+ * Arity: The size of ins.
  * OpFunc: Compute functor which has an operator() as following:
- * template <typename InT, typename OutT>
+ * template <typename InT>
  * struct XxxFunctor {
- *   HOSTDEVICE OutT operator()(const InT* args) const {
+ *   HOSTDEVICE InT operator()(const InT* args) const {
  *     return ...;
  *   }
  * };
  *
  * @param
  * out: The register pointer of out, the size is NX * NY.
- * ins: An array of pointers consisting of multiple inputs.
- * compute: Compute function which was declared like OpFunc<InT, OutT>().
+ * ins: A pointers of array consisting of multiple inputs.
+ * compute: Compute function which was declared like OpFunc<InT>().
  */
 template <typename InT, typename OutT, int NX, int NY, int BlockSize, int Arity,
           class OpFunc>
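A hedged sketch of what the ElementwiseAny contract above looks like in use with Arity = 3; SumArgsFunctor and the commented call are illustrative assumptions, not part of the commit.

template <typename InT>
struct SumArgsFunctor {
  // args holds Arity values for one output element; here Arity == 3.
  HOSTDEVICE InT operator()(const InT* args) const {
    return args[0] + args[1] + args[2];
  }
};

// Inside a kernel, with InT ins[3][NX * NY] and OutT out[NX * NY]:
// kps::ElementwiseAny<InT, OutT, NX, NY, 1, 3, SumArgsFunctor<InT>>(
//     out, ins, SumArgsFunctor<InT>());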
@@ -293,13 +289,12 @@ __device__ __forceinline__ void ElementwiseAny(OutT* out, InT (*ins)[NX * NY],
  * shape is [NY, NX].
  *
  * @template paraments
- * InT: Data type of in1 and in2.
- * OutT: Data type of out.
+ * InT: The data type of in1 and in2.
+ * OutT: The data type of out.
  * NX: The number of data columns loaded by each thread.
  * NY: The number of data rows loaded by each thread.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * OpFunc: Compute functor which has an operator() as following
  * template <typename InT, typename OutT>
  * struct XxxFunctor {
@@ -339,8 +334,7 @@ __device__ __forceinline__ void CycleBinary(OutT* out, const T* in1,
  * NX: The number of data continuously loaded by each thread.
  * NY: The number of data rows loaded by each thread, only NY = 1 was supported.
  * BlockSize: Identifies the current device thread index method. For GPU,
- * threadIdx.x is used as the thread index, and for xpu, core_id() is used as
- * the index. Currently only GPU was supported.
+ * threadIdx.x is used as the thread index. Currently only GPU was supported.
  * ReduceFunctor: Compute functor which has an operator() as following
  * template <typename InT>
  * struct ReduceFunctor {
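The Reduce documentation is truncated in this view; as a hedged illustration, a functor matching the ReduceFunctor shape it describes could look like the following. MaxFunctor is illustrative, not taken from the commit.

template <typename InT>
struct MaxFunctor {
  // Combines two partial values; Reduce applies it pairwise across the block.
  HOSTDEVICE InT operator()(const InT& a, const InT& b) const {
    return (b > a) ? b : a;
  }
};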
