This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[bug] fix higher grad log (#15120)
* fix a bug in the higher-order gradient implementation of log.

* bug: head_grads were not preserved in the higher-order gradient.
* add a test to validate the fix.

* fix the gradient with respect to head_grads and update the relevant test

* address comments

* remove the assertion on the y_grad gradient.
* rename variables.
* fix and update the computation.

* address comments

* pass arguments explicitly by name.

* fix mistyped comment.

Co-Authored-By: Lin Yuan <[email protected]>
kshitij12345 and apeforest committed Jun 20, 2019
1 parent 4d96671 commit 2b7fbc5
Showing 2 changed files with 59 additions and 30 deletions.
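
The behaviour this commit fixes can be sketched with a small double-backward example. The sketch below is illustrative and not part of the commit; it assumes an MXNet build that already contains this fix, and the input values are arbitrary.

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()
head_grads = nd.array([0.5, 1.0, 2.0])  # manual dL/dy, deliberately not all ones

with autograd.record():
    y = nd.log(x)
    # First-order gradient with explicit head_grads: x_grad = head_grads * (1/x).
    x_grad = autograd.grad(heads=y, variables=x, head_grads=head_grads,
                           create_graph=True, retain_graph=True)[0]

# Second backward pass. With the fix, x.grad = head_grads * f''(x) = -head_grads / x**2;
# with the old graph the head_grads factor entered this result twice.
x_grad.backward()
print(x.grad)  # expected: [-0.5, -0.25, -0.125]
```

The updated unit test below exercises exactly this path with random head_grads.
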
68 changes: 42 additions & 26 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1090,68 +1090,84 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
                                                    unary_bwd<mshadow_op::log_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
     [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-      // For f(x) -> f = log
+      // ograds[0]: dL/dxgrad
+      // inputs[0]: dL/dy
+      // inputs[1]: x
+      // f(x) = y = log(x)
       // f'(x) = 1/x
       // f''(x) = -1 * (f'(x) * f'(x))
-      auto gx = nnvm::NodeEntry{n};
-      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                              {gx, gx}, nullptr, &n);
-      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+      auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+      auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
+                            {n->inputs[1]}, nullptr, &n);
+      auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                                 {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+      auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                             {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
       std::vector<nnvm::NodeEntry> ret;
 
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                                {ograds[0], gx}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{dlogx}}, nullptr, &n));
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
       return ret;
     });
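
The two entries pushed into `ret` are the gradients of this backward node's output, `dL/dx = dL/dy * (1/x)`, with respect to its two inputs `dL/dy` and `x`, each multiplied by the incoming `ograds[0]`. A minimal numpy check of that algebra (illustrative only, independent of MXNet):

```python
import numpy as np

def x_grad(dldy, x):
    # Output of _backward_log: dL/dx = dL/dy * f'(x), with f(x) = log(x).
    return dldy / x

x = np.array([0.5, 1.0, 3.0])
dldy = np.array([2.0, -1.0, 0.25])
eps = 1e-6

# Finite-difference partial derivatives of x_grad.
d_wrt_dldy = (x_grad(dldy + eps, x) - x_grad(dldy - eps, x)) / (2 * eps)
d_wrt_x = (x_grad(dldy, x + eps) - x_grad(dldy, x - eps)) / (2 * eps)

assert np.allclose(d_wrt_dldy, 1.0 / x)        # matches dlogx
assert np.allclose(d_wrt_x, -dldy / x ** 2)    # matches d2ydx2
```
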

 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
                                                    unary_bwd<mshadow_op::log10_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
     [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-      // For f(x) -> f = log10
+      // ograds[0]: dL/dxgrad
+      // inputs[0]: dL/dy
+      // inputs[1]: x
+      // f(x) = y = log10(x)
       // f'(x) = 1 / (log(10) * x)
       // f''(x) = -1 * (f'(x) * 1/x)
-      auto gx = nnvm::NodeEntry{n, 0, 0};
-      auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+      auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                           {dydx_mul_dldy, n->inputs[0]}, nullptr, &n);
+      auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                            {n->inputs[1]}, nullptr, &n);
-      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                              {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+      auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                                 {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+      auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                             {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
       std::vector<nnvm::NodeEntry> ret;
 
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                                {ograds[0], gx}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
       return ret;
     });
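
For log10 the first output needs `f'(x)` on its own, so the graph above divides the node's own output, `f'(x) * dL/dy`, by `inputs[0]` (`dL/dy`) to recover it. A small numpy illustration of that recovery (not MXNet code; values arbitrary, and `dL/dy` is assumed nonzero elementwise):

```python
import numpy as np

x = np.array([0.5, 1.0, 3.0])
dldy = np.array([2.0, -1.0, 0.25])  # assumed nonzero so the division is well defined

dydx_mul_dldy = dldy / (np.log(10) * x)  # what the _backward_log10 node outputs
dydx = dydx_mul_dldy / dldy              # the elemwise_div by inputs[0]

assert np.allclose(dydx, 1.0 / (np.log(10) * x))  # f'(x) for f = log10
```
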

 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
                                                    unary_bwd<mshadow_op::log2_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
     [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-      // For f(x) -> f = log2
+      // ograds[0]: dL/dxgrad
+      // inputs[0]: dL/dy
+      // inputs[1]: x
+      // f(x) = y = log2(x)
       // f'(x) = 1 / (log(2) * x)
       // f''(x) = -1 * (f'(x) * 1/x)
-      auto gx = nnvm::NodeEntry{n};
-      auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+      auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                           {dydx_mul_dldy, n->inputs[0]}, nullptr, &n);
+      auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                            {n->inputs[1]}, nullptr, &n);
-      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                              {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+      auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                                 {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+      auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                             {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
       std::vector<nnvm::NodeEntry> ret;
 
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                                {ograds[0], gx}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
       ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                                {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
       return ret;
     });
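
The log2 registration mirrors log10 with ln(2) in place of ln(10). A quick numeric sanity check (numpy) of the derivative comments, again only illustrative:

```python
import numpy as np

x = np.array([0.5, 1.0, 3.0])
eps = 1e-5

fp = lambda t: 1.0 / (np.log(2) * t)              # f'(x) for f = log2
fpp_fd = (fp(x + eps) - fp(x - eps)) / (2 * eps)  # finite-difference f''(x)

assert np.allclose(fpp_fd, -1.0 / (np.log(2) * x ** 2), rtol=1e-4)
```
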

21 changes: 17 additions & 4 deletions tests/python/unittest/test_higher_order_grad.py
@@ -108,13 +108,26 @@ def grad_grad_op(x):
 
 def check_second_order_unary(x, op, grad_grad_op):
     x = nd.array(x)
-    expect_grad_grad = grad_grad_op(x)
+    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
+
+    # Manual head_grads.
+    y_grad = nd.random.normal(shape=x.shape)
+    head_grad_grads = nd.random.normal(shape=x.shape)
+
+    # Perform compute.
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
-    y_grad.backward()
-    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
+                               create_graph=True, retain_graph=True)[0]
+    x_grad.backward(head_grad_grads)
+
+    # Compute expected values.
+    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
+        y_grad.asnumpy()
+
+    # Validate the gradients.
+    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
 
 
 if __name__ == '__main__':
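
The expected value in the updated test follows from the chain rule: the first-order result is `x_grad = y_grad * f'(x)`, so back-propagating `head_grad_grads` through it gives `x.grad = head_grad_grads * y_grad * f''(x)`, with `f''(x)` supplied by `grad_grad_op`. A hypothetical driver for the helper might look like the sketch below; it assumes it sits next to `check_second_order_unary` in this test file, and the function name and input values are illustrative rather than part of the commit.

```python
import numpy as np
from mxnet import nd

def test_log_second_order():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        # f''(x) for f = log
        return -1 / (x ** 2)

    # Positive inputs so that log and its derivatives are well defined.
    x = np.random.uniform(low=0.5, high=3.0, size=(4, 5))
    check_second_order_unary(x, log, grad_grad_op)
```
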
