From a5604e339c06ad9bb45f141cabb4a80e1a44ab27 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 6 Jul 2019 21:56:37 +0530
Subject: [PATCH 1/3] support rsqrt, rcbrt for higher order grad

---
 .../tensor/elemwise_unary_op_basic.cc         | 74 ++++++++++++++++++-
 1 file changed, 72 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 26c74085dbe6..03270baf42d5 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -956,7 +956,42 @@ The storage type of ``rsqrt`` output is always dense
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rsqrt"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(
-    _backward_rsqrt, unary_bwd<mshadow_op::reciprocal_square_root_grad>);
+    _backward_rsqrt, unary_bwd<mshadow_op::reciprocal_square_root_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : x
+      // ograds[0] : head_grads
+      // f(x) = 1/(x^1/2)
+      // f'(x) = -1/(2*x^3/2)
+      // f''(x) = f'(x) * -3/(2*x) = 3/(4 * x^5/2)
+      const std::unordered_map<std::string, std::string> three = {{"scalar", "3.0"}};
+      const std::unordered_map<std::string, std::string> two = {{"scalar", "2.0"}};
+      auto x = n->inputs[1];
+      auto dldy_mul_dydx = nnvm::NodeEntry{n};
+      auto two_x = MakeNode("_mul_scalar", n->attrs.name + "_two_x",
+                            {nnvm::NodeEntry{x}}, &two, &n);
+      auto r_two_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_two_x",
+                              {nnvm::NodeEntry{two_x}}, nullptr, &n);
+      auto neg_r_two_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_two_x",
+                                  {nnvm::NodeEntry{r_two_x}}, nullptr, &n);
+      auto three_by_two_neg_r_x = MakeNode("_mul_scalar",
+                                           n->attrs.name + "_three_by_two_neg_reciprocal_x",
+                                           {nnvm::NodeEntry{neg_r_two_x}}, &three, &n);
+      auto grad_grad_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_x",
+                                  {nnvm::NodeEntry{three_by_two_neg_r_x}, dldy_mul_dydx},
+                                  nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_x}}, nullptr, &n));
+      return ret;
+    });
 
 // cbrt
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(cbrt, cpu, mshadow_op::cube_root)
@@ -1036,7 +1071,42 @@ Example::
 
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rcbrt)
 .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu,
-  unary_bwd<mshadow_op::reciprocal_cube_root_grad>>);
+  unary_bwd<mshadow_op::reciprocal_cube_root_grad>>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : x
+      // ograds[0] : head_grads
+      // f(x) = 1/(x^1/3)
+      // f'(x) = -1/(3*x^4/3)
+      // f''(x) = f'(x) * -4/(3*x) = 4/(9 * x^7/3)
+      const std::unordered_map<std::string, std::string> three = {{"scalar", "3.0"}};
+      const std::unordered_map<std::string, std::string> four = {{"scalar", "4.0"}};
+      auto x = n->inputs[1];
+      auto dldy_mul_dydx = nnvm::NodeEntry{n};
+      auto three_x = MakeNode("_mul_scalar", n->attrs.name + "_three_x",
+                              {nnvm::NodeEntry{x}}, &three, &n);
+      auto r_three_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_three_x",
+                                {nnvm::NodeEntry{three_x}}, nullptr, &n);
+      auto neg_r_three_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_three_x",
+                                    {nnvm::NodeEntry{r_three_x}}, nullptr, &n);
+      auto four_by_three_neg_r_x = MakeNode("_mul_scalar",
+                                            n->attrs.name + "_four_by_three_neg_reciprocal_x",
+                                            {nnvm::NodeEntry{neg_r_three_x}}, &four, &n);
+      auto grad_grad_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_x",
+                                  {nnvm::NodeEntry{four_by_three_neg_r_x}, dldy_mul_dydx},
+                                  nullptr, &n);
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
+                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
+                                {ograds[0], nnvm::NodeEntry{grad_grad_x}}, nullptr, &n));
+      return ret;
+    });
 
 // exp
 #if MSHADOW_USE_MKL == 1

From ce4c6b1148d7a613fb1192444de0edf2705d3b7c Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 6 Jul 2019 21:56:53 +0530
Subject: [PATCH 2/3] add relevant tests

---
 .../python/unittest/test_higher_order_grad.py | 41 +++++++++++++++++++
 1 file changed, 41 insertions(+)

diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index ad14c5050c1b..37625ab33239 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -17,6 +17,7 @@
 
 
 import math
+import random
 from mxnet import nd, autograd
 from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
 from common import with_seed
@@ -123,6 +124,46 @@ def grad_grad_op(x):
         check_second_order_unary(array, sigmoid, grad_grad_op)
 
 
+@with_seed()
+def test_rsqrt():
+    def rsqrt(x):
+        return nd.rsqrt(x)
+
+    def grad_grad_op(x):
+        return 3/(4 * nd.sqrt(x**5))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, rsqrt, grad_grad_op)
+
+
+@with_seed()
+def test_rcbrt():
+    def rcbrt(x):
+        return nd.rcbrt(x)
+
+    def grad_grad_op(x):
+        return 4/(9 * nd.cbrt(x**7))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, rcbrt, grad_grad_op)
+
+
 def check_second_order_unary(x, op, grad_grad_op):
     x = nd.array(x)
     grad_grad_x = grad_grad_op(x)

From 640b25b67fcc577103c7433f3c14d46e43630e4d Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Fri, 26 Jul 2019 20:47:32 +0530
Subject: [PATCH 3/3] update comments

---
 src/operator/tensor/elemwise_unary_op_basic.cc | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 03270baf42d5..0bbdce25fa53 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -962,7 +962,7 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(
       // NodeEntry{n} : y_grad * f'(x)
       // n->inputs[0] : y_grad
       // n->inputs[1] : x
-      // ograds[0] : head_grads
+      // ograds[0] : head_grad_grads (dL/dxgrad)
       // f(x) = 1/(x^1/2)
       // f'(x) = -1/(2*x^3/2)
       // f''(x) = f'(x) * -3/(2*x) = 3/(4 * x^5/2)
@@ -1077,7 +1077,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rcbrt)
       // NodeEntry{n} : y_grad * f'(x)
       // n->inputs[0] : y_grad
      // n->inputs[1] : x
-      // ograds[0] : head_grads
+      // ograds[0] : head_grad_grads (dL/dxgrad)
       // f(x) = 1/(x^1/3)
       // f'(x) = -1/(3*x^4/3)
       // f''(x) = f'(x) * -4/(3*x) = 4/(9 * x^7/3)
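
Usage sketch (not part of the patch; the input values are made up for illustration):
with the FGradient entries above in place, the second derivative of rsqrt can be taken
through autograd the same way check_second_order_unary in the new tests does, by
differentiating the first backward pass again via create_graph=True:

    from mxnet import nd, autograd

    x = nd.array([1.0, 4.0, 9.0])
    x.attach_grad()
    with autograd.record():
        y = nd.rsqrt(x)
        # First derivative, kept in the recorded graph (create_graph=True)
        # so it can be differentiated once more.
        x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(y),
                               create_graph=True, retain_graph=True)[0]
    x_grad.backward()

    # x.grad now holds d/dx [dy/dx] = 3 / (4 * x^(5/2)), the f''(x) from the comments above.
    expected_grad_grad = 3 / (4 * nd.sqrt(x ** 5))
    print(x.grad, expected_grad_grad)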