[MXNET-978] Support higher order gradient for log. (#14992)

* add higher order gradient support for log, log10, log2

* add tests

* address comments

* simplify NodeEntry creation.

* address comments

* update comment to avoid confusion.
kshitij12345 authored and apeforest committed May 28, 2019
1 parent bbab527 commit 8a9dd72
Showing 2 changed files with 143 additions and 3 deletions.
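
For quick reference (not stated verbatim in the diff, but implied by the in-code comments and the values asserted in the new tests), the second derivatives being registered are:

    f(x) = \log(x):      f''(x) = -\frac{1}{x^{2}}
    f(x) = \log_{2}(x):  f''(x) = -\frac{1}{x^{2}\,\ln 2}
    f(x) = \log_{10}(x): f''(x) = -\frac{1}{x^{2}\,\ln 10}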
66 changes: 63 additions & 3 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -1069,13 +1069,73 @@ The storage type of ``log2`` output is always dense
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log2"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log,
                                                  unary_bwd<mshadow_op::log_grad>);
                                                  unary_bwd<mshadow_op::log_grad>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // For f(x) -> f = log
    // f''(x) = -1 * (f'(x) * f'(x))
    auto gx = nnvm::NodeEntry{n};
    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
                            {gx, gx}, nullptr, &n);
    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

    std::vector<nnvm::NodeEntry> ret;

    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                              {ograds[0], gx}, nullptr, &n));
    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
    return ret;
});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10,
                                                  unary_bwd<mshadow_op::log10_grad>);
                                                  unary_bwd<mshadow_op::log10_grad>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // For f(x) -> f = log10
    // f'(x) = 1 / (log(10) * x)
    // f''(x) = -1 * (f'(x) * 1/x)
    auto gx = nnvm::NodeEntry{n, 0, 0};
    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
                         {n->inputs[1]}, nullptr, &n);
    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

    std::vector<nnvm::NodeEntry> ret;

    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                              {ograds[0], gx}, nullptr, &n));
    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
    return ret;
});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2,
                                                  unary_bwd<mshadow_op::log2_grad>);
                                                  unary_bwd<mshadow_op::log2_grad>)
.set_attr<nnvm::FGradient>("FGradient",
  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
    // For f(x) -> f = log2
    // f'(x) = 1 / (log(2) * x)
    // f''(x) = -1 * (f'(x) * 1/x)
    auto gx = nnvm::NodeEntry{n};
    auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad",
                         {n->inputs[1]}, nullptr, &n);
    auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad",
                            {gx, nnvm::NodeEntry{g_lx}}, nullptr, &n);
    auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad",
                        {nnvm::NodeEntry{ggx_mid}}, nullptr, &n);

    std::vector<nnvm::NodeEntry> ret;

    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad",
                              {ograds[0], gx}, nullptr, &n));
    ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp",
                              {ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n));
    return ret;
});

// log1p
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(log1p, cpu, mshadow_op::log1p)
80 changes: 80 additions & 0 deletions tests/python/unittest/test_higher_order_grad.py
@@ -0,0 +1,80 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import math

from mxnet import nd, autograd
from mxnet.test_utils import assert_almost_equal, random_arrays
from common import with_seed


@with_seed()
def test_log():
    def log(x):
        return nd.log(x)

    def grad_grad_op(x):
        return -1/(x**2)

    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

    for array in arrays:
        check_second_order_unary(array, log, grad_grad_op)


@with_seed()
def test_log2():
    def log2(x):
        return nd.log2(x)

    def grad_grad_op(x):
        return -1/((x**2) * math.log(2))

    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

    for array in arrays:
        check_second_order_unary(array, log2, grad_grad_op)


@with_seed()
def test_log10():
    def log10(x):
        return nd.log10(x)

    def grad_grad_op(x):
        return -1/((x**2) * math.log(10))

    arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5))

    for array in arrays:
        check_second_order_unary(array, log10, grad_grad_op)


def check_second_order_unary(x, op, grad_grad_op):
    x = nd.array(x)
    expect_grad_grad = grad_grad_op(x)
    x.attach_grad()
    with autograd.record():
        y = op(x)
        y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
    y_grad.backward()
    assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy())


if __name__ == '__main__':
    import nose
    nose.runmodule()
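
As a usage sketch (not part of the commit), the second-order gradient exposed by this change can be taken directly through autograd, mirroring the check_second_order_unary helper above; the input values below are arbitrary illustrative numbers:

# Minimal sketch: second-order gradient of log via autograd.
# Mirrors check_second_order_unary above; input values are illustrative only.
from mxnet import nd, autograd

x = nd.array([1.0, 2.0, 4.0])
x.attach_grad()
with autograd.record():
    y = nd.log(x)
    # Keep the first-order gradient in the graph so it can be differentiated again.
    y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0]
y_grad.backward()
print(x.grad)  # expected to be close to -1 / x**2, i.e. [-1.0, -0.25, -0.0625]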
