Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

delete rank switch in broadcast_function.h for compile #42645 #43205

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions paddle/fluid/operators/fused/attn_bias_add.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ template <typename InT, typename OutT, int ShapeSize, int VecSize,
__global__ void BroadcastKernelBinary(
const InT* __restrict__ in0, const InT* __restrict__ in1, OutT* out,
phi::Array<bool, MAX_INPUT_NUM> use_broadcast, uint32_t numel,
phi::Array<kps::details::BroadcastConfig<ShapeSize>, MAX_INPUT_NUM>
configlists,
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists,
int main_tid, int tail_tid, Functor func) {
int fix = blockIdx.x * blockDim.x * VecSize;
int num = tail_tid;
Expand All @@ -65,14 +64,14 @@ __global__ void BroadcastKernelBinary(

// load in0
if (use_broadcast[0]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg0, in0, fix, configlists[0], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg0, in0 + fix, num);
}
// load in1
if (use_broadcast[1]) {
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1, ShapeSize>(
kernel_primitives::ReadDataBc<InT, VecSize, DATA_PER_THREAD, 1>(
arg1, in1, fix, configlists[1], numel);
} else {
kernel_primitives::ReadData<InT, VecSize, 1, 1>(arg1, in1 + fix, num);
Expand Down Expand Up @@ -104,7 +103,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
int main_tid = numel / (data_per_thread * vec_size * threads);
int tail_tid = numel % (data_per_thread * vec_size * threads);

phi::Array<kps::details::BroadcastConfig<2>, MAX_INPUT_NUM> configlists;
phi::Array<kps::details::BroadcastConfig, MAX_INPUT_NUM> configlists;
phi::Array<bool, MAX_INPUT_NUM> use_broadcast;

use_broadcast[0] = false;
Expand All @@ -115,7 +114,7 @@ void LaunchBiasAddFwKernel(const platform::CUDADeviceContext& ctx, int m, int n,
// Here, dims are transposed due to the logic in BroadcastConfig.
std::vector<int64_t> input1_dims = {n, 1};
std::vector<int64_t> out_dims = {n, m};
configlists[1] = kps::details::BroadcastConfig<2>(out_dims, input1_dims, 2);
configlists[1] = kps::details::BroadcastConfig(out_dims, input1_dims, 2);

auto func = AddFunctor<T>();
auto stream = ctx.stream();
Expand Down
139 changes: 40 additions & 99 deletions paddle/phi/kernels/funcs/broadcast_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,19 @@ struct DimensionsTransform {
}
};

template <typename T, int VecSize, int Rank, bool IsBoundary = false>
template <typename T, int VecSize, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
T *dst,
const _ptr_ T *src,
uint32_t block_offset,
const kps::details::BroadcastConfig<Rank> &config,
const kps::details::BroadcastConfig &config,
int numel,
int num,
int need_broadcast) {
// numel : whole num of output
// num: how many data will be deal with in this time
if (need_broadcast) {
kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(
kps::ReadDataBc<T, VecSize, 1, 1, IsBoundary>(
dst, src, block_offset, config, numel);
} else {
kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
Expand All @@ -210,14 +210,13 @@ template <typename InT,
int Arity,
int NumOuts,
int VecSize,
int Rank,
bool IsBoundary = false>
__device__ void VectorizedBroadcastKernelImpl(
const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
const phi::Array<int, Arity> &use_broadcast,
uint32_t numel,
const phi::Array<kps::details::BroadcastConfig<Rank>, Arity> &configs,
const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
int num,
int block_offset,
Functor func) {
Expand All @@ -227,13 +226,13 @@ __device__ void VectorizedBroadcastKernelImpl(
#pragma unroll
for (int i = 0; i < Arity; i++) {
kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
LoadData<InT, VecSize, Rank, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
LoadData<InT, VecSize, IsBoundary>(args[i],
ins[i],
block_offset,
configs[i],
numel,
num,
use_broadcast[i]);
}
constexpr bool kCallElementwiseAny =
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
Expand All @@ -254,14 +253,13 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
int Rank>
int VecSize>
__global__ void VectorizedBroadcastKernel(
phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
phi::Array<_ptr_ OutT *, NumOuts> outs,
phi::Array<int, Arity> use_broadcast,
uint32_t numel,
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
phi::Array<kps::details::BroadcastConfig, Arity> configs,
int main_offset,
int tail_tid,
Functor func) {
Expand All @@ -276,7 +274,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
false>(ins,
outs,
use_broadcast,
Expand All @@ -294,7 +291,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, num, block_offset, func);
}
Expand All @@ -306,7 +302,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
false>(ins,
outs,
use_broadcast,
Expand All @@ -322,7 +317,6 @@ __global__ void VectorizedBroadcastKernel(
Arity,
NumOuts,
VecSize,
Rank,
true>(
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
}
Expand All @@ -334,15 +328,14 @@ template <typename InT,
typename Functor,
int Arity,
int NumOuts,
int VecSize,
int Rank>
int VecSize>
void LaunchBroadcastKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func,
DimensionsTransform merge_dims) {
int numel = (*outs)[0]->numel();
phi::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
phi::Array<kps::details::BroadcastConfig, Arity> configs;
phi::Array<int, Arity> use_broadcast;
phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
phi::Array<_ptr_ OutT *, NumOuts> outs_data;
Expand All @@ -358,7 +351,7 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
// get the broadcast config,
// if data shape is[m, n], then you should set data_dim = {n, m}
// eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
configs[i] = kps::details::BroadcastConfig<Rank>(
configs[i] = kps::details::BroadcastConfig(
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
}
}
Expand All @@ -374,15 +367,14 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
VecSize,
Rank><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
VecSize><<<blocks, threads, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#else
const int threads = 256;
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
Expand All @@ -394,58 +386,18 @@ void LaunchBroadcastKernel(const KPDevice &ctx,
Functor,
Arity,
NumOuts,
VecSize,
Rank><<<blocks, threads, 0, stream>>>(ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
VecSize><<<blocks, threads, 0, stream>>>(
ins_data,
outs_data,
use_broadcast,
numel,
configs,
main_offset,
tail_tid,
func);
#endif
}

template <typename InT,
typename OutT,
typename Functor,
int Arity,
int NumOuts,
int VecSize>
void BroadcastKernelForDifferentDimSize(
const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
int axis,
Functor func) {
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
case rank: { \
LaunchBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
ctx, ins, outs, func, merge_dims); \
} break;

switch (merge_dims.dim_size) {
CALL_BROADCAST_FOR_DIM_SIZE(1);
CALL_BROADCAST_FOR_DIM_SIZE(2);
CALL_BROADCAST_FOR_DIM_SIZE(3);
CALL_BROADCAST_FOR_DIM_SIZE(4);
CALL_BROADCAST_FOR_DIM_SIZE(5);
CALL_BROADCAST_FOR_DIM_SIZE(6);
CALL_BROADCAST_FOR_DIM_SIZE(7);
CALL_BROADCAST_FOR_DIM_SIZE(8);
default: {
PADDLE_THROW(phi::errors::InvalidArgument(
"The maximum dimension of input tensor is expected to be less than "
"%d, but recieved %d.",
merge_dims.dim_size,
phi::DDim::kMaxRank));
}
}
#undef CALL_BROADCAST_FOR_DIM_SIZE
}

template <ElementwiseType ET,
typename InT,
typename OutT,
Expand Down Expand Up @@ -506,33 +458,22 @@ void BroadcastKernelForDifferentVecSize(
: in_vec_size;
}
int vec_size = std::min(out_vec_size, in_vec_size);
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);

switch (vec_size) {
case 4: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
4>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 4>(
ctx, ins, outs, func, merge_dims);
break;
}
case 2: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
2>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 2>(
ctx, ins, outs, func, merge_dims);
break;
}
case 1: {
BroadcastKernelForDifferentDimSize<InT,
OutT,
Functor,
kArity,
NumOuts,
1>(ctx, ins, outs, axis, func);
LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, 1>(
ctx, ins, outs, func, merge_dims);
break;
}
default: {
Expand Down
Loading