[MXNET-978] Add higher order gradient support tan, tanh #15253
Conversation
@larroy @apeforest Please help. I am facing some issue here. I am returning one thing, but from the tests you can see that the actual value returned is something else. Thank you.
@kshitij12345 I will review it carefully later tonight. Could you please also rebase your other PR #15120 and trigger CI? I need your update in the unit tests for my PR as well. Thanks!
std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
Shouldn't the first item be _backward_grad_grad and the second item be _backward_grad_grad_inp?
Given that:
n->inputs[0]: ygrad
n->inputs[1]: y
Actually the code is incorrect in terms of naming the attribute and probably the variable names. My main concern was trying to figure out why we get f(x) * f'(x) where f(x) was expected.
Done.
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                          {ograds[0], gx_ograd}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
Shouldn't this be ograds[0] * f'(x) ?
Let f(x) = sigmoid(x)
xgrad = dL/dx
ygrad = dL/dy
We know dydx = y*(1-y)
In the first backward pass:
xgrad = dydx * ygrad
n->inputs[0] ---> ygrad
n->inputs[1] ---> y = f(x)
nnvm::NodeEntry{n} ---> xgrad
According to the multiplier rule, in the second backward pass:
_backward_grad_grad = d2L/dx2 * f'(x) = ograds[0] * dydx
_backward_grad_grad_inp = d2L/dx2 * ygrad * f''(x) = ograds[0] * n->inputs[0] * f''(x)
That is indeed the surprising part. I have asserted from Python that n->inputs[1] is f(x) * f'(x), surprisingly, instead of it being f(x); that is why I was very confused. Also, if I remember correctly, n->inputs[0] as well as nnvm::NodeEntry{n} are different from what we expect them to be.
Note: right now it doesn't actually implement the second order gradient for sigmoid.
I am essentially returning what I expect to be f(x) * ograds (from the second backward) at
https://github.com/apache/incubator-mxnet/blob/f49013a60b8ea5d6c75ba3515f25a8c346748269/src/operator/tensor/elemwise_unary_op_basic.cc#L134-L135
but instead we are receiving f(x) * f'(x) * ograds (from the second backward), which is confirmed from Python. That is where my confusion lies.
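A minimal sketch of that Python check (reconstructed here, not the exact script from the PR; it assumes the FGradient at this commit returns f(x) * ograds, and uses all-ones head gradients just for illustration):

```python
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()

head_grads = nd.ones_like(x)   # ograds of the first backward
head_grads2 = nd.ones_like(x)  # ograds of the second backward

with autograd.record():
    y = nd.sigmoid(x)
    # First-order gradient, kept in the graph so it can be differentiated again.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=head_grads,
                           create_graph=True, retain_graph=True)[0]
x_grad.backward(head_grads2)

fx = nd.sigmoid(x)
print(x.grad)                              # what the second backward actually produces
print(head_grads2 * fx)                    # f(x) * ograds  -- what the FGradient returns
print(head_grads2 * fx * fx * (1.0 - fx))  # f(x) * f'(x) * ograds -- what shows up instead
```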
I got your question now. I got the same result, f(x) * f'(x) * ograds, locally. Need to dig in more...
I dumped out the computation graph. If you want to try, refer to my local branch: https://github.com/apeforest/incubator-mxnet/blob/develop/higher_order_grad/src/imperative/imperative.cc#L509
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 0 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 1 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 2 sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 1: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: output 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 3 _backward_sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 0: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: output 3: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 4 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 5 elemwise_mul
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 4: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: output 5: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 6 _backward_sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 5: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: output 6: [3] (0 KB) ->
So it seems that if the backward operator takes the output of the forward node instead of the input, the backward graph may be created differently, with some of the chain rule applied already.
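One way to read that dump (a sketch of the chain rule at play, not something stated in the code): node 5 is the value returned as the gradient with respect to y (ograds * y), and node 6 is sigmoid's own backward pushing that gradient from y back to x, which multiplies in another f'(x) = y * (1 - y):

```latex
x.\mathrm{grad}
  \;=\; \underbrace{f(x)\,\mathrm{ograds}}_{\text{node 5: returned w.r.t. } y}
  \;\times\; \underbrace{y\,(1-y)}_{\text{node 6: } f'(x)}
  \;=\; f(x)\,f'(x)\,\mathrm{ograds}
```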
Oh Thanks. Will try to look into it. Thanks for the pointer.
@mxnet-label-bot add [pr-awaiting-review,autograd]
@mxnet-label-bot add [operator]
// f''(x) = 2 * f'(x) * f(x)
const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
Can we clarify / add a comment on why it is correct to multiply by y_grad (the first head gradient?) again? This would help readers, as it is not obvious, and it compounds with the very non-obvious implicit multiplication by f'(x).
Makes sense. Thanks. Will get to it.
I have updated the comment. See if it is okay, or whether the phrasing can be improved.
Thanks
OK, I clarified with @apeforest; this makes sense now.
LGTM, but I find the code non-obvious (it is consistent with sigmoid). I think we should add some documentation to the FAQ about these mechanics: the implicit multiplication by f'(x) and the multiplication by y_grad.
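For the record, a sketch of those mechanics for tan (my reading, not text from the PR; for tanh, f'(x) = 1 - y^2, so the sign of the 2y term flips and f''(x) = -2y * f'(x)):

```latex
\begin{aligned}
\mathrm{xgrad} &= \mathrm{ygrad}\cdot f'(x), \qquad f'(x) = 1 + y^2 \\
\frac{\partial\,\mathrm{xgrad}}{\partial y}
  &= \mathrm{ygrad}\cdot 2y
  \quad\text{(the explicit multiplications in the hunk above)} \\
\frac{\partial\,\mathrm{xgrad}}{\partial x}
  &= \mathrm{ygrad}\cdot 2y\cdot f'(x) \;=\; \mathrm{ygrad}\cdot f''(x)
  \quad\text{(the extra } f'(x) \text{ is applied implicitly when the gradient w.r.t. } y
  \text{ is chained back through the forward node)}
\end{aligned}
```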
// f''(x) = 2 * f'(x) * f(x)
const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
It is actually confusing.
LGTM
@apeforest - Can you please take a look at this PR and merge if it looks good? Thanks
LGTM
…5253)
* init to reset
* issue: higher order backward sigmoid
* update gradient code. update code as per apache#15288.
* undo changes
* relax tolerance of gradient mismatch for tanh
* update comments
* update comments
Description
Trying to add higher order support for tan, tanh.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes