Skip to content

Commit

Permalink
Fix p_norm GPU NaN bug caused by division by zero (PaddlePaddle#41359)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiboniu authored and wu.zeng committed Apr 10, 2022
1 parent 6e8911f commit c327f29
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions paddle/phi/kernels/gpu/p_norm_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ struct AbsMaxAndMinGradFunctor {

template <typename T>
struct PNormGradFunctor {
HOSTDEVICE explicit inline PNormGradFunctor(float porder) {
HOSTDEVICE explicit inline PNormGradFunctor(float porder, float eps) {
this->porder = static_cast<T>(porder - 1.);
this->eps = static_cast<T>(eps);
}
template <typename Context,
typename X,
Expand All @@ -58,11 +59,12 @@ struct PNormGradFunctor {
DY* dy,
const Dim& dim,
int size) {
dx->device(place) = (*x).abs().pow(this->porder) * (*x).sign() *
dy->broadcast(dim) *
(*y).pow(-this->porder).broadcast(dim);
dx->device(place) =
(*x).abs().pow(this->porder) * (*x).sign() * dy->broadcast(dim) *
(*y + y->constant(eps)).pow(-this->porder).broadcast(dim);
}
T porder;
T eps;
};

template <typename T, typename Context>
Expand Down Expand Up @@ -96,7 +98,7 @@ void PNormGradKernel(const Context& dev_ctx,
dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);

} else {
auto functor = PNormGradFunctor<T>(porder);
auto functor = PNormGradFunctor<T>(porder, epsilon);
funcs::LaunchReduceGradKernel<Context, T, PNormGradFunctor<T>>(
dev_ctx, in_x, in_norm, in_norm_dy, out_dx, functor, dims, reduce_all);
}
Expand Down

0 comments on commit c327f29

Please sign in to comment.