Commit: cherry-pick 42645

AnnaTrainingG committed Jun 6, 2022
1 parent 40a7e0a · commit 85a1938

Showing 4 changed files with 138 additions and 145 deletions.
11 changes: 5 additions & 6 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
-      phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
-          configlists,
+      phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
@@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(

// load in0
if (use_broadcast[0]) {
-    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
-    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
+    kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
@@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);

-  phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
+  phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

use_broadcast[0] = false;
@@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
-  configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
+  configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);

auto func = AddFunctor<T>();
auto stream = ctx.stream();
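The change above swaps the rank-templated BroadcastConfig<ShapeSize> for a rank-free BroadcastConfig. As a rough mental model (a sketch only — the names, fields, and kMaxRank value are assumptions, not Paddle's actual definitions), the refactor moves the rank from a template parameter into a runtime field bounded by a fixed maximum:

#include <cstdint>
#include <vector>

// Before: one type (and therefore one kernel instantiation) per rank.
template <int Rank>
struct BroadcastConfigOld {
  uint32_t strides_in[Rank];
};

// After: a single type; the rank is just data, so one kernel covers all ranks.
struct BroadcastConfigNew {
  static constexpr int kMaxRank = 8;  // assumed bound for this sketch
  uint32_t strides_in[kMaxRank];
  int rank;

  BroadcastConfigNew(const std::vector<int64_t>& out_dims,
                     const std::vector<int64_t>& in_dims,
                     int dim_size)
      : rank(dim_size) {
    // Dims arrive fastest-varying first; a size-1 input axis broadcasts,
    // so it contributes stride 0.
    uint32_t stride = 1;
    for (int i = 0; i < dim_size; ++i) {
      strides_in[i] = (in_dims[i] == 1) ? 0 : stride;
      stride *= static_cast<uint32_t>(in_dims[i]);  // x1 for broadcast axes
    }
  }
};

int main() {
  // The bias-add case above: out is [m, n] row-major, bias is [1, n], so the
  // reversed dims are out_dims = {n, m} and input1_dims = {n, 1}.
  BroadcastConfigNew cfg({4, 3}, {4, 1}, 2);
  return cfg.rank == 2 ? 0 : 1;
}

The payoff shows up in broadcast_function.h below, where the Rank template parameter disappears from every kernel and the per-rank dispatch switch can be deleted outright.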
139 changes: 40 additions & 99 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -185,19 +185,19 @@ struct DimensionsTransform {
}
};

-template <typename T, int VecSize, int Rank, bool IsBoundary = false>
+template <typename T, int VecSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
-    const kps::details::BroadcastConfig<Rank> &config,
+    const kps::details::BroadcastConfig &config,
int numel,
int num,
int need_broadcast) {
  // numel: total number of elements in the output
  // num: number of elements processed in this pass
if (need_broadcast) {
-    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
+    kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
dst, src, block_offset, config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
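For context, here is a host-side model of the two paths LoadData chooses between (a sketch only: bcast_index stands in for the output-to-input mapping that BroadcastConfig computes on the fly, and the real device code uses kps::ReadDataBc / kps::ReadData rather than plain loops):

#include <vector>

// With broadcasting, each element is gathered through an index mapping;
// without it, the read is a contiguous copy starting at block_offset,
// which the hardware can service as a vectorized load.
template <typename T, int VecSize>
void LoadDataModel(T* dst, const T* src, int block_offset,
                   const std::vector<int>& bcast_index,  // out -> in mapping
                   int num, bool need_broadcast) {
  for (int i = 0; i < VecSize && i < num; ++i) {
    dst[i] = need_broadcast ? src[bcast_index[block_offset + i]]  // gather
                            : src[block_offset + i];              // contiguous
  }
}

int main() {
  const float bias[2] = {10.f, 20.f};
  std::vector<int> idx = {0, 1, 0, 1};  // length-2 bias broadcast over 4 outputs
  float regs[4];
  LoadDataModel<float, 4>(regs, bias, 0, idx, 4, /*need_broadcast=*/true);
  return regs[2] == 10.f ? 0 : 1;
}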
@@ -210,14 +210,13 @@ template <typename InT,
int Arity,
int NumOuts,
int VecSize,
-          int Rank,
bool IsBoundary = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
uint32_t numel,
-    const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
+    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
Functor func) {
@@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
-    LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
-                                             ins[i],
-                                             block_offset,
-                                             configs[i],
-                                             numel,
-                                             num,
-                                             use_broadcast[i]);
+    LoadData<InT, VecSize, IsBoundary>(args[i],
+                                       ins[i],
+                                       block_offset,
+                                       configs[i],
+                                       numel,
+                                       num,
+                                       use_broadcast[i]);
}
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
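kCallElementwiseAny comes from a traits check on the functor's call operator: functors whose operator() takes pointer arguments go down the ElementwiseAny path, while plain value signatures take the fixed-arity path. One way such a check can be written (a sketch of the mechanism, not Paddle's actual FunctionTraits) is:

#include <type_traits>

// Inspect a functor's operator() and report whether any parameter is a pointer.
template <typename T>
struct FnTraits;  // primary template intentionally undefined

template <typename C, typename R, typename... Args>
struct FnTraits<R (C::*)(Args...) const> {
  static constexpr bool has_pointer_args =
      (std::is_pointer<Args>::value || ...);  // C++17 fold over all parameters
};

template <typename F>
constexpr bool has_pointer_args_v =
    FnTraits<decltype(&F::operator())>::has_pointer_args;

struct AddByValue {
  float operator()(float a, float b) const { return a + b; }
};
struct AddByPointer {
  float operator()(const float* args) const { return args[0] + args[1]; }
};

static_assert(!has_pointer_args_v<AddByValue>, "fixed-arity path");
static_assert(has_pointer_args_v<AddByPointer>, "ElementwiseAny path");

int main() { return 0; }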
@@ -254,14 +253,13 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
phi::Array<int, Arity> use_broadcast,
uint32_t numel,
-    phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
+    phi::Array<kps::details::BroadcastConfig, Arity> configs,
int main_offset,
int tail_tid,
Functor func) {
@@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-                                   Rank,
false>(ins,
outs,
use_broadcast,
@@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-                                   Rank,
true>(
ins, outs, use_broadcast, numel, configs, num, block_offset, func);
}
@@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-                                   Rank,
false>(ins,
outs,
use_broadcast,
@@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
-                                   Rank,
true>(
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
@@ -334,15 +328,14 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
-          int VecSize,
-          int Rank>
+          int VecSize>
void LaunchBroadcastKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func,
DimensionsTransform merge_dims) {
int numel = (*outs)[0]->numel();
-  phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
+  phi::Array<kps::details::BroadcastConfig, Arity> configs;
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
@@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
      // Get the broadcast config:
      // if the data shape is [m, n], set data_dim = {n, m}.
      // E.g., if out's shape is [3, 45, 1], then out_dims = {1, 45, 3}.
-      configs[i] = kps::details::BroadcastConfig<Rank>(
+      configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
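The reversed-dims convention in the comment above exists because dim 0 is treated as the fastest-varying axis. Here is a small host-side sketch of how an output linear index maps back into a broadcast input under that convention (the stride logic is written out explicitly; the real config precomputes it with divmod helpers):

#include <cstdint>
#include <vector>

// Decompose the output index axis by axis (dim 0 fastest-varying) and skip
// size-1 input axes, which broadcast and contribute nothing to the input index.
static uint32_t MapOutToIn(uint32_t out_idx,
                           const std::vector<int64_t>& out_dims,
                           const std::vector<int64_t>& in_dims) {
  uint32_t in_idx = 0, in_stride = 1;
  for (size_t i = 0; i < out_dims.size(); ++i) {
    uint32_t coord = out_idx % out_dims[i];  // coordinate along axis i
    out_idx /= out_dims[i];
    if (in_dims[i] != 1) {
      in_idx += coord * in_stride;
      in_stride *= static_cast<uint32_t>(in_dims[i]);
    }
  }
  return in_idx;
}

int main() {
  // Bias add: out shape [m=2, n=3] -> out_dims {3, 2}; bias [1, n] -> {3, 1}.
  // Output element (row r, col c) has linear index r * 3 + c and must read
  // bias[c] regardless of r.
  std::vector<int64_t> out_dims = {3, 2}, in_dims = {3, 1};
  if (MapOutToIn(/*r=1,c=2*/ 5, out_dims, in_dims) != 2) return 1;
  if (MapOutToIn(/*r=0,c=2*/ 2, out_dims, in_dims) != 2) return 1;
  return 0;
}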
@@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
-                        VecSize,
-                        Rank><<<blocks, threads, stream>>>(ins_data,
-                                                           outs_data,
-                                                           use_broadcast,
-                                                           numel,
-                                                           configs,
-                                                           main_offset,
-                                                           tail_tid,
-                                                           func);
+                        VecSize><<<blocks, threads, stream>>>(ins_data,
+                                                              outs_data,
+                                                              use_broadcast,
+                                                              numel,
+                                                              configs,
+                                                              main_offset,
+                                                              tail_tid,
+                                                              func);
#else
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
@@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
-                        VecSize,
-                        Rank><<<blocks, threads, 0, stream>>>(ins_data,
-                                                              outs_data,
-                                                              use_broadcast,
-                                                              numel,
-                                                              configs,
-                                                              main_offset,
-                                                              tail_tid,
-                                                              func);
+                        VecSize><<<blocks, threads, 0, stream>>>(
+      ins_data,
+      outs_data,
+      use_broadcast,
+      numel,
+      configs,
+      main_offset,
+      tail_tid,
+      func);
#endif
}

-template <typename InT,
-          typename OutT,
-          typename Functor,
-          int Arity,
-          int NumOuts,
-          int VecSize>
-void BroadcastKernelForDifferentDimSize(
-    const KPDevice &ctx,
-    const std::vector<const DenseTensor *> &ins,
-    std::vector<DenseTensor *> *outs,
-    int axis,
-    Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
-
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                  \
-  case rank: {                                                             \
-    LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
-        ctx, ins, outs, func, merge_dims);                                 \
-  } break;
-
-  switch (merge_dims.dim_size) {
-    CALL_BROADCAST_FOR_DIM_SIZE(1);
-    CALL_BROADCAST_FOR_DIM_SIZE(2);
-    CALL_BROADCAST_FOR_DIM_SIZE(3);
-    CALL_BROADCAST_FOR_DIM_SIZE(4);
-    CALL_BROADCAST_FOR_DIM_SIZE(5);
-    CALL_BROADCAST_FOR_DIM_SIZE(6);
-    CALL_BROADCAST_FOR_DIM_SIZE(7);
-    CALL_BROADCAST_FOR_DIM_SIZE(8);
-    default: {
-      PADDLE_THROW(phi::errors::InvalidArgument(
-          "The maximum dimension of input tensor is expected to be less than "
-          "%d, but recieved %d.",
-          merge_dims.dim_size,
-          phi::DDim::kMaxRank));
-    }
-  }
-#undef CALL_BROADCAST_FOR_DIM_SIZE
-}
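The deleted BroadcastKernelForDifferentDimSize existed only to turn a runtime dim_size into a compile-time Rank. In miniature (launch_old and launch_new are stand-ins for the real kernels, not Paddle calls), the before/after looks like this:

#include <cstdio>

// Before: each rank 1..8 forced a separate template instantiation, selected
// by a switch. After: the rank travels inside BroadcastConfig, so a single
// instantiation handles every rank.
template <int Rank>
void launch_old(int dim_size) {
  std::printf("rank-%d instantiation (dim_size=%d)\n", Rank, dim_size);
}

void launch_new(int dim_size) {
  std::printf("single instantiation (runtime rank %d)\n", dim_size);
}

void dispatch_old(int dim_size) {
  switch (dim_size) {  // every case below is distinct generated code
    case 1: launch_old<1>(dim_size); break;
    case 2: launch_old<2>(dim_size); break;
    case 3: launch_old<3>(dim_size); break;
    // ... the real macro expanded cases up through 8 ...
    default: break;  // the deleted code raised InvalidArgument here
  }
}

int main() {
  dispatch_old(2);  // before: the switch picks the rank-2 instantiation
  launch_new(2);    // after: no switch, and far less code bloat
  return 0;
}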

template <ElementwiseType ET,
typename InT,
typename OutT,
Expand Down Expand Up @@ -506,33 +458,22 @@ void BroadcastKernelForDifferentVecSize(
: in_vec_size;
}
int vec_size = std::min(out_vec_size, in_vec_size);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

switch (vec_size) {
case 4: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         4>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
+          ctx, ins, outs, func, merge_dims);
break;
}
case 2: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         2>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
+          ctx, ins, outs, func, merge_dims);
break;
}
case 1: {
-      BroadcastKernelForDifferentDimSize<InT,
-                                         OutT,
-                                         Functor,
-                                         kArity,
-                                         NumOuts,
-                                         1>(ctx, ins, outs, axis, func);
+      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
+          ctx, ins, outs, func, merge_dims);
break;
}
default: {
[The rest of this hunk and the diff for the remaining two changed files were not loaded in this view.]
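For reference, the launch geometry that main_offset and tail_tid describe can be reproduced on the host. This sketch assumes each block consumes threads * VecSize elements per pass, which matches the grid-size formula visible in the hunk above; the exact main_offset computation is elided from this view, so treat the split below as an assumption:

#include <cstdio>

int main() {
  const int numel = 1000003;  // arbitrary example size
  const int threads = 256;
  const int vec_size = 4;

  // Same rounding as ((numel + VecSize - 1) / VecSize + threads - 1) / threads.
  int blocks = ((numel + vec_size - 1) / vec_size + threads - 1) / threads;

  int per_block = threads * vec_size;                 // elements per full pass
  int main_offset = (numel / per_block) * per_block;  // end of the full passes
  int tail = numel % per_block;  // leftover handled by the IsBoundary path

  std::printf("blocks=%d main_offset=%d tail=%d\n", blocks, main_offset, tail);
  return 0;
}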

1 comment on commit 85a1938

@paddle-bot-old

Congratulations! Your pull request passed all required CI checks. You can ask reviewer(s) to approve and merge. 🎉
