From 8a9dd721bb49172b1b4bbe9ef99fc8a094d8724a Mon Sep 17 00:00:00 2001
From: kshitij12345
Date: Tue, 28 May 2019 23:12:21 +0530
Subject: [PATCH] [MXNET-978] Support higher order gradient for `log`.
 (#14992)

* add higher order gradient support for log, log10, log2

* add tests

* address comments

* simplify NodeEntry creation.

* address comments

* update comment to avoid confusion.
---
 .../tensor/elemwise_unary_op_basic.cc         | 66 ++++++++++++++-
 .../python/unittest/test_higher_order_grad.py | 80 +++++++++++++++++++
 2 files changed, 143 insertions(+), 3 deletions(-)
 create mode 100644 tests/python/unittest/test_higher_order_grad.py

diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index f4ef9c269918..ee77817fcec9 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1069,13 +1069,73 @@ The storage type of ``log2`` output is always dense
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log2"});
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
-                                                  unary_bwd<mshadow_op::log_grad>);
+                                                  unary_bwd<mshadow_op::log_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // For f(x) -> f = log
+      // f''(x) = -1 * (f'(x) * f'(x))
+      auto gx = nnvm::NodeEntry{n};
+      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
+                              {gx, gx}, nullptr, &n);
+      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
+                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], gx}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
+                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+      return ret;
+  });
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
-                                                  unary_bwd<mshadow_op::log10_grad>);
+                                                  unary_bwd<mshadow_op::log10_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // For f(x) -> f = log10
+      // f'(x) = 1 / (log(10) * x)
+      // f''(x) = -1 * (f'(x) * 1/x)
+      auto gx = nnvm::NodeEntry{n, 0, 0};
+      auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+                           {n->inputs[1]}, nullptr, &n);
+      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
+                              {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
+      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
+                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], gx}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
+                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+      return ret;
+  });
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
-                                                  unary_bwd<mshadow_op::log2_grad>);
+                                                  unary_bwd<mshadow_op::log2_grad>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+      // For f(x) -> f = log2
+      // f'(x) = 1 / (log(2) * x)
+      // f''(x) = -1 * (f'(x) * 1/x)
+      auto gx = nnvm::NodeEntry{n};
+      auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
+                           {n->inputs[1]}, nullptr, &n);
+      auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
+                              {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
+      auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
+                          {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);
+
+      std::vector<nnvm::NodeEntry> ret;
+
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
+                                {ograds[0], gx}, nullptr, &n));
+      ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
+                                {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
+      return ret;
+  });
 
 // log1p
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(log1p, cpu, mshadow_op::log1p)
diff --git a/tests/python/unittest/test_higher_order_grad.py b/tests/python/unittest/test_higher_order_grad.py
new file mode 100644
index 000000000000..92c78d15318d
--- /dev/null
+++ b/tests/python/unittest/test_higher_order_grad.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import math
+
+from mxnet import nd, autograd
+from mxnet.test_utils import assert_almost_equal, random_arrays
+from common import with_seed
+
+
+@with_seed()
+def test_log():
+    def log(x):
+        return nd.log(x)
+
+    def grad_grad_op(x):
+        return -1/(x**2)
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+
+    for array in arrays:
+        check_second_order_unary(array, log, grad_grad_op)
+
+
+@with_seed()
+def test_log2():
+    def log2(x):
+        return nd.log2(x)
+
+    def grad_grad_op(x):
+        return -1/((x**2) * math.log(2))
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+
+    for array in arrays:
+        check_second_order_unary(array, log2, grad_grad_op)
+
+
+@with_seed()
+def test_log10():
+    def log10(x):
+        return nd.log10(x)
+
+    def grad_grad_op(x):
+        return -1/((x**2) * math.log(10))
+
+    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))
+
+    for array in arrays:
+        check_second_order_unary(array, log10, grad_grad_op)
+
+
+def check_second_order_unary(x, op, grad_grad_op):
+    x = nd.array(x)
+    expect_grad_grad = grad_grad_op(x)
+    x.attach_grad()
+    with autograd.record():
+        y = op(x)
+        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
+    y_grad.backward()
+    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())
+
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()
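
For reference, a minimal standalone sketch of what this patch enables:
differentiating through the backward pass of `log`. This is not part of the
patch itself; it assumes an MXNet build that includes the change, and the
variable names are illustrative only.

    from mxnet import nd, autograd

    x = nd.array([0.5, 1.0, 2.0])
    x.attach_grad()
    with autograd.record():
        y = nd.log(x)
        # Keep the first derivative in the graph so it can be differentiated again.
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()   # implicit head gradient of ones, as in the tests above
    print(x.grad)       # second derivative of log, via the new FGradient path
    print(-1 / x ** 2)  # analytic f''(x) = -1/x^2, for comparison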
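
The closed-form second derivatives used in the `grad_grad_op` helpers can also
be cross-checked independently of MXNet with central finite differences. A
small sketch (assumes NumPy only; the step size and tolerance are illustrative
choices for float64):

    import numpy as np

    def second_derivative_fd(f, x, eps=1e-4):
        # Central difference: f''(x) ~= (f(x+eps) - 2*f(x) + f(x-eps)) / eps**2
        return (f(x + eps) - 2 * f(x) + f(x - eps)) / eps ** 2

    x = np.linspace(0.5, 4.0, 8)
    checks = [
        (np.log,   lambda v: -1.0 / v ** 2),                 # matches test_log
        (np.log2,  lambda v: -1.0 / (v ** 2 * np.log(2))),   # matches test_log2
        (np.log10, lambda v: -1.0 / (v ** 2 * np.log(10))),  # matches test_log10
    ]
    for f, analytic in checks:
        assert np.allclose(second_derivative_fd(f, x), analytic(x), atol=1e-5)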