Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-978] Higher order gradient for sigmoid #15288

Merged
merged 44 commits into from
Jul 3, 2019
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
45e1502
try to add support some ops
sxjscience Oct 14, 2018
904adb4
Merge branch 'higher_order_sample' of https://github.com/sxjscience/m…
apeforest Mar 7, 2019
d5dc994
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Mar 12, 2019
0e69075
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Mar 19, 2019
0c7cf98
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 2, 2019
492e4cd
add unit test for second order grad
apeforest Apr 3, 2019
45b334e
implement grad for relu and add unit test
apeforest Apr 3, 2019
3bbfbac
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 4, 2019
4dc0907
fix lint
apeforest Apr 5, 2019
c4034b2
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Apr 5, 2019
3fe54e6
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 16, 2019
76aa6ad
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 16, 2019
8458717
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 21, 2019
f66610b
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 23, 2019
30ff1e9
register FGradient attribute for backward relu
apeforest May 28, 2019
8ecffcc
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 28, 2019
d9ba3da
resolve conflict
apeforest May 28, 2019
1c93c7d
remove unused imports
apeforest May 28, 2019
de721bc
change gradient using set_attr
apeforest May 30, 2019
0ac0942
remove higher order grad test for negative(x)
apeforest May 30, 2019
f8e624e
fix lint
apeforest May 30, 2019
3315124
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 30, 2019
8538980
reverse indent
apeforest May 30, 2019
1ee38b5
remove unused backward operator
apeforest May 30, 2019
c18f317
refactor backward for sin(x) and cos(x)
apeforest May 30, 2019
689cfee
change value init to list init
apeforest May 30, 2019
d56e132
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 30, 2019
2207815
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest May 31, 2019
0b6c2ef
change to list initialization
apeforest May 31, 2019
31f671f
generate random shape in test
apeforest May 31, 2019
62fcca3
fix a bug in second order backward
apeforest Jun 3, 2019
a0a0e75
fix lint
apeforest Jun 3, 2019
451c4bd
fix lint
apeforest Jun 4, 2019
b9b0c93
address reviewer comment and renaming
apeforest Jun 5, 2019
d060102
Merge branch 'master' into develop/higher_order_grad
apeforest Jun 18, 2019
94e3b5f
test 2nd order gradient for sigmoid
apeforest Jun 19, 2019
7d95760
higher order grads for sigmoid
apeforest Jun 20, 2019
f43489d
add unit test
apeforest Jun 20, 2019
55d7ebc
remove blank lines
apeforest Jun 20, 2019
25d3f78
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Jun 20, 2019
10dab58
update test
apeforest Jun 20, 2019
d134d2f
fix lint
apeforest Jun 20, 2019
30b1ba9
Merge remote-tracking branch 'upstream/master' into develop/higher_or…
apeforest Jun 29, 2019
6848d42
fix third order gradient for sigmoid
apeforest Jul 2, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/common/exec_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ inline void LogMemoryPlan(const nnvm::Graph& g) {
const auto &idx = g.indexed_graph();
const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& vtype = g.GetAttr<nnvm::DTypeVector>("dtype");
const auto& vstorage = g.GetAttr<nnvm::StorageVector>("storage_id");
// find node range
uint32_t node_start = 0, node_end = idx.num_nodes();
if (g.attrs.count("node_range")) {
Expand All @@ -304,13 +303,13 @@ inline void LogMemoryPlan(const nnvm::Graph& g) {
auto eid = idx.entry_id(e);
size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
LOG(INFO) << "\t\tinput " << eid << ": " << vshape[eid] << " ("
<< kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
<< kilo_bytes << " KB)";
}
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
size_t kilo_bytes = vshape[eid].Size() * mshadow::mshadow_sizeof(vtype[eid]) / 1024;
LOG(INFO) << "\t\toutput " << eid << ": " << vshape[eid] << " ("
<< kilo_bytes << " KB) -> " << storage_str(vstorage[eid]);
<< kilo_bytes << " KB)";
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,10 @@ std::vector<NDArray*> Imperative::Backward(
}
}

if (dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice hack :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain how logging of static memory helps here ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

He wants to dump the graph. maybe should be separate PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually dumping out the computation graph.

Copy link
Contributor

@larroy larroy Jun 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apeforest isn't that what I wrote? could you answer wrt separating into a different PR? Also MXNET_MEM_PLAN_VERBOSE_LOGGING is not documented in faq/env_var.md

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@larroy I did not see your comment before I refresh the page. Our minds think alike :)
Sure, I can separate this into a different PR. I added here only to help @kshitij12345 dump out the graph and understand the backward pass better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PR is great, up to you to separate not a big deal. I'm not a radical on my reviews.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@apeforest Really like the dump, its great help to actually see the graph. Thank You.

common::LogMemoryPlan(graph);
}

