[MXNET-978] Higher Order Gradient Support sqrt, cbrt. (apache#15474)
* support sqrt, cbrt for higher order grad

* add relevant tests

* remove unnecessary variable
kshitij12345 authored and gyshi committed Sep 7, 2019
1 parent e51b647 commit b0621b2
Showing 2 changed files with 109 additions and 2 deletions.
71 changes: 69 additions & 2 deletions src/operator/tensor/elemwise_unary_op_pow.cc
@@ -143,7 +143,38 @@ The storage type of ``sqrt`` output depends upon the input storage type:
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sqrt"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_sqrt,
unary_bwd<mshadow_op::square_root_grad>);
unary_bwd<mshadow_op::square_root_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// NodeEntry{n} : y_grad * f'(x)
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = x^(1/2)
// ograds[0] : head_grads
// f'(x) = 1/(2*x^(1/2))
// f''(x) = f'(x) * -1/(2*x) = -1/(4 * x^(3/2))
const std::unordered_map<std::string, std::string> mul_args = {{"scalar", "0.5"}};
auto x = MakeNode("square", n->attrs.name + "_cube_x", {n->inputs[1]}, nullptr, &n);
auto r_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_x",
{nnvm::NodeEntry{x}}, nullptr, &n);
auto neg_r_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_x",
{nnvm::NodeEntry{r_x}}, nullptr, &n);
auto half_neg_r_cube_x = MakeNode("_mul_scalar", n->attrs.name + "_half_neg_reciprocal_x",
{nnvm::NodeEntry{neg_r_x}}, &mul_args, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_mid",
{nnvm::NodeEntry{half_neg_r_cube_x}, n->inputs[0]},
nullptr, &n);
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
{nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

// When building the gradient graph, the backward node of n->inputs[1] is
// added to the graph again, so the remaining f'(x) factor gets multiplied in.
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                          {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
                          {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});
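For reference, the identity the lambda above encodes. The backward node only sees the stored output y = f(x) = sqrt(x), so x is first recovered as y^2 (the "square" node), and the second derivative factors through the first:

\[
f'(x) = \frac{1}{2\sqrt{x}}, \qquad
f''(x) = -\frac{1}{4x^{3/2}} = f'(x) \cdot \left(-\frac{1}{2x}\right),
\]

which is why the second output multiplies y_grad by -1/(2x) and leaves the remaining f'(x) factor to the re-derived backward node, as the comment notes.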

// rsqrt
MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(rsqrt, cpu, mshadow_op::reciprocal_square_root)
@@ -186,7 +217,43 @@ The storage type of ``cbrt`` output depends upon the input storage type:
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_cbrt"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_cbrt,
unary_bwd<mshadow_op::cube_root_grad>);
unary_bwd<mshadow_op::cube_root_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// NodeEntry{n} : y_grad * f'(x)
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = x^(1/3)
// ograds[0] : head_grads
// f'(x) = 1/(3*x^(2/3))
// f''(x) = f'(x) * -2/(3*x) = -2/(9 * x^(5/3))
const std::unordered_map<std::string, std::string> three = {{"scalar", "3.0"}};
const std::unordered_map<std::string, std::string> two = {{"scalar", "2.0"}};
auto x = MakeNode("_power_scalar", n->attrs.name + "_x", {n->inputs[1]}, &three, &n);
auto three_x = MakeNode("_mul_scalar", n->attrs.name + "_three_x",
{nnvm::NodeEntry{x}}, &three, &n);
auto r_three_x = MakeNode("reciprocal", n->attrs.name + "_reciprocal_three_x",
{nnvm::NodeEntry{three_x}}, nullptr, &n);
auto neg_r_three_x = MakeNode("negative", n->attrs.name + "_neg_reciprocal_three_x",
{nnvm::NodeEntry{r_three_x}}, nullptr, &n);
auto two_third_neg_r_x = MakeNode("_mul_scalar",
n->attrs.name + "_two_third_neg_reciprocal_x",
{nnvm::NodeEntry{neg_r_three_x}}, &two, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_grad_mid",
{nnvm::NodeEntry{two_third_neg_r_x}, n->inputs[0]},
nullptr, &n);
auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
{nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

// When building the gradient graph, the backward node of n->inputs[1] is
// added to the graph again, so the remaining f'(x) factor gets multiplied in.
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                          {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_in",
                          {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});
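The cbrt derivation mirrors sqrt: x is recovered as y^3 and f''(x) = f'(x) * (-2/(3x)). As a quick sanity check of what these registrations enable, here is a minimal, hypothetical end-to-end example (not part of this commit; it assumes an MXNet build that includes this change):

from mxnet import nd, autograd

x = nd.array([1.0, 8.0, 27.0])
x.attach_grad()
with autograd.record():
    y = nd.cbrt(x)
    # First derivative, kept differentiable so it can be differentiated again.
    x_grad = autograd.grad(heads=y, variables=x,
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()  # second derivative accumulates into x.grad
print(x.grad)                       # computed by the new FGradient above
print(-2 / (9 * nd.cbrt(x ** 5)))   # analytic f''(x) = -2/(9 * x^(5/3))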


// rcbrt
MXNET_OPERATOR_REGISTER_UNARY(rcbrt)
40 changes: 40 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -231,6 +231,46 @@ def grad_grad_op(x):
check_second_order_unary(array, sigmoid, grad_grad_op)


@with_seed()
def test_sqrt():
    def sqrt(x):
        return nd.sqrt(x)

    def grad_grad_op(x):
        return -1/(4 * sqrt(x**3))

    sigma = random.randint(25, 100)
    mu = random.randint(500, 1000)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        array = sigma * array + mu
        # Only positive numbers
        assert((array > 0).all())
        check_second_order_unary(array, sqrt, grad_grad_op)


@with_seed()
def test_cbrt():
    def cbrt(x):
        return nd.cbrt(x)

    def grad_grad_op(x):
        return -2/(9 * cbrt(x**5))

    sigma = random.randint(25, 100)
    mu = random.randint(500, 1000)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        array = sigma * array + mu
        # Only positive numbers
        assert((array > 0).all())
        check_second_order_unary(array, cbrt, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op, rtol=None, atol=None):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
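The helper's body is truncated in this diff. For orientation, a minimal sketch of what a second-order check of this shape typically does (a hypothetical reconstruction, not the actual helper; assert_almost_equal is assumed to come from mxnet.test_utils):

from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal

def check_second_order_unary_sketch(x, op, grad_grad_op):
    # Hypothetical reconstruction -- the real helper continues above.
    x = nd.array(x)
    expected = grad_grad_op(x)
    x.attach_grad()
    with autograd.record():
        y = op(x)
        # Differentiate once, keeping the result differentiable.
        x_grad = autograd.grad(heads=y, variables=x,
                               create_graph=True, retain_graph=True)[0]
    x_grad.backward()  # exercises the newly registered second-order FGradient
    assert_almost_equal(expected.asnumpy(), x.grad.asnumpy())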