diff --git a/src/nnvm/node_op_util.h b/src/nnvm/node_op_util.h index 8d5916aafff9..54a96336fb94 100644 --- a/src/nnvm/node_op_util.h +++ b/src/nnvm/node_op_util.h @@ -68,6 +68,18 @@ class NodeOpGen { dependent_node->attrs.name + "_square", {x}, nullptr, &dependent_node)}; } + + nnvm::NodeEntry reciprocal(const nnvm::NodeEntry &x) { + return nnvm::NodeEntry{mxnet::op::MakeNode("reciprocal", + dependent_node->attrs.name + "_reciprocal", + {x}, nullptr, &dependent_node)}; + } + + nnvm::NodeEntry negative(const nnvm::NodeEntry &x) { + return nnvm::NodeEntry{mxnet::op::MakeNode("negative", + dependent_node->attrs.name + "_negative", + {x}, nullptr, &dependent_node)}; + } }; } // namespace util diff --git a/src/operator/tensor/elemwise_unary_op_logexp.cc b/src/operator/tensor/elemwise_unary_op_logexp.cc index 65394826276f..7ca12e0b248b 100644 --- a/src/operator/tensor/elemwise_unary_op_logexp.cc +++ b/src/operator/tensor/elemwise_unary_op_logexp.cc @@ -25,6 +25,7 @@ #include "elemwise_unary_op.h" #include "./elemwise_binary_op-inl.h" #include "../nn/mkldnn/mkldnn_ops-inl.h" +#include "../../nnvm/node_op_util.h" namespace mxnet { namespace op { @@ -110,25 +111,23 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log, .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { // ograds[0]: dL/dxgrad - // inputs[0]: dL/dy - // inputs[1]: x + // inputs[0]: dL/dy (ygrad) + // inputs[1]: x (ElemewiseGradUseIn) // f(x) = y = log(x) // f'(x) = 1/x // f''(x) = -1 * (f'(x) * f'(x)) + auto x = n->inputs[1]; auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads - auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx", - {n->inputs[1]}, nullptr, &n); - auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid", - {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n); - auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2", - {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n); + auto op = mxnet::util::NodeOpGen{n}; + + auto dlogx = op.reciprocal(x); + auto d2ydx2_mid = op.mul(dydx_mul_dldy, dlogx); + auto d2ydx2 = op.negative(d2ydx2_mid); std::vector ret; + ret.emplace_back(op.mul(ograds[0], dlogx)); + ret.emplace_back(op.mul(ograds[0], d2ydx2)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", - {ograds[0], nnvm::NodeEntry{dlogx}}, nullptr, &n)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", - {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n)); return ret; }); @@ -137,27 +136,24 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log10, .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { // ograds[0]: dL/dxgrad - // inputs[0]: dL/dy - // inputs[1]: x + // inputs[0]: dL/dy (ygrad) + // inputs[1]: x (ElemewiseGradUseIn) // f(x) = y = log10(x) // f'(x) = 1 / (log(10) * x) // f''(x) = -1 * (f'(x) * 1/x) + auto dldy = n->inputs[0]; + auto x = n->inputs[1]; auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads - auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx", - {n->inputs[0]}, nullptr, &n); - auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx", - {n->inputs[1]}, nullptr, &n); - auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid", - {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n); - auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2", - {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n); + auto op = mxnet::util::NodeOpGen{n}; + auto dydx = op.div(dydx_mul_dldy, dldy); + auto dlogx = op.reciprocal(x); + auto d2ydx2_mid = op.mul(dydx_mul_dldy, dlogx); + auto d2ydx2 = op.negative(d2ydx2_mid); std::vector ret; + ret.emplace_back(op.mul(ograds[0], dydx)); + ret.emplace_back(op.mul(ograds[0], d2ydx2)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", - {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", - {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n)); return ret; }); @@ -166,27 +162,24 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_DR(_backward_log2, .set_attr("FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { // ograds[0]: dL/dxgrad - // inputs[0]: dL/dy - // inputs[1]: x + // inputs[0]: dL/dy (ygrad) + // inputs[1]: x (ElemewiseGradUseIn) // f(x) = y = log2(x) // f'(x) = 1 / (log(2) * x) // f''(x) = -1 * (f'(x) * 1/x) + auto dldy = n->inputs[0]; + auto x = n->inputs[1]; auto dydx_mul_dldy = nnvm::NodeEntry{n}; // f'(x) * head_grads - auto dydx = MakeNode("elemwise_div", n->attrs.name + "_dydx", - {n->inputs[0]}, nullptr, &n); - auto dlogx = MakeNode("reciprocal", n->attrs.name + "_dlogx", - {n->inputs[1]}, nullptr, &n); - auto d2ydx2_mid = MakeNode("elemwise_mul", n->attrs.name + "_d2ydx2_mid", - {dydx_mul_dldy, nnvm::NodeEntry{dlogx}}, nullptr, &n); - auto d2ydx2 = MakeNode("negative", n->attrs.name + "_d2ydx2", - {nnvm::NodeEntry{d2ydx2_mid}}, nullptr, &n); + auto op = mxnet::util::NodeOpGen{n}; + auto dydx = op.div(dydx_mul_dldy, dldy); + auto dlogx = op.reciprocal(x); + auto d2ydx2_mid = op.mul(dydx_mul_dldy, dlogx); + auto d2ydx2 = op.negative(d2ydx2_mid); std::vector ret; + ret.emplace_back(op.mul(ograds[0], dydx)); + ret.emplace_back(op.mul(ograds[0], d2ydx2)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad", - {ograds[0], nnvm::NodeEntry{dydx}}, nullptr, &n)); - ret.emplace_back(MakeNode("elemwise_mul", n->attrs.name + "_backward_grad_grad_inp", - {ograds[0], nnvm::NodeEntry{d2ydx2}}, nullptr, &n)); return ret; });