[MXNET-978] Higher Order Gradient Support reciprocal, abs. (#15413)
* add higher order support for reciprocal and abs

* add relevant tests

* address comments

* fix extra line in tests.
* fix missing space.
* fix incorrect comment.
kshitij12345 authored and apeforest committed Jul 7, 2019
1 parent a6ed12f commit a3ae309
Showing 2 changed files with 79 additions and 2 deletions.
54 changes: 52 additions & 2 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -717,7 +717,38 @@ Example::

MXNET_OPERATOR_REGISTER_BINARY(_backward_reciprocal)
.set_attr<FCompute>("FCompute<cpu>",
ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::reciprocal_grad> >);
ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::reciprocal_grad> >)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // ograds[0]: dL/dxgrad
    // inputs[0]: dL/dy
    // inputs[1]: x
    // f(x) = y = 1/x
    // f'(x) = -1/x^2
    // f''(x) = 2/x^3 = -2 * (f'(x) * f(x))

    const std::unordered_map<std::string, std::string> args = {{"scalar", "-2.0"}};

    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
                         {dydx_mul_dldy, n->inputs[0]}, nullptr, &n);
    auto fx = MakeNode("reciprocal", n->attrs.name + "_fx",
                       {n->inputs[1]}, nullptr, &n);

    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
                               {dydx_mul_dldy, nnvm::NodeEntry{fx}}, nullptr, &n);

    auto d2ydx2 = MakeNode("_mul_scalar", n->attrs.name + "_d2ydx2",
                           {nnvm::NodeEntry{d2ydx2_mid}}, &args, &n);

    std::vector<nnvm::NodeEntry> ret;

    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
    return ret;
});
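For clarity, the factored form that the graph above implements follows in one line from the comments:

    f''(x) = \frac{2}{x^{3}} = -2 \cdot \left(-\frac{1}{x^{2}}\right) \cdot \frac{1}{x} = -2 \, f'(x) \, f(x)

Because the node's own output `dydx_mul_dldy` already carries f'(x) * (dL/dy), dividing it by `inputs[0]` = dL/dy recovers the bare f'(x), while multiplying it by `fx` = f(x) and scaling by the `-2.0` scalar yields f''(x) * (dL/dy) without recomputing any power of x.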

// abs
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(abs, cpu, mshadow_op::abs)
@@ -736,7 +767,26 @@ The storage type of ``abs`` output depends upon the input storage type:
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_abs"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_abs, unary_bwd<mshadow_op::sign>);
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_abs, unary_bwd<mshadow_op::sign>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // ograds[0]: dL/dxgrad
    // inputs[0]: dL/dy
    // inputs[1]: x
    // f(x) -> abs(x)
    // f'(x) = 1 if x > 0 else -1
    // f''(x) = 0
    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
                         {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

    std::vector<nnvm::NodeEntry> ret;
    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
    ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_backward_grad_grad_in",
                              {n->inputs[1]}, nullptr, &n));
    return ret;
});
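The same `elemwise_div` trick recovers f'(x) here: the node's output is f'(x) * (dL/dy), so dividing by `inputs[0]` = dL/dy leaves the sign. Away from the kink at x = 0,

    f(x) = |x|, \qquad f'(x) = \operatorname{sgn}(x), \qquad f''(x) = 0,

which is why the entry flowing back to the input x is simply `zeros_like`.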


// sign
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(sign, cpu, mshadow_op::sign)
27 changes: 27 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -107,6 +107,33 @@ def grad_grad_op(x):


@with_seed()
def test_reciprocal():
    def reciprocal(x):
        return nd.reciprocal(x)

    def grad_grad_op(x):
        return 2 / x**3

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, reciprocal, grad_grad_op)


@with_seed()
def test_abs():
    def abs(x):
        return nd.abs(x)

    def grad_grad_op(x):
        return nd.zeros_like(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, abs, grad_grad_op)


def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)
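Both tests delegate to `check_second_order_unary`, which is defined earlier in test_higher_order_grad.py and not shown in this diff. A minimal sketch of such a helper, assuming the mxnet 1.x `autograd.grad` API with `create_graph=True` (the names and structure here are illustrative, not the file's actual code):

import mxnet as mx
from mxnet import autograd, nd


def check_second_order_unary(x, op, grad_grad_op):
    # Sketch only: the real helper lives earlier in this test file.
    x = nd.array(x)
    expected = grad_grad_op(x).asnumpy()  # analytic d2y/dx2
    x.attach_grad()
    with autograd.record():
        y = op(x)
        # First-order gradient dy/dx, recorded so it can be differentiated again.
        x_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    # Backward through the first-order gradient; head gradients default to ones,
    # so for an elementwise op x.grad receives the second derivative d2y/dx2.
    x_grad.backward()
    mx.test_utils.assert_almost_equal(expected, x.grad.asnumpy())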
