[MXNET-978] Higher order gradient for sigmoid #15288
Changes from all commits
@@ -121,7 +121,35 @@ The storage type of ``sigmoid`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
                                               unary_bwd<mshadow_op::sigmoid_grad>)
.set_attr<nnvm::FGradient>("FGradient",
    [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
      // n->inputs[0] : y_grad
      // n->inputs[1] : f(x) = sigmoid(x)
      // ograds[0] : head_grads
      // f''(x) = f'(x) * (1 - 2*f(x))
      // NodeEntry{n} : y_grad * f'(x)
      auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
      const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
      auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
      auto one_minus_two_y = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
                                      {nnvm::NodeEntry{ones}, nnvm::NodeEntry{two_y}}, nullptr, &n);
      auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                                    {n->inputs[0], nnvm::NodeEntry{one_minus_two_y}}, nullptr, &n);
      auto dydx = MakeNode("elemwise_div", n->attrs.name + "_grad_div",
                           {nnvm::NodeEntry{n}, n->inputs[0]}, nullptr, &n);

      // when building the gradient graph, the backward node of n->inputs[1] will be
      // added to the graph again, therefore f'(x) will be multiplied

Doesn't this behaviour seem a bit odd? Is this actually the expected behaviour? What would have happened if this was ...

Yes, this is actually the expected behavior. The nnvm graph will perform a DFS traversal when performing the backward pass. Since the input to this _backward_sigmoid node is an output of another node (sigmoid), during the backward pass the gradient function of sigmoid will be invoked. See the linked code: one place collects the dependent nodes in the graph during RecordOp, the other actually performs the backward pass.

I also didn't get that. A drawing might help. I would expect that you multiply with the backward node itself, "n".

I guess that is what he meant as the backward of node ... One up for the visual drawing/graph of what is actually happening.

Thanks, I'll look at it today.

I kind of see from the graph you have drawn and the dumped graph what is happening. I still strongly feel that it is the incorrect behaviour. To check that, I modified the test code slightly to also compute and validate the 3rd-order gradient, since it should come for free when the second-order gradients are composed from differentiable functions. Have tested for ... I have also attached the hand computation of the third order for ... For some reason it is rotated.

import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed


@with_seed()
def test_sin():
    def sin(x):
        return nd.sin(x)

    def grad_grad_op(x):
        return -nd.sin(x)

    def grad_grad_grad_op(x):
        return -nd.cos(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sin, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_cos():
    def cos(x):
        return nd.cos(x)

    def grad_grad_op(x):
        return -nd.cos(x)

    def grad_grad_grad_op(x):
        return nd.sin(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, cos, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_log():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        return -1/(x**2)

    def grad_grad_grad_op(x):
        return 2/(x**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, log, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_op(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def grad_grad_op(x):
        return grad_op(x) * (1 - 2 * sigmoid(x))

    def grad_grad_grad_op(x):
        return grad_grad_op(x) - 2 * (grad_op(x)**2 + grad_grad_op(x) * sigmoid(x))

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sigmoid, grad_grad_op, grad_grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op, grad_grad_grad_op):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    grad_grad_grad_x = grad_grad_grad_op(x)
    x.attach_grad()

    # Manual head_grads.
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grads = nd.random.normal(shape=x.shape)

    # Perform compute.
    with autograd.record():
        y = op(x)
        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
                               create_graph=True, retain_graph=True)[0]
        x_grad_grad = autograd.grad(heads=x_grad, variables=x, head_grads=head_grad_grads,
                                    create_graph=True, retain_graph=True)[0]

    x_grad_grad.backward()

    # Compute expected values.
    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()
    expected_grad_grad_grad = grad_grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x_grad_grad.asnumpy())
    assert_almost_equal(expected_grad_grad_grad, x.grad.asnumpy())


if __name__ == '__main__':
    import nose
    nose.runmodule()

I suspect the discrepancy in the 3rd-order gradient is not because of an error in my implementation of grad_grad_input, but because I did not return the first grad_grad correctly. After all, the first output may be useful in calculating higher orders even if it is not a visible output at the Python level. I will look into it.

@kshitij12345 I have fixed the issue. The result can pass your test now. Please review again. Thanks!

LGTM.

      std::vector<nnvm::NodeEntry> ret;
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad",
                                {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
                                {ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
      return ret;
    });
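
For reference, a sketch of the math this lambda encodes, writing g for ograds[0] (the incoming head gradient of the second backward pass) and f(x) = sigmoid(x):

\[
f(x) = \sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
f'(x) = f(x)\bigl(1 - f(x)\bigr), \qquad
f''(x) = f'(x)\bigl(1 - 2 f(x)\bigr)
\]

The node n computes y_grad * f'(x), so its gradients with respect to its two inputs are

\[
\frac{\partial}{\partial\, \mathrm{y\_grad}} \bigl[ g \cdot \mathrm{y\_grad} \cdot f'(x) \bigr] = g \cdot f'(x), \qquad
\frac{\partial}{\partial f(x)} \bigl[ g \cdot \mathrm{y\_grad} \cdot f'(x) \bigr] = g \cdot \mathrm{y\_grad} \cdot \bigl(1 - 2 f(x)\bigr),
\]

which correspond to ret[0] (backward_grad_grad) and ret[1] (backward_grad_grad_in). As the code comment and the discussion above note, the second entry is later multiplied by another f'(x) when sigmoid's own FGradient runs, so the value reaching x is g * y_grad * f''(x).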

DMLC_REGISTER_PARAMETER(HardSigmoidParam);
MXNET_OPERATOR_REGISTER_UNARY(hard_sigmoid)
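
As a side note on the third-order formula used in the reviewer's script above, here is a small self-contained numerical sketch (not part of the PR; it assumes only numpy and the helper names are illustrative) that checks the closed form sigma''' = sigma'' - 2*(sigma'^2 + sigma''*sigma) against a central difference of sigma'':

import numpy as np

def sigma(x):
    return 1.0 / (1.0 + np.exp(-x))

def d1(x):  # sigma'(x) = sigma * (1 - sigma)
    return sigma(x) * (1.0 - sigma(x))

def d2(x):  # sigma''(x) = sigma' * (1 - 2*sigma)
    return d1(x) * (1.0 - 2.0 * sigma(x))

def d3(x):  # sigma'''(x), written exactly as grad_grad_grad_op in the script above
    return d2(x) - 2.0 * (d1(x) ** 2 + d2(x) * sigma(x))

x = np.linspace(-3.0, 3.0, 101)
eps = 1e-5
# Central difference of sigma'' should agree with the closed-form sigma'''.
fd = (d2(x + eps) - d2(x - eps)) / (2.0 * eps)
assert np.allclose(fd, d3(x), atol=1e-6)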

@@ -106,6 +106,23 @@ def grad_grad_op(x):
    check_second_order_unary(array, log10, grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_op(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def grad_grad_op(x):
        return grad_op(x) * (1 - 2 * sigmoid(x))

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sigmoid, grad_grad_op)

looks correct to me.

def check_second_order_unary(x, op, grad_grad_op):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
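
For completeness, a hedged usage sketch (not part of the diff; the array values are arbitrary) of what the new FGradient makes possible from Python: taking the gradient of the gradient of nd.sigmoid with mxnet autograd, mirroring check_second_order_unary above, and comparing it to s'(x) * (1 - 2*s(x)).

from mxnet import nd, autograd

x = nd.array([0.5, -1.0, 2.0])
x.attach_grad()
with autograd.record():
    y = nd.sigmoid(x)
    # First-order gradient, kept differentiable so it can be backpropagated through.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(x),
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()  # second-order gradient ends up in x.grad

s = nd.sigmoid(x)
expected = s * (1 - s) * (1 - 2 * s)  # sigma''(x)
print(x.grad.asnumpy())
print(expected.asnumpy())  # should match elementwise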

nice hack :)

Can you explain how logging of static memory helps here?

He wants to dump the graph. Maybe this should be a separate PR?

It's actually dumping out the computation graph.

@apeforest isn't that what I wrote? Could you answer wrt separating it into a different PR? Also, MXNET_MEM_PLAN_VERBOSE_LOGGING is not documented in faq/env_var.md.

@larroy I did not see your comment before I refreshed the page. Our minds think alike :) Sure, I can separate this into a different PR. I added it here only to help @kshitij12345 dump out the graph and understand the backward pass better.

The PR is great; up to you to separate, not a big deal. I'm not a radical on my reviews.

@apeforest Really like the dump, it's a great help to actually see the graph. Thank you.
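
For anyone who wants to reproduce the graph dump discussed in this thread, a hedged sketch: it assumes MXNET_MEM_PLAN_VERBOSE_LOGGING (named above, but undocumented) is read from the environment when the backward graph is planned, so it is set before mxnet is imported; the exact log format is not specified here.

import os
# Assumption: the variable is read at graph-planning time, so set it before
# importing mxnet to be safe.
os.environ["MXNET_MEM_PLAN_VERBOSE_LOGGING"] = "1"

from mxnet import nd, autograd

x = nd.array([0.5])
x.attach_grad()
with autograd.record():
    y = nd.sigmoid(x)
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(x),
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()  # the memory-plan / graph logging should show up around here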