
[MXNET-978] Higher Order Gradient Support sinh, cosh. #15412

Merged
51 changes: 49 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -313,7 +313,30 @@ The storage type of ``sinh`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_sinh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sinh, unary_bwd<mshadow_op::sinh_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sinh, unary_bwd<mshadow_op::sinh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dx_grad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = sinh(x)
      // f'(x) = cosh(x)
      // f''(x) = sinh(x)
      auto dydx = MakeNode("cosh", n->attrs.name + "_dydx",
                           {n->inputs[1]}, nullptr, &n);
      auto d2ydx2 = MakeNode("sinh", n->attrs.name + "_d2ydx2",
                             {n->inputs[1]}, nullptr, &n);

      // d(dL/dx)/dx = dL/dy * f''(x)
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_mid",
                                    {n->inputs[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;

      // Gradient w.r.t. inputs[0] (dL/dy): ograds[0] * f'(x)
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      // Gradient w.r.t. inputs[1] (x): ograds[0] * dL/dy * f''(x)
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });

// cosh
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(cosh, cpu, mshadow_op::cosh)
@@ -328,7 +351,31 @@ The storage type of ``cosh`` output is always dense
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_cosh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cosh, unary_bwd<mshadow_op::cosh_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_cosh, unary_bwd<mshadow_op::cosh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dx_grad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = cosh(x)
      // f'(x) = sinh(x)
      // f''(x) = cosh(x)
      auto dydx = MakeNode("sinh", n->attrs.name + "_dydx",
                           {n->inputs[1]}, nullptr, &n);
      auto d2ydx2 = MakeNode("cosh", n->attrs.name + "_d2ydx2",
                             {n->inputs[1]}, nullptr, &n);

      // d(dL/dx)/dx = dL/dy * f''(x)
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_mid",
                                    {n->inputs[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;

      // Gradient w.r.t. inputs[0] (dL/dy): ograds[0] * f'(x)
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      // Gradient w.r.t. inputs[1] (x): ograds[0] * dL/dy * f''(x)
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });


// tanh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(tanh, cpu, mshadow_op::tanh)
29 changes: 29 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -68,6 +68,34 @@ def grad_grad_op(x):
        check_second_order_unary(array, tan, grad_grad_op)


@with_seed()
def test_sinh():
    def sinh(x):
        return nd.sinh(x)

    def grad_grad_op(x):
        return sinh(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sinh, grad_grad_op)


@with_seed()
def test_cosh():
    def cosh(x):
        return nd.cosh(x)

    def grad_grad_op(x):
        return cosh(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, cosh, grad_grad_op)
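
Both new tests delegate to `check_second_order_unary`, which is defined earlier in this file and not shown in this diff. A hedged paraphrase of the property it asserts (the real helper may differ in details such as random head gradients, dtypes, and tolerances):

```python
# Sketch of the asserted property; variable names here are illustrative.
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal

x = nd.random.uniform(-1, 1, shape=(2, 3))
y_grad = nd.random.normal(shape=x.shape)           # head gradient for the first backward
head_grad_grads = nd.random.normal(shape=x.shape)  # ograds[0] for the second backward
x.attach_grad()
with autograd.record():
    y = nd.sinh(x)
    x_grad = autograd.grad(y, x, head_grads=y_grad,
                           create_graph=True, retain_graph=True)[0]
x_grad.backward(head_grad_grads)
# grad_grad_op(x) = sinh(x) for test_sinh (and cosh(x) for test_cosh)
expected = (y_grad * head_grad_grads * nd.sinh(x)).asnumpy()
assert_almost_equal(expected, x.grad.asnumpy())
```

The `y_grad * head_grad_grads * f''(x)` product is exactly the `_backward_grad_grad_in` output built in the C++ changes above.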


@with_seed()
def test_tanh():
    def tanh(x):
@@ -245,6 +273,7 @@ def grad_grad_op(x):
        check_second_order_unary(array, dropout, grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)