
[MXNET-978] Support higher order gradient for log, log2, log10. #14992

Merged
merged 4 commits into apache:master from kshitij12345:log_higher_order_grad
May 28, 2019

Conversation

kshitij12345
Contributor

@kshitij12345 kshitij12345 commented May 18, 2019

Description

With reference to #14613 and #10002, this PR adds support for higher order gradients for log (and ideally for log2 and log10 as well).

Tests are based entirely on #14613.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • higher order gradient for log (see the sketch after this list).
  • unit test for the same.
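
For reference, here is a minimal sketch of what the added support enables from the Python side (illustrative only: it assumes an mxnet build that includes this change, and the tensor values are arbitrary):

from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()
with autograd.record():
  y = nd.log(x)
  # first gradient dy/dx = 1/x, kept in the graph so it can be differentiated again
  y_grad = autograd.grad(y, x, head_grads=nd.ones_like(y), create_graph=True, retain_graph=True)[0]
y_grad.backward()
print(x.grad)  # second derivative d2y/dx2 = -1/x**2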

@kshitij12345 kshitij12345 force-pushed the log_higher_order_grad branch 2 times, most recently from 16e1815 to 49e59f0 on May 18, 2019 21:24
@kshitij12345
Contributor Author

kshitij12345 commented May 18, 2019

I don't know much about this library, but

I believe it would be better to have gradients defined for the existing backward ops, instead of a differentiable gradient (relying on the autograd machinery), at least for ops where the backward is not trivial. It allows reusing the existing optimised fused kernels and makes sure there is no regression in the backward pass.

Note: log is relatively trivial (a single reciprocal). But we may see a performance regression for the slightly less trivial sigmoid if we rely on the autograd machinery instead of the existing _backward_sigmoid.
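
To illustrate with a plain numpy sketch (not MXNet code; the names here are mine): the backward of sigmoid can be written entirely in terms of the already-computed forward output, so a dedicated fused backward only needs a few elementwise multiplications, which is what differentiating through the generic autograd machinery could give up.

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

x = np.array([-1.0, 0.0, 2.0])
y = sigmoid(x)               # forward output, already available to the backward pass
dy = y * (1.0 - y)           # f'(x)  = y * (1 - y), the closed form a fused backward can use
d2y = dy * (1.0 - 2.0 * y)   # f''(x) = f'(x) * (1 - 2y), still only cheap elementwise ops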

@kshitij12345
Contributor Author

Can anyone please point out how I can get 1/log(2.0) and 1/log(10.0) multiplied with the gradients for log2 and log10?

@pinaraws

@mxnet-label-bot add[Operator, pr-awaiting-review]

@apeforest
Contributor

@kshitij12345 Thanks for your contribution. I agree with you that it would be better to have gradients defined for the existing backward operators.

I do not fully understand your question about 1/log(2.0) and 1/log(10.0) being multiplied with the gradients for log2 and log10. Could you please elaborate?

@kshitij12345
Contributor Author

kshitij12345 commented May 21, 2019

I do not fully understand your question about 1/log(2.0) and 1/log(10.0) being multiplied with the gradients for log2 and log10. Could you please elaborate?

Reading it again, I phrased it poorly. Sorry. The plan was to update the gradient for log2, which would be 1/(log(2.0) * x), for which I would have needed log(2.0). How do I get that: is scalar multiplication allowed, or should I use ones_like followed by fill?

Note: This is not needed for this PR. But curious to know.

Thank You.

@kshitij12345
Contributor Author

@larroy @apeforest I have updated as per #14095.
Please review.

@kshitij12345
Contributor Author

[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
  // For g(x) -> g = log
  // g''(x) = -1 * (g'(x) * g'(x))
  auto gx = nnvm::NodeEntry{n, 0, 0};
Contributor

You are welcome to simplify calls to NodeEntry as per:

#14095

Just call nnvm::NodeEntry{n}

Contributor Author

Oh. Missed that. Thank You.

Contributor

@larroy larroy left a comment

Nice PR, thanks a lot for this. just a couple of questions.


ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                          {ograds[0], gx}, nullptr, &n));
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
Contributor

Why are we returning two gradients? Isn't it a unary function with just one input?

Contributor Author

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
Contributor

same comment as above.

* simplify NodeEntry creation.
@apeforest
Contributor

apeforest commented May 23, 2019

I do not fully understand your question about 1/log(2.0) and 1/log(10.0) being multiplied with the gradients for log2 and log10. Could you please elaborate?

Reading it again, I phrased it poorly. Sorry. The plan was to update the gradient for log2, which would be 1/(log(2.0) * x), for which I would have needed log(2.0). How do I get that: is scalar multiplication allowed, or should I use ones_like followed by fill?

Note: This is not needed for this PR. But curious to know.

Thank You.

After reviewing your code, I had a better understanding of what you meant. I think you can do an elemwise_mul operator with Op('log') and a vector filled with 2.0. There may be other ways to optimize the graph representation, but I think this should actually work.
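
As a rough Python-level analogue of that idea (a sketch only; the real change would live in the C++ FGradient definition, and these names are mine): the constant log(2.0) is obtained from the graph itself by applying log to a tensor filled with 2.0, instead of using a scalar literal.

from mxnet import nd

x = nd.array([1.0, 2.0, 4.0])
ln2 = nd.log(nd.ones_like(x) * 2.0)    # a tensor filled with log(2.0)
grad_log2 = nd.reciprocal(ln2 * x)     # d/dx log2(x) = 1 / (log(2) * x)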

unary_bwd<mshadow_op::log_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// For g(x) -> g = log
Contributor

nit: It's very nice to see a comment here. The g(x) here is actually a function of x, and it might easily be confused with the variable gx two lines below. Maybe use f(x) in the comment instead?

Contributor Author

Ah sure. That makes sense. Thank You.

@apeforest
Contributor

@kshitij12345 The CI failure in unix-GPU was due to a flaky test for TensorRT: #14978

The issue has been fixed by #15014. Please re-trigger CI again. Thanks!

* update comment to avoid confusion.
Contributor

@apeforest apeforest left a comment

LGTM! Thanks a lot for your contribution.

@larroy
Contributor

larroy commented May 28, 2019

Lgtm

@apeforest apeforest merged commit 8a9dd72 into apache:master May 28, 2019
// For f(x) -> f = log10
// f'(x) = 1 / (log(10) * x)
// f''(x) = -1 * (f'(x) * 1/x)
auto gx = nnvm::NodeEntry{n, 0, 0};
Contributor

Why don't we follow the same pattern as in the natural logarithm?

Contributor Author

For natural log,
in the gradient function we have gx, i.e. 1/x, as well as x.
Since the second derivative of log is -(gx * gx) = -1/(x^2), we use that pattern.

Considering log2 (the case for log10 is similar),
we have gx, i.e. 1/(log(2) * x), as well as x.
The second derivative is -1/(log(2) * x * x),
which we get in the code using negative(gx * reciprocal(x)), where gx = 1/(log(2) * x).
Another way to get it would be negative(gx * gx * log(2.0)).
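
A quick numeric check of the two forms mentioned above (plain numpy, illustrative only; the incoming head gradient is ignored here for clarity):

import numpy as np

x = np.array([0.5, 1.0, 3.0])
gx = 1.0 / (np.log(2.0) * x)                # first derivative of log2(x)
form_a = -gx * (1.0 / x)                    # negative(gx * reciprocal(x))
form_b = -gx * gx * np.log(2.0)             # negative(gx * gx * log(2.0))
np.testing.assert_allclose(form_a, form_b)  # both equal -1 / (log(2) * x**2)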

Contributor Author

@larroy Thanks for pointing this out; going through it again made me realise that there is a problem with the implementation of log.

@kshitij12345
Contributor Author

kshitij12345 commented May 31, 2019

@apeforest @larroy

https://github.com/kshitij12345/incubator-mxnet/blob/7b343d1fcde73b61322985580080333d9eee9e82/src/operator/tensor/elemwise_unary_op_basic.cc#L1077-L1079

We multiply gx * gx, where gx = ograd * f'(x), getting ograd^2 * f'(x)^2; however, we want only ograd * f'(x)^2, which can be achieved in a fashion similar to the implementation of _backward_log10/_backward_log2. (With a head gradient of 0.5, the buggy form picks up a factor of 0.5^2 = 0.25 instead of 0.5, which is exactly what the script below detects.)

I have validated the expected results.

from mxnet import nd, autograd
import numpy
import math
grad_grad_op = lambda x: (-1/x**2)

x = nd.random.normal(0,1,(3,3))
x.attach_grad()
with autograd.record():
  y = nd.log(x)
  y_grad = autograd.grad(y, x, head_grads=nd.ones_like(y) * 0.5, create_graph=True, retain_graph=True)[0]
y_grad.backward(nd.ones_like(y_grad) * 0.6)

numpy.testing.assert_allclose(x.grad.asnumpy(), (grad_grad_op(x) * 0.5 * 0.6).asnumpy(), rtol=1e-7, atol=1e-7)

This fails with the current code.
Should I make a new PR, or add commits to this PR itself? Sorry for the trouble.

I have confirmed the behaviour with PyTorch as well.

import torch
import numpy
import math

grad_grad_op = lambda x: (-1/x**2)

x = torch.randn(2,3)
x.requires_grad = True

y = torch.log(x)
y_grad = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y) * 0.5, create_graph=True, retain_graph=True)[0]
y_grad.backward(torch.ones_like(y_grad) * 0.6)

numpy.testing.assert_allclose(x.grad.detach().numpy(), (grad_grad_op(x) * 0.5 * 0.6).detach().numpy(), rtol=1e-7, atol=1e-7)

aaronmarkham pushed a commit to aaronmarkham/incubator-mxnet that referenced this pull request May 31, 2019
* add higher order gradient support for log, log10, log2

* add tests

* address comments

* simplify NodeEntry creation.

* address comments

* update comment to avoid confusion.

@apeforest
Contributor

Hi @kshitij12345 sorry, we missed that. We should have reviewed it more carefully. Please submit another PR to fix this issue. I will also update #14613 accordingly. Thanks!

std::vector<nnvm::NodeEntry> ret;

ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                          {ograds[0], gx}, nullptr, &n));
Contributor

Shouldn't this be {ograds[0], g_lx} instead? Isn't dL/dy_grad = d^2L/dx^2 * f'(x)?

Contributor Author

Yes, it should. Thanks. I have updated this change and added a relevant test in the new PR #15120.
Actually, I am having trouble exactly at this part, as the grad value is not being updated. More info in #15120.

@kshitij12345
Contributor Author

I have created a new PR for the same: #15120. Please review.

haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* add higher order gradient support for log, log10, log2

* add tests

* address comments

* simplify NodeEntry creation.

* address comments

* update comment to avoid confusion.
@kshitij12345 kshitij12345 changed the title [MXNET-978] Support higher order gradient for log. [MXNET-978] Support higher order gradient for log, log2, log10. Jul 13, 2019
@kshitij12345 kshitij12345 deleted the log_higher_order_grad branch July 13, 2019 08:40