This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[bug] fix higher grad log #15120

Merged
merged 7 commits into from
Jun 20, 2019
Changes from 2 commits
9 changes: 6 additions & 3 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1074,16 +1074,19 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// For f(x) -> f = log
// f''(x) = -1 * (f'(x) * f'(x))
auto gx = nnvm::NodeEntry{n};
auto gx_mul_head_grads = nnvm::NodeEntry{n}; // f'(x) * head_grads
auto head_grads = nnvm::NodeEntry{n->inputs[0]};
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
Contributor @larroy commented on Jun 4, 2019:
Can we add a comment about the inputs and what g_lx is? It would help reason about the code. Are the inputs of n (backward_log):

  • 0: input gradient
  • 1: x

So g_lx is a node holding 1/x, i.e. the derivative of the log, right? Can we rename it to g_logx?

Contributor Author @kshitij12345 replied:
Sure thing.

{n->inputs[1]}, nullptr, &n);
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
{gx, gx}, nullptr, &n);
{gx_mul_head_grads, nnvm::NodeEntry{g_lx}}, nullptr, &n);
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
{ograds[0], gx}, nullptr, &n));
{ograds[0], nnvm::NodeEntry{g_lx}}, nullptr, &n));
Contributor Author @kshitij12345 commented on Jun 1, 2019:
https://github.com/apache/incubator-mxnet/blob/37ce3b87268a8154f5c0ad97ce2522478038ee06/tests/python/unittest/test_higher_order_grad.py#L102

I am having trouble with head_grads.grad, which is being returned as zeros (I guess it is somehow not being updated), while I expect it to be the output of this line.
Please help.

Contributor @larroy commented on Jun 4, 2019:
Hi. What do you mean by head_grads.grad? NodeEntry doesn't have a grad field. Could you clarify? Are you referring to the Python code below? The gradient is always 0 when attach_grad() is called; the value is updated after running backward on an output, or by using autograd.grad.
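For reference, a minimal sketch of that behaviour using the standard mxnet nd / autograd API (illustrative only, not code from this PR):

from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()          # allocates x.grad as zeros; nothing is computed yet
print(x.grad.asnumpy())  # [0. 0. 0.]

with autograd.record():
    y = nd.log(x)
y.backward()             # fills x.grad with dy/dx = 1/x
print(x.grad.asnumpy())  # [1.   0.5  0.25]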

Contributor commented:
Still looking into this. The first output should be the gradient of y_grad. However, head_grads.grad does not get the value. I suspect the returned value from this function is dropped in the gradient calculation in imperative.cc. I will look more into this. Stay tuned.

Contributor Author @kshitij12345 replied:
Sorry for the confusion. I forgot to add the line from the test file.
Sure, waiting to hear what you find.

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
return ret;
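As a reading aid (not part of the diff), here is a hedged NumPy sketch of the two outputs the node graph above builds; the names mirror the C++ variables, and x, head_grads and ograds stand for the corresponding graph inputs:

import numpy as np

x = np.array([0.5, 1.0, 2.0])
head_grads = np.array([1.0, 2.0, 3.0])   # gradient flowing into _backward_log
ograds = np.ones_like(x)                 # gradient flowing into the second backward

gx_mul_head_grads = head_grads * (1.0 / x)   # output of _backward_log: f'(x) * head_grads
g_lx = 1.0 / x                               # the "reciprocal" node
ggx = -(gx_mul_head_grads * g_lx)            # negative(elemwise_mul(...)) = -head_grads / x**2

ret0 = ograds * g_lx   # first returned entry: d(x_grad)/d(head_grads), scaled by ograds[0]
ret1 = ograds * ggx    # second returned entry: d(x_grad)/dx, scaled by ograds[0]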
43 changes: 35 additions & 8 deletions tests/python/unittest/test_higher_order_grad.py
@@ -27,52 +27,79 @@ def test_log():
def log(x):
return nd.log(x)

def grad_op(x):
return 1/x

def grad_grad_op(x):
return -1/(x**2)

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
check_second_order_unary(array, log, grad_grad_op)
check_second_order_unary(array, log, grad_op, grad_grad_op)


@with_seed()
def test_log2():
def log2(x):
return nd.log2(x)

def grad_op(x):
return 1/(x * math.log(2))

def grad_grad_op(x):
return -1/((x**2) * math.log(2))

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
check_second_order_unary(array, log2, grad_grad_op)
check_second_order_unary(array, log2, grad_op, grad_grad_op)


@with_seed()
def test_log10():
def log10(x):
return nd.log10(x)

def grad_op(x):
return 1/(x * math.log(10))

def grad_grad_op(x):
return -1/((x**2) * math.log(10))

arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

for array in arrays:
check_second_order_unary(array, log10, grad_grad_op)
check_second_order_unary(array, log10, grad_op, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
def check_second_order_unary(x, op, grad_op, grad_grad_op):
x = nd.array(x)
expect_grad_grad = grad_grad_op(x)
grad_x = grad_op(x)
grad_grad_x = grad_grad_op(x)
x.attach_grad()

# Manual head_grads.
head_grads = nd.random.normal(shape=x.shape)
Contributor commented:
Rename this to y_grad as it is dL/dy?

Contributor Author @kshitij12345 replied:
Sure.

head_grad_grads = nd.random.normal(shape=x.shape)
Contributor commented:
I still don't understand what this variable is mathematically...

Contributor Author @kshitij12345 replied:
head_grads is just the input gradient node in the graph used to compute x_grad.
head_grad_grads is there only to check the validity of the chain rule/backprop.

Contributor commented:
Thanks for the clarification.
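To make those roles concrete, here is a hedged sketch (plain elementwise Python on arrays, not the test code itself) of the quantities the check below compares, assuming y = op(x) with first derivative grad_op and second derivative grad_grad_op:

def expected_second_order(x, grad_op, grad_grad_op, head_grads, head_grad_grads):
    # First backward: x_grad = head_grads * dy/dx
    x_grad = head_grads * grad_op(x)
    # Second backward driven by head_grad_grads:
    #   x.grad          should become head_grad_grads * d(x_grad)/dx
    #   head_grads.grad should become head_grad_grads * d(x_grad)/d(head_grads)
    expected_x_grad_grad = head_grad_grads * head_grads * grad_grad_op(x)
    expected_head_grads_grad = head_grad_grads * grad_op(x)
    return expected_x_grad_grad, expected_head_grads_grad

Note that the diff below compares head_grads.grad against grad_x alone; as the author points out in a later comment, the head_grad_grads factor also belongs in that expectation.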

head_grads.attach_grad()

# Perform compute.
with autograd.record():
y = op(x)
y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
y_grad = autograd.grad(y, x, head_grads=head_grads,
Contributor @apeforest commented on Jun 5, 2019:
This variable is actually dL/dx, maybe rename it to x_grad for better readability?

Contributor Author @kshitij12345 replied:
Oh yes. Will do that.

create_graph=True, retain_graph=True)[0]

y_grad.backward(head_grad_grads)

# Compute expected values.
expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
head_grads.asnumpy()
expected_heads_grad = grad_x.asnumpy()
Contributor Author @kshitij12345 commented:
Should be grad_x.asnumpy() * head_grad_grads.asnumpy()


# Validate the gradients.
assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
Contributor @larroy commented on Jun 4, 2019:
Now I understand your question. I don't think anything is updating head_grads.grad here (this is done when running backward). Why do you want to set the head gradients manually? To verify your fix?

Contributor commented:
Can you try

y_grad_grad = autograd.grad(y_grad, x, ..., create_graph=False, ...)[0]

and, in validation,

assert_almost_equal(expected_heads_grad, y_grad_grad.asnumpy())
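One hedged way to fill in that suggestion, using the names from check_second_order_unary (the ... arguments are guesses, not the final test code). Note that differentiating y_grad with respect to x yields the second derivative, so the matching expectation is expected_grad_grad; validating head_grads.grad would instead require differentiating with respect to head_grads:

with autograd.record():
    y = op(x)
    y_grad = autograd.grad(y, x, head_grads=head_grads,
                           create_graph=True, retain_graph=True)[0]

# Differentiate y_grad explicitly instead of reading gradients from .grad buffers.
y_grad_grad = autograd.grad(y_grad, x, head_grads=head_grad_grads,
                            create_graph=False)[0]
assert_almost_equal(expected_grad_grad, y_grad_grad.asnumpy())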

Contributor Author @kshitij12345 replied:
Yeah, to verify the fix.
I expected y_grad.backward(head_grad_grads) to update head_grads.grad, similar to the PyTorch script from the description.

Thanks for the suggestion; I will surely try that.

Contributor @larroy commented on Jun 5, 2019:
y_grad.backward(head_grad_grads) indicates that head_grad_grads are the head gradients passed from "upstream". Calling (output variable).backward will update all the independent input variables (on which that output depends) which have an attached gradient. In this case head_grad_grads is not an input to the graph, so the fact that the grad doesn't get updated is expected:

https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/ndarray/ndarray.py#L2188
https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py#L270

Contributor Author @kshitij12345 replied:
We are checking gradients for head_grads (not head_grad_grads), which is used to compute x_grad, so I believe we should accumulate some gradient in head_grads.

Contributor commented:
The default behaviour in PyTorch vs. MXNet is different with respect to accumulation of gradients (PyTorch: add; MXNet: write). Having said that, I still don't understand why you expect gradient accumulation in head_grads.
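As a hedged illustration of that difference (relying on NDArray.attach_grad(grad_req=...), which accepts 'write' or 'add'; the values below are just an example):

from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad(grad_req='add')   # MXNet default is grad_req='write'
for _ in range(2):
    with autograd.record():
        y = nd.log(x)
    y.backward()
# With 'add', x.grad has accumulated 2 * (1/x); with the default 'write',
# the second backward overwrites it with 1/x. PyTorch-style accumulation is opt-in.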

Contributor Author @kshitij12345 replied:
Oh, I just expected it to have gradients (by accumulation or writing), since its value is used while computing x_grad. But from your and @apeforest's explanations, I understand the behaviour better now.
Thank you for digging in and explaining.



if __name__ == '__main__':