// Execution

bool prev_recording = set_is_recording(create_graph);
Expand Down
25 changes: 24 additions & 1 deletion src/operator/tensor/elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,30 @@ The storage type of ``sigmoid`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
unary_bwd<mshadow_op::sigmoid_grad>);
unary_bwd<mshadow_op::sigmoid_grad>)
.set_attr<nnvm::FGradient>("FGradient",
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
// n->inputs[0] : y_grad
// n->inputs[1] : f(x) = sigmoid(x)
// ograds[0] : head_grads
// f''(x) = f'(x) * (1 - 2*f(x))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

auto ones = MakeNode("ones_like", n->attrs.name + "_grad_ones", {n->inputs[1]}, nullptr, &n);
const std::unordered_map<std::string, std::string> args = {{"scalar", "2.0"}};
auto two_y = MakeNode("_mul_scalar", n->attrs.name + "_mul_two", {n->inputs[1]}, &args, &n);
auto one_minus_two_y = MakeNode("elemwise_sub", n->attrs.name + "_grad_sub",
{nnvm::NodeEntry{ones}, nnvm::NodeEntry{two_y}}, nullptr, &n);
auto grad_grad_mid = MakeNode("elemwise_mul", n->attrs.name + "_grad_mul",
{n->inputs[0], nnvm::NodeEntry{one_minus_two_y}}, nullptr, &n);
// when building gradient graph, the backward node of n->inputs[1] will be
// added to the graph again, therefore f`(x) will be multiplied
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this behaviour seem a bit odd? Is this actually the expected behaviour? What would have happened if this was split like function where we would have had many outputs, backward of all outputs be added? Actually I am confused by the behaviour.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is actually the expected behavior. The nnvm graph will perform a DFS traverse when performing backward pass. Since the input to this backward_sigmoid node is an output of another node sigmoid, during the backward pass the gradient function sigmoid will be invoked.

See:

This is to collect dependent nodes in the graph during RecordOp:
https://github.com/apache/incubator-mxnet/blob/master/src/imperative/imperative.cc#L180

This is to actually perform the backward pass:
https://github.com/dmlc/tvm/blob/21935dcbf56ad3bd66ebff9891a6bc3865b8106d/nnvm/src/pass/gradient.cc#L126

https://github.com/dmlc/tvm/blob/21935dcbf56ad3bd66ebff9891a6bc3865b8106d/nnvm/src/pass/gradient.cc#L190

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also didn't get that. A drawing might help. I would expect that you multiply with the backward node itself "n"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that is what he meant as

      // n->inputs[0] : y_grad
      // n->inputs[1] : f(x) = sigmoid(x)
      // ograds[0] : head_grads
      // f''(x) = f'(x) * (1 - 2*f(x))

Backward of node n->inputs[1] is the Node n itself.

One up for the visual drawing/graph of what is actually happening.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMG-3494

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks I'll look at it today.

Copy link
Contributor

@kshitij12345 kshitij12345 Jun 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kind of see from the graph you have drawn and dumped graph as to what is happening. I still strongly feel that it is the incorrect behaviour.

To check that I modified the test code slightly to also compute and validate for the 3rd order gradient as it should come for free since the second order are composed from differentiable functions.

