Skip to content

Commit

Permalink
update BASE_SIZE
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Feb 7, 2022
1 parent 86a8d77 commit 61e2048
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions paddle/pten/kernels/funcs/elementwise_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ namespace kps = pten::kps;

#endif

// Upper bound on the number of inputs an elementwise kernel accepts.
#define MAX_NUM_INPUTS \
3 // Maximum number of inputs supported by the kernel; MAX_NUM_INPUTS equals
// the largest value in ElementwiseType.
// BASE_SIZE is added to Arity wherever the per-input arrays in this file are
// sized (they are declared with extent `Arity + BASE_SIZE`).
// NOTE(review): "k" in the comment below presumably refers to Arity, i.e. the
// extra slot keeps the array extent non-zero when a functor takes no inputs --
// TODO confirm against the callers.
#define BASE_SIZE 1 // To avoid running errors when k = 0

namespace pten {

Expand Down Expand Up @@ -562,13 +560,13 @@ template <typename InT,
bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(

const pten::framework::Array<const _ptr_ InT *__restrict__, MAX_NUM_INPUTS>
&in,
const pten::framework::Array<const _ptr_ InT *__restrict__,
Arity + BASE_SIZE> &in,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int num,
int data_offset,
Functor func) {
InT args[MAX_NUM_INPUTS][VecSize];
InT args[Arity + BASE_SIZE][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];

#pragma unroll
Expand Down Expand Up @@ -598,7 +596,8 @@ template <typename InT,
int NumOuts,
int VecSize>
__global__ void VectorizedElementwiseKernel(
pten::framework::Array<const _ptr_ InT *__restrict__, MAX_NUM_INPUTS> ins,
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
ins,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int size,
int main_offset,
Expand Down Expand Up @@ -639,7 +638,7 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
std::vector<DenseTensor *> *outs,
Functor func) {
auto numel = (*outs)[0]->numel();
pten::framework::Array<const _ptr_ InT *__restrict__, MAX_NUM_INPUTS>
pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
ins_data;
pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;

Expand Down

0 comments on commit 61e2048

Please sign in to comment.