Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Higher Order Gradient Support arcsinh, arccosh. #15530

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions src/operator/tensor/elemwise_unary_op_trig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,36 @@ The storage type of ``arcsinh`` output depends upon the input storage type:
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arcsinh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arcsinh,
unary_bwd<mshadow_op::arcsinh_grad>);
unary_bwd<mshadow_op::arcsinh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// ograds[0]: head_grad_grads (dL/dy_grad)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dL/dx_grad ?

#15331 (comment)

// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseGradUseIn)
// f(x) = arcsinh(x)
// n: f'(x) = 1/(x^2 + 1)^1/2
// f''(x) = f'(x) * x/(x^2 + 1) = x/(x^2 + 1)^(3/2)
// Note: x/(x^2 + 1) = x * f'(x)^2
auto dydx = n->inputs[0];
auto x = n->inputs[1];
auto dydx_mul_grad_x = nnvm::NodeEntry{n};
auto grad_x = MakeNode("elemwise_div", n->attrs.name + "_grad_x",
{dydx_mul_grad_x, dydx}, nullptr, &n);
auto grad_x_square = MakeNode("square", n->attrs.name + "_grad_x_square",
{nnvm::NodeEntry{grad_x}}, nullptr, &n);
auto grad_x_square_mul_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_x_square_mul_x",
{nnvm::NodeEntry{grad_x_square}, x}, nullptr, &n);
auto grad_grad_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_x",
{dydx_mul_grad_x, nnvm::NodeEntry{grad_x_square_mul_x}},
nullptr, &n);

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], nnvm::NodeEntry{grad_x}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_x}}, nullptr, &n));
return ret;
});

// arccosh
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(arccosh, cpu, mshadow_op::arccosh)
Expand All @@ -321,7 +350,36 @@ The storage type of ``arccosh`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arccosh" });

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_arccosh,
unary_bwd<mshadow_op::arccosh_grad>);
unary_bwd<mshadow_op::arccosh_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// ograds[0]: head_grad_grads (dL/dy_grad)
// inputs[0]: dL/dy
// inputs[1]: x (ElemwiseGradUseIn)
// f(x) = arccosh(x)
// n: f'(x) = 1/((x - 1)^1/2 * (x + 1)^1/2)
// f''(x) = f'(x) * x/((x + 1)*(x - 1)) = x/((x-1)^1/2 * (x+1)^1/2 * (x-1) * (x+1))
// Note: x/((x-1)*(x+1)) = x * f'(x)^2
auto dydx = n->inputs[0];
auto x = n->inputs[1];
auto dydx_mul_grad_x = nnvm::NodeEntry{n};
auto grad_x = MakeNode("elemwise_div", n->attrs.name + "_grad_x",
{dydx_mul_grad_x, dydx}, nullptr, &n);
auto grad_x_square = MakeNode("square", n->attrs.name + "_grad_x_square",
{nnvm::NodeEntry{grad_x}}, nullptr, &n);
auto grad_x_square_mul_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_x_square_mul_x",
{nnvm::NodeEntry{grad_x_square}, x}, nullptr, &n);
auto grad_grad_x = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_x",
{dydx_mul_grad_x, nnvm::NodeEntry{grad_x_square_mul_x}},
nullptr, &n);

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], nnvm::NodeEntry{grad_x}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_x}}, nullptr, &n));
return ret;
});

// arctanh
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(arctanh, cpu, mshadow_op::arctanh)
Expand Down
35 changes: 35 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@


import math
import random
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed
Expand Down Expand Up @@ -50,6 +51,40 @@ def grad_grad_op(x):
check_second_order_unary(array, cos, grad_grad_op)


@with_seed()
def test_arcsinh():
def arcsinh(x):
return nd.arcsinh(x)

def grad_grad_op(x):
return x/nd.sqrt((nd.square(x)+1)**3)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, arcsinh, grad_grad_op)


@with_seed()
def test_arccosh():
def arccosh(x):
return nd.arccosh(x)

def grad_grad_op(x):
return x/(nd.sqrt(x-1) * nd.sqrt(x+1) * (x+1) * (x-1))

sigma = random.randint(25, 100)
mu = random.randint(500, 1000)

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
array = array * sigma + mu
# Domain of arccosh 1 to infinity.
assert((array > 1).all())
check_second_order_unary(array, arccosh, grad_grad_op)


@with_seed()
def test_relu():
def relu(x):
Expand Down