From 3310def5bd80106263c4efebe2d65ed6fa28bbe0 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 7 Sep 2019 04:40:44 +0530
Subject: [PATCH] [MXNET-978] Higher Order Gradient Support `sqrt`, `cbrt`.
 (#15474)

* support sqrt, cbrt for higher order grad

* add relevant tests

* remove unnecessary variable
---
 src/operator/tensor/elemwise_unary_op_pow.cc    | 71 ++++++++++++++++++-
 .../python/unittest/test_higher_order_grad.py   | 40 +++++++++++
 2 files changed, 109 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_pow.cc b/src/operator/tensor/elemwise_unary_op_pow.cc
index f22dabc7201a..486fe268b0cf 100644
--- a/src/operator/tensor/elemwise_unary_op_pow.cc
+++ b/src/operator/tensor/elemwise_unary_op_pow.cc
@@ -143,7 +143,38 @@ The storage type of ``sqrt`` output depends upon the input storage type:
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sqrt"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sqrt,
-                                                  unary_bwd<mshadow_op::square_root_grad>);
+                                                  unary_bwd<mshadow_op::square_root_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : f(x) = x^1/2
+      // ograds[0] : head_grads
+      // f'(x) = 1/(2*x^1/2)
+      // f''(x) = f'(x) * -1/(2*x) = -1/(4 * x^3/2)
+      const std::unordered_map<std::string, std::string> mul_args = {{"scalar", "0.5"}};
+      auto x = MakeNode("square", n->attrs.name + "_cube_x", {n->inputs[1]}, nullptr, &n);
+      auto r_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_x",
+                          {nnvm::NodeEntry{x}}, nullptr, &n);
+      auto neg_r_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_x",
+                              {nnvm::NodeEntry{r_x}}, nullptr, &n);
+      auto half_neg_r_cube_x = MakeNode("_mul_scalar", n->attrs.name + "_half_neg_reciprocal_x",
+                                        {nnvm::NodeEntry{neg_r_x}}, &mul_args, &n);
+      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_mid",
+                                    {nnvm::NodeEntry{half_neg_r_cube_x}, n->inputs[0]},
+                                    nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      // when building the gradient graph, the backward node of n->inputs[1] will be
+      // added to the graph again, therefore f'(x) will be multiplied in again
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+      return ret;
+    });
 
 // rsqrt
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(rsqrt, cpu, mshadow_op::reciprocal_square_root)
@@ -186,7 +217,43 @@ The storage type of ``cbrt`` output depends upon the input storage type:
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_cbrt"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_cbrt,
-                                                  unary_bwd<mshadow_op::cube_root_grad>);
+                                                  unary_bwd<mshadow_op::cube_root_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : f(x) = x^1/3
+      // ograds[0] : head_grads
+      // f'(x) = 1/(3*x^2/3)
+      // f''(x) = f'(x) * -2/(3*x) = -2/(9 * x^5/3)
+      const std::unordered_map<std::string, std::string> three = {{"scalar", "3.0"}};
+      const std::unordered_map<std::string, std::string> two = {{"scalar", "2.0"}};
+      auto x = MakeNode("_power_scalar", n->attrs.name + "_x", {n->inputs[1]}, &three, &n);
+      auto three_x = MakeNode("_mul_scalar", n->attrs.name + "_three_x",
+                              {nnvm::NodeEntry{x}}, &three, &n);
+      auto r_three_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_three_x",
+                                {nnvm::NodeEntry{three_x}}, nullptr, &n);
+      auto neg_r_three_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_three_x",
+                                    {nnvm::NodeEntry{r_three_x}}, nullptr, &n);
+      auto two_third_neg_r_x = MakeNode("_mul_scalar",
+                                        n->attrs.name + "_two_third_neg_reciprocal_x",
+                                        {nnvm::NodeEntry{neg_r_three_x}}, &two, &n);
+      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_mid",
+                                    {nnvm::NodeEntry{two_third_neg_r_x}, n->inputs[0]},
+                                    nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      // when building the gradient graph, the backward node of n->inputs[1] will be
+      // added to the graph again, therefore f'(x) will be multiplied in again
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
+      return ret;
+    });
+
 // rcbrt
 MXNET_OPERATOR_REGISTER_UNARY(rcbrt)
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index c70c747411b8..64c429a94c49 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -231,6 +231,46 @@ def grad_grad_op(x):
     check_second_order_unary(array, sigmoid, grad_grad_op)
 
 
+@with_seed()
+def test_sqrt():
+    def sqrt(x):
+        return nd.sqrt(x)
+
+    def grad_grad_op(x):
+        return -1/(4 * sqrt(x**3))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, sqrt, grad_grad_op)
+
+
+@with_seed()
+def test_cbrt():
+    def cbrt(x):
+        return nd.cbrt(x)
+
+    def grad_grad_op(x):
+        return -2/(9 * cbrt(x**5))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, cbrt, grad_grad_op)
+
+
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
     x = nd.array(x)
     grad_grad_x = grad_grad_op(x)
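
A minimal sketch, not part of the patch itself, of how the second-order gradient
registered above can be exercised from Python. It follows the same
`autograd.grad(..., create_graph=True)` pattern as `check_second_order_unary` in
the test file, and assumes an MXNet build that includes this patch; the input
values are illustrative:

    from mxnet import autograd, nd

    x = nd.array([1.0, 4.0, 9.0])   # positive inputs, as in the tests
    x.attach_grad()

    with autograd.record():
        y = nd.sqrt(x)
        # Keep the first derivative as a graph node so it can be
        # differentiated a second time.
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()

    # x.grad now holds f''(x) = -1/(4 * x^(3/2))
    print(x.grad)
    print(-1 / (4 * nd.sqrt(x ** 3)))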