-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-978] Support higher order gradient for log
, log2
, log10
.
#14992
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1069,13 +1069,73 @@ The storage type of ``log2`` output is always dense | |
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log2"}); | ||
|
||
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log, | ||
unary_bwd<mshadow_op::log_grad>); | ||
unary_bwd<mshadow_op::log_grad>) | ||
.set_attr<nnvm::FGradient>("FGradient", | ||
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { | ||
// For g(x) -> g = log | ||
// g''(x) = -1 * (g'(x) * g'(x)) | ||
auto gx = nnvm::NodeEntry{n}; | ||
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad", | ||
{gx, gx}, nullptr, &n); | ||
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad", | ||
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n); | ||
|
||
std::vector<nnvm::NodeEntry> ret; | ||
|
||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", | ||
{ograds[0], gx}, nullptr, &n)); | ||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are we returning two gradients, isn't it an unary function with just one input? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It takes as input the output gradient and input to log. |
||
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n)); | ||
return ret; | ||
}); | ||
|
||
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10, | ||
unary_bwd<mshadow_op::log10_grad>); | ||
unary_bwd<mshadow_op::log10_grad>) | ||
.set_attr<nnvm::FGradient>("FGradient", | ||
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { | ||
// For g(x) -> g = log10 | ||
// g'(x) = 1 / (log(10) * x) | ||
// g''(x) = -1 * (g'(x) * 1/x) | ||
auto gx = nnvm::NodeEntry{n, 0, 0}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why don't we follow the same pattern as in the natural logarithm? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For natural Considering There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @larroy Thanks for pointing this, going through this again made me realise that there is a problem with the implementation of |
||
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad", | ||
{n->inputs[1]}, nullptr, &n); | ||
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad", | ||
{gx, nnvm::NodeEntry{g_lx}}, nullptr, &n); | ||
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad", | ||
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n); | ||
|
||
std::vector<nnvm::NodeEntry> ret; | ||
|
||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same comment as above. |
||
{ograds[0], gx}, nullptr, &n)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", | ||
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n)); | ||
return ret; | ||
}); | ||
|
||
MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2, | ||
unary_bwd<mshadow_op::log2_grad>); | ||
unary_bwd<mshadow_op::log2_grad>) | ||
.set_attr<nnvm::FGradient>("FGradient", | ||
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) { | ||
// For g(x) -> g = log2 | ||
// g'(x) = 1 / (log(2) * x) | ||
// g''(x) = -1 * (g'(x) * 1/x) | ||
auto gx = nnvm::NodeEntry{n}; | ||
auto g_lx = MakeNode("reciprocal", n->attrs.name + "_backward_log_grad", | ||
{n->inputs[1]}, nullptr, &n); | ||
auto ggx_mid = MakeNode("elemwise_mul", n->attrs.name + "_backward_mid_grad_grad", | ||
{gx, nnvm::NodeEntry{g_lx}}, nullptr, &n); | ||
auto ggx = MakeNode("negative", n->attrs.name + "_backward_grad_grad", | ||
{nnvm::NodeEntry{ggx_mid}}, nullptr, &n); | ||
|
||
std::vector<nnvm::NodeEntry> ret; | ||
|
||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", | ||
{ograds[0], gx}, nullptr, &n)); | ||
ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", | ||
{ograds[0], nnvm::NodeEntry{ggx}}, nullptr, &n)); | ||
return ret; | ||
}); | ||
|
||
// log1p | ||
MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(log1p, cpu, mshadow_op::log1p) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import math | ||
|
||
from mxnet import nd, autograd | ||
from mxnet.test_utils import assert_almost_equal, random_arrays | ||
from common import with_seed | ||
|
||
|
||
@with_seed() | ||
def test_log(): | ||
def log(x): | ||
return nd.log(x) | ||
|
||
def grad_grad_op(x): | ||
return -1/(x**2) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
check_second_order_unary(array, log, grad_grad_op) | ||
|
||
|
||
@with_seed() | ||
def test_log2(): | ||
def log2(x): | ||
return nd.log2(x) | ||
|
||
def grad_grad_op(x): | ||
return -1/((x**2) * math.log(2)) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
check_second_order_unary(array, log2, grad_grad_op) | ||
|
||
|
||
@with_seed() | ||
def test_log10(): | ||
def log10(x): | ||
return nd.log10(x) | ||
|
||
def grad_grad_op(x): | ||
return -1/((x**2) * math.log(10)) | ||
|
||
arrays = random_arrays((2, 2), (2, 3), (4, 5, 2), (3, 1, 4, 5)) | ||
|
||
for array in arrays: | ||
check_second_order_unary(array, log10, grad_grad_op) | ||
|
||
|
||
def check_second_order_unary(x, op, grad_grad_op): | ||
x = nd.array(x) | ||
expect_grad_grad = grad_grad_op(x) | ||
x.attach_grad() | ||
with autograd.record(): | ||
y = op(x) | ||
y_grad = autograd.grad(y, x, create_graph=True, retain_graph=True)[0] | ||
y_grad.backward() | ||
assert_almost_equal(expect_grad_grad.asnumpy(), x.grad.asnumpy()) | ||
|
||
|
||
if __name__ == '__main__': | ||
import nose | ||
nose.runmodule() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: It's very nice to see a comment here. The g(x) is actually a function of x. It might be easily confused with the variable gx two lines below. Maybe use
f(x)
in the comment here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah sure . That makes sense. Thank You.