Have tested for sin, cos, log, sigmoid. Out of them only sigmoid fails for the third order.

I have also attached the hand computation of third-order for sigmoid, so please verify that just to make sure that it is not incorrect.

For some reason it is rotated.
3_rd_ord

import math
from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays, rand_shape_nd
from common import with_seed


@with_seed()
def test_sin():
    def sin(x):
        return nd.sin(x)

    def grad_grad_op(x):
        return -nd.sin(x)
    
    def grad_grad_grad_op(x):
        return -nd.cos(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sin, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_cos():
    def cos(x):
        return nd.cos(x)

    def grad_grad_op(x):
        return -nd.cos(x)
    
    def grad_grad_grad_op(x):
        return nd.sin(x)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, cos, grad_grad_op, grad_grad_grad_op)

@with_seed()
def test_log():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        return -1/(x**2)

    def grad_grad_grad_op(x):
        return 2/(x**3)

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, log, grad_grad_op, grad_grad_grad_op)


@with_seed()
def test_sigmoid():
    def sigmoid(x):
        return nd.sigmoid(x)

    def grad_op(x):
        return sigmoid(x) * (1 - sigmoid(x))

    def grad_grad_op(x):
        return grad_op(x) * (1 - 2 * sigmoid(x))

    def grad_grad_grad_op(x):
        return grad_grad_op(x) - 2 * ( grad_op(x)**2 + grad_grad_op(x) * sigmoid(x))

    for dim in range(1, 5):
        shape = rand_shape_nd(dim)
        array = random_arrays(shape)
        check_second_order_unary(array, sigmoid, grad_grad_op, grad_grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op, grad_grad_grad_op):
    x = nd.array(x)
    grad_grad_x = grad_grad_op(x)
    grad_grad_grad_x = grad_grad_grad_op(x)
    x.attach_grad()

    # Manual head_grads.
    y_grad = nd.random.normal(shape=x.shape)
    head_grad_grads = nd.random.normal(shape=x.shape)

    # Perform compute.
    with autograd.record():
        y = op(x)
        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
                               create_graph=True, retain_graph=True)[0]
        
        x_grad_grad = autograd.grad(heads=x_grad, variables=x, head_grads=head_grad_grads,
                                create_graph=True, retain_graph=True)[0]
    x_grad_grad.backward()

    # Compute expected values.
    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    expected_grad_grad_grad = grad_grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
        y_grad.asnumpy()

    # Validate the gradients.
    assert_almost_equal(expected_grad_grad, x_grad_grad.asnumpy())
    assert_almost_equal(expected_grad_grad_grad, x.grad.asnumpy())

if __name__ == '__main__':
    import nose
    nose.runmodule()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the discrepancy in 3rd order gradient is not because an error in my implementation of grad_grad_input, but because I did not return back the first grad_grad in correctly. After all, the first output may be useful in calculating higher order even if they are not visible output at Python level. I will look into it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kshitij12345 I have fixed the issue. The result can pass your test now. Please review again. Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

std::vector<nnvm::NodeEntry> ret;
ret.emplace_back(ograds[0]); // this output is not passed out if gradient w.r.t x only
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "backward_grad_grad_in",
{ograds[0], nnvm::NodeEntry{grad_grad_mid}}, nullptr, &n));
return ret;
});



DMLC_REGISTER_PARAMETER(HardSigmoidParam);
MXNET_OPERATOR_REGISTER_UNARY(hard_sigmoid)
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,23 @@ def grad_grad_op(x):
check_second_order_unary(array, log10, grad_grad_op)


@with_seed()
def test_sigmoid():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks correct to me.

def sigmoid(x):
return nd.sigmoid(x)

def grad_op(x):
return sigmoid(x) * (1 - sigmoid(x))

def grad_grad_op(x):
return grad_op(x) * (1 - 2 * sigmoid(x))

for dim in range(1, 5):
shape = rand_shape_nd(dim)
array = random_arrays(shape)
check_second_order_unary(array, sigmoid, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
x = nd.array(x)
grad_grad_x = grad_grad_op(x)
Expand Down