[Cherry-Pick] Add Log double grad kernel (PaddlePaddle#27604)
joey12300 authored Sep 28, 2020
1 parent 5634d2c commit 8622f8c
Showing 4 changed files with 122 additions and 1 deletion.
51 changes: 51 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -839,6 +839,28 @@ class SquareDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};

// log Grad: dx = dout / x
// log Grad Grad: ddout = ddx / x; dx = -(dout / x) * (ddx / x)
template <typename T>
class LogDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("log_grad_grad");
op->SetInput("X", this->Input("X"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
// X@GRAD: dx
op->SetOutput("DX", this->InputGrad("X"));
// Out@GRAD@GRAD: ddout
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};

DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
@@ -1219,3 +1241,32 @@ REGISTER_OP_CPU_KERNEL(
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::AbsGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_OPERATOR(
log, ops::ActivationOp, ops::LogOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::LogGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(log_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::LogDoubleGradMaker<paddle::framework::OpDesc>,
ops::LogDoubleGradMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(
log_grad_grad,
ops::ActivationOpDoubleGrad<ops::LogGradGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CPU_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<float>>,
ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<double>>,
ops::LogDoubleGradKernel<plat::CPUDeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
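For reference, the formulas wired up by LogDoubleGradMaker and implemented by LogGradGradFunctor follow from differentiating the first-order log gradient once more; a brief derivation (not part of the patch itself):

$$
y = \log x, \qquad dx_{\mathrm{grad}} = \frac{\partial y}{\partial x}\, dout = \frac{dout}{x}.
$$

Treating $dx_{\mathrm{grad}}$ as a function of $x$ and $dout$, and letting $ddx$ be the gradient flowing into it from the second backward pass,

$$
ddout = ddx \, \frac{\partial dx_{\mathrm{grad}}}{\partial dout} = \frac{ddx}{x}, \qquad
dx = ddx \, \frac{\partial dx_{\mathrm{grad}}}{\partial x} = -\frac{dout \, ddx}{x^{2}} = -\frac{dout}{x}\cdot\frac{ddx}{x},
$$

which matches the comment above LogDoubleGradMaker and the Eigen expressions in LogGradGradFunctor below.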
12 changes: 12 additions & 0 deletions paddle/fluid/operators/activation_op.cu
@@ -193,3 +193,15 @@ REGISTER_OP_CUDA_KERNEL(
ops::ActivationDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::AbsGradGradFunctor<int64_t>>);
/* ========================================================================== */

/* ========================== Log register ==================================*/
REGISTER_ACTIVATION_CUDA_KERNEL(log, Log, LogFunctor, LogGradFunctor);

REGISTER_OP_CUDA_KERNEL(
log_grad_grad, ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<float>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<double>>,
ops::LogDoubleGradKernel<plat::CUDADeviceContext,
ops::LogGradGradFunctor<plat::float16>>);
/* ========================================================================== */
36 changes: 35 additions & 1 deletion paddle/fluid/operators/activation_op.h
@@ -1551,6 +1551,10 @@ class SquareDoubleGradKernel
}
};

template <typename DeviceContext, typename Functor>
class LogDoubleGradKernel
: public SquareDoubleGradKernel<DeviceContext, Functor> {};

template <typename DeviceContext, typename Functor>
class ELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1740,6 +1744,37 @@ class PowGradKernel
functor(*place, x, out, dout, dx);
}
};

template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
// ddout = ddx / x; dx = -(dout / x) * (ddx / x)
// Compute dx first so that ddout can safely be written in place over ddx.
if (dX) {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(1) / x;
}
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

} // namespace operators
} // namespace paddle

@@ -1758,7 +1793,6 @@ class PowGradKernel
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, Log, LogFunctor, LogGradFunctor); \
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
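As an independent sanity check on the Eigen expressions in LogGradGradFunctor, the NumPy sketch below (illustrative only, not part of the patch; the names rng, grad, and eps are ad hoc) compares the analytic ddout and dx against central finite differences of the first-order gradient dout / x:

import numpy as np

rng = np.random.default_rng(0)
shape = (2, 3, 7, 9)
x = rng.uniform(0.1, 1.0, shape)    # keep x > 0 so log(x) and 1/x are well behaved
dout = rng.standard_normal(shape)   # upstream gradient fed to the first-order grad op
ddx = rng.standard_normal(shape)    # gradient flowing into dx from the second backward pass
eps = 1e-6

# Analytic expressions mirrored from LogGradGradFunctor.
ddout = ddx / x
dx = -dout * ddx / (x * x)

# First-order gradient of log: grad(x, dout) = dout / x.
grad = lambda x_, dout_: dout_ / x_

# Elementwise central differences of the first-order gradient, contracted with ddx.
ddout_fd = ddx * (grad(x, dout + eps) - grad(x, dout - eps)) / (2 * eps)
dx_fd = ddx * (grad(x + eps, dout) - grad(x - eps, dout)) / (2 * eps)

assert np.allclose(ddout, ddout_fd, atol=1e-4)
assert np.allclose(dx, dx_fd, atol=1e-4)
print("LogGradGradFunctor formulas agree with finite differences")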
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -123,6 +123,30 @@ def test_grad(self):
self.func(p)


class TestLogDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 1e-6
dtype = np.float64

x = layers.data('x', shape, False, dtype)
x.persistable = True
y = layers.log(x)

x_arr = np.random.uniform(0.1, 1, shape).astype(dtype)

gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)

def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)


class TestSquareDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
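The new double-grad check can also be run on its own against a Paddle build that includes this patch; for example (the exact invocation is an assumption, shown from python/paddle/fluid/tests/unittests/):

python -m unittest test_activation_nn_grad.TestLogDoubleGradCheck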
