This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Higher order gradient for sigmoid #15288

Merged: 44 commits merged into apache:master on Jul 3, 2019

Conversation

@apeforest (Contributor) commented Jun 20, 2019:

Description

This PR adds support for the higher order gradient of the sigmoid operator.
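For context, a minimal sketch of how the new second-order gradient can be exercised (assuming the mxnet.autograd API used in the tests discussed below; this is an illustration, not code taken from this PR):

from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 3.0])
x.attach_grad()

with autograd.record():
    y = nd.sigmoid(x)
    # First-order gradient, kept in the graph so it can be differentiated again.
    x_grad = autograd.grad(heads=y, variables=x, head_grads=nd.ones_like(y),
                           create_graph=True, retain_graph=True)[0]
# Second-order gradient accumulates into x.grad.
x_grad.backward()
# With unit head grads this should equal sigmoid(x) * (1 - sigmoid(x)) * (1 - 2 * sigmoid(x)).
print(x.grad.asnumpy())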

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with this change

Changes

  • Higher order (second order) gradient for the sigmoid operator, with unit tests
  • A rudimentary computation-graph dump in imperative mode (gated by MXNET_MEM_PLAN_VERBOSE_LOGGING) to help debugging

Comments

  • If this is a backward-incompatible change, explain why it must be made.
  • Interesting edge cases to note here

sxjscience and others added 30 commits October 14, 2018 14:56
@apeforest (Contributor, Author):

@kshitij12345 I have figured out how backward works when one of the inputs is an output of the forward node. Please review this PR. Thanks!

@apeforest (Contributor, Author):

@larroy @sxjscience Please help review this PR. Thanks!

@apeforest (Contributor, Author) commented Jun 20, 2019:

@larroy I also added a method to dump the computation graph in imperative mode, since it will be very useful for debugging. However, it is still very rudimentary, and we will need your help to implement a more elegant way of printing out the graph info. Thanks!

auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
                              {n->inputs[0], nnvm::NodeEntry{one_minus_two_y}}, nullptr, &n);
// when building the gradient graph, the backward node of n->inputs[1] will be
// added to the graph again, therefore f'(x) will be multiplied
Contributor:

Doesn't this behaviour seem a bit odd? Is this actually the expected behaviour? What would have happened if this were a split-like function with many outputs; would the backwards of all the outputs be added? Actually, I am confused by this behaviour.

Contributor Author:

Yes, this is actually the expected behavior. The nnvm graph performs a DFS traversal when building the backward pass. Since the input to this backward_sigmoid node is an output of another node (sigmoid), during the backward pass the gradient function of sigmoid will be invoked.

See:

This is to collect dependent nodes in the graph during RecordOp:
https://github.com/apache/incubator-mxnet/blob/master/src/imperative/imperative.cc#L180

This is to actually perform the backward pass:
https://github.com/dmlc/tvm/blob/21935dcbf56ad3bd66ebff9891a6bc3865b8106d/nnvm/src/pass/gradient.cc#L126

https://github.com/dmlc/tvm/blob/21935dcbf56ad3bd66ebff9891a6bc3865b8106d/nnvm/src/pass/gradient.cc#L190
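Put differently (my reconstruction of the argument above, using the notation from the code comments rather than a verbatim trace of the implementation):

    grad w.r.t. n->inputs[1] built here ≈ ograds[0] * y_grad * (1 - 2*f(x))
    n->inputs[1] = f(x) is itself the output of the sigmoid node, so its backward
    is added to the gradient graph again and contributes one more factor
    f'(x) = f(x) * (1 - f(x)), giving
    grad w.r.t. x = ograds[0] * y_grad * (1 - 2*f(x)) * f'(x)
                  = ograds[0] * y_grad * f''(x)

which is exactly the expected_grad_grad = grad_grad_x * head_grad_grads * y_grad checked in the tests later in this thread.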

Contributor:

I also didn't get that. A drawing might help. I would expect that you multiply with the backward node "n" itself.

Contributor:

I guess that is what he meant, as in:

      // n->inputs[0] : y_grad
      // n->inputs[1] : f(x) = sigmoid(x)
      // ograds[0] : head_grads
      // f''(x) = f'(x) * (1 - 2*f(x))

The backward of node n->inputs[1] is the node n itself.

+1 for a visual drawing/graph of what is actually happening.

Contributor Author:

[Attached image IMG-3494: hand-drawn graph of the backward computation]

Contributor:

Thanks, I'll look at it today.

@kshitij12345 (Contributor) commented Jun 22, 2019:

I kind of see, from the graph you have drawn and the dumped graph, what is happening. I still strongly feel that it is incorrect behaviour.

To check that, I modified the test code slightly to also compute and validate the 3rd-order gradient, which should come for free since the second-order gradients are composed of differentiable functions.

I have tested sin, cos, log and sigmoid. Of these, only sigmoid fails for the third order.

I have also attached my hand computation of the third order for sigmoid; please verify it just to make sure that it is not incorrect.

For some reason it is rotated.
[Attached image 3_rd_ord: hand computation of the third-order sigmoid gradient]
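For reference, the chain being checked works out as follows (reconstructed from the grad_grad_grad_op used in the test code below, not transcribed from the attachment):

    f(x)    = sigmoid(x)
    f'(x)   = f(x) * (1 - f(x))
    f''(x)  = f'(x) * (1 - 2*f(x))
    f'''(x) = f''(x) * (1 - 2*f(x)) - 2 * f'(x)^2
            = f''(x) - 2 * (f'(x)^2 + f''(x) * f(x))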

import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed


@with_seed()
def test_sin():
    def sin(x):
        return nd.sin(x)

    def grad_grad_op(x):
        return -nd.sin(x)
    
    def grad_grad_grad_op(x):
        return -nd.cos(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sin, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_cos():
    def cos(x):
        return nd.cos(x)

    def grad_grad_op(x):
        return -nd.cos(x)
    
    def grad_grad_grad_op(x):
        return nd.sin(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, cos, grad_grad_op, grad_grad_grad_op)

@with_seed()
def test_log():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        return -1/(x**2)

    def grad_grad_grad_op(x):
        return 2/(x**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, log, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_op(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def grad_grad_op(x):
        return grad_op(x) * (1 - 2 * sigmoid(x))

    def grad_grad_grad_op(x):
        return grad_grad_op(x) - 2 * ( grad_op(x)**2 + grad_grad_op(x) * sigmoid(x))

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sigmoid, grad_grad_op, grad_grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op, grad_grad_grad_op):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    grad_grad_grad_x = grad_grad_grad_op(x)
    x.attach_grad()

    # Manual head_grads.
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grads = nd.random.normal(shape=x.shape)

    # Perform compute.
    with autograd.record():
        y = op(x)
        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
                               create_graph=True, retain_graph=True)[0]
        
        x_grad_grad = autograd.grad(heads=x_grad, variables=x, head_grads=head_grad_grads,
                                create_graph=True, retain_graph=True)[0]
    x_grad_grad.backward()

    # Compute expected values.
    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    expected_grad_grad_grad = grad_grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x_grad_grad.asnumpy())
    assert_almost_equal(expected_grad_grad_grad, x.grad.asnumpy())

if __name__ == '__main__':
    import nose
    nose.runmodule()

Contributor Author:

I suspect the discrepancy in the 3rd-order gradient is not because of an error in my implementation of grad_grad_input, but because I did not return the first grad_grad correctly. After all, the first output may be useful in calculating higher orders even if it is not a visible output at the Python level. I will look into it.

Contributor Author:

@kshitij12345 I have fixed the issue. The result passes your test now. Please review again. Thanks!

Contributor:

LGTM.

// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (1 - 2*f(x))
Contributor:

ok

// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (1 - 2*f(x))
Contributor:

ok

@@ -106,6 +106,23 @@ def grad_grad_op(x):
check_second_order_unary(array, log10, grad_grad_op)


@with_seed()
def test_sigmoid():
Contributor:

Looks correct to me.

@@ -501,6 +501,10 @@ std::vector<NDArray*> Imperative::Backward(
}
}

if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
Contributor:

nice hack :)

Contributor:

Can you explain how logging of static memory helps here?

Contributor:

He wants to dump the graph. Maybe this should be a separate PR?

Contributor Author:

It's actually dumping out the computation graph.
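A minimal sketch of how one might turn it on (assuming the flag is read via dmlc::GetEnv exactly as in the diff above, so setting the environment variable to 1 before the backward pass enables the dump):

import os
# Enable the verbose graph/memory-plan dump before triggering a backward pass.
os.environ['MXNET_MEM_PLAN_VERBOSE_LOGGING'] = '1'
import mxnet as mx
# ... build a graph and call autograd.grad / backward as usual; the dump is
# printed from Imperative::Backward when the flag is set.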

@larroy (Contributor) commented Jun 21, 2019:

@apeforest Isn't that what I wrote? Could you answer regarding separating this into a different PR? Also, MXNET_MEM_PLAN_VERBOSE_LOGGING is not documented in faq/env_var.md.

Contributor Author:

@larroy I did not see your comment before I refreshed the page. Our minds think alike :)
Sure, I can separate this into a different PR. I added it here only to help @kshitij12345 dump out the graph and understand the backward pass better.

Contributor:

The PR is great; it's up to you whether to separate it, not a big deal. I'm not a radical in my reviews.

Contributor:

@apeforest Really like the dump; it's a great help to actually see the graph. Thank you.

@Roshrini added the pr-awaiting-review (PR is waiting for code review) label on Jun 23, 2019
@larroy (Contributor) commented Jun 27, 2019:

Can we merge this?

@apeforest (Contributor, Author):

I verified the result is the same as PyTorch's:

import torch
import numpy as np
import math

op = lambda x: torch.sigmoid(x)
grad_op = lambda x: op(x) * (1 - op(x))
grad_grad_op = lambda x: grad_op(x) * (1 - 2 * op(x))
grad_grad_grad_op = lambda x: grad_grad_op(x) - 2 * ( grad_op(x)**2 + grad_grad_op(x) * op(x))

x = torch.tensor(np.array([1, 2, 3]), dtype=torch.float32)
head_grads = torch.tensor(np.array([1, 1, 1]), dtype=torch.float32) * 0.5
head_grad_grads = torch.tensor(np.array([1, 1, 1]), dtype=torch.float32) * 0.6
head_grad_grad_grads = torch.tensor(np.array([1, 1, 1]), dtype=torch.float32) * 0.7
x.requires_grad = True
head_grads.requires_grad = True

y = op(x)
x_grad = torch.autograd.grad(y, x, grad_outputs=head_grads, create_graph=True, retain_graph=True)[0]
expected_grad_x = (grad_op(x) * head_grads).detach().numpy()
print('expected_grad_x = {}'.format(expected_grad_x))
print('grad_x          = {}'.format(x_grad.detach().numpy()))
x_grad_grad = torch.autograd.grad(x_grad, x, grad_outputs=head_grad_grads, create_graph=True, retain_graph=True)[0]
x_grad_grad.backward(head_grad_grad_grads)

expected_grad_grad_x = (grad_grad_op(x) * head_grads * head_grad_grads).detach().numpy()
expected_head_grad = (grad_op(x) * head_grad_grads).detach().numpy()
expected_grad_grad_grad_x = (grad_grad_grad_op(x) * head_grads * head_grad_grads * head_grad_grad_grads).detach().numpy()

print('expected_grad_grad_x = {}'.format(expected_grad_grad_x))
print('grad_grad_x          = {}'.format(x_grad_grad.detach().numpy()))
print('expected_grad_grad_grad_x = {}'.format(expected_grad_grad_grad_x))
print('grad_grad_grad_x          = {}'.format(x.grad.detach().numpy()))

@apeforest (Contributor, Author):

@sxjscience Please help review this PR. Thanks!

kshitij12345 added a commit to kshitij12345/incubator-mxnet that referenced this pull request Jul 2, 2019
@apeforest (Contributor, Author):

@kshitij12345 Could you approve the PR if everything looks good to you now? Thanks!

@apeforest merged commit 6a8d9eb into apache:master on Jul 3, 2019
apeforest pushed a commit that referenced this pull request Jul 29, 2019
* init to reset

* issue: higher order backward sigmoid

* update gradient code.

update code as per #15288.

* undo changes

* relax tolerance of gradient mismatch for tanh

* update comments

* update comments
anirudhacharya pushed a commit to anirudhacharya/mxnet that referenced this pull request Aug 20, 2019
…5253)
