
[MXNET-978] Add higher order gradient support tan, tanh #15253

Merged: 10 commits merged into apache:master on Jul 29, 2019

Conversation

@kshitij12345 (Contributor) commented on Jun 16, 2019

Description

Trying to add higher order gradient support for tan and tanh.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created, here MXNET-978 (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Support higher order gradient for tan, with tests
  • Support higher order gradient for tanh, with tests (see the verification sketch below)
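
For reference, a minimal sketch of how these second-order gradients can be checked against the analytic formulas with mx.autograd. This approximates what the tests in tests/python/unittest/test_higher_order_grad.py do through a shared helper; the helper name and sample values below are mine, not the PR's, and it only runs once the operators have second-order support (i.e. with this PR built):

from mxnet import nd, autograd

def second_order_grad(func, x):
    # d/dx of dL/dx, where L = sum(func(x)) and all head gradients are ones
    x = x.copy()
    x.attach_grad()
    with autograd.record():
        y = func(x)
        # first backward pass, kept differentiable so it can be differentiated again
        x_grad = autograd.grad(heads=y, variables=x,
                               create_graph=True, retain_graph=True)[0]
    x_grad.backward()  # second backward pass: writes d(sum(x_grad))/dx into x.grad
    return x.grad

x = nd.array([0.1, 0.5, 1.0])

# tan:  f'(x) = 1 + tan(x)^2,  f''(x) = 2 * tan(x) * f'(x)
print(second_order_grad(nd.tan, x))
print(2 * nd.tan(x) * (1 + nd.tan(x) ** 2))

# tanh: f'(x) = 1 - tanh(x)^2, f''(x) = -2 * tanh(x) * f'(x)
print(second_order_grad(nd.tanh, x))
print(-2 * nd.tanh(x) * (1 - nd.tanh(x) ** 2))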

@kshitij12345 (Contributor, Author):

@larroy @apeforest Please help.

I am facing an issue here.

https://github.com/apache/incubator-mxnet/blob/f49013a60b8ea5d6c75ba3515f25a8c346748269/src/operator/tensor/elemwise_unary_op_basic.cc#L123-L137

Here I am returning input[1] * ograd as the value for x_grad_grad, which I expect to be f(x) * ograd, since we use ElemGradUseOut to wrap the _backward_sigmoid function.

However, from the tests:

https://github.com/apache/incubator-mxnet/blob/f49013a60b8ea5d6c75ba3515f25a8c346748269/tests/python/unittest/test_higher_order_grad.py#L109-L123

https://github.com/apache/incubator-mxnet/blob/f49013a60b8ea5d6c75ba3515f25a8c346748269/tests/python/unittest/test_higher_order_grad.py#L136-L151

You can see that the actual value returned is f(x) * f'(x) * ograd.
I have also tried the method suggested by @larroy of using create_graph=False and retain_graph=True.

Thank You.
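
For context, here is a hedged sketch of the kind of Python-side check being described. It assumes the experimental _backward_sigmoid FGradient from the linked lines is compiled in, uses head gradients of ones to stand in for ograd, and the variable names are mine:

from mxnet import nd, autograd

x = nd.array([-1.0, 0.0, 2.0])
x.attach_grad()

with autograd.record():
    y = nd.sigmoid(x)
    # first backward pass, recorded so it can be differentiated again
    x_grad = autograd.grad(heads=y, variables=x,
                           head_grads=nd.ones_like(y),
                           create_graph=True, retain_graph=True)[0]
x_grad.backward()  # second backward pass; its ograds are ones here

fx = nd.sigmoid(x)
fx_prime = fx * (1 - fx)
print(x.grad)         # what actually reaches x.grad
print(fx)             # candidate 1: f(x) * ograd, the expected value
print(fx * fx_prime)  # candidate 2: f(x) * f'(x) * ograd, the value reported in the tests

Comparing x.grad against the two candidate expressions is essentially what the linked tests do.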

@apeforest apeforest self-requested a review June 17, 2019 23:31
@apeforest (Contributor):

@kshitij12345 I will review it carefully later tonight. Could you please also rebase your other PR #15120 and trigger CI? I need your update in the unit tests for my PR as well. Thanks!


std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
Contributor:

Shouldn't the first item be _backward_grad_grad and the second item be _backward_grad_grad_inp?

Given that:

n->input[0]:  ygrad
n->input[1]: y

Contributor Author:

Actually the code is incorrect in terms of naming the attributes and probably the variable names,
as my main concern was figuring out why we get f(x) * f'(x) where f(x) was expected.

@kshitij12345 (Contributor, Author):

> @kshitij12345 I will review it carefully later tonight. Could you please also rebase your other PR #15120 and trigger CI? I need your update in the unit tests for my PR as well. Thanks!

Done.


ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                          {ograds[0], gx_ograd}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
Contributor:

Shouldn't this be ograds[0] * f'(x) ?

Let f(x) = sigmoid(x)
xgrad = dL/dx
ygrad = dL/dy
We know dydx = y*(1-y)
In the first backward pass:

xgrad = dydx * ygrad

n->inputs[0] ---> ygrad
n->inputs[1] ---> y = f(x)
nnvm::NodeEntry{n} ---> xgrad

According to the multiplier rule, in the second backward pass:

_backward_grad_grad = d2L/dx2 * f'(x) = ograds[0] * dydx
_backward_grad_grad_inp = d2L/dx2 * ygrad * f''(x) = ograds[0] * n->input[0] * f''(x)

@kshitij12345 (Contributor, Author) commented on Jun 18, 2019:

That is indeed the surprising part: I have asserted from Python that n->inputs[1] is f(x) * f'(x) instead of f(x), which is why I was very confused. Also, if I remember correctly, n->inputs[0] as well as nnvm::NodeEntry{n} are different from what we expect them to be.

Note: right now it doesn't actually implement the second order gradient for sigmoid.

I am essentially returning what I expect to be f(x) * ograds (from the second backward) at
https://github.com/apache/incubator-mxnet/blob/f49013a60b8ea5d6c75ba3515f25a8c346748269/src/operator/tensor/elemwise_unary_op_basic.cc#L134-L135
but instead we are receiving f(x) * f'(x) * ograds (from the second backward), which is confirmed from Python. That is where my confusion lies.

Contributor:

I got your question now. I got the same result f(x) * f'(x) * ograds locally. Need to dig in more...

Contributor:

I dumped out the computation graph. If you want to try, refer to my local branch: https://github.com/apeforest/incubator-mxnet/blob/develop/higher_order_grad/src/imperative/imperative.cc#L509

[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 0 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 1 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 2 sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 1: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: 		output 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 3 _backward_sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 0: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: 		output 3: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:300: node 4 var
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 5 elemwise_mul
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 4: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: 		output 5: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:302: node 6 _backward_sigmoid
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 5: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:306: 		input 2: [3] (0 KB) ->
[16:37:15] /Users/lnyuan/work/mxnet/src/executor/../common/exec_utils.h:312: 		output 6: [3] (0 KB) ->

So it seems that if the backward operator takes the output of the forward node instead of its input, the backward graph may be created differently, with some of the chain rule already applied.
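
A short worked version of that observation (my own summary, using the notation from the thread): because sigmoid's backward is created with ElemGradUseOut, _backward_sigmoid takes ygrad and y = f(x) as inputs rather than x. Whatever the second-order FGradient returns for the input y is treated as a gradient with respect to y, and autograd then chains it through sigmoid's own backward before it reaches x:

grad_x = grad_y * dy/dx = (f(x) * ograd) * f'(x) = f(x) * f'(x) * ograd

That chain-rule factor f'(x) is the extra multiplication seen in the tests, and it is what the dump shows: node 5 builds ograd * y, and node 6 is another _backward_sigmoid applied to node 5's output.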

Contributor Author:

Oh Thanks. Will try to look into it. Thanks for the pointer.

@larroy (Contributor) commented on Jun 18, 2019:

@mxnet-label-bot add [pr-awaiting-review,autograd]

@marcoabreu added the Autograd and pr-awaiting-review (PR is waiting for code review) labels on Jun 18, 2019
@larroy (Contributor) commented on Jun 18, 2019:

@mxnet-label-bot add [operator]

@kshitij12345 changed the title from "Add higher order gradient support sigmoid, tan, tanh" to "Add higher order gradient support tan, tanh" on Jul 2, 2019
@kshitij12345 changed the title from "Add higher order gradient support tan, tanh" to "[MXNET-978] Add higher order gradient support tan, tanh" on Jul 2, 2019
// f''(x) = 2 * f'(x) * f(x)
const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
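
As an aside on the identity in the comment above (f''(x) = 2 * f'(x) * f(x) for f(x) = tan(x)), here is a quick numeric sanity check; a sketch of my own, not part of the PR:

from mxnet import nd

x = nd.array([0.1, 0.5, 1.0], dtype='float64')
f = nd.tan(x)
f_prime = 1 + f * f          # d/dx tan(x) = sec^2(x) = 1 + tan(x)^2
f_second = 2 * f_prime * f   # the identity from the code comment
eps = 1e-4                   # central finite-difference approximation of f''
fd = (nd.tan(x + eps) - 2 * f + nd.tan(x - eps)) / (eps * eps)
print(f_second)
print(fd)

(float64 is used so the finite difference is not swamped by rounding error.)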
Contributor:
Can we clarify / add a comment on why it is correct to multiply by y_grad (the first head gradient?) again? This would help readers, as it is not obvious, and it compounds with the very non-obvious implicit multiplication by f'(x).

Contributor Author:

Makes sense. Thanks. Will get to it.

Contributor Author:

I have updated the comment. Please see if it is okay, or whether the phrasing can be improved.
Thanks.

Contributor:

Thanks. About the outputs, I think we should write some documentation explaining what we are doing, as I find it non-trivial. Can you help me understand the y_grad_grad (first output)?

If you want, we can move the conversation to the dev list or slack, as the PR LGTM.

[attached image: IMG_20190711_122948__01]

Contributor:

OK, I clarified with @apeforest; this makes sense now.

@larroy (Contributor) left a review comment:

LGTM, but I find the code non-obvious (it is consistent with sigmoid). I think we should add some documentation in the FAQ about these mechanics: the implicit multiplication by f'(x) and the multiplication by y_grad.


@kshitij12345 (Contributor, Author):

It is actually confusing.
Sure.
Thank You.

@larroy (Contributor) left a review comment:

LGTM

@sandeep-krishnamurthy (Contributor):

@apeforest - Can you please take a look at this PR and merge it if it looks good? Thanks

@apeforest (Contributor) left a review comment:

LGTM

@apeforest apeforest merged commit 8e31dad into apache:master Jul 29, 2019
@kshitij12345 kshitij12345 deleted the add-higher-ord branch July 30, 2019 14:04
anirudhacharya pushed a commit to anirudhacharya/mxnet that referenced this pull request Aug 20, 2019
…5253)

* init to reset

* issue: higher order backward sigmoid

* update gradient code.

update code as per apache#15288.

* undo changes

* relax tolerance of gradient mismatch for tanh

* update comments

* update comments
Labels: Autograd, Operator, pr-awaiting-review