[MXNET-978] Higher Order Gradient Support arcsinh, arccosh. (#15530)
* support arcsinh, arccosh for higher order grad

* add relevant tests

* update comments

* use NodeOpGen

* retrigger CI
kshitij12345 authored and apeforest committed Sep 30, 2019
1 parent 3ffd2c2 commit 66f1656
Showing 2 changed files with 84 additions and 2 deletions.
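The second-order FGradient registrations added below are driven from Python through mx.autograd.grad with create_graph=True. A minimal sketch of how a caller exercises them (input values and names are illustrative; this mirrors the check_second_order_unary helper used in the test file):

import mxnet as mx
from mxnet import autograd, nd

x = nd.array([0.5, 1.0, 2.0])
x.attach_grad()

with autograd.record():
    y = nd.arcsinh(x)
    # Keep the first-order gradient in the graph so it can be differentiated again.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(y),
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()  # second-order gradient accumulates into x.grad
print(x.grad)      # expected: -x / (x^2 + 1)^(3/2)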
52 changes: 50 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
@@ -437,7 +437,31 @@ The storage type of ``arcsinh`` output depends upon the input storage type:
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsinh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arcsinh,
                                                  unary_bwd<mshadow_op::arcsinh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dx_grad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = arcsinh(x)
      // n: f'(x) = 1/(x^2 + 1)^(1/2)
      // f''(x) = -f'(x) * x/(x^2 + 1) = -x/(x^2 + 1)^(3/2)
      // Note: x/(x^2 + 1) = x * f'(x)^2
      auto dydx = n->inputs[0];
      auto x = n->inputs[1];
      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
      auto op = mxnet::util::NodeOpGen{n};

      auto grad_x = op.div(dydx_mul_grad_x, dydx);
      auto grad_x_square = op.square(grad_x);
      auto grad_x_square_mul_x = op.mul(grad_x_square, x);
      // f''(x) is negative, so negate the product dL/dy * x * f'(x)^3.
      auto grad_grad_x = op.negative(op.mul(dydx_mul_grad_x, grad_x_square_mul_x));

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(op.mul(ograds[0], grad_x));
      ret.emplace_back(op.mul(ograds[0], grad_grad_x));
      return ret;
    });
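A quick numerical sanity check of the closed form above (a standalone sketch, not part of the commit; the function name, sample points, and step size h are illustrative):

import numpy as np

def arcsinh_f2(x):
    # f''(x) = -x / (x^2 + 1)^(3/2), as derived in the comment above
    return -x / (x ** 2 + 1) ** 1.5

xs = np.linspace(-2.0, 2.0, 9)
h = 1e-4
# Central second difference: (f(x+h) - 2*f(x) + f(x-h)) / h^2
numeric = (np.arcsinh(xs + h) - 2 * np.arcsinh(xs) + np.arcsinh(xs - h)) / h ** 2
assert np.allclose(numeric, arcsinh_f2(xs), atol=1e-5)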

// arccosh
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(arccosh, cpu, mshadow_op::arccosh)
@@ -451,7 +475,31 @@ The storage type of ``arccosh`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccosh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arccosh,
                                                  unary_bwd<mshadow_op::arccosh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // ograds[0]: head_grad_grads (dL/dx_grad)
      // inputs[0]: dL/dy
      // inputs[1]: x (ElemwiseGradUseIn)
      // f(x) = arccosh(x)
      // n: f'(x) = 1/((x - 1)^(1/2) * (x + 1)^(1/2))
      // f''(x) = -f'(x) * x/((x + 1) * (x - 1)) = -x/((x - 1)^(3/2) * (x + 1)^(3/2))
      // Note: x/((x - 1) * (x + 1)) = x * f'(x)^2
      auto dydx = n->inputs[0];
      auto x = n->inputs[1];
      auto dydx_mul_grad_x = nnvm::NodeEntry{n};
      auto op = mxnet::util::NodeOpGen{n};

      auto grad_x = op.div(dydx_mul_grad_x, dydx);
      auto grad_x_square = op.square(grad_x);
      auto grad_x_square_mul_x = op.mul(grad_x_square, x);
      // f''(x) is negative, so negate the product dL/dy * x * f'(x)^3.
      auto grad_grad_x = op.negative(op.mul(dydx_mul_grad_x, grad_x_square_mul_x));

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(op.mul(ograds[0], grad_x));
      ret.emplace_back(op.mul(ograds[0], grad_grad_x));
      return ret;
    });
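The same finite-difference sketch for arccosh, restricted to its domain (again illustrative, not part of the commit):

import numpy as np

def arccosh_f2(x):
    # f''(x) = -x / (x^2 - 1)^(3/2), as derived in the comment above
    return -x / (x ** 2 - 1) ** 1.5

xs = np.linspace(1.5, 5.0, 8)  # arccosh is defined for x >= 1; stay strictly inside
h = 1e-4
numeric = (np.arccosh(xs + h) - 2 * np.arccosh(xs) + np.arccosh(xs - h)) / h ** 2
assert np.allclose(numeric, arccosh_f2(xs), atol=1e-5)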

// arctanh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arctanh, cpu, mshadow_op::arctanh)
34 changes: 34 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -150,6 +150,40 @@ def grad_grad_op(x):
        check_second_order_unary(array, arctan, grad_grad_op)


@with_seed()
def test_arcsinh():
    def arcsinh(x):
        return nd.arcsinh(x)

    def grad_grad_op(x):
        return -x / nd.sqrt((nd.square(x) + 1) ** 3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, arcsinh, grad_grad_op)


@with_seed()
def test_arccosh():
    def arccosh(x):
        return nd.arccosh(x)

    def grad_grad_op(x):
        return -x / (nd.sqrt(x - 1) * nd.sqrt(x + 1) * (x + 1) * (x - 1))

    sigma = random.randint(25, 100)
    mu = random.randint(500, 1000)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        array = array * sigma + mu
        # The domain of arccosh is [1, inf); shift samples well above 1.
        assert((array > 1).all())
        check_second_order_unary(array, arccosh, grad_grad_op)
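Either test can be run on its own (a sketch, assuming the nose runner MXNet's unit tests used at the time):

nosetests -v tests/python/unittest/test_higher_order_grad.py:test_arcsinh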


@with_seed()
def test_arctanh():
    def arctanh(x):
