From 66f16560c083a15aed1f13b097fda4a8f24929fb Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Tue, 1 Oct 2019 02:33:10 +0530
Subject: [PATCH] [MXNET-978] Higher Order Gradient Support `arcsinh`, `arccosh`. (#15530)

* support arcsinh, arccosh for higher order grad
* add relevant tests
* update comments
* use NodeOpGen
* retrigger CI
---
 src/operator/tensor/elemwise_unary_op_trig.cc | 52 ++++++++++++++++++-
 .../python/unittest/test_higher_order_grad.py | 34 ++++++++++++
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_trig.cc b/src/operator/tensor/elemwise_unary_op_trig.cc
index 7546f5096ccd..a436ebb284a3 100644
--- a/src/operator/tensor/elemwise_unary_op_trig.cc
+++ b/src/operator/tensor/elemwise_unary_op_trig.cc
@@ -437,7 +437,31 @@ The storage type of ``arcsinh`` output depends upon the input storage type:
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsinh" });
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arcsinh,
-                                                  unary_bwd<mshadow_op::arcsinh_grad>);
+                                                  unary_bwd<mshadow_op::arcsinh_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // ograds[0]: head_grad_grads (dL/dxgrad)
+      // inputs[0]: dL/dy
+      // inputs[1]: x (ElemwiseGradUseIn)
+      // f(x) = arcsinh(x)
+      // n: f'(x) = 1/(x^2 + 1)^(1/2)
+      // f''(x) = -f'(x) * x/(x^2 + 1) = -x/(x^2 + 1)^(3/2)
+      // Note: x/(x^2 + 1) = x * f'(x)^2
+      auto dydx = n->inputs[0];
+      auto x = n->inputs[1];
+      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
+      auto op = mxnet::util::NodeOpGen{n};
+
+      auto grad_x = op.div(dydx_mul_grad_x, dydx);
+      auto grad_x_square = op.square(grad_x);
+      auto grad_x_square_mul_x = op.mul(grad_x_square, x);
+      auto grad_grad_x = op.negative(op.mul(dydx_mul_grad_x, grad_x_square_mul_x));
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(op.mul(ograds[0], grad_x));
+      ret.emplace_back(op.mul(ograds[0], grad_grad_x));
+      return ret;
+    });
 
 // arccosh
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(arccosh, cpu, mshadow_op::arccosh)
@@ -451,7 +475,31 @@ The storage type of ``arccosh`` output is always dense
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccosh" });
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arccosh,
-                                                  unary_bwd<mshadow_op::arccosh_grad>);
+                                                  unary_bwd<mshadow_op::arccosh_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // ograds[0]: head_grad_grads (dL/dxgrad)
+      // inputs[0]: dL/dy
+      // inputs[1]: x (ElemwiseGradUseIn)
+      // f(x) = arccosh(x)
+      // n: f'(x) = 1/((x - 1)^(1/2) * (x + 1)^(1/2))
+      // f''(x) = -f'(x) * x/((x + 1)*(x - 1)) = -x/((x-1)^(1/2) * (x+1)^(1/2) * (x-1) * (x+1))
+      // Note: x/((x-1)*(x+1)) = x * f'(x)^2
+      auto dydx = n->inputs[0];
+      auto x = n->inputs[1];
+      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
+      auto op = mxnet::util::NodeOpGen{n};
+
+      auto grad_x = op.div(dydx_mul_grad_x, dydx);
+      auto grad_x_square = op.square(grad_x);
+      auto grad_x_square_mul_x = op.mul(grad_x_square, x);
+      auto grad_grad_x = op.negative(op.mul(dydx_mul_grad_x, grad_x_square_mul_x));
+
+      std::vector<nnvm::NodeEntry> ret;
+      ret.emplace_back(op.mul(ograds[0], grad_x));
+      ret.emplace_back(op.mul(ograds[0], grad_grad_x));
+      return ret;
+    });
 
 // arctanh
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arctanh, cpu, mshadow_op::arctanh)
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index c4a1948fef72..0b0b00fffac7 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -150,6 +150,40 @@ def grad_grad_op(x):
     check_second_order_unary(array, arctan, grad_grad_op)
 
 
+@with_seed()
+def test_arcsinh():
+    def arcsinh(x):
+        return nd.arcsinh(x)
+
+    def grad_grad_op(x):
+        return -x/nd.sqrt((nd.square(x)+1)**3)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        check_second_order_unary(array, arcsinh, grad_grad_op)
+
+
+@with_seed()
+def test_arccosh():
+    def arccosh(x):
+        return nd.arccosh(x)
+
+    def grad_grad_op(x):
+        return -x/(nd.sqrt(x-1) * nd.sqrt(x+1) * (x+1) * (x-1))
+
+    sigma = random.randint(25, 100)
+    mu = random.randint(500, 1000)
+
+    for dim in range(1, 5):
+        shape = rand_shape_nd(dim)
+        array = random_arrays(shape)
+        array = array * sigma + mu
+        # The domain of arccosh is [1, inf).
+        assert((array > 1).all())
+        check_second_order_unary(array, arccosh, grad_grad_op)
+
+
 @with_seed()
 def test_arctanh():
     def arctanh(x):
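
The analytic second derivatives used in `grad_grad_op` above carry a minus sign, since
d^2/dx^2 arcsinh(x) = -x/(x^2 + 1)^(3/2) and, for x > 1,
d^2/dx^2 arccosh(x) = -x/((x - 1)^(1/2) * (x + 1)^(1/2) * (x - 1) * (x + 1)).
Below is a minimal, patch-independent sanity check of those two expressions using
plain NumPy central differences. It assumes only NumPy; the `second_diff` helper is
introduced here for illustration and is not part of the patch or the MXNet test suite.

    # Standalone sketch (not part of the patch): verify the closed-form
    # second derivatives against a central finite-difference estimate.
    import numpy as np

    def second_diff(f, x, h=1e-4):
        # Central finite-difference approximation of f''(x).
        return (f(x + h) - 2.0 * f(x) + f(x - h)) / h**2

    # d^2/dx^2 arcsinh(x) = -x / (x^2 + 1)^(3/2), valid for all real x.
    x = np.linspace(-2.0, 2.0, 9)
    assert np.allclose(second_diff(np.arcsinh, x),
                       -x / (x**2 + 1)**1.5, atol=1e-5)

    # d^2/dx^2 arccosh(x) = -x / (sqrt(x-1) * sqrt(x+1) * (x-1) * (x+1)),
    # valid only on the domain x > 1 (hence the sampling range below).
    x = np.linspace(2.0, 5.0, 9)
    assert np.allclose(second_diff(np.arccosh, x),
                       -x / (np.sqrt(x - 1) * np.sqrt(x + 1) * (x - 1) * (x + 1)),
                       atol=1e-5)

The same identity the C++ comments rely on, x * f'(x)^2 = x/(x^2 + 1) for arcsinh
(and x/((x-1)*(x+1)) for arccosh), is what lets the FGradient lambdas build f'' from
the already-available first-derivative node instead of materializing a new expression.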