[MXNET-978] Add higher order gradient support for tan, tanh (apache#15253)

* init to reset

* issue: higher order backward sigmoid

* update gradient code.

update code as per apache#15288.

* undo changes

* relax tolerance of gradient mismatch for tanh

* update comments

* update comments
kshitij12345 authored and Ubuntu committed Aug 20, 2019
1 parent 0f35cdf commit 4bc265c
Showing 2 changed files with 93 additions and 4 deletions.
57 changes: 55 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -139,7 +139,33 @@ The storage type of ``tan`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tan" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tan, unary_bwd<mshadow_op::tan_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // NodeEntry{n} : y_grad * f'(x)
      // n->inputs[0] : y_grad (dL/dy)
      // n->inputs[1] : y = f(x) = tan(x) (ElemwiseGradUseOut)
      // ograds[0]    : head_grads (dL/dx_grad)
      // f'(x) = sec^2(x)
      // f''(x) = 2 * f'(x) * f(x)
      //
      // Note: when the gradient graph is built, the backward node of n->inputs[1] is
      // added to the graph again, which supplies the extra f'(x) factor.
      // So here we only need to compute 2 * f(x) * dL/dx_grad * y_grad.
      const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
      auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                                    {n->inputs[0], nnvm::NodeEntry{two_y}}, nullptr, &n);
      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });
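For reference (a restatement, not part of the diff): with y = tan(x) and g = y_grad = dL/dy, the first backward emits x_grad = g * (1 + y^2), so differentiating that expression gives the two outputs returned above, written in LaTeX:

x_{\mathrm{grad}} = g\,(1 + y^2), \qquad
\frac{\partial x_{\mathrm{grad}}}{\partial g} = 1 + y^2 = \frac{x_{\mathrm{grad}}}{g}, \qquad
\frac{\partial x_{\mathrm{grad}}}{\partial y} = 2\,y\,g

The elemwise_div node recovers 1 + y^2 = sec^2(x) as NodeEntry{n} / y_grad, and each partial is then multiplied by ograds[0] = dL/dx_grad.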

// arcsin
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsin, cpu, mshadow_op::arcsin)
@@ -290,7 +316,34 @@ The storage type of ``tanh`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{ "_backward_tanh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_tanh, unary_bwd<mshadow_op::tanh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // NodeEntry{n} : y_grad * f'(x)
      // n->inputs[0] : y_grad (dL/dy)
      // n->inputs[1] : y = f(x) = tanh(x) (ElemwiseGradUseOut)
      // ograds[0]    : head_grads (dL/dx_grad)
      // f'(x) = sech^2(x)
      // f''(x) = -2 * f'(x) * f(x)
      //
      // Note: when the gradient graph is built, the backward node of n->inputs[1] is
      // added to the graph again, which supplies the extra f'(x) factor.
      // So here we only need to compute -2 * f(x) * dL/dx_grad * y_grad.
      const std::unordered_map<std::string, std::string> args = {{"scalar", "-2.0"}};
      auto neg_two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_neg_two",
                                {n->inputs[1]}, &args, &n);
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                                    {n->inputs[0], nnvm::NodeEntry{neg_two_y}}, nullptr, &n);
      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });
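The same bookkeeping for tanh (again a restatement, not part of the diff), with y = tanh(x) and g = y_grad:

x_{\mathrm{grad}} = g\,(1 - y^2), \qquad
\frac{\partial x_{\mathrm{grad}}}{\partial g} = 1 - y^2 = \frac{x_{\mathrm{grad}}}{g}, \qquad
\frac{\partial x_{\mathrm{grad}}}{\partial y} = -2\,y\,g

which is why the scalar is -2.0 here while the node wiring is otherwise identical to the tan case.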

// arcsinh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arcsinh, cpu, mshadow_op::arcsinh)
40 changes: 38 additions & 2 deletions tests/python/unittest/test_higher_order_grad.py
@@ -50,6 +50,41 @@ def grad_grad_op(x):
        check_second_order_unary(array, cos, grad_grad_op)


@with_seed()
def test_tan():
    def tan(x):
        return nd.tan(x)

    def grad_op(x):
        return 1 / nd.cos(x)**2

    def grad_grad_op(x):
        return 2 * tan(x) * grad_op(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, tan, grad_grad_op)


@with_seed()
def test_tanh():
    def tanh(x):
        return nd.tanh(x)

    def grad_op(x):
        return 1 / nd.cosh(x)**2

    def grad_grad_op(x):
        return -2 * tanh(x) * grad_op(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(
            array, tanh, grad_grad_op, rtol=1e-6, atol=1e-6)


@with_seed()
def test_relu():
    def relu(x):
@@ -150,7 +185,7 @@ def grad_grad_op(x):
        check_second_order_unary(array, sigmoid, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    x.attach_grad()
@@ -171,7 +206,8 @@ def check_second_order_unary(x, op, grad_grad_op):
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
    assert_almost_equal(expected_grad_grad,
                        x.grad.asnumpy(), rtol=rtol, atol=atol)


if __name__ == '__main__':
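For context, a minimal standalone sketch (not part of this commit) of the pattern the tests rely on: taking a first gradient with create_graph=True via mx.autograd.grad, backpropagating through it again, and comparing against the closed form. Array values and variable names are illustrative only.

import mxnet as mx
from mxnet import nd, autograd

x = nd.array([0.1, 0.5, 1.0])
x.attach_grad()
with autograd.record():
    y = nd.tanh(x)
    # First backward, kept differentiable so it can be backpropagated through again.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(y),
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()                             # second backward: d(dy/dx)/dx lands in x.grad
expected = -2 * nd.tanh(x) / nd.cosh(x) ** 2  # -2 * f(x) * f'(x) for f = tanh
print(x.grad.asnumpy(), expected.asnumpy())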
