[bug] fix higher grad log #15120
Changes from 2 commits
@@ -1074,16 +1074,19 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
 [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
   // For f(x) -> f = log
   // f''(x) = -1 * (f'(x) * f'(x))
-  auto gx = nnvm::NodeEntry{n};
+  auto gx_mul_head_grads = nnvm::NodeEntry{n};  // f'(x) * head_grads
+  auto head_grads = nnvm::NodeEntry{n->inputs[0]};
+  auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+                       {n->inputs[1]}, nullptr, &n);
   auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                          {gx, gx}, nullptr, &n);
+                          {gx_mul_head_grads, nnvm::NodeEntry{g_lx}}, nullptr, &n);
   auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                       {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

   std::vector<nnvm::NodeEntry> ret;

   ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                            {ograds[0], gx}, nullptr, &n));
+                            {ograds[0], nnvm::NodeEntry{g_lx}}, nullptr, &n));
Review comments:
- I am having trouble with head_grads.grad.
- Hi. What do you mean by head_grads.grad? NodeEntry doesn't have a grad field. Could you clarify? Are you referring to the Python code below? The gradient is always 0 when attach_grad() is called. The value is updated after running backward on an output, or using autograd.grad.
- Still looking into this. The first output should be the gradient of y_grad. However, the …
- Sorry for the confusion. I forgot to add the line from the test file.
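As an aside on the attach_grad() behaviour mentioned in the reply above, here is a minimal MXNet sketch (values purely illustrative): the gradient buffer starts out as zeros and is only filled once backward runs on a recorded output.

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()
print(x.grad.asnumpy())   # all zeros right after attach_grad()

with autograd.record():
    y = nd.log(x)
y.backward()              # writes dy/dx into x.grad
print(x.grad.asnumpy())   # [1. , 0.5 , 0.25], i.e. 1/x
```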
   ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                             {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
   return ret;
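To help reason about what the registered nodes compute numerically, here is a small standalone NumPy sketch (not part of the PR; it assumes, per the variable names above, that n's inputs are [head_grads, x] and that n's output is head_grads * 1/x):

```python
import numpy as np

# Illustrative values: x is the input of log, head_grads is dL/dy fed into
# _backward_log, and ograds is the gradient arriving at _backward_log's output.
x = np.array([0.5, 1.0, 2.0])
head_grads = np.array([1.5, -0.3, 0.7])
ograds = np.array([0.2, 0.4, -1.1])

g_lx = 1.0 / x                           # "reciprocal" node: f'(x)
gx_mul_head_grads = head_grads * g_lx    # output of _backward_log: dL/dx
ggx_mid = gx_mul_head_grads * g_lx       # "_backward_mid_grad_grad" elemwise_mul
ggx = -ggx_mid                           # "negative" node: head_grads * f''(x)

ret0 = ograds * g_lx                     # ret[0]: gradient w.r.t. head_grads
ret1 = ograds * ggx                      # ret[1]: gradient w.r.t. x

# Cross-check ret[1] against the analytic second derivative f''(x) = -1/x**2.
assert np.allclose(ret1, ograds * head_grads * (-1.0 / x ** 2))
```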
@@ -27,52 +27,79 @@ def test_log():
     def log(x):
         return nd.log(x)

+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)

     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

     for array in arrays:
-        check_second_order_unary(array, log, grad_grad_op)
+        check_second_order_unary(array, log, grad_op, grad_grad_op)


 @with_seed()
 def test_log2():
     def log2(x):
         return nd.log2(x)

+    def grad_op(x):
+        return 1/(x * math.log(2))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(2))

     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

     for array in arrays:
-        check_second_order_unary(array, log2, grad_grad_op)
+        check_second_order_unary(array, log2, grad_op, grad_grad_op)


 @with_seed()
 def test_log10():
     def log10(x):
         return nd.log10(x)

+    def grad_op(x):
+        return 1/(x * math.log(10))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(10))

     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

     for array in arrays:
-        check_second_order_unary(array, log10, grad_grad_op)
+        check_second_order_unary(array, log10, grad_op, grad_grad_op)

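As a quick numerical sanity check on the grad_op / grad_grad_op formulas used in these tests, a standalone NumPy sketch (not part of the PR):

```python
import math
import numpy as np

def second_derivative_fd(f, x, eps=1e-4):
    # Central finite-difference approximation of f''(x).
    return (f(x + eps) - 2 * f(x) + f(x - eps)) / eps ** 2

x = np.linspace(0.5, 3.0, 7)
assert np.allclose(second_derivative_fd(np.log, x), -1 / x ** 2, rtol=1e-3)
assert np.allclose(second_derivative_fd(np.log2, x), -1 / (x ** 2 * math.log(2)), rtol=1e-3)
assert np.allclose(second_derivative_fd(np.log10, x), -1 / (x ** 2 * math.log(10)), rtol=1e-3)
```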
-def check_second_order_unary(x, op, grad_grad_op):
+def check_second_order_unary(x, op, grad_op, grad_grad_op):
     x = nd.array(x)
-    expect_grad_grad = grad_grad_op(x)
+    grad_x = grad_op(x)
+    grad_grad_x = grad_grad_op(x)
     x.attach_grad()

+    # Manual head_grads.
+    head_grads = nd.random.normal(shape=x.shape)
Review comments:
- Rename this to …
- Sure.
+    head_grad_grads = nd.random.normal(shape=x.shape)
Review comments:
- I still don't understand what this variable is mathematically…
- Thanks for the clarification.
+    head_grads.attach_grad()
+
+    # Perform compute.
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
-    y_grad.backward()
-    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+        y_grad = autograd.grad(y, x, head_grads=head_grads,
Review comments:
- This variable is actually dL/dx, maybe rename it to …
- Oh yes. Will do that.
+                               create_graph=True, retain_graph=True)[0]
+
+    y_grad.backward(head_grad_grads)
+
+    # Compute expected values.
+    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
+        head_grads.asnumpy()
+    expected_heads_grad = grad_x.asnumpy()
Review comments:
- Should be …
+
+    # Validate the gradients.
+    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
+    assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
Review comments:
- Now I understand your question. I don't think anything is updating head_grads.grad here (that is done when running backward). Why do you want to set the head gradients manually? To verify your fix?
- Can you try … and in validation …
- Yeah, to verify the fix. Thanks for the suggestion, …
- y_grad.backward(head_grad_grads) indicates that head_grad_grads are the head gradients passed from "upstream". Calling backward on an output variable updates all the independent input variables (on which those outputs depend) that have an attached gradient. In this case head_grad_grads is not an input to the graph, so it is expected that its grad does not get updated: https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/ndarray/ndarray.py#L2188
- We are checking gradients for …
- The default behaviour differs between PyTorch and MXNet with respect to accumulation of gradients (PyTorch: add, MXNet: write). Having said that, I still don't understand why you expect gradient accumulation in head_grads.
- Oh. I just expected it to have gradients (by accumulation or writing), since its value is used while computing the …
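To illustrate the point above that head gradients act as upstream weights on the backward pass rather than as graph inputs, a small MXNet sketch (values chosen only for illustration):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()

# Head gradients: per-element weights on dy, not an input of the recorded graph.
w = nd.array([0.5, 1.0, 2.0])
w.attach_grad()

with autograd.record():
    y = nd.log(x)

y.backward(w)               # computes x.grad = w * dy/dx = w / x
print(x.grad.asnumpy())     # [0.5, 0.5, 0.5]

# w never appears inside the recorded graph, so nothing writes to w.grad;
# it stays at the zeros created by attach_grad().
print(w.grad.asnumpy())     # [0., 0., 0.]
```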

 if __name__ == '__main__':
Review comments:
- Can we add a comment about the inputs and what g_lx is? It would help to reason about the code. Are the inputs of n (backward_log) …? So g_lx is a node holding 1/x, i.e. the derivative of the log, right? Can we rename it to g_logx?
- Sure thing.