diff --git a/src/operator/tensor/elemwise_unary_op_pow.cc b/src/operator/tensor/elemwise_unary_op_pow.cc
index 084772980ed1..6702625fcc43 100644
--- a/src/operator/tensor/elemwise_unary_op_pow.cc
+++ b/src/operator/tensor/elemwise_unary_op_pow.cc
@@ -222,7 +222,33 @@ The storage type of ``rsqrt`` output is always dense
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rsqrt"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(
-    _backward_rsqrt, unary_bwd<mshadow_op::reciprocal_square_root_grad>);
+    _backward_rsqrt, unary_bwd<mshadow_op::reciprocal_square_root_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : x
+      // ograds[0] : head_grad_grads (dL/dxgrad)
+      // f(x) = 1/(x^1/2)
+      // f'(x) = -1/(2*x^3/2)
+      // f''(x) = f'(x) * -3/(2*x) = 3/(4 * x^5/2)
+      auto dydx = n->inputs[0];
+      auto x = n->inputs[1];
+      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
+      auto op = mxnet::util::NodeOpGen{n};
+
+      auto two_x = op.mul(2.0, x);
+      auto r_two_x = op.reciprocal(two_x);
+      auto neg_r_two_x = op.negative(r_two_x);
+      auto three_by_two_neg_r_x = op.mul(3.0, neg_r_two_x);
+      auto x_grad_grad = op.mul(three_by_two_neg_r_x, dydx_mul_grad_x);
+      auto x_grad = op.div(dydx_mul_grad_x, dydx);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(op.mul(ograds[0], x_grad));
+      ret.emplace_back(op.mul(ograds[0], x_grad_grad));
+      return ret;
+    });
 
 // cbrt
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(cbrt, cpu, mshadow_op::cube_root)
@@ -301,7 +327,33 @@ Example::
 MXNET_OPERATOR_REGISTER_BINARY(_backward_rcbrt)
 .set_attr<FCompute>("FCompute<cpu>",
                     ElemwiseBinaryOp::Compute<cpu,
-                      unary_bwd<mshadow_op::reciprocal_cube_root_grad>>);
+                      unary_bwd<mshadow_op::reciprocal_cube_root_grad>>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // NodeEntry{n} : y_grad * f'(x)
+      // n->inputs[0] : y_grad
+      // n->inputs[1] : x
+      // ograds[0] : head_grad_grads (dL/dxgrad)
+      // f(x) = 1/(x^1/3)
+      // f'(x) = -1/(3*x^4/3)
+      // f''(x) = f'(x) * -4/(3*x) = 4/(9 * x^7/3)
+      auto dydx = n->inputs[0];
+      auto x = n->inputs[1];
+      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
+      auto op = mxnet::util::NodeOpGen{n};
+
+      auto three_x = op.mul(3.0, x);
+      auto r_three_x = op.reciprocal(three_x);
+      auto neg_r_three_x = op.negative(r_three_x);
+      auto four_by_three_neg_r_x = op.mul(4.0, neg_r_three_x);
+      auto x_grad_grad = op.mul(four_by_three_neg_r_x, dydx_mul_grad_x);
+      auto x_grad = op.div(dydx_mul_grad_x, dydx);
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(op.mul(ograds[0], x_grad));
+      ret.emplace_back(op.mul(ograds[0], x_grad_grad));
+      return ret;
+    });
 
 } // namespace op
 } // namespace mxnet
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 527c35d5dd94..b995e5863c45 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -424,6 +424,46 @@ def grad_grad_op(x):
     check_second_order_unary(array, cbrt, grad_grad_op)
 
 
+@with_seed()
+def test_rsqrt():
+    def rsqrt(x):
+        return nd.rsqrt(x)
+
+    def grad_grad_op(x):
+        return 3/(4 * nd.sqrt(x**5))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, rsqrt, grad_grad_op)
+
+
+@with_seed()
+def test_rcbrt():
+    def rcbrt(x):
+        return nd.rcbrt(x)
+
+    def grad_grad_op(x):
+        return 4/(9 * nd.cbrt(x**7))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = sigma * array + mu
+        # Only positive numbers
+        assert((array > 0).all())
+        check_second_order_unary(array, rcbrt, grad_grad_op)
+
+
 def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
     check_nth_order_unary(x, op, grad_grad_op, 2, rtol, atol)