From a3ae30979989f488cd933c2fbb6416a4e187de9d Mon Sep 17 00:00:00 2001 From: kshitij12345 Date: Mon, 8 Jul 2019 02:35:09 +0530 Subject: [PATCH] [MXNET-978] Higher Order Gradient Support `reciprocal`, `abs`. (#15413) * add higher order support for reciprocal and abs * add relevant tests * address comments * fix extra line in tests. * fix missing space. * fix incorrect comment. --- .../tensor/elemwise_unary_op_basic.cc | 54 ++++++++++++++++++- .../python/unittest/test_higher_order_grad.py | 27 ++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index 26c74085dbe6..6da384d3679b 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -717,7 +717,38 @@ Example:: MXNET_OPERATOR_REGISTER_BINARY(_backward_reciprocal) .set_attr("FCompute", - ElemwiseBinaryOp::Compute >); + ElemwiseBinaryOp::Compute >) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + // ograds[0]: dL/dxgrad + // inputs[0]: dL/dy + // inputs[1]: x + // f(x) = y = 1/x + // f'(x) = -1/x^2 + // f''(x) = 2/x^3 = -2 * (f'(x) * f(x)) + + const std::unordered_map args = {{"scalar", "-2.0"}}; + + auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads + auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx", + {dydx_mul_dldy, n->inputs[0]}, nullptr, &n); + auto fx = MakeNode("reciprocal", n->attrs.name + "_fx", + {n->inputs[1]}, nullptr, &n); + + auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid", + {dydx_mul_dldy, nnvm::NodeEntry{fx}}, nullptr, &n); + + auto d2ydx2 = MakeNode("_mul_scalar", n->attrs.name + "_d2ydx2", + {nnvm::NodeEntry{d2ydx2_mid}}, &args, &n); + + std::vector ret; + + ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", + {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n)); + ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", + {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n)); + return ret; +}); // abs MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(abs, cpu, mshadow_op::abs) @@ -736,7 +767,26 @@ The storage type of ``abs`` output depends upon the input storage type: )code" ADD_FILELINE) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_abs"}); -MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_abs, unary_bwd); +MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_abs, unary_bwd) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + // ograds[0]: dL/dxgrad + // inputs[0]: dL/dy + // inputs[1]: x + // f(x) -> abs(x) + // f'(x) = 1 if x > 0 else -1 + // f''(x) = 0 + auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx", + {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n); + + std::vector ret; + ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", + {ograds[0], nnvm::NodeEntry(dydx)}, nullptr, &n)); + ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_backward_grad_grad_in", + {n->inputs[1]}, nullptr, &n)); + return ret; + }); + // sign MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(sign, cpu, mshadow_op::sign) diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py index ad14c5050c1b..0f07d014d435 100644 --- a/tests/python/unittest/test_higher_order_grad.py +++ b/tests/python/unittest/test_higher_order_grad.py @@ -107,6 +107,33 @@ def grad_grad_op(x): @with_seed() +def test_reciprocal(): + def reciprocal(x): + return nd.reciprocal(x) + + def grad_grad_op(x): + return 2 / x**3 + + for dim in range(1, 5): + shape = rand_shape_nd(dim) + array = random_arrays(shape) + check_second_order_unary(array, reciprocal, grad_grad_op) + + +@with_seed() +def test_abs(): + def abs(x): + return nd.abs(x) + + def grad_grad_op(x): + return nd.zeros_like(x) + + for dim in range(1, 5): + shape = rand_shape_nd(dim) + array = random_arrays(shape) + check_second_order_unary(array, abs, grad_grad_op) + + def test_sigmoid(): def sigmoid(x): return nd.sigmoid(x)