From b23e36b412a500c1d7f34e559c3b4b62b84e47ed Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Wed, 24 Nov 2021 11:21:11 +0800 Subject: [PATCH 1/8] [p_norm] optimize zeronorm, infnorm and neginfonorm --- paddle/fluid/operators/p_norm_op.cu | 133 ++++++++++------------------ 1 file changed, 48 insertions(+), 85 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index cfe778c49121f..b74799b0a4be3 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -23,6 +23,8 @@ namespace cub = hipcub; #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/p_norm_op.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" namespace paddle { namespace operators { @@ -56,87 +58,32 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) { return pow(base, exponent); } -template -__global__ void Pnorm(const T* x, const int pre, - const int axis_n, // dim in axis - const int post, float porder, T* out_norm) { - using MT = typename details::MPTypeTrait::Type; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - int num = pre * post; - auto porder_t = static_cast(porder); - auto porder_inv = static_cast(1.0 / porder); - - for (int i = blockIdx.x; i < num; i += gridDim.x) { - int base = (i / post) * post * axis_n + (i % post); - MT sum = static_cast(0.0); - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - const MT x_ij = static_cast(x[base + j * post]); - sum += inline_pow(inline_abs(x_ij), porder_t); - } - MT reduce_result = BlockReduce(temp_storage).Sum(sum); - if (threadIdx.x == 0) - out_norm[i] = static_cast(inline_pow(reduce_result, porder_inv)); +struct NonzeroFunctor { + HOSTDEVICE explicit inline NonzeroFunctor() {} + template + HOSTDEVICE inline T operator()(const T& x) const { + return static_cast(static_cast(x) != 0); } -} +}; -template -__global__ void ZeorNorm(const T* x, const int pre, - const int axis_n, // dim in axis - const int post, T* out_norm) { - using MT = typename details::MPTypeTrait::Type; - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - int num = pre * post; - for (int i = blockIdx.x; i < num; i += gridDim.x) { - int base = (i / post) * post * axis_n + (i % post); - MT sum = static_cast(0.0); - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - const MT x_ij = static_cast(x[base + j * post]); - sum += static_cast(static_cast(x_ij) != 0); - } - MT reduce_result = BlockReduce(temp_storage).Sum(sum); - if (threadIdx.x == 0) out_norm[i] = static_cast(reduce_result); +struct AbsFunctor { + HOSTDEVICE explicit inline AbsFunctor() {} + template + HOSTDEVICE inline T operator()(const T& x) const { + return static_cast(inline_abs(x)); } -} +}; -template -__global__ void InfNorm(const T* x, const int pre, - const int axis_n, // dim in axis - const int post, T* out_norm) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - int num = pre * post; - for (int i = blockIdx.x; i < num; i += gridDim.x) { - int base = (i / post) * post * axis_n + (i % post); - T cur_max = inline_abs(x[base]); - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - T x_ij_abs = inline_abs(x[base + j * post]); - if (cur_max < x_ij_abs) cur_max = x_ij_abs; - } - T reduce_result = 
BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); - if (threadIdx.x == 0) out_norm[i] = reduce_result; +struct PowFunctor { + HOSTDEVICE explicit inline PowFunctor(float porder) { + this->porder = porder; } -} - -template -__global__ void NegInfNorm(const T* x, const int pre, - const int axis_n, // dim in axis - const int post, T* out_norm) { - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - int num = pre * post; - for (int i = blockIdx.x; i < num; i += gridDim.x) { - int base = (i / post) * post * axis_n + (i % post); - T cur_min = inline_abs(x[base]); - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - T x_ij_abs = inline_abs(x[base + j * post]); - if (cur_min > x_ij_abs) cur_min = x_ij_abs; - } - T reduce_result = BlockReduce(temp_storage).Reduce(cur_min, cub::Min()); - if (threadIdx.x == 0) out_norm[i] = reduce_result; + template + HOSTDEVICE inline T operator()(const T& x) const { + return inline_pow(inline_abs(x), static_cast(porder)); } -} + float porder; +}; template class PnormCUDAKernel : public framework::OpKernel { @@ -145,7 +92,6 @@ class PnormCUDAKernel : public framework::OpKernel { auto* in_x = ctx.Input("X"); auto* out_norm = ctx.Output("Out"); const T* x = in_x->data(); - T* norm = out_norm->mutable_data(ctx.GetPlace()); auto xdim = in_x->dims(); auto ndim = out_norm->dims(); @@ -153,10 +99,12 @@ class PnormCUDAKernel : public framework::OpKernel { int axis = ctx.Attr("axis"); bool asvector = ctx.Attr("asvector"); if (axis < 0) axis = xdim.size() + axis; + std::vector reduce_axis = {axis}; int pre, n, post; GetDims(xdim, axis, &pre, &n, &post, asvector); - auto& dev_ctx = ctx.cuda_device_context(); + auto& dev_ctx = ctx.device_context(); + auto stream = ctx.cuda_device_context().stream(); #ifdef __HIPCC__ const int block = 256; @@ -167,18 +115,33 @@ class PnormCUDAKernel : public framework::OpKernel { int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); int grid = std::min(max_blocks, pre * post); + int reduce_idx = (blockIdx.x / post) * post * n + (blockIdx.x % post); if (porder == 0) { - ZeorNorm<<>>(x, pre, n, post, - norm); + TensorReduce ( + *in_x, out_norm, reduce_axis, static_cast(0), cub::Sum(), + NonzeroFunctor(), stream); } else if (porder == INFINITY) { - InfNorm<<>>(x, pre, n, post, - norm); + TensorReduce ( + *in_x, out_norm, reduce_axis, static_cast(inline_abs(x[reduce_idx])), + cub::Max(), AbsFunctor(), stream); } else if (porder == -INFINITY) { - NegInfNorm<<>>(x, pre, n, - post, norm); + TensorReduce ( + *in_x, out_norm, reduce_axis, static_cast(inline_abs(x[reduce_idx])), + cub::Min(), AbsFunctor(), stream); } else { - Pnorm<<>>(x, pre, n, post, - porder, norm); + TensorReduce ( + *in_x, out_norm, reduce_axis, static_cast(0), + cub::Sum(), PowFunctor(porder), stream); + //const T* tmp_norm = out_norm->data(); + //T* norm = out_norm->mutable_data(ctx.GetPlace()); + //kps::ElementwiseUnary ( + // norm, tmp_norm, PowFunctor(1. 
/ porder)); + const framework::Tensor* tmp_norm = out_norm; + std::vector ins = {tmp_norm}; + std::vector outs = {out_norm}; + auto func = PowFunctor(porder); + LaunchSameDimsElementwiseCudaKernel ( + dev_ctx, ins, &outs, func); } } }; From 7ee4116ed0ea437637940cca8166f28e9c094a71 Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Mon, 29 Nov 2021 14:54:49 +0000 Subject: [PATCH 2/8] [pnorm] optimize pnorm for special case, a preliminary attempt --- paddle/fluid/operators/p_norm_op.cu | 76 +++++++++++++++-------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index b74799b0a4be3..a8b95f397457e 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -21,10 +21,10 @@ limitations under the License. */ namespace cub = hipcub; #endif #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/p_norm_op.h" -#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -74,8 +74,17 @@ struct AbsFunctor { } }; +template struct PowFunctor { - HOSTDEVICE explicit inline PowFunctor(float porder) { + HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; } + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(inline_pow(inline_abs(x), static_cast(porder))); + } + float porder; +}; + +struct UnaryPowFunctor { + HOSTDEVICE explicit inline UnaryPowFunctor(float porder) { this->porder = porder; } template @@ -92,7 +101,7 @@ class PnormCUDAKernel : public framework::OpKernel { auto* in_x = ctx.Input("X"); auto* out_norm = ctx.Output("Out"); const T* x = in_x->data(); - + T* norm = out_norm->mutable_data(ctx.GetPlace()); auto xdim = in_x->dims(); auto ndim = out_norm->dims(); float porder = ctx.Attr("porder"); @@ -100,48 +109,41 @@ class PnormCUDAKernel : public framework::OpKernel { bool asvector = ctx.Attr("asvector"); if (axis < 0) axis = xdim.size() + axis; std::vector reduce_axis = {axis}; - int pre, n, post; - GetDims(xdim, axis, &pre, &n, &post, asvector); - auto& dev_ctx = ctx.device_context(); + auto& dev_ctx = ctx.cuda_device_context(); auto stream = ctx.cuda_device_context().stream(); -#ifdef __HIPCC__ - const int block = 256; -#else - const int block = 512; -#endif - - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - int grid = std::min(max_blocks, pre * post); - int reduce_idx = (blockIdx.x / post) * post * n + (blockIdx.x % post); + using MT = typename details::MPTypeTrait::Type; if (porder == 0) { - TensorReduce ( - *in_x, out_norm, reduce_axis, static_cast(0), cub::Sum(), - NonzeroFunctor(), stream); + TensorReduce( + *in_x, out_norm, reduce_axis, static_cast(0), cub::Sum(), + NonzeroFunctor(), stream); } else if (porder == INFINITY) { - TensorReduce ( - *in_x, out_norm, reduce_axis, static_cast(inline_abs(x[reduce_idx])), - cub::Max(), AbsFunctor(), stream); + auto inf = -std::numeric_limits::infinity(); + TensorReduce(*in_x, out_norm, reduce_axis, + static_cast(inf), cub::Max(), + AbsFunctor(), stream); } else if (porder == -INFINITY) { - TensorReduce ( - *in_x, out_norm, reduce_axis, static_cast(inline_abs(x[reduce_idx])), - cub::Min(), AbsFunctor(), 
stream); + auto inf = std::numeric_limits::infinity(); + TensorReduce(*in_x, out_norm, reduce_axis, + static_cast(inf), cub::Min(), + AbsFunctor(), stream); } else { - TensorReduce ( - *in_x, out_norm, reduce_axis, static_cast(0), - cub::Sum(), PowFunctor(porder), stream); - //const T* tmp_norm = out_norm->data(); - //T* norm = out_norm->mutable_data(ctx.GetPlace()); - //kps::ElementwiseUnary ( - // norm, tmp_norm, PowFunctor(1. / porder)); - const framework::Tensor* tmp_norm = out_norm; + framework::Tensor tmp; + tmp.mutable_data(ndim, ctx.GetPlace()); + TensorReduce( + *in_x, &tmp, reduce_axis, static_cast(0), cub::Sum(), + UnaryPowFunctor(porder), stream); + const framework::Tensor* tmp_norm = &tmp; std::vector ins = {tmp_norm}; std::vector outs = {out_norm}; - auto func = PowFunctor(porder); - LaunchSameDimsElementwiseCudaKernel ( - dev_ctx, ins, &outs, func); + auto func = PowFunctor(1. / porder); + const auto& cuda_ctx = + ctx.template device_context(); + + LaunchSameDimsElementwiseCudaKernel>(cuda_ctx, ins, + &outs, func); } } }; From 5cd7469416ebbc35bdd7b90eebccb72cf4e64632 Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Tue, 30 Nov 2021 12:06:19 +0000 Subject: [PATCH 3/8] [pnorm] optimize p_norm for special cases with Kernel Primitive API --- paddle/fluid/operators/p_norm_op.cu | 83 +++++++++++++++++++---------- 1 file changed, 56 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index a8b95f397457e..8f7c834ba8747 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -24,6 +24,8 @@ namespace cub = hipcub; #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/p_norm_op.h" #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -60,6 +62,7 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) { struct NonzeroFunctor { HOSTDEVICE explicit inline NonzeroFunctor() {} + HOSTDEVICE explicit inline NonzeroFunctor(int n) {} template HOSTDEVICE inline T operator()(const T& x) const { return static_cast(static_cast(x) != 0); @@ -68,6 +71,7 @@ struct NonzeroFunctor { struct AbsFunctor { HOSTDEVICE explicit inline AbsFunctor() {} + HOSTDEVICE explicit inline AbsFunctor(int n) {} template HOSTDEVICE inline T operator()(const T& x) const { return static_cast(inline_abs(x)); @@ -83,15 +87,37 @@ struct PowFunctor { float porder; }; -struct UnaryPowFunctor { - HOSTDEVICE explicit inline UnaryPowFunctor(float porder) { - this->porder = porder; +template +struct AbsAndMin { + using Transformer = AbsFunctor; + using MT = typename details::MPTypeTrait::Type; + inline Ty initial() { + return static_cast(std::numeric_limits::infinity()); } - template - HOSTDEVICE inline T operator()(const T& x) const { - return inline_pow(inline_abs(x), static_cast(porder)); + __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const { + return (a < b) ? a : b; + } +}; + +template +struct AbsAndMax { + using Transformer = AbsFunctor; + using MT = typename details::MPTypeTrait::Type; + inline Ty initial() { + return static_cast(-std::numeric_limits::infinity()); + } + __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const { + return (a > b) ? 
a : b; + } +}; + +template +struct NonzeroAndSum { + using Transformer = NonzeroFunctor; + inline Ty initial() { return static_cast(0.0f); } + __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const { + return b + a; } - float porder; }; template @@ -115,35 +141,38 @@ class PnormCUDAKernel : public framework::OpKernel { using MT = typename details::MPTypeTrait::Type; if (porder == 0) { - TensorReduce( - *in_x, out_norm, reduce_axis, static_cast(0), cub::Sum(), - NonzeroFunctor(), stream); + TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis, + stream); } else if (porder == INFINITY) { - auto inf = -std::numeric_limits::infinity(); - TensorReduce(*in_x, out_norm, reduce_axis, - static_cast(inf), cub::Max(), - AbsFunctor(), stream); + TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis, + stream); } else if (porder == -INFINITY) { - auto inf = std::numeric_limits::infinity(); - TensorReduce(*in_x, out_norm, reduce_axis, - static_cast(inf), cub::Min(), - AbsFunctor(), stream); + TensorReduceFunctorImpl(*in_x, out_norm, reduce_axis, + stream); } else { - framework::Tensor tmp; - tmp.mutable_data(ndim, ctx.GetPlace()); - TensorReduce( - *in_x, &tmp, reduce_axis, static_cast(0), cub::Sum(), - UnaryPowFunctor(porder), stream); - const framework::Tensor* tmp_norm = &tmp; - std::vector ins = {tmp_norm}; - std::vector outs = {out_norm}; - auto func = PowFunctor(1. / porder); + framework::Tensor tmp_x; + tmp_x.mutable_data(xdim, ctx.GetPlace()); + std::vector ins = {in_x}; + std::vector outs = {&tmp_x}; + auto func = PowFunctor(porder); const auto& cuda_ctx = ctx.template device_context(); LaunchSameDimsElementwiseCudaKernel>(cuda_ctx, ins, &outs, func); + framework::Tensor tmp_y; + tmp_y.mutable_data(ndim, ctx.GetPlace()); + TensorReduceFunctorImpl(tmp_x, &tmp_y, reduce_axis, + stream); + const framework::Tensor* tmp_norm = &tmp_y; + ins = {tmp_norm}; + outs = {out_norm}; + auto func_inverse = PowFunctor(1. / porder); + + LaunchSameDimsElementwiseCudaKernel>( + cuda_ctx, ins, &outs, func_inverse); } } }; From cc21bcfcf09cbf36a07108fd2d6c7fa8baff4e2a Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Mon, 6 Dec 2021 12:56:20 +0000 Subject: [PATCH 4/8] [pnorm] optimize backward speed for special cases, and modify reduce_op for flexible call. 
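
For reference, the math this backward path implements for finite, non-zero porder is
dx_i = dy * sign(x_i) * |x_i|^(porder - 1) * norm^(1 - porder). The patch computes
|x|^(porder - 1) and norm^(1 - porder) with elementwise launches and leaves the
broadcast over the reduced axis to LaunchReduceGradKernel with PNormPostGradFunctor.
The host-side sketch below is only an illustrative reference for that formula and is
not part of the patch; reference_pnorm_grad is a made-up name, and zeros are simply
mapped to a zero gradient, whereas the original CUDA kernel added an epsilon to the
denominator instead.

    // Illustrative reference only (not Paddle code): p-norm gradient of a
    // 1-D vector for finite, non-zero porder.
    #include <cmath>
    #include <cstddef>
    #include <vector>

    std::vector<float> reference_pnorm_grad(const std::vector<float>& x,
                                            float norm, float dy,
                                            float porder) {
      std::vector<float> dx(x.size());
      // Shared factor dy * norm^(1 - porder); on the GPU this is the part
      // broadcast back over the reduced axis.
      const float scale = dy * std::pow(norm, 1.0f - porder);
      for (std::size_t i = 0; i < x.size(); ++i) {
        if (x[i] == 0.0f) {  // sign(0) == 0; also avoids 0 * inf when porder < 1
          dx[i] = 0.0f;
          continue;
        }
        const float sign = (x[i] > 0.0f) ? 1.0f : -1.0f;
        dx[i] = scale * sign * std::pow(std::fabs(x[i]), porder - 1.0f);
      }
      return dx;
    }
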
--- paddle/fluid/operators/p_norm_op.cu | 146 ++++++++---------- paddle/fluid/operators/reduce_ops/reduce_op.h | 117 +++++++------- 2 files changed, 133 insertions(+), 130 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index 8f7c834ba8747..e7784dd982838 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -26,6 +26,7 @@ namespace cub = hipcub; #include "paddle/fluid/operators/reduce_ops/cub_reduce.h" #include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/platform/float16.h" namespace paddle { @@ -78,11 +79,22 @@ struct AbsFunctor { } }; +template +struct UnsignedPowFunctor { + HOSTDEVICE explicit inline UnsignedPowFunctor(float porder) { + this->porder = porder; + } + HOSTDEVICE inline Ty operator()(const Tx& x) const { + return static_cast(inline_pow(inline_abs(x), static_cast(porder))); + } + float porder; +}; + template struct PowFunctor { HOSTDEVICE explicit inline PowFunctor(float porder) { this->porder = porder; } HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(inline_pow(inline_abs(x), static_cast(porder))); + return static_cast(inline_pow(x, static_cast(porder))); } float porder; }; @@ -154,13 +166,13 @@ class PnormCUDAKernel : public framework::OpKernel { tmp_x.mutable_data(xdim, ctx.GetPlace()); std::vector ins = {in_x}; std::vector outs = {&tmp_x}; - auto func = PowFunctor(porder); + auto func = UnsignedPowFunctor(porder); const auto& cuda_ctx = ctx.template device_context(); LaunchSameDimsElementwiseCudaKernel>(cuda_ctx, ins, - &outs, func); + UnsignedPowFunctor>( + cuda_ctx, ins, &outs, func); framework::Tensor tmp_y; tmp_y.mutable_data(ndim, ctx.GetPlace()); TensorReduceFunctorImpl(tmp_x, &tmp_y, reduce_axis, @@ -168,73 +180,45 @@ class PnormCUDAKernel : public framework::OpKernel { const framework::Tensor* tmp_norm = &tmp_y; ins = {tmp_norm}; outs = {out_norm}; - auto func_inverse = PowFunctor(1. / porder); + auto func_inverse = UnsignedPowFunctor(1. 
/ porder); LaunchSameDimsElementwiseCudaKernel>( + UnsignedPowFunctor>( cuda_ctx, ins, &outs, func_inverse); } } }; -template -__global__ void PnormGradient(const T* x, const T* x_norm, const T* y_grad, - const float porder, const int pre, - const int axis_n, const int post, const T eps, - T* x_grad) { - using MT = typename details::MPTypeTrait::Type; - // dx = (x/pnorm_broadcast).pow(p-1) * norm_dy.broadcast * sign(x) - int num = pre * post; - auto porder_grad = static_cast(porder - 1.0f); - for (int i = blockIdx.x; i < num; i += gridDim.x) { - __shared__ MT pnorm_i; - __shared__ MT yout_i; - - auto base = (i / post) * post * axis_n + (i % post); - - if (threadIdx.x == 0) { - pnorm_i = static_cast(x_norm[i]); - yout_i = static_cast(y_grad[i]); - } - __syncthreads(); - - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - int index = base + j * post; - const MT x_ij = static_cast(inline_abs(x[index])); - x_grad[index] = static_cast( - inline_pow(x_ij, porder_grad) / - (inline_pow(pnorm_i, porder_grad) + static_cast(eps)) * yout_i * - static_cast(inline_sign(x[index]))); - } +template +struct AbsMaxAndMinGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + auto equals = ((*x).abs() == y->broadcast(dim)); + auto ones = dx->constant(static_cast(1.)); + auto negs = dx->constant(static_cast(-1.)); + auto zeros = dx->constant(static_cast(0.)); + auto positives = (*x) > zeros; + dx->device(place) = dy->broadcast(dim) * equals.select(ones, zeros) * + positives.select(ones, negs); } -} - -template -__global__ void InfNormGradient(const T* x, const T* x_norm, const T* y_grad, - const int pre, const int axis_n, const int post, - T* x_grad) { - int num = pre * post; - for (int i = blockIdx.x; i < num; i += gridDim.x) { - __shared__ T pnorm_i; - __shared__ T yout_i; - auto base = (i / post) * post * axis_n + (i % post); - if (threadIdx.x == 0) { - pnorm_i = x_norm[i]; - yout_i = y_grad[i]; - } - __syncthreads(); +}; - for (int j = threadIdx.x; j < axis_n; j += blockDim.x) { - int index = base + j * post; - const T x_ij = inline_abs(x[index]); - if (x_ij == pnorm_i) { - x_grad[index] = static_cast(inline_sign(x[index])) * yout_i; - } else { - x_grad[index] = static_cast(0); - } - } +template +struct PNormPostGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + auto ones = dx->constant(static_cast(1.)); + auto negs = dx->constant(static_cast(-1.)); + auto zeros = dx->constant(static_cast(0.)); + auto positives = (*x) > zeros; + dx->device(place) = (*dx) * dy->broadcast(dim) * y->broadcast(dim) * + positives.select(ones, negs); } -} +}; template class PnormGradCUDAKernel : public framework::OpKernel { @@ -254,32 +238,38 @@ class PnormGradCUDAKernel : public framework::OpKernel { float porder = ctx.Attr("porder"); T eps = static_cast(ctx.Attr("epsilon")); int axis = ctx.Attr("axis"); + bool reduce_all = ((axis < 0) || (in_norm->numel() == 1)); bool asvector = ctx.Attr("asvector"); if (axis < 0) axis = xdim.size() + axis; - int pre, n, post; - GetDims(xdim, axis, &pre, &n, &post, asvector); + const std::vector dims = {axis}; auto& dev_ctx = ctx.cuda_device_context(); + auto& cuda_ctx = ctx.template device_context(); -#ifdef __HIPCC__ - const int block = 256; -#else - const int block = 512; -#endif - - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - int grid = std::min(max_blocks, 
pre * post); if (porder == 0) { math::SetConstant set_zero; - auto& dev_ctx = ctx.template device_context(); - set_zero(dev_ctx, out_dx, static_cast(0)); + set_zero(cuda_ctx, out_dx, static_cast(0)); } else if (porder == INFINITY || porder == -INFINITY) { - InfNormGradient<<>>( - x, x_norm, norm_dy, pre, n, post, dx); + LaunchReduceGradKernel>( + ctx, in_x, in_norm, in_norm_dy, out_dx, dims, reduce_all); } else { - PnormGradient<<>>( - x, x_norm, norm_dy, porder, pre, n, post, eps, dx); + framework::Tensor tmp_norm; + tmp_norm.mutable_data(in_norm->dims(), ctx.GetPlace()); + std::vector ins = {in_norm}; + std::vector outs = {&tmp_norm}; + auto pow_functor = PowFunctor(1. - porder); + LaunchSameDimsElementwiseCudaKernel>(cuda_ctx, ins, &outs, + pow_functor); + ins = {in_x}; + outs = {out_dx}; + auto unsigned_pow = UnsignedPowFunctor(porder - 1.); + LaunchSameDimsElementwiseCudaKernel>( + cuda_ctx, ins, &outs, unsigned_pow); + const framework::Tensor* tmp_norm_const = &tmp_norm; + LaunchReduceGradKernel>( + ctx, in_x, tmp_norm_const, in_norm_dy, out_dx, dims, reduce_all); } } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index df27083edbed5..1b6d82b149e65 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -326,6 +326,67 @@ class BoolReduceKernel : public framework::OpKernel { } }; +template +void LaunchReduceGradKernel(const framework::ExecutionContext& context, + const framework::Tensor* input0, + const framework::Tensor* input1, + const framework::Tensor* input2, + paddle::framework::Tensor* output, + const std::vector& dims, + bool reduce_all = false) { + if (reduce_all) { + auto x = EigenVector::Flatten(*input0); + auto x_reduce = EigenVector::Flatten(*input1); + auto x_reduce_grad = EigenVector::Flatten(*input2); + auto x_grad = EigenVector::Flatten(*output); + auto& place = + *context.template device_context().eigen_device(); + auto broadcast_dim = + Eigen::array({{static_cast(input0->numel())}}); + Functor functor; + functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, + broadcast_dim[0]); + } else { + int rank = input0->dims().size(); + switch (rank) { + case 1: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + case 2: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + case 3: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + case 4: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + case 5: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + case 6: + ReduceGradFunctor( + context.template device_context(), *input0, *input1, + *input2, output, dims); + break; + default: + HandleLargeDimGrad(context, input0, input1, + input2, output, dims); + break; + } + } +} + template class ReduceGradKernel : public framework::OpKernel { @@ -362,61 +423,13 @@ class ReduceGradKernel : public framework::OpKernel { input1 = input2; } + const std::vector const_dims = dims; + // NOTE(dengkaipeng): Out is unnecessary in some reduce kernel and // not be set as Input in grad Maker, use Out_grad to replace here if (!input1) input1 = input2; - - if (reduce_all) { - auto x = EigenVector::Flatten(*input0); - auto x_reduce = EigenVector::Flatten(*input1); - auto 
x_reduce_grad = EigenVector::Flatten(*input2); - auto x_grad = EigenVector::Flatten(*output); - auto& place = - *context.template device_context().eigen_device(); - auto broadcast_dim = - Eigen::array({{static_cast(input0->numel())}}); - Functor functor; - functor(place, &x, &x_reduce, &x_grad, &x_reduce_grad, broadcast_dim, - broadcast_dim[0]); - } else { - int rank = input0->dims().size(); - switch (rank) { - case 1: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - case 2: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - case 3: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - case 4: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - case 5: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - case 6: - ReduceGradFunctor( - context.template device_context(), *input0, - *input1, *input2, output, dims); - break; - default: - HandleLargeDimGrad(context, input0, input1, - input2, output, dims); - break; - } - } + LaunchReduceGradKernel( + context, input0, input1, input2, output, const_dims, reduce_all); } void Compute(const framework::ExecutionContext& context) const override { From 6790302de01c968e8caac46cd8075d4fa144a894 Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Tue, 7 Dec 2021 05:22:50 +0000 Subject: [PATCH 5/8] [pnorm] remove reduce_function_op import --- paddle/fluid/operators/p_norm_op.cu | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index e7784dd982838..641f3630617b8 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -23,8 +23,6 @@ namespace cub = hipcub; #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" #include "paddle/fluid/operators/p_norm_op.h" -#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" -#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.h" #include "paddle/fluid/platform/float16.h" @@ -61,6 +59,15 @@ __device__ __forceinline__ double inline_pow(double base, double exponent) { return pow(base, exponent); } +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor() {} + HOSTDEVICE explicit inline IdentityFunctor(int n) {} + template + HOSTDEVICE inline T operator()(const T& x) const { + return static_cast(x); + } +}; + struct NonzeroFunctor { HOSTDEVICE explicit inline NonzeroFunctor() {} HOSTDEVICE explicit inline NonzeroFunctor(int n) {} @@ -132,6 +139,15 @@ struct NonzeroAndSum { } }; +template +struct IdentityAndSum { + using Transformer = IdentityFunctor; + inline Ty initial() { return static_cast(0.0f); } + __device__ __forceinline__ Ty operator()(const Ty& a, const Ty& b) const { + return b + a; + } +}; + template class PnormCUDAKernel : public framework::OpKernel { public: @@ -175,8 +191,8 @@ class PnormCUDAKernel : public framework::OpKernel { cuda_ctx, ins, &outs, func); framework::Tensor tmp_y; tmp_y.mutable_data(ndim, ctx.GetPlace()); - TensorReduceFunctorImpl(tmp_x, &tmp_y, reduce_axis, - stream); + TensorReduceFunctorImpl(tmp_x, &tmp_y, reduce_axis, + 
stream); const framework::Tensor* tmp_norm = &tmp_y; ins = {tmp_norm}; outs = {out_norm}; From d0588489d044eb653b5913a14f8ec846abd8d7fa Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Tue, 7 Dec 2021 07:33:54 +0000 Subject: [PATCH 6/8] [pnorm] remove p_norm_op.cc from unity group --- paddle/fluid/operators/unity_build_rule.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 5faa0dba6b878..23b6d6b2b2ebf 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -186,7 +186,6 @@ register_unity_group(cc norm_op.cc one_hot_op.cc one_hot_v2_op.cc - p_norm_op.cc pad2d_op.cc pad3d_op.cc pad_constant_like_op.cc From efbb36191840d249a406fb81295f7390c3855c4d Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Tue, 7 Dec 2021 07:58:44 +0000 Subject: [PATCH 7/8] [pnorm] remove p_norm_op.cu from unity group --- paddle/fluid/operators/unity_build_rule.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 23b6d6b2b2ebf..25aef67425ef9 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -467,7 +467,6 @@ register_unity_group(cu nll_loss_op.cu norm_op.cu one_hot_op.cu - p_norm_op.cu pad2d_op.cu pad3d_op.cu pad_constant_like_op.cu From 1cb511d1488bb86ebb587330902840cb01c79c0d Mon Sep 17 00:00:00 2001 From: Huijuan Wang Date: Thu, 9 Dec 2021 12:27:11 +0000 Subject: [PATCH 8/8] [pnorm] remove unused variables --- paddle/fluid/operators/p_norm_op.cu | 8 -------- 1 file changed, 8 deletions(-) diff --git a/paddle/fluid/operators/p_norm_op.cu b/paddle/fluid/operators/p_norm_op.cu index 641f3630617b8..1a481c1cf5c11 100644 --- a/paddle/fluid/operators/p_norm_op.cu +++ b/paddle/fluid/operators/p_norm_op.cu @@ -160,11 +160,9 @@ class PnormCUDAKernel : public framework::OpKernel { auto ndim = out_norm->dims(); float porder = ctx.Attr("porder"); int axis = ctx.Attr("axis"); - bool asvector = ctx.Attr("asvector"); if (axis < 0) axis = xdim.size() + axis; std::vector reduce_axis = {axis}; - auto& dev_ctx = ctx.cuda_device_context(); auto stream = ctx.cuda_device_context().stream(); using MT = typename details::MPTypeTrait::Type; @@ -246,20 +244,14 @@ class PnormGradCUDAKernel : public framework::OpKernel { ctx.Input(framework::GradVarName("Out")); auto* out_dx = ctx.Output(framework::GradVarName("X")); T* dx = out_dx->mutable_data(ctx.GetPlace()); - const T* x = in_x->data(); - const T* x_norm = in_norm->data(); - const T* norm_dy = in_norm_dy->data(); auto xdim = in_x->dims(); float porder = ctx.Attr("porder"); - T eps = static_cast(ctx.Attr("epsilon")); int axis = ctx.Attr("axis"); bool reduce_all = ((axis < 0) || (in_norm->numel() == 1)); - bool asvector = ctx.Attr("asvector"); if (axis < 0) axis = xdim.size() + axis; const std::vector dims = {axis}; - auto& dev_ctx = ctx.cuda_device_context(); auto& cuda_ctx = ctx.template device_context(); if (porder == 0) {
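
Taken together, the series replaces the hand-written block-reduce kernels with
TensorReduceFunctorImpl plus elementwise launches. As a reminder of the semantics
those reductions must preserve, here is a host-side reference for a single reduced
vector. This sketch is not part of the patch and reference_pnorm is an illustrative
name, not a Paddle symbol.

    // Illustrative reference only (not Paddle code): forward p-norm of a
    // 1-D vector, covering the special cases handled above.
    #include <algorithm>
    #include <cmath>
    #include <limits>
    #include <vector>

    float reference_pnorm(const std::vector<float>& x, float porder) {
      if (porder == 0.0f) {  // porder == 0: count of non-zero entries
        float cnt = 0.0f;
        for (float v : x) cnt += (v != 0.0f) ? 1.0f : 0.0f;
        return cnt;
      }
      if (porder == std::numeric_limits<float>::infinity()) {  // max |x_i|
        float m = -std::numeric_limits<float>::infinity();
        for (float v : x) m = std::max(m, std::fabs(v));
        return m;
      }
      if (porder == -std::numeric_limits<float>::infinity()) {  // min |x_i|
        float m = std::numeric_limits<float>::infinity();
        for (float v : x) m = std::min(m, std::fabs(v));
        return m;
      }
      // General case: (sum_i |x_i|^porder)^(1 / porder).
      float sum = 0.0f;
      for (float v : x) sum += std::pow(std::fabs(v), porder);
      return std::pow(sum, 1.0f / porder);
    }
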