
[MXNET-978] Add higher order gradient support for tan, tanh #15253

Merged: 10 commits, Jul 29, 2019
57 changes: 55 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -139,7 +139,33 @@ The storage type of ``tan`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tan" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // NodeEntry{n} : y_grad * f'(x)
      // n->inputs[0] : y_grad (dL/dy)
      // n->inputs[1] : y = f(x) = tan(x) (ElemwiseGradUseOut)
      // ograds[0]    : head_grads (dL/dx_grad)
      // f'(x) = sec^2(x)
      // f''(x) = 2 * f'(x) * f(x)
      //
      // Note: when building the gradient graph, the backward node of n->inputs[1] will be
      //       added to the graph again, so the remaining factor f'(x) is multiplied in by
      //       the chain rule. Hence we only need to compute 2 * f(x) * head_grads * y_grad here.
      const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
      auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
Contributor: Can we clarify / add a comment on why it is correct to multiply by y_grad (the first head gradient) again? This would help readers, as it is not obvious, and neither is the implicit multiplication by f'(x) that it compounds with.

Contributor (Author): Makes sense. Thanks. Will get to it.

Contributor (Author): I have updated the comment. See if it is okay, or whether the phrasing can be improved. Thanks.

Contributor: Thanks. About the outputs, I think we should write some documentation explaining what we are doing, as I find it non-trivial. Can you help me understand y_grad_grad (the first output)?

If you want, we can move the conversation to the dev list or Slack, as the PR LGTM.

[Attached image: IMG_20190711_122948__01]

Contributor: OK, I clarified with @apeforest; this makes sense now.

                                    {n->inputs[0], nnvm::NodeEntry{two_y}}, nullptr, &n);
      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });
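For reference, a short worked sketch of why the two returned entries are correct; it restates the code comments above and speaks to the y_grad_grad question raised in the review thread (g and h are shorthand introduced here, not names used in the code). With f(x) = tan(x), g = y_grad (dL/dy), h = ograds[0] (dL/dx_grad), and x_grad = g * f'(x):

d(x_grad)/d(y_grad) = f'(x), so ret[0] = h * f'(x) = h * (x_grad / g), built from the dydx node.
d(x_grad)/dy = g * 2 * y (using f'(x) = sec^2(x) = 1 + y^2), so ret[1] = h * g * 2 * f(x), built from the grad_grad_mid node.

Because y = tan(x) is itself the output of the forward node, backpropagating ret[1] through it multiplies in one more factor dy/dx = f'(x), yielding h * g * 2 * f(x) * f'(x) = h * g * f''(x), the full second-order term.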

// arcsin
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsin, cpu, mshadow_op::arcsin)
@@ -290,7 +316,34 @@ The storage type of ``tanh`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tanh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // NodeEntry{n} : y_grad * f'(x)
      // n->inputs[0] : y_grad (dL/dy)
      // n->inputs[1] : y = f(x) = tanh(x) (ElemwiseGradUseOut)
      // ograds[0]    : head_grads (dL/dx_grad)
      // f'(x) = sech^2(x)
      // f''(x) = -2 * f'(x) * f(x)
      //
      // Note: when building the gradient graph, the backward node of n->inputs[1] will be
      //       added to the graph again, so the remaining factor f'(x) is multiplied in by
      //       the chain rule. Hence we only need to compute -2 * f(x) * head_grads * y_grad here.
      const std::unordered_map<std::string, std::string> args = {{"scalar", "-2.0"}};
      auto neg_two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_neg_two",
                                {n->inputs[1]}, &args, &n);
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                                    {n->inputs[0], nnvm::NodeEntry{neg_two_y}}, nullptr, &n);
      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });
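The tanh registration follows the same sketch with f(x) = tanh(x) and the same shorthand (g = y_grad, h = ograds[0]): f'(x) = sech^2(x) = 1 - y^2, so ret[0] = h * f'(x) as before, while d(x_grad)/dy = g * (-2 * y) gives ret[1] = h * g * (-2 * f(x)); the chain rule through y = tanh(x) then contributes the remaining f'(x), for h * g * (-2) * f(x) * f'(x) = h * g * f''(x).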

// arcsinh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsinh, cpu, mshadow_op::arcsinh)
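For completeness, a minimal usage sketch (not part of the diff) of what the registrations above enable from Python, assuming the mxnet.autograd.grad call with create_graph=True and retain_graph=True, i.e. the same pattern check_second_order_unary uses in the tests below; second_derivative_of_tanh is only an illustrative helper name.

import math

from mxnet import autograd, nd


def second_derivative_of_tanh(values):
    """Compute d2/dx2 tanh(x) elementwise via a double backward pass."""
    x = nd.array(values)
    x.attach_grad()
    with autograd.record():
        y = nd.tanh(x)
        # First backward, kept differentiable so it can be backpropagated again.
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    # Second backward: x.grad now holds f''(x) = -2 * tanh(x) / cosh(x)**2.
    y_grad.backward()
    return x.grad


points = [0.0, 0.5, 1.0]
print(second_derivative_of_tanh(points))
print(nd.array([-2 * math.tanh(v) / math.cosh(v) ** 2 for v in points]))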
40 changes: 38 additions & 2 deletions tests/python/unittest/test_higher_order_grad.py
Expand Up @@ -50,6 +50,41 @@ def grad_grad_op(x):
check_second_order_unary(array, cos, grad_grad_op)


@with_seed()
def test_tan():
    def tan(x):
        return nd.tan(x)

    def grad_op(x):
        return 1 / nd.cos(x)**2

    def grad_grad_op(x):
        return 2 * tan(x) * grad_op(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, tan, grad_grad_op)


@with_seed()
def test_tanh():
    def tanh(x):
        return nd.tanh(x)

    def grad_op(x):
        return 1 / nd.cosh(x)**2

    def grad_grad_op(x):
        return -2 * tanh(x) * grad_op(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(
            array, tanh, grad_grad_op, rtol=1e-6, atol=1e-6)


@with_seed()
def test_relu():
    def relu(x):
@@ -123,7 +158,7 @@ def grad_grad_op(x):
        check_second_order_unary(array, sigmoid, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    x.attach_grad()
@@ -144,7 +179,8 @@ def check_second_order_unary(x, op, grad_grad_op):
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
    assert_almost_equal(expected_grad_grad,
                        x.grad.asnumpy(), rtol=rtol, atol=atol)


if __name__ == '__main__':