diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index f2b8dd6b1314..98dc8dad825f 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1090,20 +1090,26 @@
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
                                                   unary_bwd<mshadow_op::log_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log(x)
+    // f'(x) = 1/x
     // f''(x) = -1 * (f'(x) * f'(x))
-    auto gx = nnvm::NodeEntry{n};
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, gx}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
+                          {n->inputs[1]}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
     std::vector<nnvm::NodeEntry> ret;
 
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dlogx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
   });
@@ -1111,23 +1117,28 @@
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
                                                   unary_bwd<mshadow_op::log10_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log10
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log10(x)
     // f'(x) = 1 / (log(10) * x)
     // f''(x) = -1 * (f'(x) * 1/x)
-    auto gx = nnvm::NodeEntry{n, 0, 0};
-    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                         {dydx_mul_dldy, n->inputs[0]}, nullptr, &n);
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                           {n->inputs[1]}, nullptr, &n);
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
     std::vector<nnvm::NodeEntry> ret;
 
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
   });
@@ -1135,23 +1146,28 @@
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
                                                   unary_bwd<mshadow_op::log2_grad>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    // For f(x) -> f = log2
+    // ograds[0]: dL/dxgrad
+    // inputs[0]: dL/dy
+    // inputs[1]: x
+    // f(x) = y = log2(x)
     // f'(x) = 1 / (log(2) * x)
     // f''(x) = -1 * (f'(x) * 1/x)
-    auto gx = nnvm::NodeEntry{n};
-    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+    auto dydx_mul_dldy = nnvm::NodeEntry{n};  // f'(x) * head_grads
+    auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx",
+                         {dydx_mul_dldy, n->inputs[0]}, nullptr, &n);
+    auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx",
                           {n->inputs[1]}, nullptr, &n);
-    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
-                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
-    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
-                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+    auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid",
+                               {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n);
+    auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2",
+                           {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n);
 
     std::vector<nnvm::NodeEntry> ret;
 
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
-                              {ograds[0], gx}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n));
     ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
-                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+                              {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n));
     return ret;
   });
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
index 77bfa68157aa..4f1ea9a6c7b8 100644
--- a/tests/python/unittest/test_higher_order_grad.py
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -108,13 +108,26 @@ def grad_grad_op(x):
 
 def check_second_order_unary(x, op, grad_grad_op):
     x = nd.array(x)
-    expect_grad_grad = grad_grad_op(x)
+    grad_grad_x = grad_grad_op(x)
     x.attach_grad()
+
+    # Manual head_grads.
+    y_grad = nd.random.normal(shape=x.shape)
+    head_grad_grads = nd.random.normal(shape=x.shape)
+
+    # Perform compute.
     with autograd.record():
         y = op(x)
-        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
-    y_grad.backward()
-    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+        x_grad = autograd.grad(heads=y, variables=x, head_grads=y_grad,
+                               create_graph=True, retain_graph=True)[0]
+    x_grad.backward(head_grad_grads)
+
+    # Compute expected values.
+    expected_grad_grad = grad_grad_x.asnumpy() * head_grad_grads.asnumpy() * \
+        y_grad.asnumpy()
+
+    # Validate the gradients.
+    assert_almost_equal(expected_grad_grad, x.grad.asnumpy())
 
 
 if __name__ == '__main__':
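
The relation both sides of this patch depend on: after `x_grad.backward(head_grad_grads)`, `x.grad` should equal `ograds[0] * dL/dy * f''(x)`, with `f''` as given in the patch comments. The sketch below is not part of the patch; it is a minimal plain-NumPy mirror of that relation, cross-checking each `f''` against a central finite difference of the corresponding `f'`. The dictionary names and tolerances are illustrative, not from the patch.

```python
import numpy as np

# First and second derivatives as stated in the patch comments:
#   log:   f'(x) = 1/x,            f''(x) = -1/x^2
#   log10: f'(x) = 1/(log(10)*x),  f''(x) = -1/(log(10)*x^2)
#   log2:  f'(x) = 1/(log(2)*x),   f''(x) = -1/(log(2)*x^2)
first_derivative = {
    'log':   lambda x: 1.0 / x,
    'log10': lambda x: 1.0 / (np.log(10) * x),
    'log2':  lambda x: 1.0 / (np.log(2) * x),
}
second_derivative = {
    'log':   lambda x: -1.0 / x ** 2,
    'log10': lambda x: -1.0 / (np.log(10) * x ** 2),
    'log2':  lambda x: -1.0 / (np.log(2) * x ** 2),
}

rng = np.random.RandomState(0)
x = rng.uniform(1.0, 2.0, size=5)     # keep x > 0 for the logs
y_grad = rng.normal(size=5)           # head_grads of the first backward
head_grad_grads = rng.normal(size=5)  # head_grads of the second backward

for name in ('log', 'log10', 'log2'):
    # What the test asserts x.grad equals after x_grad.backward(head_grad_grads):
    expected_grad_grad = second_derivative[name](x) * head_grad_grads * y_grad

    # Independent check: d(x_grad)/dx = y_grad * f''(x), via central differences
    # applied to x_grad = y_grad * f'(x).
    eps = 1e-6
    d_x_grad_dx = y_grad * (first_derivative[name](x + eps) -
                            first_derivative[name](x - eps)) / (2 * eps)
    assert np.allclose(expected_grad_grad, d_x_grad_dx * head_grad_grads,
                       rtol=1e-4), name
```

This also shows why the test multiplies by both `head_grad_grads` and `y_grad`: the first backward scales `f'(x)` by `y_grad`, and the second backward scales its derivative by `head_grad_grads`, so `grad_grad_op(x)` alone is not the full expected value of `x.grad`.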