From 452c75b8034e485a2626e22cac39c95c07b883b4 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 9 Mar 2022 21:37:32 +0800 Subject: [PATCH 001/239] move elementwise mul grad (#40252) --- .../new_executor/standalone_executor_test.cc | 2 +- .../elementwise/elementwise_functor.h | 41 --- .../elementwise/elementwise_mul_op.cc | 49 ---- .../elementwise/elementwise_mul_op.cu | 68 ----- .../elementwise/elementwise_mul_op.h | 238 --------------- .../kernels/cpu/elementwise_grad_kernel.cc | 61 +++- paddle/phi/kernels/elementwise_grad_kernel.h | 39 +++ .../phi/kernels/funcs/elementwise_functor.h | 44 +++ paddle/phi/kernels/gpu/elementwise_grad.h | 37 +++ .../kernels/gpu/elementwise_grad_kernel.cu | 54 ++++ .../impl/elementwise_grad_kernel_impl.h | 273 ++++++++++++++++++ paddle/phi/ops/compat/elementwise_sig.cc | 34 +++ 12 files changed, 539 insertions(+), 401 deletions(-) diff --git a/paddle/fluid/framework/new_executor/standalone_executor_test.cc b/paddle/fluid/framework/new_executor/standalone_executor_test.cc index 62d87b6917e40..a69cc0d6b866d 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor_test.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor_test.cc @@ -46,7 +46,7 @@ USE_OP(matmul_grad); USE_OP(square); USE_OP(transpose2_grad); USE_OP(concat_grad); -USE_OP(elementwise_mul_grad); +USE_OP_ITSELF(elementwise_mul_grad); USE_OP(sigmoid_grad); USE_OP(tanh_grad); USE_OP(sum); diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h index 8e0bf78e9b7f9..14baeaa74d242 100644 --- a/paddle/fluid/operators/elementwise/elementwise_functor.h +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -196,47 +196,6 @@ struct MinGradXYFunctor { } }; -template -struct MulGradFunctor { - inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; } -}; -template -struct MulGradFunctor> { - inline HOSTDEVICE Complex operator()(const Complex a, - const Complex b) const { - Complex b_conj(b.real, -b.imag); - return a * b_conj; - } -}; - -template -struct MulGradXYFunctor { - inline HOSTDEVICE phi::Array operator()(const InT a, const InT b, - const InT c) { - phi::Array outs; - // dx = dout * y - outs[0] = a * b; - // dy = dout * x - outs[1] = a * c; - return outs; - } -}; - -template -struct MulGradXYFunctor, Complex> { - inline HOSTDEVICE phi::Array, 2> operator()( - const Complex a, const Complex b, const Complex c) { - phi::Array, 2> outs; - // dx = dout * y - Complex b_conj(b.real, -b.imag); - outs[0] = a * b_conj; - // dy = dout * x - Complex c_conj(c.real, -c.imag); - outs[1] = a * c_conj; - return outs; - } -}; - // Ternary compare template struct MaxGradXFunctor { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index e172279145e28..830e09eeae481 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -173,55 +173,6 @@ REGISTER_OP_CPU_KERNEL( paddle::platform::complex>, ops::ElementwiseMulKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_grad, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel>, - ops::ElementwiseMulGradKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_grad_grad, - ops::ElementwiseMulDoubleGradKernel, - 
ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel>, - ops::ElementwiseMulDoubleGradKernel>); -REGISTER_OP_CPU_KERNEL( - elementwise_mul_triple_grad, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel>, - ops::ElementwiseMulTripleGradKernel>); REGISTER_OP_VERSION(elementwise_mul) .AddCheckpoint( diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 45c87a27a180a..f7b9fd1e265f5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -63,33 +63,6 @@ class ElementwiseMulKernel } }; -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - const auto& dev_ctx = - ctx.template device_context(); - const auto place = ctx.GetPlace(); - - if (dx != nullptr && dy != nullptr) { - std::vector ins = {dout, y, x}; - GetGradXAndYOut( - dev_ctx, place, axis, ins, dout, dx, dy, MulGradXYFunctor()); - } else if (dx != nullptr && dy == nullptr) { - std::vector ins = {dout, y}; - GetGradXOrYOut(dev_ctx, place, axis, ins, dout, - dx, MulGradFunctor()); - } else if (dx == nullptr && dy != nullptr) { - std::vector ins = {dout, x}; - GetGradXOrYOut(dev_ctx, place, axis, ins, dout, - dy, MulGradFunctor()); - } -} - } // namespace operators } // namespace paddle @@ -103,44 +76,3 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMulKernel, ops::ElementwiseMulKernel>, ops::ElementwiseMulKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_grad, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel, - ops::ElementwiseMulGradKernel>, - ops::ElementwiseMulGradKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_grad_grad, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel, - ops::ElementwiseMulDoubleGradKernel>, - ops::ElementwiseMulDoubleGradKernel>); -REGISTER_OP_CUDA_KERNEL( - elementwise_mul_triple_grad, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel, - ops::ElementwiseMulTripleGradKernel>, - ops::ElementwiseMulTripleGradKernel>); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index c81266d584468..58a3123c7e332 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ 
b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -137,244 +137,6 @@ class ElementwiseMulKernel : public framework::OpKernel { } } }; -template -struct MulGradDX { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } -}; - -template -struct MulGradDX> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex y_conj(y.real, -y.imag); - return dout * y_conj; - } -}; - -template -struct MulGradDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } -}; - -template -struct MulGradDY> { - HOSTDEVICE paddle::platform::complex operator()( - paddle::platform::complex x, paddle::platform::complex y, - paddle::platform::complex out, - paddle::platform::complex dout) const { - paddle::platform::complex x_conj(x.real, -x.imag); - return dout * x_conj; - } -}; -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy) { - int axis = ctx.Attr("axis"); - ElemwiseGradCompute, MulGradDY>( - ctx, *x, *y, *out, *dout, axis, dx, dy, MulGradDX(), MulGradDY()); -} - -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -template -typename std::enable_if< - std::is_same::value>::type -ElementwiseMulGrad(const framework::ExecutionContext& ctx, - const framework::Tensor* x, const framework::Tensor* y, - const framework::Tensor* out, const framework::Tensor* dout, - framework::Tensor* dx, framework::Tensor* dy); -#endif - -template -class ElementwiseMulGradKernel : public ElemwiseGradKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - ElemwiseGradKernel::Compute(ctx); - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input(framework::GradVarName("Out")); - auto* out = dout; // out is not necessary - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - ElementwiseMulGrad(ctx, x, y, out, dout, dx, dy); - } -}; - -template -class ElementwiseMulDoubleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input("DOut"); - auto* ddx = ctx.Input("DDX"); - auto* ddy = ctx.Input("DDY"); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - auto* ddout = ctx.Output("DDOut"); - - if (ddout) ddout->mutable_data(ctx.GetPlace()); - - Tensor ddx_safe, ddy_safe; - GetDoubleGradSafeTensor(ctx, x, ddx, &ddx_safe); - GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); - - // dx = dout * ddy - // dy = dout * ddx - // ddout = ddx * y + x * ddy - // change computation sequence to save memory, so ddout can inplace ddx and - // dx can be used as 'tmp' tensor - // (1) dx = x * ddy - // (2) dy = dout * ddx - // (3) ddout = ddx * y - // (4) ddout = ddout + dx - // (5) dx = dout * ddy - if (ddout) { - int axis = ctx.Attr("axis"); - auto& place = - *ctx.template device_context().eigen_device(); - // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace - if (ddout->numel() > 
ddx->numel()) { - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, *dout, axis, dx, dy, MulGradDX(), - MulGradDY()); - - Tensor ddout_tmp; - ddout_tmp.mutable_data(ddout->dims(), ctx.GetPlace()); - - default_elementwise_mul(ctx, y, &ddx_safe, ddout); - default_elementwise_mul(ctx, &ddy_safe, x, - &ddout_tmp); - - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - } else { - // use dx to save memory, other than alloc tmp tensor - Tensor* ddout_tmp = dx; - - default_elementwise_mul(ctx, x, &ddy_safe, ddout_tmp); - // NOTE: in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. - ElemwiseGradCompute, MulGradDY>( - ctx, ddx_safe, ddy_safe, *dout, *dout, axis, nullptr, dy, - MulGradDX(), MulGradDY()); - default_elementwise_mul(ctx, &ddx_safe, y, ddout); - - auto ddout_t = framework::EigenVector::Flatten(*ddout); - auto ddout_tmp_t = framework::EigenVector::Flatten(*ddout_tmp); - ddout_t.device(place) = ddout_t + ddout_tmp_t; - default_elementwise_mul(ctx, dout, &ddy_safe, dx); - } - } - } -}; - -template -class ElementwiseMulTripleGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - // get input - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* dout = ctx.Input("DOut"); - auto* ddx = ctx.Input("DDX"); - auto* ddy = ctx.Input("DDY"); - - auto* d_dx = ctx.Input("D_DX"); - auto* d_dy = ctx.Input("D_DY"); - auto* d_ddout = ctx.Input("D_DDOut"); - - // get output - auto* out_d_x = ctx.Output("D_X"); - auto* out_d_y = ctx.Output("D_Y"); - auto* out_d_dout = ctx.Output("D_DOut"); - - auto* out_d_ddx = ctx.Output("D_DDX"); - auto* out_d_ddy = ctx.Output("D_DDY"); - - if (out_d_x) out_d_x->mutable_data(x->dims(), ctx.GetPlace()); - if (out_d_y) out_d_y->mutable_data(y->dims(), ctx.GetPlace()); - if (out_d_dout) out_d_dout->mutable_data(dout->dims(), ctx.GetPlace()); - if (out_d_ddx) out_d_ddx->mutable_data(x->dims(), ctx.GetPlace()); - if (out_d_ddy) out_d_ddy->mutable_data(y->dims(), ctx.GetPlace()); - - auto& place = *ctx.template device_context().eigen_device(); - - Tensor ddx_safe, ddy_safe; - GetDoubleGradSafeTensor(ctx, x, ddx, &ddx_safe); - GetDoubleGradSafeTensor(ctx, y, ddy, &ddy_safe); - - if (d_ddout) { - if (out_d_x) { - // out_d_x = ddy * d_ddout - default_elementwise_mul(ctx, &ddy_safe, d_ddout, - out_d_x); - } - if (out_d_y) { - // out_d_y = ddx * d_ddout - default_elementwise_mul(ctx, &ddx_safe, d_ddout, - out_d_y); - } - } - - if (out_d_dout) { - // get out_d_dout - // out_d_dout = ddy * d_dx + d_dy * ddx - Tensor out_d_dout_tmp; - out_d_dout_tmp.mutable_data(dout->dims(), ctx.GetPlace()); - default_elementwise_mul(ctx, d_dy, &ddx_safe, - out_d_dout); - default_elementwise_mul(ctx, &ddy_safe, d_dx, - &out_d_dout_tmp); - auto out_d_dout_t = framework::EigenVector::Flatten(*out_d_dout); - auto out_d_dout_tmp_t = - framework::EigenVector::Flatten(out_d_dout_tmp); - out_d_dout_t.device(place) = out_d_dout_t + out_d_dout_tmp_t; - } - - if (out_d_ddx) { - // get out_d_ddx - // out_d_ddx = dout * d_dy + y * d_ddout - Tensor out_d_ddx_tmp; - out_d_ddx_tmp.mutable_data(ddx->dims(), ctx.GetPlace()); - 
default_elementwise_mul(ctx, dout, d_dy, out_d_ddx); - default_elementwise_mul(ctx, y, d_ddout, - &out_d_ddx_tmp); - auto out_d_ddx_t = framework::EigenVector::Flatten(*out_d_ddx); - auto out_d_ddx_tmp_t = framework::EigenVector::Flatten(out_d_ddx_tmp); - out_d_ddx_t.device(place) = out_d_ddx_t + out_d_ddx_tmp_t; - } - - if (out_d_ddy) { - // get out_d_ddy - // out_d_ddy = dout * d_dx + x * d_ddout - Tensor out_d_ddy_tmp; - out_d_ddy_tmp.mutable_data(ddy->dims(), ctx.GetPlace()); - default_elementwise_mul(ctx, dout, d_dx, out_d_ddy); - default_elementwise_mul(ctx, x, d_ddout, - &out_d_ddy_tmp); - auto out_d_ddy_t = framework::EigenVector::Flatten(*out_d_ddy); - auto out_d_ddy_tmp_t = framework::EigenVector::Flatten(out_d_ddy_tmp); - out_d_ddy_t.device(place) = out_d_ddy_t + out_d_ddy_tmp_t; - } - } -}; } // namespace operators } // namespace paddle diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc index c9177f1c46eac..cd513e809fd84 100644 --- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc @@ -121,6 +121,20 @@ void DivideGradKernel(const Context& dev_ctx, dev_ctx, x, y, out, dout, axis, dx, dy, DivGradDX(), DivGradDY()); } +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + auto* out = &dout; // out is not necessary + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, x, y, *out, dout, axis, dx, dy, MulGradDX(), MulGradDY()); +} + } // namespace phi PD_REGISTER_KERNEL(add_grad, @@ -193,8 +207,8 @@ PD_REGISTER_KERNEL(divide_grad, double, int, int64_t, - paddle::platform::complex, - paddle::platform::complex) {} + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(divide_double_grad, CPU, @@ -204,5 +218,44 @@ PD_REGISTER_KERNEL(divide_double_grad, double, int, int64_t, - paddle::platform::complex, - paddle::platform::complex) {} + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_double_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyDoubleGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_triple_grad, + CPU, + ALL_LAYOUT, + phi::MultiplyTripleGradKernel, + float, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/elementwise_grad_kernel.h b/paddle/phi/kernels/elementwise_grad_kernel.h index bcd5a98f07ee9..58ae11a9c4256 100644 --- a/paddle/phi/kernels/elementwise_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_grad_kernel.h @@ -85,4 +85,43 @@ void DivideDoubleGradKernel(const Context& dev_ctx, DenseTensor* dy, DenseTensor* dout, DenseTensor* ddout); + +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy); + +template +void MultiplyDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + int axis, + DenseTensor* dx, + 
DenseTensor* dy, + DenseTensor* ddout); + +template +void MultiplyTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + paddle::optional d_ddout, + int axis, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_dout, + DenseTensor* d_ddx, + DenseTensor* d_ddy); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/elementwise_functor.h b/paddle/phi/kernels/funcs/elementwise_functor.h index 5615a450b5c54..b01d50015f01a 100644 --- a/paddle/phi/kernels/funcs/elementwise_functor.h +++ b/paddle/phi/kernels/funcs/elementwise_functor.h @@ -160,5 +160,49 @@ struct DivGradYFunctor> { } }; +template +struct MultiplyGradFunctor { + inline HOSTDEVICE T operator()(const T a, const T b) const { return a * b; } +}; +template +struct MultiplyGradFunctor> { + inline HOSTDEVICE ComplexType operator()(const ComplexType a, + const ComplexType b) const { + ComplexType b_conj(b.real, -b.imag); + return a * b_conj; + } +}; + +template +struct MultiplyGradXYFunctor { + inline HOSTDEVICE phi::Array operator()(const InT a, + const InT b, + const InT c) { + phi::Array outs; + // dx = dout * y + outs[0] = a * b; + // dy = dout * x + outs[1] = a * c; + return outs; + } +}; + +template +struct MultiplyGradXYFunctor, ComplexType> { + inline HOSTDEVICE phi::Array, 2> operator()( + const ComplexType a, + const ComplexType b, + const ComplexType c) { + phi::Array, 2> outs; + // dx = dout * y + ComplexType b_conj(b.real, -b.imag); + outs[0] = a * b_conj; + // dy = dout * x + ComplexType c_conj(c.real, -c.imag); + outs[1] = a * c_conj; + return outs; + } +}; + } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad.h b/paddle/phi/kernels/gpu/elementwise_grad.h index 98df65c92f34c..e5432b5f9187c 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad.h +++ b/paddle/phi/kernels/gpu/elementwise_grad.h @@ -360,4 +360,41 @@ void ElementwiseDivGrad(const GPUContext &dev_ctx, } } +/* +****************************** + Mul Grad +****************************** +*/ + +template +void ElementwiseMulGrad(const GPUContext &dev_ctx, + const DenseTensor &x, + const DenseTensor &y, + const DenseTensor &dout, + DenseTensor *dx, + DenseTensor *dy, + int axis) { + const auto place = dev_ctx.GetPlace(); + + if (dx != nullptr && dy != nullptr) { + std::vector ins = {&dout, &y, &x}; + GetGradXAndYOut( + dev_ctx, + place, + axis, + ins, + dout, + dx, + dy, + funcs::MultiplyGradXYFunctor()); + } else if (dx != nullptr && dy == nullptr) { + std::vector ins = {&dout, &y}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dx, funcs::MultiplyGradFunctor()); + } else if (dx == nullptr && dy != nullptr) { + std::vector ins = {&dout, &x}; + GetGradXOrYOut( + dev_ctx, place, axis, ins, dout, dy, funcs::MultiplyGradFunctor()); + } +} } // namespace phi diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu index 45c8b9a21639f..81f7fac108803 100644 --- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu @@ -136,6 +136,18 @@ void DivideGradKernel(const Context& dev_ctx, } } +template +void MultiplyGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + int axis, + DenseTensor* dx, + DenseTensor* dy) { + funcs::ElementwiseGradPreProcess(dout, dx); + ElementwiseMulGrad(dev_ctx, x, y, 
dout, dx, dy, axis); +} + } // namespace phi PD_REGISTER_KERNEL(add_grad, @@ -228,3 +240,45 @@ PD_REGISTER_KERNEL(divide_double_grad, int64_t, phi::dtype::complex, phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_double_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyDoubleGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} + +PD_REGISTER_KERNEL(multiply_triple_grad, + GPU, + ALL_LAYOUT, + phi::MultiplyTripleGradKernel, + float, + phi::dtype::float16, + double, + int, + int64_t, + bool, + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index e8831f90213b6..65427e87506f7 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -259,4 +259,277 @@ void DivideDoubleGradKernel(const Context& dev_ctx, } } +template +struct MulGradDX { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * y; } +}; + +template +struct MulGradDX> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex y_conj(y.real, -y.imag); + return dout * y_conj; + } +}; + +/* +****************************** + Multiply Grad +****************************** +*/ + +template +struct MulGradDY { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { return dout * x; } +}; + +template +struct MulGradDY> { + HOSTDEVICE phi::dtype::complex operator()( + phi::dtype::complex x, + phi::dtype::complex y, + phi::dtype::complex out, + phi::dtype::complex dout) const { + phi::dtype::complex x_conj(x.real, -x.imag); + return dout * x_conj; + } +}; + +template +void MultiplyDoubleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + int axis, + DenseTensor* dx, + DenseTensor* dy, + DenseTensor* ddout) { + if (ddout) dev_ctx.template Alloc(ddout); + + DenseTensor ddx_safe, ddy_safe; + funcs::GetDoubleGradSafeTensor( + dev_ctx, x, ddx.get_ptr(), &ddx_safe); + funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddy_safe); + + // dx = dout * ddy + // dy = dout * ddx + // ddout = ddx * y + x * ddy + // change computation sequence to save memory, so ddout can inplace ddx and + // dx can be used as 'tmp' tensor + // (1) dx = x * ddy + // (2) dy = dout * ddx + // (3) ddout = ddx * y + // (4) ddout = ddout + dx + // (5) dx = dout * ddy + if (ddout) { + auto& place = *dev_ctx.eigen_device(); + // size(ddout) > size(ddx), ddout can't use memory of ddx using inplace + if (ddout->numel() > ddx.get_ptr()->numel()) { + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + dx, + dy, + MulGradDX(), + MulGradDY()); + + DenseTensor ddout_tmp; + ddout_tmp.Resize(ddout->dims()); + dev_ctx.template Alloc(&ddout_tmp); + + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, ddx_safe, ddout, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, 
x, &ddout_tmp, axis); + + auto ddout_t = phi::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = phi::EigenVector::Flatten(ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + } else { + // use dx to save memory, other than alloc tmp tensor + DenseTensor* ddout_tmp = dx; + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, ddy_safe, ddout_tmp, axis); + // NOTE: in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + phi::funcs::ElemwiseGradCompute, MulGradDY>( + dev_ctx, + ddx_safe, + ddy_safe, + dout, + dout, + axis, + nullptr, + dy, + MulGradDX(), + MulGradDY()); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, y, ddout, axis); + + auto ddout_t = phi::EigenVector::Flatten(*ddout); + auto ddout_tmp_t = phi::EigenVector::Flatten(*ddout_tmp); + ddout_t.device(place) = ddout_t + ddout_tmp_t; + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, ddy_safe, dx, axis); + } + } +} + +template +void MultiplyTripleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& dout, + paddle::optional ddx, + paddle::optional ddy, + const DenseTensor& d_dx, + const DenseTensor& d_dy, + paddle::optional d_ddout, + int axis, + DenseTensor* d_x, + DenseTensor* d_y, + DenseTensor* d_dout, + DenseTensor* d_ddx, + DenseTensor* d_ddy) { + if (d_x) { + d_x->Resize(x.dims()); + dev_ctx.template Alloc(d_x); + } + if (d_y) { + d_y->Resize(y.dims()); + dev_ctx.template Alloc(d_y); + } + if (d_dout) { + d_dout->Resize(dout.dims()); + dev_ctx.template Alloc(d_dout); + } + if (d_ddx) { + d_ddx->Resize(x.dims()); + dev_ctx.template Alloc(d_ddx); + } + if (d_ddy) { + d_ddy->Resize(y.dims()); + dev_ctx.template Alloc(d_ddy); + } + + auto& place = *dev_ctx.eigen_device(); + + DenseTensor ddx_safe, ddy_safe; + funcs::GetDoubleGradSafeTensor( + dev_ctx, x, ddx.get_ptr(), &ddx_safe); + funcs::GetDoubleGradSafeTensor( + dev_ctx, y, ddy.get_ptr(), &ddy_safe); + + if (d_ddout.get_ptr()) { + if (d_x) { + // d_x = ddy * d_ddout + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, *(d_ddout.get_ptr()), d_x, axis); + } + if (d_y) { + // d_y = ddx * d_ddout + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddx_safe, *(d_ddout.get_ptr()), d_y, axis); + } + } + + if (d_dout) { + // get d_dout + // d_dout = ddy * d_dx + d_dy * ddx + DenseTensor d_dout_tmp; + d_dout_tmp.Resize(dout.dims()); + dev_ctx.template Alloc(&d_dout_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, d_dy, ddx_safe, d_dout, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, ddy_safe, d_dx, &d_dout_tmp, axis); + auto d_dout_t = phi::EigenVector::Flatten(*d_dout); + auto d_dout_tmp_t = phi::EigenVector::Flatten(d_dout_tmp); + d_dout_t.device(place) = d_dout_t + d_dout_tmp_t; + } + + if (d_ddx) { + // get d_ddx + // d_ddx = dout * d_dy + y * d_ddout + DenseTensor d_ddx_tmp; + d_ddx_tmp.Resize(ddx->dims()); + dev_ctx.template Alloc(&d_ddx_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dy, d_ddx, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, y, 
*(d_ddout.get_ptr()), &d_ddx_tmp, axis); + auto d_ddx_t = phi::EigenVector::Flatten(*d_ddx); + auto d_ddx_tmp_t = phi::EigenVector::Flatten(d_ddx_tmp); + d_ddx_t.device(place) = d_ddx_t + d_ddx_tmp_t; + } + + if (d_ddy) { + // get d_ddy + // d_ddy = dout * d_dx + x * d_ddout + DenseTensor d_ddy_tmp; + d_ddy_tmp.Resize(ddy->dims()); + dev_ctx.template Alloc(&d_ddy_tmp); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, dout, d_dx, d_ddy, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, x, *(d_ddout.get_ptr()), &d_ddy_tmp, axis); + auto d_ddy_t = phi::EigenVector::Flatten(*d_ddy); + auto d_ddy_tmp_t = phi::EigenVector::Flatten(d_ddy_tmp); + d_ddy_t.device(place) = d_ddy_t + d_ddy_tmp_t; + } +} + } // namespace phi diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc index d4a25866907a0..fc890fa3a4923 100644 --- a/paddle/phi/ops/compat/elementwise_sig.cc +++ b/paddle/phi/ops/compat/elementwise_sig.cc @@ -122,6 +122,31 @@ KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( {GradVarName("Y"), "DOut", "DDOut"}); } +KernelSignature ElementwiseMulGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("multiply_grad", + {"X", "Y", GradVarName("Out")}, + {"axis"}, + {GradVarName("X"), GradVarName("Y")}); +} + +KernelSignature ElementwiseMulDoubleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("multiply_double_grad", + {"X", "Y", "DOut", "DDX", "DDY"}, + {"axis"}, + {GradVarName("X"), GradVarName("Y"), "DDOut"}); +} + +KernelSignature ElementwiseMulTripleGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "multiply_triple_grad", + {"X", "Y", "DOut", "DDX", "DDY", "D_DX", "D_DY", "D_DDOut"}, + {"axis"}, + {"D_X", "D_Y", "D_DOut", "D_DDX", "D_DDY"}); +} + } // namespace phi PD_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); @@ -135,6 +160,9 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad_grad, subtract_double_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad, divide_grad); PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad); +PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad); PD_REGISTER_ARG_MAPPING_FN(elementwise_add, phi::ElementwiseAddOpArgumentMapping); @@ -158,3 +186,9 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad, phi::ElementwiseDivGradOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(elementwise_div_grad_grad, phi::ElementwiseDivDoubleGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad, + phi::ElementwiseMulGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad, + phi::ElementwiseMulDoubleGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad, + phi::ElementwiseMulTripleGradOpArgumentMapping); From b97e6d13fd552df98bda8156e7851d21399c6579 Mon Sep 17 00:00:00 2001 From: Linjie Chen <40840292+linjieccc@users.noreply.github.com> Date: Wed, 9 Mar 2022 22:38:14 +0800 Subject: [PATCH 002/239] [phi] move viterbi_decode to phi (#40186) * move viterbi to phi * move infershape to phi * update infershape * fix * resolve conflicts --- paddle/fluid/operators/viterbi_decode_op.cc | 53 +-- paddle/fluid/operators/viterbi_decode_op.cu | 
206 -------- paddle/fluid/operators/viterbi_decode_op.h | 438 ------------------ paddle/phi/infermeta/ternary.cc | 47 ++ paddle/phi/infermeta/ternary.h | 8 + .../phi/kernels/cpu/viterbi_decode_kernel.cc | 319 +++++++++++++ .../kernels/funcs/viterbi_decode_functor.h | 140 ++++++ .../phi/kernels/gpu/viterbi_decode_kernel.cu | 402 ++++++++++++++++ paddle/phi/kernels/viterbi_decode_kernel.h | 30 ++ 9 files changed, 953 insertions(+), 690 deletions(-) delete mode 100644 paddle/fluid/operators/viterbi_decode_op.cu delete mode 100644 paddle/fluid/operators/viterbi_decode_op.h create mode 100644 paddle/phi/kernels/cpu/viterbi_decode_kernel.cc create mode 100644 paddle/phi/kernels/funcs/viterbi_decode_functor.h create mode 100644 paddle/phi/kernels/gpu/viterbi_decode_kernel.cu create mode 100644 paddle/phi/kernels/viterbi_decode_kernel.h diff --git a/paddle/fluid/operators/viterbi_decode_op.cc b/paddle/fluid/operators/viterbi_decode_op.cc index bf1cdeed65a84..602376d54e0d2 100644 --- a/paddle/fluid/operators/viterbi_decode_op.cc +++ b/paddle/fluid/operators/viterbi_decode_op.cc @@ -9,8 +9,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/viterbi_decode_op.h" +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/infermeta_utils.h" +#include "paddle/phi/infermeta/ternary.h" namespace paddle { namespace operators { @@ -19,47 +21,6 @@ class ViterbiDecodeOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "ViterbiDecode"); - OP_INOUT_CHECK(ctx->HasInput("Transition"), "Input", "Transition", - "ViterbiDecode"); - OP_INOUT_CHECK(ctx->HasInput("Length"), "Input", "Length", "ViterbiDecode"); - OP_INOUT_CHECK(ctx->HasOutput("Scores"), "Output", "Scores", - "ViterbiDecode"); - OP_INOUT_CHECK(ctx->HasOutput("Path"), "Output", "Path", "ViterbiDecode"); - auto in_dims = ctx->GetInputDim("Input"); - PADDLE_ENFORCE_EQ(in_dims.size(), 3, - platform::errors::InvalidArgument( - "The rank of Input in ViterbiDecode must be 3. But " - "received Input's rank is %d.", - in_dims.size())); - auto length_dims = ctx->GetInputDim("Length"); - PADDLE_ENFORCE_EQ(length_dims.size(), 1, - platform::errors::InvalidArgument( - "The rank of Length in ViterbiDecode must be 1. But " - "received Length's rank is %d.", - length_dims.size())); - auto transition_dims = ctx->GetInputDim("Transition"); - PADDLE_ENFORCE_EQ( - transition_dims.size(), 2, - platform::errors::InvalidArgument( - "The rank of Transition in ViterbiDecode must be 2. 
But " - "received Transition's rank is %d.", - transition_dims.size())); - if (ctx->IsRuntime()) { - PADDLE_ENFORCE_EQ( - in_dims[0], length_dims[0], - platform::errors::InvalidArgument( - "The batch size of Input and Length should be equal.")); - PADDLE_ENFORCE_EQ(in_dims[2], transition_dims[0], - platform::errors::InvalidArgument( - "The number of tags of Input (%d) and Transition " - "(%d) should be equal.", - transition_dims[0], in_dims[2])); - } - ctx->SetOutputDim("Scores", length_dims); - } - protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -102,8 +63,8 @@ class ViterbiDecodeOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; namespace platform = paddle::platform; +DECLARE_INFER_SHAPE_FUNCTOR(viterbi_decode, ViterbiDecodeInferShapeFunctor, + PD_INFER_META(phi::ViterbiDecodeInferMeta)); REGISTER_OP_WITHOUT_GRADIENT(viterbi_decode, ops::ViterbiDecodeOp, - ops::ViterbiDecodeOpMaker); -REGISTER_OP_CPU_KERNEL( - viterbi_decode, ops::ViterbiDecodeKernel, - ops::ViterbiDecodeKernel); + ops::ViterbiDecodeOpMaker, + ViterbiDecodeInferShapeFunctor); diff --git a/paddle/fluid/operators/viterbi_decode_op.cu b/paddle/fluid/operators/viterbi_decode_op.cu deleted file mode 100644 index 68628fb2748c4..0000000000000 --- a/paddle/fluid/operators/viterbi_decode_op.cu +++ /dev/null @@ -1,206 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/elementwise/elementwise_functor.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/viterbi_decode_op.h" -#include "paddle/phi/kernels/funcs/gather.cu.h" - -#ifdef __NVCC__ -#include "cub/cub.cuh" -#endif -#ifdef __HIPCC__ -#include -namespace cub = hipcub; -#endif - -namespace paddle { -namespace operators { - -#define FIXED_BLOCK_DIM_CASE_BASE(log2_block_dim, ...) \ - case (1 << (log2_block_dim)): { \ - constexpr auto kBlockDim = (1 << (log2_block_dim)); \ - __VA_ARGS__; \ - } break - -#define FIXED_BLOCK_DIM_CASE(...) \ - FIXED_BLOCK_DIM_CASE_BASE(10, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(9, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(8, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(7, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(6, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(5, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(4, ##__VA_ARGS__); \ - FIXED_BLOCK_DIM_CASE_BASE(3, ##__VA_ARGS__); - -int64_t ComputeBlockSize(int64_t col) { - if (col > 512) - return 1024; - else if (col > 256) - return 512; - else if (col > 128) - return 256; - else if (col > 64) - return 128; - else if (col > 32) - return 64; - else if (col > 16) - return 32; - else if (col > 8) - return 16; - else - return 8; -} - -template
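Note: the hunks above all implement a single gradient rule for elementwise multiply — dx = dout * y and dy = dout * x, with the conjugate of the other operand used for complex types (see the MulGradDX / MulGradDY functors and funcs::MultiplyGradXYFunctor moved in this patch). The sketch below restates that rule as a minimal, self-contained C++ program. It is illustrative only, independent of the Paddle/phi APIs touched here, and the helper names mul_grad_dx / mul_grad_dy do not exist in the codebase.

// Minimal sketch of the multiply-gradient rule used by the kernels in this
// patch: dx = dout * y, dy = dout * x; for complex numbers the other operand
// is conjugated, as in the complex specializations of MulGradDX / MulGradDY.
#include <cassert>
#include <complex>

// Illustrative helpers only; not part of Paddle or phi.
template <typename T>
T mul_grad_dx(T /*x*/, T y, T dout) { return dout * y; }  // d(x*y)/dx = y

template <typename T>
T mul_grad_dy(T x, T /*y*/, T dout) { return dout * x; }  // d(x*y)/dy = x

// Complex overloads: multiply by the conjugate of the other operand.
template <typename T>
std::complex<T> mul_grad_dx(std::complex<T> /*x*/, std::complex<T> y,
                            std::complex<T> dout) {
  return dout * std::conj(y);
}

template <typename T>
std::complex<T> mul_grad_dy(std::complex<T> x, std::complex<T> /*y*/,
                            std::complex<T> dout) {
  return dout * std::conj(x);
}

int main() {
  // Real case: x = 2, y = 3, dout = 1  ->  dx = y = 3, dy = x = 2.
  assert(mul_grad_dx(2.0, 3.0, 1.0) == 3.0);
  assert(mul_grad_dy(2.0, 3.0, 1.0) == 2.0);

  // Complex case: with dout = 1, dx = conj(y) and dy = conj(x).
  std::complex<double> x(1.0, 2.0), y(3.0, -1.0), dout(1.0, 0.0);
  assert(mul_grad_dx(x, y, dout) == std::conj(y));
  assert(mul_grad_dy(x, y, dout) == std::conj(x));
  return 0;
}

The conjugation mirrors the convention used throughout the moved functors (e.g. b_conj(b.real, -b.imag)), which yields gradients consistent with treating the loss as a real-valued function of complex inputs.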