[MXNET-978] Higher order gradient for sigmoid #15288
Conversation
…xnet into develop/higher_order_grad
@kshitij12345 I have figured out how backward works when one of the inputs is an output of the forward node. Please review this PR. Thanks!
@larroy @sxjscience Please help review this PR. Thanks!
@larroy I also added a method to dump the computation graph in imperative mode, since it will be very useful for us to debug. However, it's still very rudimentary and we still need your help to implement a …
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                              {n->inputs[0], nnvm::NodeEntry{one_minus_two_y}}, nullptr, &n);
// when building the gradient graph, the backward node of n->inputs[1] will be
// added to the graph again, therefore f'(x) will be multiplied
Doesn't this behaviour seem a bit odd? Is this actually the expected behaviour? What would have happened if this were a split-like function with many outputs? Would the backward nodes of all the outputs be added? Actually, I am confused by this behaviour.
Yes, this is actually the expected behavior. The nnvm graph performs a DFS traversal when building the backward pass. Since the input to this backward_sigmoid node is an output of another node (sigmoid), the gradient function of sigmoid will be invoked during the backward pass.
See:
This is where dependent nodes in the graph are collected during RecordOp:
https://github.com/apache/incubator-mxnet/blob/master/src/imperative/imperative.cc#L180
This is where the backward pass is actually performed:
https://github.com/dmlc/tvm/blob/21935dcbf56ad3bd66ebff9891a6bc3865b8106d/nnvm/src/pass/gradient.cc#L126
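To make it concrete at the Python level, here is a small sketch (my own example, not code from this PR; it assumes the second-order gradient added here is in place):

```python
from mxnet import nd, autograd

x = nd.array([0.5, -1.0, 2.0])
x.attach_grad()
with autograd.record():
    y = nd.sigmoid(x)
    # First-order gradient; create_graph=True records the backward pass so
    # that it can itself be differentiated.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(x),
                           create_graph=True, retain_graph=True)[0]

# Differentiating x_grad = y_grad * f'(x) forces the gradient pass to walk
# through y = sigmoid(x) (n->inputs[1]), so sigmoid's own backward node is
# added again and contributes the extra f'(x) factor.
x_grad.backward()

s = nd.sigmoid(x)
expected = s * (1 - s) * (1 - 2 * s)  # f''(x) = f'(x) * (1 - 2*f(x))
print(x.grad.asnumpy())    # should match `expected` up to numerical tolerance
print(expected.asnumpy())
```

The extra factor of f'(x) in f''(x) = f'(x) * (1 - 2*f(x)) is exactly what gets contributed when the gradient pass differentiates through the sigmoid output that backward_sigmoid consumes as n->inputs[1].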
I also didn't get that. A drawing might help. I would expect that you multiply with the backward node "n" itself.
I guess that is what he meant by

// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (1 - 2*f(x))

The backward of node n->inputs[1] is the node n itself.
One up for the visual drawing/graph of what is actually happening.
(A hand-drawn diagram and a dump of the computation graph were attached here as images.)
Thanks, I'll look at it today.
I kind of see from the graph you have drawn and the dumped graph what is happening. I still strongly feel that it is incorrect behaviour.
To check that, I modified the test code slightly to also compute and validate the 3rd-order gradient, which should come for free since the second-order gradients are composed from differentiable functions.
I have tested sin, cos, log and sigmoid. Out of them only sigmoid fails for the third order.
I have also attached the hand computation of the third-order gradient for sigmoid (typed out below as well), so please verify that just to make sure it is not incorrect. For some reason the image is rotated.
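For reference, the same derivation typed out (my own working, so please also check it against the image):

$$
\begin{aligned}
f(x)    &= \sigma(x) \\
f'(x)   &= f(x)\bigl(1 - f(x)\bigr) \\
f''(x)  &= f'(x)\bigl(1 - 2f(x)\bigr) \\
f'''(x) &= f''(x)\bigl(1 - 2f(x)\bigr) - 2f'(x)^2
         = f''(x) - 2\bigl(f'(x)^2 + f''(x)\,f(x)\bigr)
\end{aligned}
$$

The last form is what grad_grad_grad_op encodes in the test below.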
import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed


@with_seed()
def test_sin():
    def sin(x):
        return nd.sin(x)

    def grad_grad_op(x):
        return -nd.sin(x)

    def grad_grad_grad_op(x):
        return -nd.cos(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sin, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_cos():
    def cos(x):
        return nd.cos(x)

    def grad_grad_op(x):
        return -nd.cos(x)

    def grad_grad_grad_op(x):
        return nd.sin(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, cos, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_log():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        return -1/(x**2)

    def grad_grad_grad_op(x):
        return 2/(x**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, log, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_op(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def grad_grad_op(x):
        return grad_op(x) * (1 - 2 * sigmoid(x))

    def grad_grad_grad_op(x):
        return grad_grad_op(x) - 2 * (grad_op(x)**2 + grad_grad_op(x) * sigmoid(x))

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sigmoid, grad_grad_op, grad_grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op, grad_grad_grad_op):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    grad_grad_grad_x = grad_grad_grad_op(x)
    x.attach_grad()

    # Manual head_grads.
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grads = nd.random.normal(shape=x.shape)

    # Perform compute.
    with autograd.record():
        y = op(x)
        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
                               create_graph=True, retain_graph=True)[0]
        x_grad_grad = autograd.grad(heads=x_grad, variables=x, head_grads=head_grad_grads,
                                    create_graph=True, retain_graph=True)[0]
    x_grad_grad.backward()

    # Compute expected values.
    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()
    expected_grad_grad_grad = grad_grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x_grad_grad.asnumpy())
    assert_almost_equal(expected_grad_grad_grad, x.grad.asnumpy())


if __name__ == '__main__':
    import nose
    nose.runmodule()
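(For reference, I have been running this by dropping the snippet into tests/python/unittest/test_higher_order_grad.py and invoking `nosetests -v tests/python/unittest/test_higher_order_grad.py` from the repo root; adjust the path if your layout differs, the location is just my assumption of where these tests live.)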
I suspect the discrepancy in the 3rd-order gradient is not due to an error in my implementation of grad_grad_input, but because I did not return the first grad_grad correctly. After all, the first output may be useful in calculating higher orders even if it is not a visible output at the Python level. I will look into it.
@kshitij12345 I have fixed the issue. It passes your test now. Please review again. Thanks!
LGTM.
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (1 - 2*f(x))
ok
@@ -106,6 +106,23 @@ def grad_grad_op(x):
        check_second_order_unary(array, log10, grad_grad_op)


@with_seed()
def test_sigmoid():
looks correct to me.
@@ -501,6 +501,10 @@ std::vector<NDArray*> Imperative::Backward(
    }
  }

  if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
nice hack :)
Can you explain how logging of the static memory plan helps here?
He wants to dump the graph. Maybe this should be a separate PR?
It's actually dumping out the computation graph.
@apeforest Isn't that what I wrote? Could you answer w.r.t. separating this into a different PR? Also, MXNET_MEM_PLAN_VERBOSE_LOGGING is not documented in faq/env_var.md.
@larroy I did not see your comment before I refreshed the page. Our minds think alike :)
Sure, I can separate this into a different PR. I added it here only to help @kshitij12345 dump out the graph and understand the backward pass better.
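In case it helps anyone reproduce the dump, a minimal sketch of how to turn it on (this assumes only the dmlc::GetEnv gate visible in the diff above; the exact output format is whatever Imperative::Backward prints):

```python
import os
# Gate from the diff above; set it before the backward pass runs.
os.environ['MXNET_MEM_PLAN_VERBOSE_LOGGING'] = '1'

from mxnet import nd, autograd

x = nd.array([0.3, 0.7])
x.attach_grad()
with autograd.record():
    y = nd.sigmoid(x)
y.backward()  # with the flag set, Imperative::Backward also dumps the graph
```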
The PR is great; it's up to you whether to separate it, not a big deal. I'm not a radical in my reviews.
@apeforest Really like the dump, it's a great help to actually see the graph. Thank you.
Can we merge this?
I verified the result is the same as PyTorch's.
@sxjscience Please help review this PR. Thanks!
update code as per apache#15288.
@kshitij12345 Could you approve the PR if everything looks good to you now? Thanks!
…5253)
* init to reset
* issue: higher order backward sigmoid
* update gradient code. update code as per apache#15288.
* undo changes
* relax tolerance of gradient mismatch for tanh
* update comments
* update comments
Description
This PR adds support for higher-order gradients for the sigmoid operator.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments