[Kernel] Optimize p_norm gpu #69660

Merged
36 changes: 36 additions & 0 deletions paddle/phi/kernels/funcs/p_norm_utils.h
@@ -60,4 +60,40 @@ __device__ __forceinline__ float inline_pow(float base, float exponent) {
__device__ __forceinline__ double inline_pow(double base, double exponent) {
return pow(base, exponent);
}

__device__ __forceinline__ dtype::float16 inline_fabs(dtype::float16 x) {
return static_cast<dtype::float16>(fabs(static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_fabs(dtype::bfloat16 x) {
return static_cast<dtype::bfloat16>(fabs(static_cast<float>(x)));
}
__device__ __forceinline__ float inline_fabs(float x) { return fabs(x); }
__device__ __forceinline__ double inline_fabs(double x) { return fabs(x); }

__device__ __forceinline__ dtype::float16 inline_square(dtype::float16 x) {
return static_cast<dtype::float16>(static_cast<float>(x) *
static_cast<float>(x));
}
__device__ __forceinline__ dtype::bfloat16 inline_square(dtype::bfloat16 x) {
return static_cast<dtype::bfloat16>(static_cast<float>(x) *
static_cast<float>(x));
}
__device__ __forceinline__ float inline_square(float x) { return x * x; }
__device__ __forceinline__ double inline_square(double x) { return x * x; }

__device__ __forceinline__ dtype::float16 inline_fabs_cubic(dtype::float16 x) {
return static_cast<dtype::float16>(fabs(
static_cast<float>(x) * static_cast<float>(x) * static_cast<float>(x)));
}
__device__ __forceinline__ dtype::bfloat16 inline_fabs_cubic(
dtype::bfloat16 x) {
return static_cast<dtype::bfloat16>(fabs(
static_cast<float>(x) * static_cast<float>(x) * static_cast<float>(x)));
}
__device__ __forceinline__ float inline_fabs_cubic(float x) {
return fabs(x * x * x);
}
__device__ __forceinline__ double inline_fabs_cubic(double x) {
return fabs(x * x * x);
}
} // namespace phi
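For reference (an editorial sketch, not part of the PR): the specialized helpers above are drop-in replacements for the generic pow-based path at p = 1, 2, and 3, since |x|^1 = fabs(x), |x|^2 = x*x, and |x|^3 = fabs(x*x*x). A minimal host-side check, assuming only the C++ standard library:

#include <cassert>
#include <cmath>
#include <cstdio>

int main() {
  const float x = -2.5f;
  // Generic path: pow(|x|, p), as UnsignedPowFunctor computes it.
  const float p1 = std::pow(std::fabs(x), 1.0f);
  const float p2 = std::pow(std::fabs(x), 2.0f);
  const float p3 = std::pow(std::fabs(x), 3.0f);
  // Specialized paths mirroring inline_fabs, inline_square, inline_fabs_cubic.
  const float f1 = std::fabs(x);
  const float f2 = x * x;
  const float f3 = std::fabs(x * x * x);
  assert(std::fabs(p1 - f1) < 1e-6f);
  assert(std::fabs(p2 - f2) < 1e-5f);
  assert(std::fabs(p3 - f3) < 1e-4f);
  std::printf("pow-based and specialized paths agree\n");
  return 0;
}
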
56 changes: 49 additions & 7 deletions paddle/phi/kernels/gpu/p_norm_kernel.cu
@@ -49,6 +49,30 @@ struct UnsignedPowFunctor {
float porder;
};

template <typename T>
struct FabsFunctor {
HOSTDEVICE explicit inline FabsFunctor() = default;
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(inline_fabs(x));
}
};

template <typename T>
struct SquareFunctor {
HOSTDEVICE explicit inline SquareFunctor() = default;
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(inline_square(x));
}
};

template <typename T>
struct FabsCubicFunctor {
HOSTDEVICE explicit inline FabsCubicFunctor() = default;
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(inline_fabs_cubic(x));
}
};

template <typename T, typename Context>
void PNormKernel(const Context& dev_ctx,
const DenseTensor& x,
@@ -84,14 +108,32 @@ void PNormKernel(const Context& dev_ctx,
phi::funcs::ReduceKernel<T, T, kps::MinFunctor, AbsFunctor<T>>(
dev_ctx, *in_x, out_norm, AbsFunctor<T>(), reduce_axis);
} else {
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
if (porder == 1.0) {
// fast 1-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsFunctor<T>>(
dev_ctx, *in_x, out_norm, FabsFunctor<T>(), reduce_axis);
} else if (porder == 2.0) {
// fast 2-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, SquareFunctor<T>>(
dev_ctx, *in_x, out_norm, SquareFunctor<T>(), reduce_axis);
} else if (porder == 3.0) {
// fast 3-norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, FabsCubicFunctor<T>>(
dev_ctx, *in_x, out_norm, FabsCubicFunctor<T>(), reduce_axis);
} else {
// vanilla norm
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, UnsignedPowFunctor<T>>(
dev_ctx, *in_x, out_norm, UnsignedPowFunctor<T>(porder), reduce_axis);
}

const DenseTensor* tmp_norm = out_norm;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out_norm};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
if (porder != 1.0) {
// save computation when porder is 1.0
Contributor

The description "porder is 1.0" does not quite match the guard condition `porder != 1.0`.

Contributor Author

> The description "porder is 1.0" does not quite match the guard condition `porder != 1.0`.

The code itself should be correct: when porder is 1.0 the power computation below does not need to run. I will move the comment above the condition in a follow-up.
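
Editorial note (not part of the thread): the reduction above already produces \sum_i |x_i|^{porder}; the ElementwiseKernel below only applies the outer exponent of the norm,

    \|x\|_p = \Bigl(\sum_i |x_i|^p\Bigr)^{1/p}.

For porder == 1.0 that exponent is 1/p = 1, the pass is the identity, and the porder != 1.0 guard skips it entirely.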

const DenseTensor* tmp_norm = out_norm;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out_norm};
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, UnsignedPowFunctor<T>(1. / porder));
}
}
}
} // namespace phi