From 03cd1c7f6ae08e4238811897bdbc71ff0a7310ff Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 1 Jun 2019 14:48:58 +0530
Subject: [PATCH 1/5] fix bug with higher order log implementation.

* bug: the head_grads were not preserved in higher order.
* add test to validate the fix of the same.
---
 src/operator/tensor/elemwise_unary_op_basic.cc  |  6 ++++--
 tests/python/unittest/test_higher_order_grad.py | 14 ++++++++++----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index ee77817fcec9..e6d7a8250fe7 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1074,9 +1074,11 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     // For f(x) -> f = log
     // f''(x) = -1 * (f'(x) * f'(x))
-    auto gx = nnvm::NodeEntry{n};
+    auto gx = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+                         {n->inputs[1]}, nullptr, &n);
     auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, gx}, nullptr, &n);
+                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
     auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                         {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 92c78d15318d..68bdb6243e12 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -66,13 +66,19 @@ def grad_grad_op(x):
 
 def check_second_order_unary(x, op, grad_grad_op):
     x = nd.array(x)
-    expect_grad_grad = grad_grad_op(x)
+    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
-    y_grad.backward()
-    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+        head_grads = nd.random.normal(shape=y.shape)
+        y_grad = autograd.grad(y, x, head_grads=head_grads,
+                               create_graph=True, retain_graph=True)[0]
+    head_grad_grads = nd.random.normal(shape=y.shape)
+    y_grad.backward(head_grad_grads)
+
+    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
+        head_grads.asnumpy()
+    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
 
 
 if __name__ == '__main__':

From 37ce3b87268a8154f5c0ad97ce2522478038ee06 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 1 Jun 2019 23:15:17 +0530
Subject: [PATCH 2/5] fix grad for head_grads and update relevant test

---
 .../tensor/elemwise_unary_op_basic.cc         |  7 ++--
 .../python/unittest/test_higher_order_grad.py | 33 +++++++++++++++----
 2 files changed, 31 insertions(+), 9 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index e6d7a8250fe7..469a3818ed57 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1074,18 +1074,19 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
                                                   unary_bwd<mshadow_op::log_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
     // For f(x) -> f = log
     // f''(x) = -1 * (f'(x) * f'(x))
-    auto gx = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto gx_mul_head_grads = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto head_grads = nnvm::NodeEntry{n->inputs[0]};
     auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
                          {n->inputs[1]}, nullptr, &n);
     auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
+                            {gx_mul_head_grads, nnvm::NodeEntry{g_lx}}, nullptr, &n);
     auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                         {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{g_lx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                               {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
     return ret;
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 68bdb6243e12..45d46b47d1a3 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -27,13 +27,16 @@ def test_log():
     def log(x):
         return nd.log(x)
 
+    def grad_op(x):
+        return 1/x
+
     def grad_grad_op(x):
         return -1/(x**2)
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log, grad_grad_op)
+        check_second_order_unary(array, log, grad_op, grad_grad_op)
 
 
 @with_seed()
@@ -41,13 +44,16 @@ def test_log2():
     def log2(x):
         return nd.log2(x)
 
+    def grad_op(x):
+        return 1/(x * math.log(2))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(2))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log2, grad_grad_op)
+        check_second_order_unary(array, log2, grad_op, grad_grad_op)
 
 
 @with_seed()
@@ -55,30 +61,45 @@ def test_log10():
     def log10(x):
         return nd.log10(x)
 
+    def grad_op(x):
+        return 1/(x * math.log(10))
+
     def grad_grad_op(x):
         return -1/((x**2) * math.log(10))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log10, grad_grad_op)
+        check_second_order_unary(array, log10, grad_op, grad_grad_op)
 
 
-def check_second_order_unary(x, op, grad_grad_op):
+def check_second_order_unary(x, op, grad_op, grad_grad_op):
     x = nd.array(x)
+    grad_x = grad_op(x)
     grad_grad_x = grad_grad_op(x)
     x.attach_grad()
+
+    # Manual head_grads.
+    head_grads = nd.random.normal(shape=x.shape)
+    head_grad_grads = nd.random.normal(shape=x.shape)
+    head_grads.attach_grad()
+
+    # Perform compute.
     with autograd.record():
         y = op(x)
-        head_grads = nd.random.normal(shape=y.shape)
         y_grad = autograd.grad(y, x, head_grads=head_grads,
                                create_graph=True, retain_graph=True)[0]
-    head_grad_grads = nd.random.normal(shape=y.shape)
+
    y_grad.backward(head_grad_grads)
 
+    # Compute expected values.
     expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
         head_grads.asnumpy()
+    expected_heads_grad = grad_x.asnumpy()
+
+    # Validate the gradients.
     assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
+    assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
 
 
 if __name__ == '__main__':
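
For reference, the second-order value asserted by check_second_order_unary above follows directly from the chain rule. With y = log(x), the first backward pass (y_grad in the test) equals head_grads * (1/x); backpropagating head_grad_grads through that result accumulates x.grad = head_grad_grads * head_grads * (-1/x**2), which is exactly head_grad_grads * head_grads * grad_grad_op(x). The standalone NumPy sketch below is not part of the patch (w and v are illustrative stand-ins for head_grads and head_grad_grads); it checks the identity with a central finite difference:

import numpy as np

# Standalone sanity check of the expected value used by the test above.
rng = np.random.RandomState(0)
x = rng.uniform(1.0, 2.0, size=(2, 3))    # keep x > 0 so log(x) is defined
w = rng.normal(size=x.shape)              # stand-in for head_grads
v = rng.normal(size=x.shape)              # stand-in for head_grad_grads

def first_backward(x):
    # Gradient of y = log(x) scaled by the head gradient w,
    # i.e. what autograd.grad returns as the first-order result.
    return w / x

# Closed form: v * w * f''(x), with f''(x) = -1 / x**2.
expected = v * w * (-1.0 / x ** 2)

# Central finite difference of the first backward, contracted with v.
h = 1e-6
numeric = v * (first_backward(x + h) - first_backward(x - h)) / (2.0 * h)

assert np.allclose(expected, numeric, rtol=1e-4, atol=1e-8)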

From 10f2b10ef807411f2e8f6f516fb747f77004d963 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 8 Jun 2019 15:24:51 +0530
Subject: [PATCH 3/5] address comments

* remove assertion for y_grad gradient.
* rename variables.
* fix and update computation.
---
 .../tensor/elemwise_unary_op_basic.cc         | 69 +++++++++++--------
 .../python/unittest/test_higher_order_grad.py | 30 +++-----
 2 files changed, 49 insertions(+), 50 deletions(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 469a3818ed57..d84f98ebb53a 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1072,23 +1072,26 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
                                                   unary_bwd<mshadow_op::log_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log(x)
+    // f'(x) = 1/x
     // f''(x) = -1 * (f'(x) * f'(x))
-    auto gx_mul_head_grads = nnvm::NodeEntry{n};  // f'(x) * head_grads
-    auto head_grads = nnvm::NodeEntry{n->inputs[0]};
-    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                          {n->inputs[1]}, nullptr, &n);
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx_mul_head_grads, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], nnvm::NodeEntry{g_lx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dlogx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
 });
@@ -1096,23 +1099,28 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
                                                   unary_bwd<mshadow_op::log10_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log10
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log10(x)
     // f'(x) = 1 / (log(10) * x)
     // f''(x) = -1 * (f'(x) * 1/x)
-    auto gx = nnvm::NodeEntry{n, 0, 0};
-    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                         {n->inputs[0]}, nullptr, &n);
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                          {n->inputs[1]}, nullptr, &n);
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
 });
@@ -1120,23 +1128,28 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
                                                   unary_bwd<mshadow_op::log2_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log2
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log10(x)
     // f'(x) = 1 / (log(2) * x)
     // f''(x) = -1 * (f'(x) * 1/x)
-    auto gx = nnvm::NodeEntry{n};
-    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                         {n->inputs[0]}, nullptr, &n);
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                          {n->inputs[1]}, nullptr, &n);
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
     std::vector<nnvm::NodeEntry> ret;
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
 });
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 45d46b47d1a3..2acc53676fef 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -27,16 +27,13 @@ def test_log():
     def log(x):
         return nd.log(x)
 
-    def grad_op(x):
-        return 1/x
-
     def grad_grad_op(x):
         return -1/(x**2)
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log, grad_op, grad_grad_op)
+        check_second_order_unary(array, log, grad_grad_op)
 
 
 @with_seed()
@@ -44,16 +41,13 @@ def test_log2():
     def log2(x):
         return nd.log2(x)
 
-    def grad_op(x):
-        return 1/(x * math.log(2))
-
     def grad_grad_op(x):
         return -1/((x**2) * math.log(2))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log2, grad_op, grad_grad_op)
+        check_second_order_unary(array, log2, grad_grad_op)
 
 
 @with_seed()
@@ -61,45 +55,37 @@ def test_log10():
     def log10(x):
         return nd.log10(x)
 
-    def grad_op(x):
-        return 1/(x * math.log(10))
-
     def grad_grad_op(x):
         return -1/((x**2) * math.log(10))
 
     arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
 
     for array in arrays:
-        check_second_order_unary(array, log10, grad_op, grad_grad_op)
+        check_second_order_unary(array, log10, grad_grad_op)
 
 
-def check_second_order_unary(x, op, grad_op, grad_grad_op):
+def check_second_order_unary(x, op, grad_grad_op):
     x = nd.array(x)
-    grad_x = grad_op(x)
     grad_grad_x = grad_grad_op(x)
     x.attach_grad()
 
     # Manual head_grads.
-    head_grads = nd.random.normal(shape=x.shape)
+    y_grad = nd.random.normal(shape=x.shape)
     head_grad_grads = nd.random.normal(shape=x.shape)
-    head_grads.attach_grad()
 
     # Perform compute.
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, head_grads=head_grads,
+        x_grad = autograd.grad(y, x, head_grads=y_grad,
                                create_graph=True, retain_graph=True)[0]
-
-    y_grad.backward(head_grad_grads)
+    x_grad.backward(head_grad_grads)
 
     # Compute expected values.
     expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
-        head_grads.asnumpy()
-    expected_heads_grad = grad_x.asnumpy()
+        y_grad.asnumpy()
 
     # Validate the gradients.
     assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
-    assert_almost_equal(expected_heads_grad, head_grads.grad.asnumpy())
 
 
 if __name__ == '__main__':

From 23eaf42f94ee5d14a49ac6f6cbcf0ff4b6fa83de Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Sat, 15 Jun 2019 00:19:26 +0530
Subject: [PATCH 4/5] address comments

* explicitly pass arguments with name.
---
 tests/python/unittest/test_higher_order_grad.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 1fb0ca2eee59..4f1ea9a6c7b8 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -118,7 +118,7 @@ def check_second_order_unary(x, op, grad_grad_op):
     # Perform compute.
     with autograd.record():
         y = op(x)
-        x_grad = autograd.grad(y, x, head_grads=y_grad,
+        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
                                create_graph=True, retain_graph=True)[0]
     x_grad.backward(head_grad_grads)
 

From cf80ed69290291a47fa68d92ad7ad25654cc0616 Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Thu, 20 Jun 2019 08:53:05 +0530
Subject: [PATCH 5/5] fix mistyped comment.

Co-Authored-By: Lin Yuan
---
 src/operator/tensor/elemwise_unary_op_basic.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 39c6ae7ceeff..98dc8dad825f 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1149,7 +1149,7 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
     // ograds[0]: dL/dxgrad
     // inputs[0]: dL/dy
     // inputs[1]: x
-    // f(x) = y = log10(x)
+    // f(x) = y = log2(x)
     // f'(x) = 1 / (log(2) * x)
     // f''(x) = -1 * (f'(x) * 1/x)
     auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
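
After the series, the FGradient registered for _backward_log consumes ograds[0] (the gradient flowing into x_grad), inputs[0] (dL/dy, the head gradients) and inputs[1] (x), and emits two outputs: the gradient with respect to dL/dy and the gradient with respect to x. The plain-Python sketch below is for illustration only; it mirrors the node names from the graph above but is not MXNet API code:

import numpy as np

def backward_log(dldy, x):
    # First-order backward of y = log(x): dL/dx = dL/dy * (1/x).
    return dldy / x

def backward_log_grad_grad(ograd, dldy, x):
    # Second-order backward, following the node graph registered above.
    dydx_mul_dldy = backward_log(dldy, x)   # the node's own output: dL/dy * f'(x)
    dlogx = 1.0 / x                         # "reciprocal" node: f'(x)
    d2ydx2 = -(dydx_mul_dldy * dlogx)       # negative(elemwise_mul(...)): dL/dy * f''(x)
    grad_wrt_dldy = ograd * dlogx           # output 0: flows back to the head gradients
    grad_wrt_x = ograd * d2ydx2             # output 1: flows back to x
    return grad_wrt_dldy, grad_wrt_x

x = np.array([0.5, 1.0, 2.0])
print(backward_log_grad_grad(np.ones_like(x), np.ones_like(x), x))
# With unit head gradients this prints (1/x, -1/x**2), matching
# f'(x) = 1/x and f''(x) = -1/x**2 from the comments above.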